Skip to content

Commit

Permalink
Merge pull request #386 from MilesCranmer/dynamic-autodiff
Browse files Browse the repository at this point in the history
Add differential operator
  • Loading branch information
MilesCranmer authored Dec 13, 2024
2 parents 701889f + 5fd91a6 commit 9b6d4fc
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 7 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicDiff = "7317a516-7a03-4707-b902-c6dba1468ba0"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Expand Down Expand Up @@ -48,6 +49,7 @@ Dates = "1"
DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicDiff = "0.2"
DynamicExpressions = "~1.8"
DynamicQuantities = "1"
Enzyme = "0.12, 0.13"
Expand Down
3 changes: 1 addition & 2 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ using DynamicExpressions:
with_contents,
with_metadata
using DynamicExpressions: with_type_parameters
# TODO: Reexport D once DynamicAutodiff is registered
# @reexport using DynamicAutodiff: D
@reexport using LossFunctions:
MarginLoss,
DistanceLoss,
Expand Down Expand Up @@ -160,6 +158,7 @@ using DynamicExpressions: with_type_parameters
LogitDistLoss,
QuantileLoss,
LogCoshLoss
using DynamicDiff: D
using Compat: @compat, Fix

#! format: off
Expand Down
7 changes: 2 additions & 5 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module TemplateExpressionModule

using Random: AbstractRNG
using Compat: Fix
using DynamicDiff: DynamicDiff
using DispatchDoctor: @unstable, @stable
using StyledStrings: @styled_str, annotatedstring
using DynamicExpressions:
Expand Down Expand Up @@ -39,9 +40,6 @@ using ..MutateModule: MutateModule as MM
using ..PopMemberModule: PopMember
using ..ComposableExpressionModule: ComposableExpression, ValidVector

# TODO: Modify `D` once DynamicAutodiff is registered
# import DynamicAutodiff: D

"""
TemplateStructure{K,E,NF} <: Function
Expand Down Expand Up @@ -94,10 +92,9 @@ struct ArgumentRecorder{F} <: Function
end
(f::ArgumentRecorder)(args...) = f.f(args...)

# TODO: Modify `D` once DynamicAutodiff is registered
# We pass through the derivative operators, since
# we just want to record the number of arguments.
# DA.D(f::ArgumentRecorder, _::Integer) = f
DynamicDiff.D(f::ArgumentRecorder, ::Integer) = f

"""Infers number of features used by each subexpression, by passing in test data."""
function infer_variable_constraints(::Val{K}, combiner::F) where {K,F}
Expand Down
22 changes: 22 additions & 0 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,25 @@ end
expr = TemplateExpression((; f=c1, g=x1 + x2); structure, operators, variable_names)
@test expr(X) [6.0] # 3 + (1 + 2)
end

@testitem "Test TemplateExpression with differential operator" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression: D
using DynamicExpressions: OperatorEnum

operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(sin, cos))
variable_names = ["x1", "x2"]
x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)
x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names)
x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names)

structure = TemplateStructure{(:f, :g)}(
((; f, g), (x1, x2, x3)) -> f(x1) + D(g, 1)(x2, x3)
)
expr = TemplateExpression(
(; f=x1, g=cos(x1 - x2) + 2.5 * x1); structure, operators, variable_names
)
# Truth: x1 - sin(x2 - x3) + 2.5
X = stack(([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]); dims=1)
@test expr(X) [1.0, 2.0] .- sin.([3.0, 4.0] .- [5.0, 6.0]) .+ 2.5
end

0 comments on commit 9b6d4fc

Please sign in to comment.