Skip to content

Commit

Permalink
Added test that compares symbolic pullback to zygote pullback.
Browse files Browse the repository at this point in the history
Flipped order of functions.

Added _get_contents method for Tuple as argument.

Added another method to deal with Zygote idiosyncracies.

Fixed method for _get_params.
  • Loading branch information
benedict-96 committed Dec 16, 2024
1 parent f495dc8 commit a260ae5
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
AbstractNeuralNetworks = "0.3, 0.4"
Documenter = "1.8.0"
ForwardDiff = "0.10.38"
GeometricMachineLearning = "0.3.7"
Latexify = "0.16.5"
RuntimeGeneratedFunctions = "0.5"
Symbolics = "5, 6"
Expand All @@ -28,6 +29,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
GeometricMachineLearning = "194d25b2-d3f5-49f0-af24-c124f4aa80cc"

[targets]
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote"]
test = ["Test", "ForwardDiff", "Random", "Documenter", "Latexify", "SafeTestsets", "Zygote", "GeometricMachineLearning"]
19 changes: 11 additions & 8 deletions src/derivatives/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ nn = SymbolicNeuralNetwork(c)
loss = FeedForwardLoss()
pb = SymbolicPullback(nn, loss)
ps = initialparameters(c) |> NeuralNetworkParameters
pv_values = pb(ps, nn.model, (rand(2), rand(1)))[2](1) |> typeof
pb_values = pb(ps, nn.model, (rand(2), rand(1)))[2](1) |> typeof
# output
Expand Down Expand Up @@ -47,19 +47,20 @@ import Random
Random.seed!(123)
c = Chain(Dense(2, 1, tanh))
nn = SymbolicNeuralNetwork(c)
nn = NeuralNetwork(c)
snn = SymbolicNeuralNetwork(nn)
loss = FeedForwardLoss()
pb = SymbolicPullback(nn, loss)
ps = initialparameters(c) |> NeuralNetworkParameters
pb = SymbolicPullback(snn, loss)
input_output = (rand(2), rand(1))
loss_and_pullback = pb(ps, nn.model, input_output)
pv_values = loss_and_pullback[2](1)
loss_and_pullback = pb(nn.params, nn.model, input_output)
# note that we apply the second argument to another input `1`
pb_values = loss_and_pullback[2](1)
@variables soutput[1:SymbolicNeuralNetworks.output_dimension(nn.model)]
symbolic_pullbacks = SymbolicNeuralNetworks.symbolic_pullback(loss(nn.model, nn.params, nn.input, soutput), nn)
pv_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)
pb_values2 = build_nn_function(symbolic_pullbacks, nn.params, nn.input, soutput)(input_output[1], input_output[2], ps)
pv_values == (pv_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)
pb_values == (pb_values2 |> SymbolicNeuralNetworks._get_params |> SymbolicNeuralNetworks._get_contents)
# output
Expand Down Expand Up @@ -106,6 +107,7 @@ Return the `NamedTuple` that's equivalent to the `NeuralNetworkParameters`.
"""
_get_params(nt::NamedTuple) = nt
_get_params(ps::NeuralNetworkParameters) = ps.params
_get_params(ps::NamedTuple{(:params,), Tuple{NT}}) where {NT<:NamedTuple} = ps.params
_get_params(ps::AbstractArray{<:Union{NamedTuple, NeuralNetworkParameters}}) = [_get_params(nt) for nt in ps]

"""
Expand Down Expand Up @@ -134,6 +136,7 @@ function __get_contents(nt::AbstractArray{<:NamedTuple})
nt
end
_get_contents(nt::AbstractArray{<:NamedTuple}) = __get_contents(nt)
_get_contents(nt::Tuple{<:NamedTuple}) = nt[1]

# (_pullback::SymbolicPullback)(ps, model, input_nt::QPTOAT)::Tuple = Zygote.pullback(ps -> _pullback.loss(model, ps, input_nt), ps)
function (_pullback::SymbolicPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})::Tuple
Expand Down
38 changes: 38 additions & 0 deletions test/derivatives/pullback.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using SymbolicNeuralNetworks
using SymbolicNeuralNetworks: _get_params, _get_contents
using AbstractNeuralNetworks
using Symbolics
using GeometricMachineLearning: ZygotePullback
using Test
import Random
Random.seed!(123)

compare_values(arr1::Array, arr2::Array) = @test arr1 arr2
function compare_values(nt1::NamedTuple, nt2::NamedTuple)
@assert keys(nt1) == keys(nt2)
NamedTuple{keys(nt1)}((compare_values(arr1, arr2) for (arr1, arr2) in zip(values(nt1), values(nt2))))
end

function compare_symbolic_pullback_to_zygote_pullback(input_dim::Integer, output_dim::Integer, second_dim::Integer=1)
c = Chain(Dense(input_dim, output_dim, tanh))
nn = NeuralNetwork(c)
snn = SymbolicNeuralNetwork(nn)
loss = FeedForwardLoss()
spb = SymbolicPullback(snn, loss)
input_output = (rand(input_dim, second_dim), rand(output_dim, second_dim))
loss_and_pullback = spb(nn.params, nn.model, input_output)
# note that we apply the second argument to another input `1`
pb_values = loss_and_pullback[2](1)

zpb = ZygotePullback(loss)
loss_and_pullback_zygote = zpb(nn.params, nn.model, input_output)
pb_values_zygote = loss_and_pullback_zygote[2](1) |> _get_contents |> _get_params

compare_values(pb_values, pb_values_zygote)
end

for input_dim (2, 3)
for output_dim (1, 2)
compare_symbolic_pullback_to_zygote_pullback(input_dim, output_dim)
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ using SafeTestsets
@safetestset "Symbolic Params " begin include("symbolic_neuralnet/symbolize.jl") end
@safetestset "Tests associated with 'build_function.jl' " begin include("build_function/build_function.jl") end
@safetestset "Tests associated with 'build_function_double_input.jl' " begin include("build_function/build_function_double_input.jl") end
@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end
@safetestset "Tests associated with 'build_function_array.jl " begin include("build_function/build_function_arrays.jl") end
@safetestset "Compare Zygote Pullback with Symbolic Pullback " begin include("derivatives/pullback.jl") end

0 comments on commit a260ae5

Please sign in to comment.