Skip to content

Commit

Permalink
Moved remaining tests from docstring tests to build_function directory.
Browse files Browse the repository at this point in the history
ps -> nn.params.

Vector can't be used on Tuple of Vectors (apparently).
  • Loading branch information
benedict-96 committed Dec 16, 2024
1 parent 02c67ea commit bbd5d70
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 13 deletions.
18 changes: 7 additions & 11 deletions src/build_function/build_function_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,21 @@ Return a function that takes an input, (optionally) an output and neural network
```jldoctest
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
using AbstractNeuralNetworks: Chain, Dense, initialparameters, NeuralNetworkParameters
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork
import Random
Random.seed!(123)
c = Chain(Dense(2, 1, tanh))
nn = SymbolicNeuralNetwork(c)
eqs = (a = c(nn.input, nn.params), b = c(nn.input, nn.params).^2)
funcs = build_nn_function(eqs, nn.params, nn.input)
nn = NeuralNetwork(c)
snn = SymbolicNeuralNetwork(nn)
eqs = (a = c(snn.input, snn.params), b = c(snn.input, snn.params).^2)
funcs = build_nn_function(eqs, snn.params, snn.input)
input = [1., 2.]
ps = initialparameters(c) |> NeuralNetworkParameters
a = c(input, ps)
b = c(input, ps).^2
funcs_evaluated = funcs(input, ps)
(funcs_evaluated.a, funcs_evaluated.b) .≈ (a, b)
funcs_evaluated = funcs(input, nn.params)
# output
(true, true)
(a = [-0.9999386280616135], b = [0.9998772598897417])
```
# Implementation
Expand Down
38 changes: 37 additions & 1 deletion test/build_function/build_function_arrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork, function_valued_parameters
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork
using Test
import Random
Expand All @@ -21,8 +21,44 @@ function build_function_for_array_valued_equation(input_dim::Integer=2, output_d
@test funcs_evaluated_as_vector result_of_standard_computation
end

function build_function_for_named_tuple(input_dim::Integer=2, output_dim::Integer=1)
c = Chain(Dense(input_dim, output_dim, tanh))
nn = NeuralNetwork(c)
snn = SymbolicNeuralNetwork(nn)
eqs = (a = c(snn.input, snn.params), b = c(snn.input, snn.params).^2)
funcs = build_nn_function(eqs, snn.params, snn.input)
input = [1., 2.]
a = c(input, nn.params)
b = c(input, nn.params).^2
funcs_evaluated = funcs(input, nn.params)

funcs_evaluated_as_vector = [funcs_evaluated.a, funcs_evaluated.b]
result_of_standard_computation = [a, b]

@test funcs_evaluated_as_vector result_of_standard_computation
end

function function_valued_parameters_for_named_tuple(input_dim::Integer=2, output_dim::Integer=1)
c = Chain(Dense(input_dim, output_dim, tanh))
nn = NeuralNetwork(c)
snn = SymbolicNeuralNetwork(nn)
eqs = (a = c(snn.input, snn.params), b = c(snn.input, snn.params).^2)
funcs = function_valued_parameters(eqs, snn.params, snn.input)
input = [1., 2.]
a = c(input, nn.params)
b = c(input, nn.params).^2

funcs_evaluated_as_vector = [funcs.a(input, nn.params), funcs.b(input, nn.params)]
result_of_standard_computation = [a, b]

@test funcs_evaluated_as_vector result_of_standard_computation
end

# we test in the following order: `function_valued_parameters` → `build_function` (for `NamedTuple`) → `build_function` (for `Array` of `NamedTuple`s) as this is also how the functions are built.
for input_dim (2, 3)
for output_dim (1, 2)
function_valued_parameters_for_named_tuple(input_dim, output_dim)
build_function_for_named_tuple(input_dim, output_dim)
build_function_for_array_valued_equation(input_dim, output_dim)
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using SymbolicNeuralNetworks
using SafeTestsets

@safetestset "Docstring tests. " begin include("doctest.jl") end
# @safetestset "Docstring tests. " begin include("doctest.jl") end
@safetestset "Symbolic gradient " begin include("derivatives/symbolic_gradient.jl") end
@safetestset "Symbolic Neural network " begin include("derivatives/jacobian.jl") end
@safetestset "Symbolic Params " begin include("symbolic_neuralnet/symbolize.jl") end
Expand Down

0 comments on commit bbd5d70

Please sign in to comment.