Skip to content

Commit

Permalink
Merge branch 'master' into customize-mlj-interface
Browse files Browse the repository at this point in the history
  • Loading branch information
atharvas authored Jan 4, 2025
2 parents 982f0da + e5afc91 commit fce3b06
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 15 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "1.5.1"
version = "1.5.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -50,7 +50,7 @@ DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicDiff = "0.2"
DynamicExpressions = "~1.9"
DynamicExpressions = "~1.9.2"
DynamicQuantities = "1"
Enzyme = "0.12, 0.13"
JSON3 = "1"
Expand Down
18 changes: 14 additions & 4 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using DynamicExpressions:
AbstractExpressionNode,
AbstractOperatorEnum,
Metadata,
EvalOptions,
constructorof,
get_metadata,
eval_tree_array,
Expand Down Expand Up @@ -51,16 +52,18 @@ f(f, f) # == (x1 * sin(x2)) * sin((x1 * sin(x2)))
struct ComposableExpression{
T,
N<:AbstractExpressionNode{T},
D<:@NamedTuple{operators::O, variable_names::V} where {O<:AbstractOperatorEnum,V},
D<:@NamedTuple{
operators::O, variable_names::V, eval_options::E
} where {O<:AbstractOperatorEnum,V,E<:Union{Nothing,EvalOptions}},
} <: AbstractComposableExpression{T,N}
tree::N
metadata::Metadata{D}
end

@inline function ComposableExpression(
tree::AbstractExpressionNode{T}; metadata...
tree::AbstractExpressionNode{T}; operators, variable_names=nothing, eval_options=nothing
) where {T}
d = (; metadata...)
d = (; operators, variable_names, eval_options)
return ComposableExpression(tree, Metadata(d))
end

Expand Down Expand Up @@ -152,6 +155,9 @@ struct ValidVector{A<:AbstractVector}
end
ValidVector(x::Tuple{Vararg{Any,2}}) = ValidVector(x...)

# Return the evaluation options stored in the expression's metadata,
# falling back to a default `EvalOptions()` when none were attached
# (the metadata field may be `nothing` for expressions built without options).
function get_eval_options(ex::AbstractComposableExpression)
    stored_options = get_metadata(ex).eval_options
    if stored_options === nothing
        return EvalOptions()
    else
        return stored_options
    end
end
# Fallback functor method: any input type without a more specific method
# is rejected with an informative error rather than a generic MethodError.
(ex::AbstractComposableExpression)(x) =
    error("ComposableExpression does not support input of type $(typeof(x))")
Expand Down Expand Up @@ -181,11 +187,15 @@ function (ex::AbstractComposableExpression)(
return ValidVector(_get_value(first(xs)), false)
else
X = Matrix(stack(map(_get_value, xs))')
return ValidVector(eval_tree_array(ex, X))
eval_options = get_eval_options(ex)
return ValidVector(eval_tree_array(ex, X; eval_options))
end
end
function (ex::AbstractComposableExpression{T})() where {T}
X = Matrix{T}(undef, 0, 1) # Value is irrelevant as it won't be used
# TODO: We force avoid the eval_options here,
# to get a faster constant evaluation result...
# but not sure if this is a good idea.
out, complete = eval_tree_array(ex, X) # TODO: The valid is not used; not sure how to incorporate
y = only(out)
return complete ? y::T : nan(y)::T
Expand Down
15 changes: 8 additions & 7 deletions src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ using DynamicExpressions:
AbstractExpression,
AbstractExpressionNode,
Node,
GraphNode
GraphNode,
EvalOptions
using DynamicQuantities: dimension, ustrip
using ..CoreModule: AbstractOptions, Dataset
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
Expand Down Expand Up @@ -56,16 +57,16 @@ which speed up evaluation significantly.
tree::Union{AbstractExpressionNode,AbstractExpression},
X::AbstractMatrix,
options::AbstractOptions;
turbo=nothing,
bumper=nothing,
kws...,
)
A = expected_array_type(X, typeof(tree))
eval_options = EvalOptions(;
turbo=something(turbo, options.turbo), bumper=something(bumper, options.bumper)
)
out, complete = DE.eval_tree_array(
tree,
X,
DE.get_operators(tree, options);
turbo=options.turbo,
bumper=options.bumper,
kws...,
tree, X, DE.get_operators(tree, options); eval_options, kws...
)
if isnothing(out)
return nothing, false
Expand Down
4 changes: 3 additions & 1 deletion src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using DynamicExpressions:
AbstractOperatorEnum,
OperatorEnum,
Metadata,
EvalOptions,
get_contents,
with_contents,
get_metadata,
Expand Down Expand Up @@ -251,8 +252,9 @@ function EB.create_expression(
# NOTE: We need to copy over the operators so we can call the structure function
operators = options.operators
variable_names = embed ? dataset.variable_names : nothing
eval_options = EvalOptions(; turbo=options.turbo, bumper=options.bumper)
inner_expressions = ntuple(
_ -> ComposableExpression(copy(t); operators, variable_names),
_ -> ComposableExpression(copy(t); operators, variable_names, eval_options),
Val(length(function_keys)),
)
# TODO: Generalize to other inner expression types
Expand Down
26 changes: 26 additions & 0 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,29 @@ end
X_complex = [-1.0 - 1.0im]'
@test_throws DimensionMismatch expr(X_complex)
end

@testitem "Test eval_options with turbo mode" tags = [:part3] begin
    using SymbolicRegression
    using DynamicExpressions: OperatorEnum, EvalOptions
    using LoopVectorization: LoopVectorization

    operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
    variable_names = ["x1", "x2"]
    # Request LoopVectorization-backed evaluation for the inner expressions.
    eval_options = EvalOptions(; turbo=true)

    # Create expressions with turbo mode enabled
    x1 = ComposableExpression(
        Node{Float64}(; feature=1); operators, variable_names, eval_options
    )
    f = x1 + x1
    g = x1
    structure = TemplateStructure{(:f, :g)}(((; f, g), (x1, x2)) -> f(x1) * g(x2)^2)
    # Reuse the inner expressions built above (`f` and `g`) for the template.
    expr = TemplateExpression((; f, g); structure, operators, variable_names)

    n = 32
    X = randn(2, n)
    result = expr(X)
    # Expected: f(x1) * g(x2)^2 = (x1 + x1) * x2^2, evaluated elementwise.
    # (The `≈` comparison was missing, making `@test` a syntax error.)
    @test result ≈ @.((X[1, :] + X[1, :]) * (X[2, :] * X[2, :]))
    # n.b., we can't actually test whether turbo is used here,
    # this is basically just a smoke test
end
13 changes: 12 additions & 1 deletion test/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,18 @@ end
@testitem "Turbo mode matches regular mode" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression:
plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond
Node,
plus,
sub,
mult,
square,
cube,
neg,
relu,
greater,
logical_or,
logical_and,
cond
using Random: MersenneTwister
using Suppressor: @capture_err
using LoopVectorization: LoopVectorization as _
Expand Down

0 comments on commit fce3b06

Please sign in to comment.