diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 713a0b168..cb1c9ea86 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,8 +26,9 @@ jobs: fail-fast: false matrix: test: - - "unit" - - "integration" + - "part1" + - "part2" + - "part3" julia-version: - "1.6" - "1.8" @@ -37,22 +38,31 @@ jobs: include: - os: windows-latest julia-version: "1" - test: "unit" + test: "part1" - os: windows-latest julia-version: "1" - test: "integration" + test: "part2" + - os: windows-latest + julia-version: "1" + test: "part3" + - os: macOS-latest + julia-version: "1" + test: "part1" - os: macOS-latest julia-version: "1" - test: "unit" + test: "part2" - os: macOS-latest julia-version: "1" - test: "integration" + test: "part3" - os: ubuntu-latest julia-version: "~1.11.0-0" - test: "unit" + test: "part1" - os: ubuntu-latest julia-version: "~1.11.0-0" - test: "integration" + test: "part2" + - os: ubuntu-latest + julia-version: "~1.11.0-0" + test: "part3" steps: - uses: actions/checkout@v4 @@ -62,6 +72,8 @@ jobs: version: ${{ matrix.julia-version }} - name: "Cache dependencies" uses: julia-actions/cache@v2 + with: + cache-name: julia-cache;workflow=${{ github.workflow }};job=${{ github.job }};os=${{ matrix.os }};julia=${{ matrix.julia-version }};project=${{ hashFiles('**/Project.toml') }} - name: "Build package" uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" diff --git a/.gitignore b/.gitignore index d362e9fe5..2cb9c5d85 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ docs/src/index.md *.code-workspace .vscode **/*.json +LocalPreferences.toml diff --git a/Project.toml b/Project.toml index 4927ffae0..1cc170cc8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,14 @@ name = "SymbolicRegression" uuid = "8254be44-1295-4e6a-a16d-46603ac705cb" authors = ["MilesCranmer "] -version = "0.24.5" +version = "0.25.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" @@ -27,20 +30,28 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [extensions] +SymbolicRegressionEnzymeExt = "Enzyme" SymbolicRegressionJSON3Ext = "JSON3" SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" [compat] +ADTypes = "^1.4.0" Compat = "^4.2" +ConstructionBase = "<1.5.7" Dates = "1" +DifferentiationInterface = "0.5" DispatchDoctor = "0.4" -Distributed = "1" -DynamicExpressions = "0.16" -DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14" +# Note that the <0.0.1 bound is required for old julia compat (which does +# not have stdlib packages available in [compat]) +Distributed = "<0.0.1, 1" +DynamicExpressions = "1" +DynamicQuantities = "1" +Enzyme = "0.12" JSON3 = "1" LineSearches = "7" LossFunctions = "0.10, 0.11" @@ -48,18 +59,19 @@ MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11" MacroTools = "0.4, 0.5" Optim = "~1.8, ~1.9" PackageExtensionCompat = "1" -Pkg = "1" +Pkg = "<0.0.1, 1" PrecompileTools = "1" -Printf = "1" +Printf = "<0.0.1, 1" ProgressBars = "~1.4, ~1.5" 
-Random = "1" +Random = "<0.0.1, 1" Reexport = "1" SpecialFunctions = "0.10.1, 1, 2" StatsBase = "0.33, 0.34" -SymbolicUtils = "0.19, ^1.0.5" -TOML = "1" +SymbolicUtils = "0.19, ^1.0.5, 2, 3" +TOML = "<0.0.1, 1" julia = "1.6" [extras] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index d6b434665..4adbd499a 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -3,5 +3,12 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[preferences.DynamicExpressions] +instability_check = "disable" + +[preferences.SymbolicRegression] +instability_check = "disable" diff --git a/docs/src/api.md b/docs/src/api.md index 07d9b12bb..d9ac1fa97 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -10,35 +10,27 @@ MultitargetSRRegressor ## equation_search ```@docs -equation_search(X::AbstractMatrix{T}, y::AbstractMatrix{T}; - niterations::Int=10, - weights::Union{AbstractVector{T}, Nothing}=nothing, - variable_names::Union{Array{String, 1}, Nothing}=nothing, - options::Options=Options(), - numprocs::Union{Int, Nothing}=nothing, - procs::Union{Array{Int, 1}, Nothing}=nothing, - runtests::Bool=true, - loss_type::Type=Nothing, -) where {T<:DATA_TYPE} +equation_search ``` ## Options ```@docs Options -MutationWeights(;) +MutationWeights ``` ## Printing ```@docs -string_tree(tree::Node, options::Options; kws...) +string_tree ``` ## Evaluation ```@docs -eval_tree_array(tree::Node, X::AbstractMatrix, options::Options; kws...) +eval_tree_array +EvalOptions ``` ## Derivatives @@ -51,16 +43,14 @@ all variables (or, all constants). Both use forward-mode automatic, but use `Zygote.jl` to compute derivatives of each operator, so this is very efficient. ```@docs -eval_diff_tree_array(tree::Node, X::AbstractMatrix, options::Options, direction::Int) -eval_grad_tree_array(tree::Node, X::AbstractMatrix, options::Options; kws...) +eval_diff_tree_array +eval_grad_tree_array ``` ## SymbolicUtils.jl interface ```@docs -node_to_symbolic(tree::Node, options::Options; - variable_names::Union{Array{String, 1}, Nothing}=nothing, - index_functions::Bool=false) +node_to_symbolic ``` Note that use of this function requires `SymbolicUtils.jl` to be installed and loaded. @@ -68,5 +58,5 @@ Note that use of this function requires `SymbolicUtils.jl` to be installed and l ## Pareto frontier ```@docs -calculate_pareto_frontier(hallOfFame::HallOfFame{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE} +calculate_pareto_frontier ``` diff --git a/docs/src/examples.md b/docs/src/examples.md index d6f785f0b..4a179cda0 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -131,7 +131,7 @@ we can see that the output types are `Float32`: r = report(mach) best = r.equations[r.best_idx] println(typeof(best)) -# Node{Float32} +# Expression{Float32} ``` We can also use `Complex` numbers (ignore the warning @@ -228,7 +228,93 @@ a constant `"2.6353e-22[m s⁻²]"`. Note that you can also search for dimensionless units by settings `dimensionless_constants_only` to `true`. -## 7. Additional features +## 7. 
Working with Expressions + +Expressions in `SymbolicRegression.jl` are represented using the `Expression` type, which combines the raw `Node` type with an `OperatorEnum`. This allows for more flexible and powerful expression manipulation and evaluation. + +Here's an example: + +```julia +using SymbolicRegression + +# Define options with operators +options = Options(; binary_operators=[+, -, *], unary_operators=[cos]) + +# Create expression nodes +operators = options.operators +variable_names = ["x1", "x2"] +x1 = Expression(Node{Float64}(feature=1); operators, variable_names) +x2 = Expression(Node{Float64}(feature=2); operators, variable_names) + +# Construct an expression using the operators from options +expr = x1 * cos(x2 - 3.2) + +# Evaluate the expression directly +X = rand(Float64, 2, 100) +output = expr(X) +``` + +This `Expression` type, contains both the structure +and the operators used in the expression. These are what +are returned by the search. The raw `Node` type (which is +what used to be output directly) is accessible with + +```julia +get_contents(expr) +``` + +## 8. Parametric Expressions + +Parametric expressions allow the algorithm to optimize parameters within the expressions during the search process. This is useful for finding expressions that not only fit the data but also have tunable parameters. + +To use this, the data needs to have information on which class +each row belongs to --- this class information will be used to +select the parameters when evaluating each expression. + +For example: + +```julia +using SymbolicRegression +using MLJ + +# Define the dataset +X = NamedTuple{(:x1, :x2)}(ntuple(_ -> randn(Float32, 30), Val(2))) +X = (; X..., classes=rand(1:2, 30)) +p1 = [0.0f0, 3.2f0] +p2 = [1.5f0, 0.5f0] + +y = [ + 2 * cos(X.x1[i] + p1[X.classes[i]]) + X.x2[i]^2 - p2[X.classes[i]] for + i in eachindex(X.classes) +] + +# Define the model with parametric expressions +model = SRRegressor( + niterations=100, + binary_operators=[+, *, /, -], + unary_operators=[cos], + expression_type=ParametricExpression, + expression_options=(; max_parameters=2), + parallelism=:multithreading +) + +# Train the model +mach = machine(model, X, y) +fit!(mach) + +# View the best expression +report(mach) +``` + +The final equations will contain parameters that were optimized during training: + +```julia +typeof(report(mach).equations[end]) +``` + +This example demonstrates how to set up a symbolic regression model that searches for expressions with parameters, optimizing both the structure and the parameters of the expressions based on the provided class information. + +## 9. Additional features For the many other features available in SymbolicRegression.jl, check out the API page for `Options`. You might also find it useful diff --git a/docs/src/types.md b/docs/src/types.md index 2ce365271..5c7277c57 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -3,19 +3,10 @@ ## Equations Equations are specified as binary trees with the `Node` type, defined -as follows: +as follows. 
```@docs -Node{T<:DATA_TYPE} -``` - -There are a variety of constructors for `Node` objects, including: - -```@docs -Node(; val::DATA_TYPE=nothing, feature::Integer=nothing) -Node(op::Int, l::Node) -Node(op::Int, l::Node, r::Node) -Node(var_string::String) +Node ``` When you create an `Options` object, the operators @@ -39,7 +30,7 @@ convert(::Type{Node{T1}}, tree::Node{T2}) where {T1, T2} You can set a `tree` (in-place) with `set_node!`: ```@docs -set_node!(tree::Node{T}, new_tree::Node{T}) where {T} +set_node! ``` You can create a copy of a node with `copy_node`: @@ -48,6 +39,45 @@ You can create a copy of a node with `copy_node`: copy_node(tree::Node) ``` +## Expressions + +Expressions are represented using the `Expression` type, which combines the raw `Node` type with an `OperatorEnum`. + +```@docs +Expression +``` + +These types allow you to define and manipulate expressions with a clear separation between the structure and the operators used. + +## Parametric Expressions + +Parametric expressions are a type of expression that includes parameters which can be optimized during the search. + +```@docs +ParametricExpression +ParametricNode +``` + +These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`. + +## Custom Expressions + +You can create your own expression types by defining a new type that extends `AbstractExpression`. + +```@docs +AbstractExpression +``` + +The interface is fairly flexible, and permits you define specific functional forms, +extra parameters, etc. See the documentation of DynamicExpressions.jl for more details on what +methods you need to implement. Then, for SymbolicRegression.jl, you would +pass `expression_type` to the `Options` constructor, as well as any +`expression_options` you need (as a `NamedTuple`). + +If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_init_params` in +case your expression needs additional parameters. See the method for `ParametricExpression` +as an example. + ## Population Groups of equations are given as a population, which is @@ -68,21 +98,11 @@ PopMember ```@docs HallOfFame -HallOfFame(options::Options, ::Type{T}, ::Type{L}) where {T<:DATA_TYPE,L<:LOSS_TYPE} ``` ## Dataset ```@docs Dataset -Dataset(X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing; - weights::Union{AbstractVector{T}, Nothing}=nothing, - variable_names::Union{Array{String, 1}, Nothing}=nothing, - y_variable_name::Union{String,Nothing}=nothing, - extra::NamedTuple=NamedTuple(), - loss_type::Type=Nothing, - X_units::Union{AbstractVector, Nothing}=nothing, - y_units=nothing, -) where {T<:DATA_TYPE} update_baseline_loss! 
``` diff --git a/examples/parameterized_function.jl b/examples/parameterized_function.jl new file mode 100644 index 000000000..9faefc97c --- /dev/null +++ b/examples/parameterized_function.jl @@ -0,0 +1,47 @@ +using SymbolicRegression +using Random: MersenneTwister +using Zygote +using MLJBase: machine, fit!, predict, report +using Test + +rng = MersenneTwister(0) +X = NamedTuple{(:x1, :x2, :x3, :x4, :x5)}(ntuple(_ -> randn(rng, Float32, 30), Val(5))) +X = (; X..., classes=rand(rng, 1:2, 30)) +p1 = [0.0f0, 3.2f0] +p2 = [1.5f0, 0.5f0] + +y = [ + 2 * cos(X.x4[i] + p1[X.classes[i]]) + X.x1[i]^2 - p2[X.classes[i]] for + i in eachindex(X.classes) +] + +stop_at = Ref(1e-4) + +model = SRRegressor(; + niterations=100, + binary_operators=[+, *, /, -], + unary_operators=[cos, exp], + populations=30, + expression_type=ParametricExpression, + expression_options=(; max_parameters=2), + autodiff_backend=:Zygote, + parallelism=:multithreading, + early_stop_condition=(loss, _) -> loss < stop_at[], +) + +mach = machine(model, X, y) + +fit!(mach) +idx1 = lastindex(report(mach).equations) +ypred1 = predict(mach, (data=X, idx=idx1)) +loss1 = sum(i -> abs2(ypred1[i] - y[i]), eachindex(y)) + +# Should keep all parameters +stop_at[] = 1e-5 +fit!(mach) +idx2 = lastindex(report(mach).equations) +ypred2 = predict(mach, (data=X, idx=idx2)) +loss2 = sum(i -> abs2(ypred2[i] - y[i]), eachindex(y)) + +# Should get better: +@test loss1 >= loss2 diff --git a/ext/SymbolicRegressionEnzymeExt.jl b/ext/SymbolicRegressionEnzymeExt.jl new file mode 100644 index 000000000..b8b8be60b --- /dev/null +++ b/ext/SymbolicRegressionEnzymeExt.jl @@ -0,0 +1,61 @@ +module SymbolicRegressionEnzymeExt + +using SymbolicRegression.LossFunctionsModule: eval_loss +using DynamicExpressions: + AbstractExpression, + AbstractExpressionNode, + get_scalar_constants, + set_scalar_constants!, + extract_gradient, + with_contents, + get_contents +using ADTypes: AutoEnzyme +using Enzyme: autodiff, Reverse, Active, Const, Duplicated + +import SymbolicRegression.ConstantOptimizationModule: GradEvaluator + +# We prepare a copy of the tree and all arrays +function GradEvaluator(f::F, backend::AE) where {F,AE<:AutoEnzyme} + storage_tree = copy(f.tree) + _, storage_refs = get_scalar_constants(storage_tree) + storage_dataset = deepcopy(f.dataset) + # TODO: It is super inefficient to deepcopy; how can we skip this + return GradEvaluator(f, backend, (; storage_tree, storage_refs, storage_dataset)) +end + +function evaluator(tree, dataset, options, idx, output) + output[] = eval_loss(tree, dataset, options; regularization=false, idx=idx) + return nothing +end + +with_stacksize(f::F, n) where {F} = fetch(schedule(Task(f, n))) + +function (g::GradEvaluator{<:Any,<:AutoEnzyme})(_, G, x::AbstractVector{T}) where {T} + set_scalar_constants!(g.f.tree, x, g.f.refs) + set_scalar_constants!(g.extra.storage_tree, zero(x), g.extra.storage_refs) + fill!(g.extra.storage_dataset, 0) + + output = [zero(T)] + doutput = [one(T)] + + with_stacksize(32 * 1024 * 1024) do + autodiff( + Reverse, + evaluator, + Duplicated(g.f.tree, g.extra.storage_tree), + Duplicated(g.f.dataset, g.extra.storage_dataset), + Const(g.f.options), + Const(g.f.idx), + Duplicated(output, doutput), + ) + end + + if G !== nothing + # TODO: This is redundant since we already have the references. + # Should just be able to extract from the references directly. 
+ G .= first(get_scalar_constants(g.extra.storage_tree)) + end + return output[] +end + +end diff --git a/ext/SymbolicRegressionSymbolicUtilsExt.jl b/ext/SymbolicRegressionSymbolicUtilsExt.jl index 2b0096ee0..f94fa463a 100644 --- a/ext/SymbolicRegressionSymbolicUtilsExt.jl +++ b/ext/SymbolicRegressionSymbolicUtilsExt.jl @@ -1,8 +1,9 @@ module SymbolicRegressionSymbolicUtilsExt using SymbolicUtils: Symbolic -using SymbolicRegression: AbstractExpressionNode, Node, Options +using SymbolicRegression: AbstractExpressionNode, AbstractExpression, Node, Options using SymbolicRegression.MLJInterfaceModule: AbstractSRRegressor, get_options +using DynamicExpressions: get_tree, get_operators import SymbolicRegression: node_to_symbolic, symbolic_to_node @@ -11,10 +12,14 @@ import SymbolicRegression: node_to_symbolic, symbolic_to_node Convert an expression to SymbolicUtils.jl form. """ -function node_to_symbolic(tree::AbstractExpressionNode, options::Options; kws...) - return node_to_symbolic(tree, options.operators; kws...) +function node_to_symbolic( + tree::Union{AbstractExpressionNode,AbstractExpression}, options::Options; kws... +) + return node_to_symbolic(get_tree(tree), get_operators(tree, options); kws...) end -function node_to_symbolic(tree::AbstractExpressionNode, m::AbstractSRRegressor; kws...) +function node_to_symbolic( + tree::Union{AbstractExpressionNode,AbstractExpression}, m::AbstractSRRegressor; kws... +) return node_to_symbolic(tree, get_options(m); kws...) end @@ -31,24 +36,30 @@ function symbolic_to_node(eqn::Symbolic, m::AbstractSRRegressor; kws...) end function Base.convert( - ::Type{Symbolic}, tree::AbstractExpressionNode, options::Options; kws... + ::Type{Symbolic}, + tree::Union{AbstractExpressionNode,AbstractExpression}, + options::Union{Options,Nothing}=nothing; + kws..., ) - return convert(Symbolic, tree, options.operators; kws...) + return convert(Symbolic, get_tree(tree), get_operators(tree, options); kws...) end function Base.convert( - ::Type{Symbolic}, tree::AbstractExpressionNode, m::AbstractSRRegressor; kws... + ::Type{Symbolic}, + tree::Union{AbstractExpressionNode,AbstractExpression}, + m::AbstractSRRegressor; + kws..., ) return convert(Symbolic, tree, get_options(m); kws...) end function Base.convert( ::Type{N}, x::Union{Number,Symbolic}, options::Options; kws... -) where {N<:AbstractExpressionNode} +) where {N<:Union{AbstractExpressionNode,AbstractExpression}} return convert(N, x, options.operators; kws...) end function Base.convert( ::Type{N}, x::Union{Number,Symbolic}, m::AbstractSRRegressor; kws... -) where {N<:AbstractExpressionNode} +) where {N<:Union{AbstractExpressionNode,AbstractExpression}} return convert(N, x, get_options(m); kws...) 
end diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl index 6b60ff840..7f6093631 100644 --- a/src/CheckConstraints.jl +++ b/src/CheckConstraints.jl @@ -1,7 +1,7 @@ module CheckConstraintsModule -using DynamicExpressions: AbstractExpressionNode, count_depth, tree_mapreduce -using ..UtilsModule: vals +using DynamicExpressions: + AbstractExpressionNode, AbstractExpression, get_tree, count_depth, tree_mapreduce using ..CoreModule: Options using ..ComplexityModule: compute_complexity, past_complexity_limit @@ -70,6 +70,15 @@ function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Boo end """Check if user-passed constraints are violated or not""" +function check_constraints( + ex::AbstractExpression, + options::Options, + maxsize::Int, + cursize::Union{Int,Nothing}=nothing, +)::Bool + tree = get_tree(ex) + return check_constraints(tree, options, maxsize, cursize) +end function check_constraints( tree::AbstractExpressionNode, options::Options, @@ -93,7 +102,8 @@ function check_constraints( return true end -check_constraints(tree::AbstractExpressionNode, options::Options)::Bool = - check_constraints(tree, options, options.maxsize) +check_constraints( + ex::Union{AbstractExpression,AbstractExpressionNode}, options::Options +)::Bool = check_constraints(ex, options, options.maxsize) end diff --git a/src/Complexity.jl b/src/Complexity.jl index f3b0bba91..dccb05bd3 100644 --- a/src/Complexity.jl +++ b/src/Complexity.jl @@ -1,9 +1,12 @@ module ComplexityModule -using DynamicExpressions: AbstractExpressionNode, count_nodes, tree_mapreduce +using DynamicExpressions: + AbstractExpression, AbstractExpressionNode, get_tree, count_nodes, tree_mapreduce using ..CoreModule: Options, ComplexityMapping -function past_complexity_limit(tree::AbstractExpressionNode, options::Options, limit)::Bool +function past_complexity_limit( + tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options, limit +)::Bool return compute_complexity(tree, options) > limit end @@ -14,6 +17,11 @@ By default, this is the number of nodes in a tree. However, it could use the custom settings in options.complexity_mapping if these are defined. 
""" +function compute_complexity( + tree::AbstractExpression, options::Options; break_sharing=Val(false) +) + return compute_complexity(get_tree(tree), options; break_sharing) +end function compute_complexity( tree::AbstractExpressionNode, options::Options; break_sharing=Val(false) )::Int diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index 7dc8471a3..fe66b4f5d 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -2,8 +2,16 @@ module ConstantOptimizationModule using LineSearches: LineSearches using Optim: Optim -using DynamicExpressions: Node, count_constants -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE +using ADTypes: AbstractADType, AutoEnzyme +using DifferentiationInterface: value_and_gradient +using DynamicExpressions: + AbstractExpression, + Expression, + count_scalar_constants, + get_scalar_constants, + set_scalar_constants!, + extract_gradient +using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample using ..PopMemberModule: PopMember @@ -22,45 +30,58 @@ end function dispatch_optimize_constants( dataset::Dataset{T,L}, member::P, options::Options, idx ) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} - nconst = count_constants(member.tree) + nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) if nconst == 1 && !(T <: Complex) algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking()) return _optimize_constants( - dataset, member, options, algorithm, options.optimizer_options, idx + dataset, + member, + specialized_options(options), + algorithm, + options.optimizer_options, + idx, ) end return _optimize_constants( dataset, member, - options, + specialized_options(options), + # We use specialized options here due to Enzyme being + # more particular about dynamic dispatch options.optimizer_algorithm, options.optimizer_options, idx, ) end +"""How many constants will be optimized.""" +count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex) + function _optimize_constants( dataset, member::P, options, algorithm, optimizer_options, idx )::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}} tree = member.tree eval_fraction = options.batching ? (options.batch_size / dataset.n) : 1.0 - f(t) = eval_loss(t, dataset, options; regularization=false, idx=idx)::L - baseline = f(tree) - result = Optim.optimize(f, tree, algorithm, optimizer_options) + x0, refs = get_scalar_constants(tree) + @assert count_constants_for_optimization(tree) == length(x0) + f = Evaluator(tree, refs, dataset, options, idx) + fg! = GradEvaluator(f, options.autodiff_backend) + obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing + f + else + Optim.only_fg!(fg!) + end + baseline = f(x0) + result = Optim.optimize(obj, x0, algorithm, optimizer_options) num_evals = result.f_calls * eval_fraction # Try other initial conditions: for _ in 1:(options.optimizer_nrestarts) - tmptree = copy(tree) - foreach(tmptree) do node - if node.degree == 0 && node.constant - node.val = (node.val) * (T(1) + T(1//2) * randn(T)) - end - end - tmpresult = Optim.optimize( - f, tmptree, algorithm, optimizer_options; make_copy=false - ) + eps = randn(T, size(x0)...) + xt = @. x0 * (T(1) + T(1//2) * eps) + tmpresult = Optim.optimize(obj, xt, algorithm, optimizer_options) num_evals += tmpresult.f_calls * eval_fraction + # TODO: Does this need to take into account h_calls? 
if tmpresult.minimum < result.minimum result = tmpresult @@ -68,16 +89,48 @@ function _optimize_constants( end if result.minimum < baseline - member.tree = result.minimizer - member.loss = eval_loss(member.tree, dataset, options; regularization=true, idx=idx) + member.tree = tree + member.loss = f(result.minimizer; regularization=true) member.score = loss_to_score( member.loss, dataset.use_baseline, dataset.baseline_loss, member, options ) member.birth = get_birth_order(; deterministic=options.deterministic) num_evals += eval_fraction + else + set_scalar_constants!(member.tree, x0, refs) end return member, num_evals end +struct Evaluator{N<:AbstractExpression,R,D<:Dataset,O<:Options,I} <: Function + tree::N + refs::R + dataset::D + options::O + idx::I +end +function (e::Evaluator)(x::AbstractVector; regularization=false) + set_scalar_constants!(e.tree, x, e.refs) + return eval_loss(e.tree, e.dataset, e.options; regularization, e.idx) +end +struct GradEvaluator{F<:Evaluator,AD<:Union{Nothing,AbstractADType},EX} <: Function + f::F + backend::AD + extra::EX +end +GradEvaluator(f::F, backend::AD) where {F,AD} = GradEvaluator(f, backend, nothing) + +function (g::GradEvaluator{<:Any,AD})(_, G, x::AbstractVector) where {AD} + AD isa AutoEnzyme && error("Please load the `Enzyme.jl` package.") + set_scalar_constants!(g.f.tree, x, g.f.refs) + (val, grad) = value_and_gradient(g.backend, g.f.tree) do tree + eval_loss(tree, g.f.dataset, g.f.options; regularization=false, idx=g.f.idx) + end + if G !== nothing && grad !== nothing + G .= extract_gradient(grad, g.f.tree) + end + return val +end + end diff --git a/src/Core.jl b/src/Core.jl index 7a04b975f..8f917c906 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -1,5 +1,7 @@ module CoreModule +function create_expression end + include("Utils.jl") include("ProgramConstants.jl") include("Dataset.jl") @@ -12,7 +14,7 @@ using .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE using .DatasetModule: Dataset using .MutationWeightsModule: MutationWeights, sample_mutation -using .OptionsStructModule: Options, ComplexityMapping +using .OptionsStructModule: Options, ComplexityMapping, specialized_options using .OptionsModule: Options using .OperatorsModule: plus, diff --git a/src/Dataset.jl b/src/Dataset.jl index b2a7b42d8..99c31ee3d 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -1,13 +1,6 @@ module DatasetModule -using DynamicQuantities: - AbstractDimensions, - Dimensions, - SymbolicDimensions, - Quantity, - uparse, - sym_uparse, - DEFAULT_DIM_BASE_TYPE +using DynamicQuantities: Quantity using ..UtilsModule: subscriptify, get_base_type, @constfield using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE @@ -22,6 +15,8 @@ import ...deprecate_varmap - `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`. - `y::AbstractVector{T}`: The desired output values, with shape `(n,)`. +- `index::Int`: The index of the output feature corresponding to this + dataset, if any. - `n::Int`: The number of samples. - `nfeatures::Int`: The number of features. - `weighted::Bool`: Whether the dataset is non-uniformly weighted. 
@@ -64,6 +59,7 @@ mutable struct Dataset{ } @constfield X::AX @constfield y::AY + @constfield index::Int @constfield n::Int @constfield nfeatures::Int @constfield weighted::Bool @@ -99,6 +95,7 @@ function Dataset( X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing, loss_type::Type{L}=Nothing; + index::Int=1, weights::Union{AbstractVector{T},Nothing}=nothing, variable_names::Union{Array{String,1},Nothing}=nothing, display_variable_names=variable_names, @@ -123,6 +120,7 @@ function Dataset( X, y, kws[:loss_type]; + index, weights, variable_names, display_variable_names, @@ -206,6 +204,7 @@ function Dataset( }( X, y, + index, n, nfeatures, weighted, @@ -260,4 +259,17 @@ function has_units(dataset::Dataset) return dataset.X_units !== nothing || dataset.y_units !== nothing end +# Used for Enzyme +function Base.fill!(d::Dataset, val) + _fill!(d.X, val) + _fill!(d.y, val) + _fill!(d.weights, val) + _fill!(d.extra, val) + return d +end +_fill!(x::AbstractArray, val) = fill!(x, val) +_fill!(x::NamedTuple, val) = foreach(v -> _fill!(v, val), values(x)) +_fill!(::Nothing, val) = nothing +_fill!(x, val) = x + end diff --git a/src/DimensionalAnalysis.jl b/src/DimensionalAnalysis.jl index d0e975fbf..cc9440db1 100644 --- a/src/DimensionalAnalysis.jl +++ b/src/DimensionalAnalysis.jl @@ -1,7 +1,7 @@ module DimensionalAnalysisModule -using DynamicExpressions: AbstractExpressionNode -using DynamicQuantities: Quantity, DimensionError, AbstractQuantity, uparse, constructorof +using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree +using DynamicQuantities: Quantity, DimensionError, AbstractQuantity, constructorof using ..CoreModule: Options, Dataset using ..UtilsModule: safe_call @@ -192,6 +192,11 @@ function violates_dimensional_constraints( tree, dataset.X_units, dataset.y_units, (@view X[:, 1]), options ) end +function violates_dimensional_constraints( + tree::AbstractExpression, dataset::Dataset, options::Options +) + return violates_dimensional_constraints(get_tree(tree), dataset, options) +end function violates_dimensional_constraints( tree::AbstractExpressionNode{T}, X_units::AbstractVector{<:Quantity}, diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl new file mode 100644 index 000000000..a54d97bdd --- /dev/null +++ b/src/ExpressionBuilder.jl @@ -0,0 +1,290 @@ +module ExpressionBuilderModule + +using DispatchDoctor: @unstable +using DynamicExpressions: + AbstractExpressionNode, + AbstractExpression, + Expression, + ParametricExpression, + ParametricNode, + constructorof, + get_tree, + get_contents, + get_metadata, + with_contents, + with_metadata, + count_scalar_constants, + eval_tree_array +using Random: default_rng, AbstractRNG +using StatsBase: StatsBase +using ..CoreModule: Options, Dataset, DATA_TYPE +using ..HallOfFameModule: HallOfFame +using ..LossFunctionsModule: maybe_getindex +using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..PopulationModule: Population +using ..PopMemberModule: PopMember + +import DynamicExpressions: get_operators +import ..CoreModule: create_expression +import ..MutationFunctionsModule: + make_random_leaf, crossover_trees, mutate_constant, mutate_factor +import ..LossFunctionsModule: eval_tree_dispatch +import ..ConstantOptimizationModule: count_constants_for_optimization + +@unstable function create_expression( + t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false) +) where {T,L,embed} + return create_expression( + constructorof(options.node_type)(; val=t), options, dataset, 
Val(embed) + ) +end +@unstable function create_expression( + t::AbstractExpressionNode{T}, + options::Options, + dataset::Dataset{T,L}, + ::Val{embed}=Val(false), +) where {T,L,embed} + return constructorof(options.expression_type)( + t; init_params(options, dataset, nothing, Val(embed))... + ) +end +function create_expression( + ex::AbstractExpression{T}, ::Options, ::Dataset{T,L}, ::Val{embed}=Val(false) +) where {T,L,embed} + return ex +end +@unstable function init_params( + options::Options, + dataset::Dataset{T,L}, + prototype::Union{Nothing,AbstractExpression}, + ::Val{embed}, +) where {T,L,embed} + consistency_checks(options, prototype) + return (; + operators=embed ? options.operators : nothing, + variable_names=embed ? dataset.variable_names : nothing, + extra_init_params( + options.expression_type, prototype, options, dataset, Val(embed) + )..., + ) +end +function extra_init_params( + ::Type{E}, + prototype::Union{Nothing,AbstractExpression}, + options::Options, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:AbstractExpression} + return (; options.expression_options...) +end +function extra_init_params( + ::Type{E}, + prototype::Union{Nothing,ParametricExpression}, + options::Options, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:ParametricExpression} + num_params = options.expression_options.max_parameters + num_classes = length(unique(dataset.extra.classes)) + parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing + _parameters = if prototype === nothing + randn(T, (num_params, num_classes)) + else + copy(get_metadata(prototype).parameters) + end + return (; parameters=_parameters, parameter_names) +end + +consistency_checks(::Options, prototype::Nothing) = nothing +function consistency_checks(options::Options, prototype) + if prototype === nothing + return nothing + end + @assert( + prototype isa options.expression_type, + "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))" + ) + if prototype isa ParametricExpression + if prototype.metadata.parameter_names !== nothing + @assert( + length(prototype.metadata.parameter_names) == + options.expression_options.max_parameters, + "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(prototype.metadata.parameter_names)" + ) + end + @assert size(prototype.metadata.parameters, 1) == + options.expression_options.max_parameters + end + return nothing +end + +@unstable begin + function embed_metadata( + ex::AbstractExpression, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) 
+ end + function embed_metadata( + member::PopMember, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return PopMember( + embed_metadata(member.tree, options, dataset), + member.score, + member.loss, + nothing; + member.ref, + member.parent, + deterministic=options.deterministic, + ) + end + function embed_metadata( + pop::Population, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return Population( + map(member -> embed_metadata(member, options, dataset), pop.members) + ) + end + function embed_metadata( + hof::HallOfFame, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return HallOfFame( + map(member -> embed_metadata(member, options, dataset), hof.members), hof.exists + ) + end + function embed_metadata( + vec::Vector{H}, options::Options, dataset::Dataset{T,L} + ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} + return map(elem -> embed_metadata(elem, options, dataset), vec) + end +end + +"""Strips all metadata except for top-level information""" +function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L} + return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) +end +function strip_metadata( + ex::ParametricExpression, options::Options, dataset::Dataset{T,L} +) where {T,L} + return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) +end +function strip_metadata( + member::PopMember, options::Options, dataset::Dataset{T,L} +) where {T,L} + return PopMember( + strip_metadata(member.tree, options, dataset), + member.score, + member.loss, + nothing; + member.ref, + member.parent, + deterministic=options.deterministic, + ) +end +function strip_metadata( + pop::Population, options::Options, dataset::Dataset{T,L} +) where {T,L} + return Population(map(member -> strip_metadata(member, options, dataset), pop.members)) +end +function strip_metadata( + hof::HallOfFame, options::Options, dataset::Dataset{T,L} +) where {T,L} + return HallOfFame( + map(member -> strip_metadata(member, options, dataset), hof.members), hof.exists + ) +end + +function eval_tree_dispatch( + tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx +) where {T<:DATA_TYPE} + A = expected_array_type(dataset.X) + return eval_tree_array( + tree, + maybe_getindex(dataset.X, :, idx), + maybe_getindex(dataset.extra.classes, idx), + options.operators, + )::Tuple{A,Bool} +end + +function make_random_leaf( + nfeatures::Int, + ::Type{T}, + ::Type{N}, + rng::AbstractRNG=default_rng(), + options::Union{Options,Nothing}=nothing, +) where {T<:DATA_TYPE,N<:ParametricNode} + choice = rand(rng, 1:3) + if choice == 1 + return ParametricNode(; val=randn(rng, T)) + elseif choice == 2 + return ParametricNode(T; feature=rand(rng, 1:nfeatures)) + else + tree = ParametricNode{T}() + tree.val = zero(T) + tree.degree = 0 + tree.feature = 0 + tree.constant = false + tree.is_parameter = true + tree.parameter = rand( + rng, UInt16(1):UInt16(options.expression_options.max_parameters) + ) + return tree + end +end + +function crossover_trees( + ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng() +) where {T} + tree1 = get_contents(ex1) + tree2 = get_contents(ex2) + out1, out2 = crossover_trees(tree1, tree2, rng) + ex1 = with_contents(ex1, out1) + ex2 = with_contents(ex2, out2) + + # We also randomly share parameters + nparams1 = size(ex1.metadata.parameters, 1) + nparams2 = size(ex2.metadata.parameters, 1) + num_params_switch = min(nparams1, nparams2) + idx_to_switch = StatsBase.sample( + rng, 
1:num_params_switch, num_params_switch; replace=false + ) + for param_idx in idx_to_switch + ex2_params = ex2.metadata.parameters[param_idx, :] + ex2.metadata.parameters[param_idx, :] .= ex1.metadata.parameters[param_idx, :] + ex1.metadata.parameters[param_idx, :] .= ex2_params + end + + return ex1, ex2 +end + +function count_constants_for_optimization(ex::ParametricExpression) + return count_scalar_constants(get_tree(ex)) + length(ex.metadata.parameters) +end + +function mutate_constant( + ex::ParametricExpression{T}, + temperature, + options::Options, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + if rand(rng, Bool) + # Normal mutation of inner constant + tree = get_contents(ex) + return with_contents(ex, mutate_constant(tree, temperature, options, rng)) + else + # Mutate parameters + parameter_index = rand(rng, 1:(options.expression_options.max_parameters)) + # We mutate all the parameters at once + factor = mutate_factor(T, temperature, options, rng) + ex.metadata.parameters[parameter_index, :] .*= factor + return ex + end +end + +@unstable function get_operators(ex::AbstractExpression, options::Options) + return get_operators(ex, options.operators) +end +@unstable function get_operators(ex::AbstractExpressionNode, options::Options) + return get_operators(ex, options.operators) +end + +end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 8224fe480..19c52f933 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -1,12 +1,11 @@ module HallOfFameModule -using DynamicExpressions: AbstractExpressionNode, Node, constructorof, string_tree -using DynamicExpressions.EquationModule: with_type_parameters +using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: split_string -using ..CoreModule: MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu +using ..CoreModule: + MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression using ..ComplexityModule: compute_complexity using ..PopMemberModule: PopMember -using ..LossFunctionsModule: eval_loss using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING using Printf: @sprintf @@ -23,13 +22,33 @@ have been set, you can run `.members[exists]`. These are ordered by complexity, with `.members[1]` the member with complexity 1. - `exists::Array{Bool,1}`: Whether the member at the given complexity has been set. """ -struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode{T}} +struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} members::Array{PopMember{T,L,N},1} exists::Array{Bool,1} #Whether it has been set end +function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N} + println(io, "HallOfFame{...}:") + for i in eachindex(hof.members, hof.exists) + s_member, s_exists = if hof.exists[i] + sprint((io, m) -> show(io, mime, m), hof.members[i]), "true" + else + "undef", "false" + end + println(io, " "^4 * ".exists[$i] = $s_exists") + print(io, " "^4 * ".members[$i] =") + splitted = split(strip(s_member), '\n') + if length(splitted) == 1 + println(io, " " * s_member) + else + println(io) + foreach(line -> println(io, " "^8 * line), splitted) + end + end + return nothing +end """ - HallOfFame(options::Options, ::Type{T}, ::Type{L}) where {T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode} + HallOfFame(options::Options, dataset::Dataset{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE} Create empty HallOfFame. 
The HallOfFame stores a list of `PopMember` objects in `.members`, which is enumerated @@ -39,18 +58,18 @@ has been instantiated or not. Arguments: - `options`: Options containing specification about deterministic. -- `T`: Type of Nodes to use in the population. e.g., `Float64`. -- `L`: Type of loss to use in the population. e.g., `Float64`. +- `dataset`: Dataset containing the input data. """ function HallOfFame( - options::Options, ::Type{T}, ::Type{L} + options::Options, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} actualMaxsize = options.maxsize + MAX_DEGREE - NT = with_type_parameters(options.node_type, T) - return HallOfFame{T,L,NT}( + base_tree = create_expression(zero(T), options, dataset) + + return HallOfFame{T,L,typeof(base_tree)}( [ PopMember( - constructorof(options.node_type)(T; val=convert(T, 1)), + copy(base_tree), L(0), L(Inf), options; diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index ebd88c7f2..d5cf52300 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -1,27 +1,24 @@ module InterfaceDynamicExpressionsModule using Printf: @sprintf -using DynamicExpressions: DynamicExpressions using DynamicExpressions: - OperatorEnum, GenericOperatorEnum, AbstractExpressionNode, Node, GraphNode -using DynamicExpressions.StringsModule: needs_brackets + DynamicExpressions as DE, + OperatorEnum, + GenericOperatorEnum, + AbstractExpression, + AbstractExpressionNode, + ParametricExpression, + Node, + GraphNode using DynamicQuantities: dimension, ustrip using ..CoreModule: Options using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap using ..UtilsModule: subscriptify -import DynamicExpressions: - eval_tree_array, - eval_diff_tree_array, - eval_grad_tree_array, - print_tree, - string_tree, - differentiable_eval_tree_array - import ..deprecate_varmap """ - eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...) + eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; kws...) Evaluate a binary tree (equation) over a given input data matrix. The operators contain all of the operators used. This function fuses doublets @@ -42,7 +39,7 @@ The bulk of the code is for optimizations and pre-emptive NaN/Inf checks, which speed up evaluation significantly. # Arguments -- `tree::AbstractExpressionNode`: The root node of the tree to evaluate. +- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The root node of the tree to evaluate. - `X::AbstractArray`: The input data to evaluate the tree on. - `options::Options`: Options used to define the operators used in the tree. @@ -53,12 +50,38 @@ which speed up evaluation significantly. or nan was encountered, and a large loss should be assigned to the equation. """ -function eval_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws... +function DE.eval_tree_array( + tree::Union{AbstractExpressionNode,AbstractExpression}, + X::AbstractMatrix, + options::Options; + kws..., ) A = expected_array_type(X) - return eval_tree_array( - tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws... 
+ return DE.eval_tree_array( + tree, + X, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., + )::Tuple{A,Bool} +end +function DE.eval_tree_array( + tree::ParametricExpression, + X::AbstractMatrix, + classes::AbstractVector{<:Integer}, + options::Options; + kws..., +) + A = expected_array_type(X) + return DE.eval_tree_array( + tree, + X, + classes, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., )::Tuple{A,Bool} end @@ -68,7 +91,7 @@ function expected_array_type(X::AbstractArray) end """ - eval_diff_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int) + eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options, direction::Int) Compute the forward derivative of an expression, using a similar structure and optimization to eval_tree_array. `direction` is the index of a particular @@ -77,7 +100,7 @@ respect to `x1`. # Arguments -- `tree::AbstractExpressionNode`: The expression tree to evaluate. +- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. - `X::AbstractArray`: The data matrix, with each column being a data point. - `options::Options`: The options containing the operators used to create the `tree`. - `direction::Int`: The index of the variable to take the derivative with respect to. @@ -87,15 +110,21 @@ respect to `x1`. - `(evaluation, derivative, complete)::Tuple{AbstractVector, AbstractVector, Bool}`: the normal evaluation, the derivative, and whether the evaluation completed as normal (or encountered a nan or inf). """ -function eval_diff_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int +function DE.eval_diff_tree_array( + tree::Union{AbstractExpression,AbstractExpressionNode}, + X::AbstractArray, + options::Options, + direction::Int, ) A = expected_array_type(X) - return eval_diff_tree_array(tree, X, options.operators, direction)::Tuple{A,A,Bool} + # TODO: Add `AbstractExpression` implementation in `Expression.jl` + return DE.eval_diff_tree_array( + DE.get_tree(tree), X, DE.get_operators(tree, options), direction + )::Tuple{A,A,Bool} end """ - eval_grad_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; variable::Bool=false) + eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; variable::Bool=false) Compute the forward-mode derivative of an expression, using a similar structure and optimization to eval_tree_array. `variable` specifies whether @@ -104,7 +133,7 @@ to every constant in the expression. # Arguments -- `tree::AbstractExpressionNode`: The expression tree to evaluate. +- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. - `X::AbstractArray`: The data matrix, with each column being a data point. - `options::Options`: The options containing the operators used to create the `tree`. - `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`), @@ -115,12 +144,17 @@ to every constant in the expression. - `(evaluation, gradient, complete)::Tuple{AbstractVector, AbstractArray, Bool}`: the normal evaluation, the gradient, and whether the evaluation completed as normal (or encountered a nan or inf). """ -function eval_grad_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws... 
+function DE.eval_grad_tree_array( + tree::Union{AbstractExpression,AbstractExpressionNode}, + X::AbstractArray, + options::Options; + kws..., ) A = expected_array_type(X) M = typeof(X) # TODO: This won't work with StaticArrays! - return eval_grad_tree_array(tree, X, options.operators; kws...)::Tuple{A,M,Bool} + return DE.eval_grad_tree_array( + tree, X, DE.get_operators(tree, options); kws... + )::Tuple{A,M,Bool} end """ @@ -128,11 +162,16 @@ end Evaluate an expression tree in a way that can be auto-differentiated. """ -function differentiable_eval_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options +function DE.differentiable_eval_tree_array( + tree::Union{AbstractExpression,AbstractExpressionNode}, + X::AbstractArray, + options::Options, ) A = expected_array_type(X) - return differentiable_eval_tree_array(tree, X, options.operators)::Tuple{A,Bool} + # TODO: Add `AbstractExpression` implementation in `Expression.jl` + return DE.differentiable_eval_tree_array( + DE.get_tree(tree), X, DE.get_operators(tree, options) + )::Tuple{A,Bool} end const WILDCARD_UNIT_STRING = "[?]" @@ -149,8 +188,8 @@ Convert an equation to a string. - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables to print for each feature. """ -@inline function string_tree( - tree::AbstractExpressionNode, +@inline function DE.string_tree( + tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; raw::Bool=true, X_sym_units=nothing, @@ -164,16 +203,19 @@ Convert an equation to a string. if raw tree = tree isa GraphNode ? convert(Node, tree) : tree - return string_tree( - tree, options.operators; f_variable=string_variable_raw, variable_names + return DE.string_tree( + tree, + DE.get_operators(tree, options); + f_variable=string_variable_raw, + variable_names, ) end vprecision = vals[options.print_precision] if X_sym_units !== nothing || y_sym_units !== nothing - return string_tree( + return DE.string_tree( tree, - options.operators; + DE.get_operators(tree, options); f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units), f_constant=let unit_placeholder = @@ -184,9 +226,9 @@ Convert an equation to a string. kws..., ) else - return string_tree( + return DE.string_tree( tree, - options.operators; + DE.get_operators(tree, options); f_variable=string_variable, f_constant=(val,) -> string_constant(val, vprecision, ""), variable_names=display_variable_names, @@ -252,22 +294,15 @@ Print an equation - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables to print for each feature. """ -function print_tree(tree::AbstractExpressionNode, options::Options; kws...) - return print_tree(tree, options.operators; kws...) -end -function print_tree(io::IO, tree::AbstractExpressionNode, options::Options; kws...) - return print_tree(io, tree, options.operators; kws...) +function DE.print_tree( + tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... +) + return DE.print_tree(tree, DE.get_operators(tree, options); kws...) end - -""" - convert(::Type{<:AbstractExpressionNode{T}}, tree::AbstractExpressionNode, options::Options; kws...) where {T} - -Convert an equation to a different base type `T`. -""" -function Base.convert( - ::Type{N}, tree::AbstractExpressionNode, options::Options -) where {T,N<:AbstractExpressionNode{T}} - return convert(N, tree, options.operators) +function DE.print_tree( + io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... 
+) + return DE.print_tree(io, tree, DE.get_operators(tree, options); kws...) end """ @@ -283,13 +318,13 @@ defined. macro extend_operators(options) operators = :($(options).operators) type_requirements = Options - @gensym alias_operators + alias_operators = gensym("alias_operators") return quote if !isa($(options), $type_requirements) error("You must pass an options type to `@extend_operators`.") end $alias_operators = $define_alias_operators($operators) - $(DynamicExpressions).@extend_operators $alias_operators + $(DE).@extend_operators $alias_operators end |> esc end function define_alias_operators(operators) @@ -304,14 +339,22 @@ function define_alias_operators(operators) ) end -function (tree::AbstractExpressionNode)(X, options::Options; kws...) - return tree(X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...) +function (tree::Union{AbstractExpression,AbstractExpressionNode})( + X, options::Options; kws... +) + return tree( + X, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., + ) end -function DynamicExpressions.EvaluationHelpersModule._grad_evaluator( - tree::AbstractExpressionNode, X, options::Options; kws... +function DE.EvaluationHelpersModule._grad_evaluator( + tree::Union{AbstractExpression,AbstractExpressionNode}, X, options::Options; kws... ) - return DynamicExpressions.EvaluationHelpersModule._grad_evaluator( - tree, X, options.operators; turbo=options.turbo, kws... + return DE.EvaluationHelpersModule._grad_evaluator( + tree, X, DE.get_operators(tree, options); turbo=options.turbo, kws... ) end diff --git a/src/InterfaceDynamicQuantities.jl b/src/InterfaceDynamicQuantities.jl index 34580a3cc..91fbdd7ee 100644 --- a/src/InterfaceDynamicQuantities.jl +++ b/src/InterfaceDynamicQuantities.jl @@ -29,7 +29,7 @@ end function get_units(_, _, ::Nothing, ::Function) return nothing end -function get_units(::Type{T}, ::Type{D}, x::AbstractString, f::Function) where {T,D} +function get_units(::Type{T}, ::Type{D}, x::AbstractString, f::F) where {T,D,F<:Function} isempty(x) && return one(Quantity{T,D}) return convert(Quantity{T,D}, f(x)) end @@ -42,7 +42,7 @@ end function get_units(::Type{T}, ::Type{D}, x::Real, ::Function) where {T,D} return Quantity(convert(T, x), D)::Quantity{T,D} end -function get_units(::Type{T}, ::Type{D}, x::AbstractVector, f::Function) where {T,D} +function get_units(::Type{T}, ::Type{D}, x::AbstractVector, f::F) where {T,D,F<:Function} return Quantity{T,D}[get_units(T, D, xi, f) for xi in x] end # TODO: Allow for AbstractQuantity output here diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index ac5493234..a84218879 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -1,12 +1,12 @@ module LossFunctionsModule -using Random: MersenneTwister using StatsBase: StatsBase -using DynamicExpressions: AbstractExpressionNode, Node, constructorof +using DynamicExpressions: + AbstractExpression, AbstractExpressionNode, get_tree, eval_tree_array using LossFunctions: LossFunctions using LossFunctions: SupervisedLoss -using ..InterfaceDynamicExpressionsModule: eval_tree_array -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE +using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..CoreModule: Options, Dataset, create_expression, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..DimensionalAnalysisModule: violates_dimensional_constraints @@ -25,7 +25,7 @@ function _weighted_loss( x::AbstractArray{T}, y::AbstractArray{T}, w::AbstractArray{T}, 
loss::LT ) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}} if loss isa SupervisedLoss - return LossFunctions.sum(loss, x, y, w; normalize=true) + return sum(loss, x, y, w; normalize=true) else l(i) = loss(x[i], y[i], w[i]) return sum(l, eachindex(x)) / sum(w) @@ -41,17 +41,25 @@ end end end +function eval_tree_dispatch( + tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, + dataset::Dataset{T}, + options::Options, + idx, +) where {T<:DATA_TYPE} + A = expected_array_type(dataset.X) + return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options)::Tuple{A,Bool} +end + # Evaluate the loss of a particular expression on the input dataset. function _eval_loss( - tree::AbstractExpressionNode{T}, + tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T,L}, options::Options, regularization::Bool, idx, )::L where {T<:DATA_TYPE,L<:LOSS_TYPE} - (prediction, completion) = eval_tree_array( - tree, maybe_getindex(dataset.X, :, idx), options - ) + (prediction, completion) = eval_tree_dispatch(tree, dataset, options, idx) if !completion return L(Inf) end @@ -95,7 +103,7 @@ end # Evaluate the loss of a particular expression on the input dataset. function eval_loss( - tree::AbstractExpressionNode{T}, + tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T,L}, options::Options; regularization::Bool=true, @@ -105,14 +113,14 @@ function eval_loss( _eval_loss(tree, dataset, options, regularization, idx) else f = options.loss_function::Function - evaluator(f, tree, dataset, options, idx) + evaluator(f, get_tree(tree), dataset, options, idx) end return loss_val end function eval_loss_batched( - tree::AbstractExpressionNode{T}, + tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T,L}, options::Options; regularization::Bool=true, @@ -127,8 +135,8 @@ function batch_sample(dataset, options) end # Just so we can pass either PopMember or Node here: -get_tree(t::AbstractExpressionNode) = t -get_tree(m) = m.tree +get_tree_from_member(t::Union{AbstractExpression,AbstractExpressionNode}) = t +get_tree_from_member(m) = m.tree # Beware: this is a circular dependency situation... # PopMember is using losses, but then we also want # losses to use the PopMember's cached complexity for trees. @@ -161,7 +169,7 @@ end function score_func( dataset::Dataset{T,L}, member, options::Options; complexity::Union{Int,Nothing}=nothing )::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} - result_loss = eval_loss(get_tree(member), dataset, options) + result_loss = eval_loss(get_tree_from_member(member), dataset, options) score = loss_to_score( result_loss, dataset.use_baseline, @@ -181,7 +189,7 @@ function score_func_batched( complexity::Union{Int,Nothing}=nothing, idx=nothing, )::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} - result_loss = eval_loss_batched(get_tree(member), dataset, options; idx=idx) + result_loss = eval_loss_batched(get_tree_from_member(member), dataset, options; idx=idx) score = loss_to_score( result_loss, dataset.use_baseline, @@ -201,7 +209,8 @@ Update the baseline loss of the dataset using the loss function specified in `op function update_baseline_loss!( dataset::Dataset{T,L}, options::Options ) where {T<:DATA_TYPE,L<:LOSS_TYPE} - example_tree = constructorof(options.node_type)(T; val=dataset.avg_y) + example_tree = create_expression(zero(T), options, dataset) + # constructorof(options.node_type)(T; val=dataset.avg_y) # TODO: It could be that the loss function is not defined for this example type? 
baseline_loss = eval_loss(example_tree, dataset, options) if isfinite(baseline_loss) @@ -215,7 +224,9 @@ function update_baseline_loss!( end function dimensional_regularization( - tree::AbstractExpressionNode{T}, dataset::Dataset{T,L}, options::Options + tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, + dataset::Dataset{T,L}, + options::Options, ) where {T<:DATA_TYPE,L<:LOSS_TYPE} if !violates_dimensional_constraints(tree, dataset, options) return zero(L) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index f0f9943b1..2bbce47f1 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -3,7 +3,16 @@ module MLJInterfaceModule using Optim: Optim using LineSearches: LineSearches using MLJModelInterface: MLJModelInterface as MMI -using DynamicExpressions: eval_tree_array, string_tree, AbstractExpressionNode, Node +using ADTypes: AbstractADType +using DynamicExpressions: + eval_tree_array, + string_tree, + AbstractExpressionNode, + AbstractExpression, + Node, + Expression, + default_node_type, + get_tree using DynamicQuantities: QuantityArray, UnionAbstractQuantity, @@ -20,12 +29,20 @@ using ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame -using ..UtilsModule: subscriptify +using ..UtilsModule: subscriptify, @ignore import ..equation_search abstract type AbstractSRRegressor <: MMI.Deterministic end +# For static analysis tools: +@ignore mutable struct SRRegressor <: AbstractSRRegressor + selection_method::Function +end +@ignore mutable struct MultitargetSRRegressor <: AbstractSRRegressor + selection_method::Function +end + # TODO: To reduce code re-use, we could forward these defaults from # `equation_search`, similar to what we do for `Options`. @@ -122,11 +139,27 @@ function MMI.update( m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing ) options = old_fitresult === nothing ? get_options(m) : old_fitresult.options - return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) + return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing) end -function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) +function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes) + if isnothing(classes) && MMI.istable(X) && haskey(X, :classes) + if !(X isa NamedTuple) + error("Classes can only be specified with named tuples.") + end + new_X = Base.structdiff(X, (; X.classes)) + new_classes = X.classes + return _update( + m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_classes + ) + end + if !isnothing(old_fitresult) + @assert( + old_fitresult.has_classes == !isnothing(classes), + "If the first fit used classes, the second fit must also use classes." + ) + end # To speed up iterative fits, we cache the types: - types = if old_fitresult === nothing + types = if isnothing(old_fitresult) (; T=Any, X_t=Any, @@ -174,6 +207,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) X_units=X_units_clean, y_units=y_units_clean, verbosity=verbosity, + extra=isnothing(classes) ? (;) : (; classes), # Help out with inference: v_dim_out=isa(m, SRRegressor) ? 
Val(1) : Val(2), ) @@ -184,6 +218,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) variable_names=variable_names, y_variable_names=y_variable_names, y_is_table=MMI.istable(y), + has_classes=!isnothing(classes), X_units=X_units_clean, y_units=y_units_clean, types=( @@ -333,9 +368,20 @@ function MMI.fitted_params(m::AbstractSRRegressor, fitresult) end function eval_tree_mlj( - tree::Node, X_t, m::AbstractSRRegressor, ::Type{T}, fitresult, i, prototype + tree::AbstractExpression, + X_t, + classes, + m::AbstractSRRegressor, + ::Type{T}, + fitresult, + i, + prototype, ) where {T} - out, completed = eval_tree_array(tree, X_t, fitresult.options) + out, completed = if isnothing(classes) + eval_tree_array(tree, X_t, fitresult.options) + else + eval_tree_array(tree, X_t, classes, fitresult.options) + end if completed return wrap_units(out, fitresult.y_units, i) else @@ -343,13 +389,32 @@ function eval_tree_mlj( end end -function MMI.predict(m::M, fitresult, Xnew; idx=nothing) where {M<:AbstractSRRegressor} +function MMI.predict( + m::M, fitresult, Xnew; idx=nothing, classes=nothing +) where {M<:AbstractSRRegressor} + return _predict(m, fitresult, Xnew, idx, classes) +end +function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegressor} if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data)) @assert( haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2, "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`." ) - return MMI.predict(m, fitresult, Xnew.data; idx=Xnew.idx) + return _predict(m, fitresult, Xnew.data, Xnew.idx, classes) + end + if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes) + if !(Xnew isa NamedTuple) + error("Classes can only be specified with named tuples.") + end + Xnew2 = Base.structdiff(Xnew, (; Xnew.classes)) + return _predict(m, fitresult, Xnew2, idx, Xnew.classes) + end + + if fitresult.has_classes + @assert( + !isnothing(classes), + "Classes must be specified if the model was fit with classes." + ) end params = full_report(m, fitresult; v_with_strings=Val(false)) @@ -370,12 +435,12 @@ function MMI.predict(m::M, fitresult, Xnew; idx=nothing) where {M<:AbstractSRReg if M <: SRRegressor return eval_tree_mlj( - params.equations[idx], Xnew_t, m, T, fitresult, nothing, prototype + params.equations[idx], Xnew_t, classes, m, T, fitresult, nothing, prototype ) elseif M <: MultitargetSRRegressor outs = [ eval_tree_mlj( - params.equations[i][idx[i]], Xnew_t, m, T, fitresult, i, prototype + params.equations[i][idx[i]], Xnew_t, classes, m, T, fitresult, i, prototype ) for i in eachindex(idx, params.equations) ] out_matrix = reduce(hcat, outs) @@ -434,10 +499,16 @@ MMI.metadata_pkg( is_wrapper=false, ) +const input_scitype = Union{ + MMI.Table(MMI.Continuous), + AbstractMatrix{<:MMI.Continuous}, + MMI.Table(MMI.Continuous, MMI.Count), +} + # TODO: Allow for Count data, and coerce it into Continuous as needed. 
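The prediction paths added above accept either a plain table, a named tuple with `data` and `idx` keys (to evaluate a specific Pareto-front equation), or a table carrying a `classes` column when the model was fitted with classes. A minimal usage sketch, assuming an already-fitted machine `mach` and illustrative feature names:

```julia
using MLJ

# Minimal sketch, assuming `mach` is an already-fitted `machine(SRRegressor(...), X, y)`;
# the feature names and the two-class setup below are illustrative only.
Xnew = (; x1=randn(10), x2=randn(10))

# Default: evaluate the equation chosen by `selection_method`:
yhat = predict(mach, Xnew)

# Evaluate a specific Pareto-front equation by index:
yhat3 = predict(mach, (data=Xnew, idx=3))

# If the model was fitted with a `classes` column (e.g., for `ParametricExpression`),
# the same column must be supplied at prediction time, as part of a named tuple:
yhat_c = predict(mach, (; x1=randn(10), x2=randn(10), classes=rand(1:2, 10)))
```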
MMI.metadata_model( SRRegressor; - input_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, + input_scitype, target_scitype=AbstractVector{<:MMI.Continuous}, supports_weights=true, reports_feature_importances=false, @@ -446,7 +517,7 @@ MMI.metadata_model( ) MMI.metadata_model( MultitargetSRRegressor; - input_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, + input_scitype, target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, supports_weights=true, reports_feature_importances=false, @@ -518,11 +589,11 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt # Operations - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which - should have same scitype as `X` above. The expression used for prediction is defined - by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`. + should have same scitype as `X` above. The expression used for prediction is defined + by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`. - `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features - `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys - `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`. + `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys + `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`. $(bottom_matter) """ diff --git a/src/Migration.jl b/src/Migration.jl index fb6e99d02..daab9255f 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -1,7 +1,7 @@ module MigrationModule using StatsBase: StatsBase -using ..CoreModule: Options, DATA_TYPE, LOSS_TYPE +using ..CoreModule: Options using ..PopulationModule: Population using ..PopMemberModule: PopMember, reset_birth! 
using ..UtilsModule: poisson_sample diff --git a/src/Mutate.jl b/src/Mutate.jl index 8820359fc..8b4c6acf0 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -1,16 +1,16 @@ module MutateModule using DynamicExpressions: - AbstractExpressionNode, - Node, + AbstractExpression, + ParametricExpression, + with_contents, + get_tree, preserve_sharing, copy_node, - count_nodes, - count_constants, + count_scalar_constants, simplify_tree!, combine_operators -using ..CoreModule: - Options, MutationWeights, Dataset, RecordType, sample_mutation, DATA_TYPE, LOSS_TYPE +using ..CoreModule: Options, MutationWeights, Dataset, RecordType, sample_mutation using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, score_func_batched using ..CheckConstraintsModule: check_constraints @@ -34,32 +34,32 @@ using ..RecorderModule: @recorder function condition_mutation_weights!( weights::MutationWeights, member::PopMember, options::Options, curmaxsize::Int ) + tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 weights.break_connection = 0.0 end - if member.tree.degree == 0 + if tree.degree == 0 # If equation is too small, don't delete operators # or simplify weights.mutate_operator = 0.0 weights.swap_operands = 0.0 weights.delete_node = 0.0 weights.simplify = 0.0 - if !member.tree.constant + if !tree.constant weights.optimize = 0.0 weights.mutate_constant = 0.0 end return nothing end - if !any(node -> node.degree == 2, member.tree) + if !any(node -> node.degree == 2, tree) # swap is implemented only for binary ops weights.swap_operands = 0.0 end - #More constants => more likely to do constant mutation - n_constants = count_constants(member.tree) - weights.mutate_constant *= min(8, n_constants) / 8.0 + condition_mutate_constant!(typeof(member.tree), weights, member, options, curmaxsize) + complexity = compute_complexity(member, options) if complexity >= curmaxsize @@ -75,6 +75,33 @@ function condition_mutation_weights!( return nothing end +""" +Use this to modify how `mutate_constant` changes for an expression type. +""" +function condition_mutate_constant!( + ::Type{<:AbstractExpression}, + weights::MutationWeights, + member::PopMember, + options::Options, + curmaxsize::Int, +) + n_constants = count_scalar_constants(member.tree) + weights.mutate_constant *= min(8, n_constants) / 8.0 + + return nothing +end +function condition_mutate_constant!( + ::Type{<:ParametricExpression}, + weights::MutationWeights, + member::PopMember, + options::Options, + curmaxsize::Int, +) + # Avoid modifying the mutate_constant weight, since + # otherwise we would be mutating constants all the time! 
+ return nothing +end + # Go through one simulated options.annealing mutation cycle # exp(-delta/T) defines probability of accepting a change function next_generation( @@ -87,7 +114,7 @@ function next_generation( tmp_recorder::RecordType, )::Tuple{ P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpressionNode{T},P<:PopMember{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} parent_ref = member.ref mutation_accepted = false num_evals = 0.0 @@ -158,11 +185,10 @@ function next_generation( elseif mutation_choice == :simplify @assert options.should_simplify simplify_tree!(tree, options.operators) - if tree isa Node - tree = combine_operators(tree, options.operators) - end + tree = combine_operators(tree, options.operators) @recorder tmp_recorder["type"] = "partial_simplify" mutation_accepted = true + is_success_always_possible = true return ( PopMember( tree, @@ -175,16 +201,16 @@ function next_generation( mutation_accepted, num_evals, ) - - is_success_always_possible = true # Simplification shouldn't hurt complexity; unless some non-symmetric constraint # to commutative operator... - elseif mutation_choice == :randomize # We select a random size, though the generated tree # may have fewer nodes than we request. tree_size_to_generate = rand(1:curmaxsize) - tree = gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T) + tree = with_contents( + tree, + gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T), + ) @recorder tmp_recorder["type"] = "regenerate" is_success_always_possible = true @@ -202,9 +228,8 @@ function next_generation( num_evals += new_num_evals @recorder tmp_recorder["type"] = "optimize" mutation_accepted = true - return (cur_member, mutation_accepted, num_evals) - is_success_always_possible = true + return (cur_member, mutation_accepted, num_evals) elseif mutation_choice == :do_nothing @recorder begin tmp_recorder["type"] = "identity" @@ -212,6 +237,7 @@ function next_generation( tmp_recorder["reason"] = "identity" end mutation_accepted = true + is_success_always_possible = true return ( PopMember( tree, @@ -243,6 +269,7 @@ function next_generation( attempts += 1 end ############################################# + tree::AbstractExpression if !successful_mutation @recorder begin @@ -360,7 +387,7 @@ end """Generate a generation via crossover of two members.""" function crossover_generation( member1::P, member2::P, dataset::D, curmaxsize::Int, options::Options -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},P<:PopMember{T,L}} +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} tree1 = member1.tree tree2 = member2.tree crossover_accepted = false @@ -413,7 +440,7 @@ function crossover_generation( afterSize1; parent=member1.ref, deterministic=options.deterministic, - ) + )::P baby2 = PopMember( child_tree2, afterScore2, @@ -422,7 +449,7 @@ function crossover_generation( afterSize2; parent=member2.ref, deterministic=options.deterministic, - ) + )::P crossover_accepted = true return baby1, baby2, crossover_accepted, num_evals diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index a45a32cf7..805a03b7d 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -3,8 +3,11 @@ module MutationFunctionsModule using Random: default_rng, AbstractRNG using DynamicExpressions: AbstractExpressionNode, + AbstractExpression, AbstractNode, NodeSampler, + get_contents, + with_contents, constructorof, copy_node, set_node!, @@ -31,6 +34,11 @@ function random_node( 
end """Swap operands in binary operator for ops like pow and divide""" +function swap_operands(ex::AbstractExpression, rng::AbstractRNG=default_rng()) + tree = get_contents(ex) + ex = with_contents(ex, swap_operands(tree, rng)) + return ex +end function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) if !any(node -> node.degree == 2, tree) return tree @@ -41,6 +49,13 @@ function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) end """Randomly convert an operator into another one (binary->binary; unary->unary)""" +function mutate_operator( + ex::AbstractExpression{T}, options::Options, rng::AbstractRNG=default_rng() +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, mutate_operator(tree, options, rng)) + return ex +end function mutate_operator( tree::AbstractExpressionNode{T}, options::Options, rng::AbstractRNG=default_rng() ) where {T} @@ -57,6 +72,13 @@ function mutate_operator( end """Randomly perturb a constant""" +function mutate_constant( + ex::AbstractExpression{T}, temperature, options::Options, rng::AbstractRNG=default_rng() +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, mutate_constant(tree, temperature, options, rng)) + return ex +end function mutate_constant( tree::AbstractExpressionNode{T}, temperature, @@ -70,25 +92,39 @@ function mutate_constant( end node = rand(rng, NodeSampler(; tree, filter=t -> (t.degree == 0 && t.constant))) + node.val *= mutate_factor(T, temperature, options, rng) + + return tree +end + +function mutate_factor(::Type{T}, temperature, options, rng) where {T<:DATA_TYPE} bottom = 1//10 maxChange = options.perturbation_factor * temperature + 1 + bottom factor = T(maxChange^rand(rng, T)) makeConstBigger = rand(rng, Bool) - if makeConstBigger - node.val *= factor - else - node.val /= factor - end + factor = makeConstBigger ? factor : 1 / factor if rand(rng) > options.probability_negate_constant - node.val *= -1 + factor *= -1 end - - return tree + return factor end +# TODO: Shouldn't we add a mutate_feature here? 
+ """Add a random unary/binary operation to the end of a tree""" +function append_random_op( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(); + makeNewBinOp::Union{Bool,Nothing}=nothing, +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, append_random_op(tree, options, nfeatures, rng; makeNewBinOp)) + return ex +end function append_random_op( tree::AbstractExpressionNode{T}, options::Options, @@ -104,14 +140,15 @@ function append_random_op( end if makeNewBinOp - newnode = constructorof(typeof(tree))( - rand(rng, 1:(options.nbin)), - make_random_leaf(nfeatures, T, typeof(tree), rng), - make_random_leaf(nfeatures, T, typeof(tree), rng), + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nbin)), + l=make_random_leaf(nfeatures, T, typeof(tree), rng, options), + r=make_random_leaf(nfeatures, T, typeof(tree), rng, options), ) else - newnode = constructorof(typeof(tree))( - rand(rng, 1:(options.nuna)), make_random_leaf(nfeatures, T, typeof(tree), rng) + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nuna)), + l=make_random_leaf(nfeatures, T, typeof(tree), rng, options), ) end @@ -121,6 +158,16 @@ function append_random_op( end """Insert random node""" +function insert_random_op( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, insert_random_op(tree, options, nfeatures, rng)) + return ex +end function insert_random_op( tree::AbstractExpressionNode{T}, options::Options, @@ -133,16 +180,28 @@ function insert_random_op( left = copy_node(node) if makeNewBinOp - right = make_random_leaf(nfeatures, T, typeof(tree), rng) - newnode = constructorof(typeof(tree))(rand(rng, 1:(options.nbin)), left, right) + right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nbin)), l=left, r=right + ) else - newnode = constructorof(typeof(tree))(rand(rng, 1:(options.nuna)), left) + newnode = constructorof(typeof(tree))(; op=rand(rng, 1:(options.nuna)), l=left) end set_node!(node, newnode) return tree end """Add random node to the top of a tree""" +function prepend_random_op( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, prepend_random_op(tree, options, nfeatures, rng)) + return ex +end function prepend_random_op( tree::AbstractExpressionNode{T}, options::Options, @@ -155,20 +214,26 @@ function prepend_random_op( left = copy_node(tree) if makeNewBinOp - right = make_random_leaf(nfeatures, T, typeof(tree), rng) - newnode = constructorof(typeof(tree))(rand(rng, 1:(options.nbin)), left, right) + right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nbin)), l=left, r=right + ) else - newnode = constructorof(typeof(tree))(rand(rng, 1:(options.nuna)), left) + newnode = constructorof(typeof(tree))(; op=rand(rng, 1:(options.nuna)), l=left) end set_node!(node, newnode) return node end function make_random_leaf( - nfeatures::Int, ::Type{T}, ::Type{N}, rng::AbstractRNG=default_rng() + nfeatures::Int, + ::Type{T}, + ::Type{N}, + rng::AbstractRNG=default_rng(), + ::Union{Options,Nothing}=nothing, ) where {T<:DATA_TYPE,N<:AbstractExpressionNode} if rand(rng, Bool) - return constructorof(N)(; val=randn(rng, T)) + 
return constructorof(N)(T; val=randn(rng, T)) else return constructorof(N)(T; feature=rand(rng, 1:nfeatures)) end @@ -188,6 +253,16 @@ function random_node_and_parent(tree::AbstractNode, rng::AbstractRNG=default_rng end """Select a random node, and splice it out of the tree.""" +function delete_random_op!( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, delete_random_op!(tree, options, nfeatures, rng)) + return ex +end function delete_random_op!( tree::AbstractExpressionNode{T}, options::Options, @@ -199,7 +274,7 @@ function delete_random_op!( if node.degree == 0 # Replace with new constant - newnode = make_random_leaf(nfeatures, T, typeof(tree), rng) + newnode = make_random_leaf(nfeatures, T, typeof(tree), rng, options) set_node!(node, newnode) elseif node.degree == 1 # Join one of the children with the parent @@ -253,7 +328,7 @@ function gen_random_tree_fixed_size( ::Type{T}, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = make_random_leaf(nfeatures, T, options.node_type, rng) + tree = make_random_leaf(nfeatures, T, options.node_type, rng, options) cur_size = count_nodes(tree) while cur_size < node_count if cur_size == node_count - 1 # only unary operator allowed. @@ -267,12 +342,21 @@ function gen_random_tree_fixed_size( return tree end +function crossover_trees( + ex1::E, ex2::E, rng::AbstractRNG=default_rng() +) where {T,E<:AbstractExpression{T}} + tree1 = get_contents(ex1) + tree2 = get_contents(ex2) + out1, out2 = crossover_trees(tree1, tree2, rng) + ex1 = with_contents(ex1, out1) + ex2 = with_contents(ex2, out2) + return ex1, ex2 +end + """Crossover between two expressions""" function crossover_trees( - tree1::AbstractExpressionNode{T}, - tree2::AbstractExpressionNode{T}, - rng::AbstractRNG=default_rng(), -) where {T} + tree1::N, tree2::N, rng::AbstractRNG=default_rng() +) where {T,N<:AbstractExpressionNode{T}} tree1 = copy_node(tree1) tree2 = copy_node(tree2) @@ -315,6 +399,10 @@ function get_two_nodes_without_loop(tree::AbstractNode, rng::AbstractRNG; max_at return (tree, tree, true) end +function form_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) + tree = get_contents(ex) + return with_contents(ex, form_random_connection!(tree, rng)) +end function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) if length(tree) < 5 return tree @@ -334,6 +422,11 @@ function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rn end return tree end + +function break_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) + tree = get_contents(ex) + return with_contents(ex, break_random_connection!(tree, rng)) +end function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) tree.degree == 0 && return tree parent = rand(rng, NodeSampler(; tree, filter=t -> t.degree != 0)) diff --git a/src/MutationWeights.jl b/src/MutationWeights.jl index 8549dd810..1f3f7369f 100644 --- a/src/MutationWeights.jl +++ b/src/MutationWeights.jl @@ -18,14 +18,14 @@ will be normalized to sum to 1.0 after initialization. - `randomize::Float64`: How often to create a random tree. - `do_nothing::Float64`: How often to do nothing. - `optimize::Float64`: How often to optimize the constants in the tree, as a mutation. - Note that this is different from `optimizer_probability`, which is - performed at the end of an iteration for all individuals. 
+ Note that this is different from `optimizer_probability`, which is + performed at the end of an iteration for all individuals. - `form_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**. - Otherwise, this will automatically be set to 0.0. How often to form a - connection between two nodes. + Otherwise, this will automatically be set to 0.0. How often to form a + connection between two nodes. - `break_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**. - Otherwise, this will automatically be set to 0.0. How often to break a - connection between two nodes. + Otherwise, this will automatically be set to 0.0. How often to break a + connection between two nodes. """ Base.@kwdef mutable struct MutationWeights mutate_constant::Float64 = 0.048 diff --git a/src/Operators.jl b/src/Operators.jl index cc756f0d6..e7b99ea10 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -1,10 +1,12 @@ module OperatorsModule +using DynamicExpressions: DynamicExpressions as DE using SpecialFunctions: SpecialFunctions using DynamicQuantities: UnionAbstractQuantity using SpecialFunctions: erf, erfc using Base: @deprecate using ..ProgramConstantsModule: DATA_TYPE +using ...UtilsModule: @ignore #TODO - actually add these operators to the module! # TODO: Should this be limited to AbstractFloat instead? @@ -95,8 +97,21 @@ function logical_and(x, y) return ((x > zero(x)) & (y > zero(y))) * one(x) end +# Strings +DE.get_op_name(::typeof(safe_pow)) = "^" +DE.get_op_name(::typeof(safe_log)) = "log" +DE.get_op_name(::typeof(safe_log2)) = "log2" +DE.get_op_name(::typeof(safe_log10)) = "log10" +DE.get_op_name(::typeof(safe_log1p)) = "log1p" +DE.get_op_name(::typeof(safe_acosh)) = "acosh" +DE.get_op_name(::typeof(safe_sqrt)) = "sqrt" + # Deprecated operations: @deprecate pow(x, y) safe_pow(x, y) @deprecate pow_abs(x, y) safe_pow(x, y) +# For static analysis tools: +@ignore pow(x, y) = safe_pow(x, y) +@ignore pow_abs(x, y) = safe_pow(x, y) + end diff --git a/src/Options.jl b/src/Options.jl index 4131a40bf..0815bf744 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -4,8 +4,8 @@ using DispatchDoctor: @unstable using Optim: Optim using Dates: Dates using StatsBase: StatsBase -using DynamicExpressions: OperatorEnum, Node -using Distributed: nworkers +using DynamicExpressions: OperatorEnum, Expression, default_node_type +using ADTypes: AbstractADType, ADTypes using LossFunctions: L2DistLoss, SupervisedLoss using Optim: Optim using LineSearches: LineSearches @@ -28,65 +28,115 @@ using ..OperatorsModule: using ..MutationWeightsModule: MutationWeights, mutations import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization -using ..UtilsModule: max_ops, @save_kwargs - -""" - build_constraints(una_constraints, bin_constraints, - unary_operators, binary_operators) - -Build constraints on operator-level complexity from a user-passed dict. -""" -function build_constraints( - una_constraints, bin_constraints, unary_operators, binary_operators, nuna, nbin -)::Tuple{Array{Int,1},Array{Tuple{Int,Int},1}} +using ..UtilsModule: max_ops, @save_kwargs, @ignore + +"""Build constraints on operator-level complexity from a user-passed dict.""" +@unstable function build_constraints(; + una_constraints, + bin_constraints, + @nospecialize(unary_operators), + @nospecialize(binary_operators) +)::Tuple{Vector{Int},Vector{Tuple{Int,Int}}} # Expect format ((*)=>(-1, 3)), etc. # TODO: Need to disable simplification if (*, -, +, /) are constrained? 
# Or, just quit simplification is constraints violated. - is_bin_constraints_already_done = typeof(bin_constraints) <: Array{Tuple{Int,Int},1} - is_una_constraints_already_done = typeof(una_constraints) <: Array{Int,1} - - if typeof(bin_constraints) <: Array && !is_bin_constraints_already_done - bin_constraints = Dict(bin_constraints) + is_una_constraints_already_done = una_constraints isa Vector{Int} + _una_constraints1 = if una_constraints isa Array && !is_una_constraints_already_done + Dict(una_constraints) + else + una_constraints end - if typeof(una_constraints) <: Array && !is_una_constraints_already_done - una_constraints = Dict(una_constraints) + _una_constraints2 = if _una_constraints1 === nothing + fill(-1, length(unary_operators)) + elseif !is_una_constraints_already_done + [ + haskey(_una_constraints1, op) ? _una_constraints1[op]::Int : -1 for + op in unary_operators + ] + else + _una_constraints1 end - if una_constraints === nothing - una_constraints = [-1 for i in 1:nuna] - elseif !is_una_constraints_already_done - una_constraints::Dict - _una_constraints = Int[] - for (i, op) in enumerate(unary_operators) - did_user_declare_constraints = haskey(una_constraints, op) - if did_user_declare_constraints - constraint::Int = una_constraints[op] - push!(_una_constraints, constraint) - else - push!(_una_constraints, -1) - end - end - una_constraints = _una_constraints + is_bin_constraints_already_done = bin_constraints isa Vector{Tuple{Int,Int}} + _bin_constraints1 = if bin_constraints isa Array && !is_bin_constraints_already_done + Dict(bin_constraints) + else + bin_constraints end - if bin_constraints === nothing - bin_constraints = [(-1, -1) for i in 1:nbin] + _bin_constraints2 = if _bin_constraints1 === nothing + fill((-1, -1), length(binary_operators)) elseif !is_bin_constraints_already_done - bin_constraints::Dict - _bin_constraints = Tuple{Int,Int}[] - for (i, op) in enumerate(binary_operators) - did_user_declare_constraints = haskey(bin_constraints, op) - if did_user_declare_constraints - constraint::Tuple{Int,Int} = bin_constraints[op] - push!(_bin_constraints, constraint) + [ + if haskey(_bin_constraints1, op) + _bin_constraints1[op]::Tuple{Int,Int} else - push!(_bin_constraints, (-1, -1)) + (-1, -1) + end for op in binary_operators + ] + else + _bin_constraints1 + end + + return _una_constraints2, _bin_constraints2 +end + +@unstable function build_nested_constraints(; + @nospecialize(binary_operators), @nospecialize(unary_operators), nested_constraints +) + nested_constraints === nothing && return nested_constraints + # Check that intersection of binary operators and unary operators is empty: + for op in binary_operators + if op ∈ unary_operators + error( + "Operator $(op) is both a binary and unary operator. " * + "You can't use nested constraints.", + ) + end + end + + # Convert to dict: + _nested_constraints = if nested_constraints isa Dict + nested_constraints + else + # Convert to dict: + nested_constraints = Dict( + [cons[1] => Dict(cons[2]...) for cons in nested_constraints]... 
+ ) + end + for (op, nested_constraint) in _nested_constraints + if !(op ∈ binary_operators || op ∈ unary_operators) + error("Operator $(op) is not in the operator set.") + end + for (nested_op, max_nesting) in nested_constraint + if !(nested_op ∈ binary_operators || nested_op ∈ unary_operators) + error("Operator $(nested_op) is not in the operator set.") end + @assert nested_op ∈ binary_operators || nested_op ∈ unary_operators + @assert max_nesting >= -1 && typeof(max_nesting) <: Int end - bin_constraints = _bin_constraints end - return una_constraints, bin_constraints + # Lastly, we clean it up into a dict of (degree,op_idx) => max_nesting. + return [ + let (degree, idx) = if op ∈ binary_operators + 2, findfirst(isequal(op), binary_operators)::Int + else + 1, findfirst(isequal(op), unary_operators)::Int + end, + new_max_nesting_dict = [ + let (nested_degree, nested_idx) = if nested_op ∈ binary_operators + 2, findfirst(isequal(nested_op), binary_operators)::Int + else + 1, findfirst(isequal(nested_op), unary_operators)::Int + end + (nested_degree, nested_idx, max_nesting) + end for (nested_op, max_nesting) in nested_constraint + ] + + (degree, idx, new_max_nesting_dict) + end for (op, nested_constraint) in _nested_constraints + ] end function binopmap(op::F) where {F} @@ -177,6 +227,9 @@ const deprecated_options_mapping = Base.ImmutableDict( :loss => :elementwise_loss, ) +# For static analysis tools: +@ignore const DEFAULT_OPTIONS = () + const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators (functions) to use. Each operator should be defined for two input scalars, and one output scalar. All operators @@ -248,8 +301,10 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators return sum((prediction .- dataset.y) .^ 2) / dataset.n end -- `node_type::Type{N}=Node`: The type of node to use for the search. - For example, `Node` or `GraphNode`. +- `expression_type::Type{E}=Expression`: The type of expression to use. + For example, `Expression`. +- `node_type::Type{N}=default_node_type(Expression)`: The type of node to use for the search. + For example, `Node` or `GraphNode`. The default is computed by `default_node_type(expression_type)`. - `populations`: How many populations of equations to use. - `population_size`: How many equations in each population. - `ncycles_per_iteration`: How many generations to consider per iteration. @@ -312,14 +367,21 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators - `optimizer_probability`: Probability of performing optimization of constants at the end of a given iteration. - `optimizer_iterations`: How many optimization iterations to perform. This gets - passed to `Optim.Options` as `iterations`. The default is 8. + passed to `Optim.Options` as `iterations`. The default is 8. - `optimizer_f_calls_limit`: How many function calls to allow during optimization. This gets passed to `Optim.Options` as `f_calls_limit`. The default is - `0` which means no limit. + `10_000`. - `optimizer_options`: General options for the constant optimization. For details we refer to the documentation on `Optim.Options` from the `Optim.jl` package. Options can be provided here as `NamedTuple`, e.g. `(iterations=16,)`, as a `Dict`, e.g. Dict(:x_tol => 1.0e-32,), or as an `Optim.Options` instance. +- `autodiff_backend`: The backend to use for differentiation, which should be + an instance of `AbstractADType` (see `DifferentiationInterface.jl`). 
+ Default is `nothing`, which means `Optim.jl` will estimate gradients (likely + with finite differences). You can also pass a symbolic version of the backend + type, such as `:Zygote` for Zygote, `:Enzyme`, etc. Most backends will not + work, and many will never work due to incompatibilities, though support for some + is gradually being added. - `output_file`: What file to store equations to, as a backup. - `perturbation_factor`: When mutating a constant, either multiply or divide by (1+perturbation_factor)^(rand()+1). @@ -401,7 +463,9 @@ $(OPTION_DESCRIPTIONS) should_simplify::Union{Nothing,Bool}=nothing, should_optimize_constants::Bool=true, output_file::Union{Nothing,AbstractString}=nothing, - node_type::Type=Node, + expression_type::Type=Expression, + node_type::Type=default_node_type(expression_type), + expression_options::NamedTuple=NamedTuple(), populations::Integer=15, perturbation_factor::Real=0.076, annealing::Bool=false, @@ -434,6 +498,7 @@ $(OPTION_DESCRIPTIONS) optimizer_iterations::Union{Nothing,Integer}=nothing, optimizer_f_calls_limit::Union{Nothing,Integer}=nothing, optimizer_options::Union{Dict,NamedTuple,Optim.Options,Nothing}=nothing, + autodiff_backend::Union{AbstractADType,Symbol,Nothing}=nothing, use_recorder::Bool=false, recorder_file::AbstractString="pysr_recorder.json", early_stop_condition::Union{Function,Real,Nothing}=nothing, @@ -561,69 +626,15 @@ $(OPTION_DESCRIPTIONS) end end - nuna = length(unary_operators) - nbin = length(binary_operators) @assert maxsize > 3 @assert warmup_maxsize_by >= 0.0f0 - @assert nuna <= max_ops && nbin <= max_ops + @assert length(unary_operators) <= max_ops + @assert length(binary_operators) <= max_ops # Make sure nested_constraints contains functions within our operator set: - if nested_constraints !== nothing - # Check that intersection of binary operators and unary operators is empty: - for op in binary_operators - if op ∈ unary_operators - error( - "Operator $(op) is both a binary and unary operator. " * - "You can't use nested constraints.", - ) - end - end - - # Convert to dict: - if !(typeof(nested_constraints) <: Dict) - # Convert to dict: - nested_constraints = Dict( - [cons[1] => Dict(cons[2]...) for cons in nested_constraints]... - ) - end - for (op, nested_constraint) in nested_constraints - if !(op ∈ binary_operators || op ∈ unary_operators) - error("Operator $(op) is not in the operator set.") - end - for (nested_op, max_nesting) in nested_constraint - if !(nested_op ∈ binary_operators || nested_op ∈ unary_operators) - error("Operator $(nested_op) is not in the operator set.") - end - @assert nested_op ∈ binary_operators || nested_op ∈ unary_operators - @assert max_nesting >= -1 && typeof(max_nesting) <: Int - end - end - - # Lastly, we clean it up into a dict of (degree,op_idx) => max_nesting. 
- new_nested_constraints = [] - # Dict() - for (op, nested_constraint) in nested_constraints - (degree, idx) = if op ∈ binary_operators - 2, findfirst(isequal(op), binary_operators) - else - 1, findfirst(isequal(op), unary_operators) - end - new_max_nesting_dict = [] - # Dict() - for (nested_op, max_nesting) in nested_constraint - (nested_degree, nested_idx) = if nested_op ∈ binary_operators - 2, findfirst(isequal(nested_op), binary_operators) - else - 1, findfirst(isequal(nested_op), unary_operators) - end - # new_max_nesting_dict[(nested_degree, nested_idx)] = max_nesting - push!(new_max_nesting_dict, (nested_degree, nested_idx, max_nesting)) - end - # new_nested_constraints[(degree, idx)] = new_max_nesting_dict - push!(new_nested_constraints, (degree, idx, new_max_nesting_dict)) - end - nested_constraints = new_nested_constraints - end + _nested_constraints = build_nested_constraints(; + binary_operators, unary_operators, nested_constraints + ) if typeof(constraints) <: Tuple constraints = collect(constraints) @@ -642,8 +653,8 @@ $(OPTION_DESCRIPTIONS) una_constraints = constraints end - una_constraints, bin_constraints = build_constraints( - una_constraints, bin_constraints, unary_operators, binary_operators, nuna, nbin + _una_constraints, _bin_constraints = build_constraints(; + una_constraints, bin_constraints, unary_operators, binary_operators ) complexity_mapping = ComplexityMapping( @@ -692,7 +703,7 @@ $(OPTION_DESCRIPTIONS) if !isa(optimizer_options, Optim.Options) optimizer_iterations = isnothing(optimizer_iterations) ? 8 : optimizer_iterations optimizer_f_calls_limit = if isnothing(optimizer_f_calls_limit) - 0 + 10_000 else optimizer_f_calls_limit end @@ -710,35 +721,33 @@ $(OPTION_DESCRIPTIONS) @warn "Optimizer warnings are turned on. This might result in a lot of warnings being printed from NaNs, as these are common during symbolic regression" end - ## Create tournament weights: - tournament_selection_weights = - let n = tournament_selection_n, p = tournament_selection_p - k = collect(0:(n - 1)) - prob_each = p * ((1 - p) .^ k) - - StatsBase.Weights(prob_each, sum(prob_each)) - end - set_mutation_weights = create_mutation_weights(mutation_weights) @assert print_precision > 0 + _autodiff_backend = if autodiff_backend isa Union{Nothing,AbstractADType} + autodiff_backend + else + ADTypes.Auto(autodiff_backend) + end + options = Options{ typeof(complexity_mapping), operator_specialization(typeof(operators)), node_type, + expression_type, + typeof(expression_options), turbo, bumper, deprecated_return_state, - typeof(tournament_selection_weights), + typeof(_autodiff_backend), }( operators, - bin_constraints, - una_constraints, + _bin_constraints, + _una_constraints, complexity_mapping, tournament_selection_n, tournament_selection_p, - tournament_selection_weights, parsimony, dimensional_constraint_penalty, dimensionless_constants_only, @@ -772,18 +781,21 @@ $(OPTION_DESCRIPTIONS) print_precision, save_to_file, probability_negate_constant, - nuna, - nbin, + length(unary_operators), + length(binary_operators), seed, elementwise_loss, loss_function, node_type, + expression_type, + expression_options, progress, terminal_width, optimizer_algorithm, optimizer_probability, optimizer_nrestarts, optimizer_options, + _autodiff_backend, recorder_file, tournament_selection_p, early_stop_condition, @@ -791,7 +803,7 @@ $(OPTION_DESCRIPTIONS) timeout_in_seconds, max_evals, skip_mutation_failures, - nested_constraints, + _nested_constraints, deterministic, define_helper_functions, use_recorder, 
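Taken together, the `Options` changes above introduce `expression_type`, `expression_options`, a `node_type` derived via `default_node_type`, an `autodiff_backend` that may be given as a `Symbol` (converted through `ADTypes.Auto`), and a new `optimizer_f_calls_limit` default of `10_000`. A hedged construction sketch follows; the operator set and the `:Zygote` backend are assumptions, the backend package must be loaded separately, and (per the docstring above) not every backend is compatible:

```julia
using SymbolicRegression
# using Zygote  # the chosen AD backend generally needs to be loaded as well

# Illustrative construction only, using the keywords introduced above.
options = Options(;
    binary_operators=[+, -, *, /],
    unary_operators=[cos, exp],
    expression_type=Expression,       # `node_type` defaults to `default_node_type(expression_type)`
    expression_options=NamedTuple(),  # extra options for the expression type (none needed here)
    autodiff_backend=:Zygote,         # a `Symbol` is converted internally via `ADTypes.Auto`
    optimizer_f_calls_limit=10_000,   # matches the new default shown above
)
```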
diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index f89a284d3..2417b75ba 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -1,8 +1,9 @@ module OptionsStructModule +using DispatchDoctor: @unstable using Optim: Optim using DynamicExpressions: - AbstractOperatorEnum, AbstractExpressionNode, OperatorEnum, GenericOperatorEnum + AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum using LossFunctions: SupervisedLoss import ..MutationWeightsModule: MutationWeights @@ -124,10 +125,12 @@ struct Options{ CM<:ComplexityMapping, OP<:AbstractOperatorEnum, N<:AbstractExpressionNode, + E<:AbstractExpression, + EO<:NamedTuple, _turbo, _bumper, _return_state, - W, + AD, } operators::OP bin_constraints::Vector{Tuple{Int,Int}} @@ -135,7 +138,6 @@ struct Options{ complexity_mapping::CM tournament_selection_n::Int tournament_selection_p::Float32 - tournament_selection_weights::W parsimony::Float32 dimensional_constraint_penalty::Union{Float32,Nothing} dimensionless_constants_only::Bool @@ -175,12 +177,15 @@ struct Options{ elementwise_loss::Union{SupervisedLoss,Function} loss_function::Union{Nothing,Function} node_type::Type{N} + expression_type::Type{E} + expression_options::EO progress::Union{Bool,Nothing} terminal_width::Union{Int,Nothing} optimizer_algorithm::Optim.AbstractOptimizer optimizer_probability::Float32 optimizer_nrestarts::Int optimizer_options::Optim.Options + autodiff_backend::AD recorder_file::String prob_pick_first::Float32 early_stop_condition::Union{Function,Nothing} @@ -218,4 +223,19 @@ function Base.print(io::IO, options::Options) end Base.show(io::IO, ::MIME"text/plain", options::Options) = Base.print(io, options) +@unstable function specialized_options(options::Options) + return _specialized_options(options) +end +@generated function _specialized_options(options::O) where {O<:Options} + # Return an options struct with concrete operators + type_parameters = O.parameters + fields = Any[:(getfield(options, $(QuoteNode(k)))) for k in fieldnames(O)] + quote + operators = getfield(options, :operators) + Options{$(type_parameters[1]),typeof(operators),$(type_parameters[3:end]...)}( + $(fields...) 
+ ) + end +end + end diff --git a/src/PopMember.jl b/src/PopMember.jl index d47042c20..84f29f451 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -1,15 +1,14 @@ module PopMemberModule using DispatchDoctor: @unstable - -using DynamicExpressions: AbstractExpressionNode, copy_node, count_nodes -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE +using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree +using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, create_expression import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: score_func # Define a member of population by equation, score, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode{T}} +mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} tree::N score::L # Inludes complexity penalty, normalization loss::L # Raw loss @@ -33,11 +32,20 @@ end ) return getfield(member, field) end +function Base.show(io::IO, p::PopMember{T,L,N}) where {T,L,N} + shower(x) = sprint(show, x) + print(io, "PopMember(") + print(io, "tree = (", string_tree(p.tree), "), ") + print(io, "loss = ", shower(p.loss), ", ") + print(io, "score = ", shower(p.score)) + print(io, ")") + return nothing +end generate_reference() = abs(rand(Int)) """ - PopMember(t::AbstractExpressionNode{T}, score::L, loss::L) + PopMember(t::AbstractExpression{T}, score::L, loss::L) Create a population member with a birth date at the current time. The type of the `Node` may be different from the type of the score @@ -45,23 +53,30 @@ and loss. # Arguments -- `t::AbstractExpressionNode{T}`: The tree for the population member. +- `t::AbstractExpression{T}`: The tree for the population member. - `score::L`: The score (normalized to a baseline, and offset by a complexity penalty) - `loss::L`: The raw loss to assign. """ function PopMember( - t::AbstractExpressionNode{T}, + t::AbstractExpression{T}, score::L, loss::L, - options::Options, + options::Union{Options,Nothing}=nothing, complexity::Union{Int,Nothing}=nothing; ref::Int=-1, parent::Int=-1, - deterministic=false, + deterministic=nothing, ) where {T<:DATA_TYPE,L<:LOSS_TYPE} if ref == -1 ref = generate_reference() end + if !(deterministic isa Bool) + throw( + ArgumentError( + "You must declare `deterministic` as `true` or `false`, it cannot be left undefined.", + ), + ) + end complexity = complexity === nothing ? -1 : complexity return PopMember{T,L,typeof(t)}( t, @@ -75,8 +90,11 @@ function PopMember( end """ - PopMember(dataset::Dataset{T,L}, - t::AbstractExpressionNode{T}, options::Options) + PopMember( + dataset::Dataset{T,L}, + t::AbstractExpression{T}, + options::Options + ) Create a population member with a birth date at the current time. Automatically compute the score for this tree. @@ -84,23 +102,24 @@ Automatically compute the score for this tree. # Arguments - `dataset::Dataset{T,L}`: The dataset to evaluate the tree on. -- `t::AbstractExpressionNode{T}`: The tree for the population member. +- `t::AbstractExpression{T}`: The tree for the population member. - `options::Options`: What options to use. """ function PopMember( dataset::Dataset{T,L}, - t::AbstractExpressionNode{T}, + tree::Union{AbstractExpressionNode{T},AbstractExpression{T}}, options::Options, complexity::Union{Int,Nothing}=nothing; ref::Int=-1, parent::Int=-1, deterministic=nothing, ) where {T<:DATA_TYPE,L<:LOSS_TYPE} - set_complexity = complexity === nothing ? 
compute_complexity(t, options) : complexity + ex = create_expression(tree, options, dataset) + set_complexity = complexity === nothing ? compute_complexity(ex, options) : complexity @assert set_complexity != -1 - score, loss = score_func(dataset, t, options; complexity=set_complexity) + score, loss = score_func(dataset, ex, options; complexity=set_complexity) return PopMember( - t, + ex, score, loss, options, diff --git a/src/Population.jl b/src/Population.jl index 67e90b3a9..9bc1df309 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -1,19 +1,18 @@ module PopulationModule using StatsBase: StatsBase -using Random: randperm using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpressionNode, Node, string_tree +using DynamicExpressions: AbstractExpression, string_tree using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutationFunctionsModule: gen_random_tree using ..PopMemberModule: PopMember -using ..UtilsModule: bottomk_fast, argmin_fast +using ..UtilsModule: bottomk_fast, argmin_fast, PerThreadCache # A list of members of the population, with easy constructors, # which allow for random generation of new populations -struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode{T}} +struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} members::Array{PopMember{T,L,N},1} n::Int end @@ -147,7 +146,7 @@ function _best_of_sample( argmin_fast(scores) else # First, decide what place we take (usually 1st place wins): - tournament_winner = StatsBase.sample(options.tournament_selection_weights) + tournament_winner = StatsBase.sample(get_tournament_selection_weights(options)) # Then, find the member that won that place, given # their fitness: if tournament_winner == 1 @@ -159,6 +158,26 @@ function _best_of_sample( return members[chosen_idx] end +const CACHED_WEIGHTS = + let init_k = collect(0:5), + init_prob_each = 0.5f0 * (1 - 0.5f0) .^ init_k, + test_weights = StatsBase.Weights(init_prob_each, sum(init_prob_each)) + + PerThreadCache{Dict{Tuple{Int,Float32},typeof(test_weights)}}() + end + +@unstable function get_tournament_selection_weights(@nospecialize(options::Options)) + n = options.tournament_selection_n + p = options.tournament_selection_p + # Computing the weights for the tournament becomes quite expensive, + return get!(CACHED_WEIGHTS, (n, p)) do + k = collect(0:(n - 1)) + prob_each = p * ((1 - p) .^ k) + + return StatsBase.Weights(prob_each, sum(prob_each)) + end +end + function finalize_scores( dataset::Dataset{T,L}, pop::P, options::Options )::Tuple{P,Float64} where {T,L,P<:Population{T,L}} diff --git a/src/Recorder.jl b/src/Recorder.jl index d7bf1f668..a25ac0e78 100644 --- a/src/Recorder.jl +++ b/src/Recorder.jl @@ -1,6 +1,6 @@ module RecorderModule -using ..CoreModule: RecordType, Options +using ..CoreModule: RecordType "Assumes that `options` holds the user options::Options" macro recorder(ex) diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl index f00de16b9..141e85888 100644 --- a/src/RegularizedEvolution.jl +++ b/src/RegularizedEvolution.jl @@ -2,7 +2,6 @@ module RegularizedEvolutionModule using DynamicExpressions: string_tree using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE -using ..PopMemberModule: PopMember using ..PopulationModule: Population, best_of_sample using ..AdaptiveParsimonyModule: 
RunningSearchStatistics using ..MutateModule: next_generation, crossover_generation diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index fb2da4876..a540e35f6 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -4,20 +4,19 @@ This includes: process management, stdin reading, checking for early stops.""" module SearchUtilsModule using Printf: @printf, @sprintf -using Distributed +using Distributed: Distributed, @spawnat, Future, procs using StatsBase: mean using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpressionNode, string_tree +using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: subscriptify using ..CoreModule: Dataset, Options, MAX_DEGREE, RecordType using ..ComplexityModule: compute_complexity using ..PopulationModule: Population using ..PopMemberModule: PopMember -using ..HallOfFameModule: - HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve +using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ProgressBarsModule: WrappedProgressBar, set_multiline_postfix!, manually_iterate! -using ..AdaptiveParsimonyModule: update_frequencies!, RunningSearchStatistics +using ..AdaptiveParsimonyModule: RunningSearchStatistics """ RuntimeOptions{N,PARALLELISM,DIM_OUT,RETURN_STATE} @@ -129,11 +128,26 @@ end function init_dummy_pops( npops::Int, datasets::Vector{D}, options::Options ) where {T,L,D<:Dataset{T,L}} + prototype = Population( + first(datasets); + population_size=1, + options=options, + nfeatures=first(datasets).nfeatures, + ) + # ^ Due to occasional inference issue, we manually specify the return type return [ - [ - Population(d; population_size=1, options=options, nfeatures=d.nfeatures) for - _ in 1:npops - ] for d in datasets + typeof(prototype)[ + if (i == 1 && j == 1) + prototype + else + Population( + datasets[j]; + population_size=1, + options=options, + nfeatures=datasets[j].nfeatures, + ) + end for i in 1:npops + ] for j in 1:length(datasets) ] end @@ -211,76 +225,52 @@ function check_max_evals(num_evals, options::Options)::Bool return options.max_evals !== nothing && options.max_evals::Int <= sum(sum, num_evals) end -const TIME_TYPE = Float64 +""" +This struct is used to monitor resources. -"""This struct is used to monitor resources.""" +Whenever we check a channel, we record if it was empty or not. +This gives us a measure for how much of a bottleneck there is +at the head worker. 
+""" Base.@kwdef mutable struct ResourceMonitor - """The time the search started.""" - absolute_start_time::TIME_TYPE = time() - """The time the head worker started doing work.""" - start_work::TIME_TYPE = Inf - """The time the head worker finished doing work.""" - stop_work::TIME_TYPE = Inf - - num_starts::UInt = 0 - num_stops::UInt = 0 - work_intervals::Vector{TIME_TYPE} = TIME_TYPE[] - rest_intervals::Vector{TIME_TYPE} = TIME_TYPE[] - - """Number of intervals to store.""" - num_intervals_to_store::Int -end - -function start_work_monitor!(monitor::ResourceMonitor) - monitor.start_work = time() - monitor.num_starts += 1 - if monitor.num_stops > 0 - push!(monitor.rest_intervals, monitor.start_work - monitor.stop_work) - if length(monitor.rest_intervals) > monitor.num_intervals_to_store - popfirst!(monitor.rest_intervals) - end - end - return nothing + population_ready::Vector{Bool} = Bool[] + max_recordings::Int + start_reporting_at::Int + window_size::Int end -function stop_work_monitor!(monitor::ResourceMonitor) - monitor.stop_work = time() - push!(monitor.work_intervals, monitor.stop_work - monitor.start_work) - monitor.num_stops += 1 - @assert monitor.num_stops == monitor.num_starts - if length(monitor.work_intervals) > monitor.num_intervals_to_store - popfirst!(monitor.work_intervals) +function record_channel_state!(monitor::ResourceMonitor, state) + push!(monitor.population_ready, state) + if length(monitor.population_ready) > monitor.max_recordings + popfirst!(monitor.population_ready) end return nothing end function estimate_work_fraction(monitor::ResourceMonitor)::Float64 - if monitor.num_stops <= 1 + if length(monitor.population_ready) <= monitor.start_reporting_at return 0.0 # Can't estimate from only one interval, due to JIT. end - work_intervals = monitor.work_intervals - rest_intervals = monitor.rest_intervals - # Trim 1st, in case we are still in the first interval. - if monitor.num_stops <= monitor.num_intervals_to_store + 1 - work_intervals = work_intervals[2:end] - rest_intervals = rest_intervals[2:end] - end - return mean(work_intervals) / (mean(work_intervals) + mean(rest_intervals)) + return mean(monitor.population_ready[(end - (monitor.window_size - 1)):end]) end function get_load_string(; head_node_occupation::Float64, parallelism=:serial) - parallelism == :serial && return "" - out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100) - - raise_usage_warning = head_node_occupation > 0.4 - if raise_usage_warning - out *= "." - out *= " This is high, and will prevent efficient resource usage." - out *= " Increase `ncycles_per_iteration` to reduce load on head worker." + if parallelism == :serial || head_node_occupation == 0.0 + return "" end + return "" + ## TODO: Debug why populations are always ready + # out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100) + + # raise_usage_warning = head_node_occupation > 0.4 + # if raise_usage_warning + # out *= "." + # out *= " This is high, and will prevent efficient resource usage." + # out *= " Increase `ncycles_per_iteration` to reduce load on head worker." + # end - out *= "\n" - return out + # out *= "\n" + # return out end function update_progress_bar!( @@ -386,9 +376,7 @@ The state of a search, including the populations, worker outputs, tasks, and channels. This is used to manage the search and keep track of runtime variables in a single struct. 
""" -Base.@kwdef struct SearchState{ - T,L,N<:AbstractExpressionNode{T},WorkerOutputType,ChannelType -} +Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} procs::Vector{Int} we_created_procs::Bool worker_output::Vector{Vector{WorkerOutputType}} @@ -478,6 +466,7 @@ function construct_datasets( y_variable_names, X_units, y_units, + extra, ::Type{L}, ) where {L} nout = size(y, 1) @@ -486,6 +475,7 @@ function construct_datasets( X, y[j, :], L; + index=j, weights=(weights === nothing ? weights : weights[j, :]), variable_names=variable_names, display_variable_names=display_variable_names, @@ -506,6 +496,7 @@ function construct_datasets( end, X_units=X_units, y_units=isa(y_units, AbstractVector) ? y_units[j] : y_units, + extra=extra, ) for j in 1:nout ] end diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index 2ae3428ee..582f25b15 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -1,17 +1,12 @@ module SingleIterationModule -using DynamicExpressions: - AbstractExpressionNode, - Node, - constructorof, - string_tree, - simplify_tree!, - combine_operators +using ADTypes: AutoEnzyme +using DynamicExpressions: AbstractExpression, string_tree, simplify_tree!, combine_operators using ..UtilsModule: @threads_if -using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE +using ..CoreModule: Options, Dataset, RecordType, create_expression using ..ComplexityModule: compute_complexity -using ..PopMemberModule: PopMember, generate_reference -using ..PopulationModule: Population, finalize_scores, best_sub_pop +using ..PopMemberModule: generate_reference +using ..PopulationModule: Population, finalize_scores using ..HallOfFameModule: HallOfFame using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..RegularizedEvolutionModule: reg_evol_cycle @@ -32,22 +27,20 @@ function s_r_cycle( record::RecordType, )::Tuple{ P,HallOfFame{T,L,N},Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpressionNode{T},P<:Population{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:Population{T,L,N}} max_temp = 1.0 min_temp = 0.0 if !options.annealing min_temp = max_temp end all_temperatures = LinRange(max_temp, min_temp, ncycles) - best_examples_seen = HallOfFame(options, T, L) + best_examples_seen = HallOfFame(options, dataset) num_evals = 0.0 # For evaluating on a fixed batch (for batching) idx = options.batching ? batch_sample(dataset, options) : Int[] - loss_cache = [ - (oid=constructorof(typeof(member.tree))(T; val=zero(T)), score=zero(L)) for - member in pop.members - ] + example_tree = create_expression(zero(T), options, dataset) + loss_cache = [(oid=example_tree, score=zero(L)) for member in pop.members] first_loop = true for temperature in all_temperatures @@ -109,13 +102,14 @@ function optimize_and_simplify_population( )::Tuple{P,Float64} where {T,L,D<:Dataset{T,L},P<:Population{T,L}} array_num_evals = zeros(Float64, pop.n) do_optimization = rand(pop.n) .< options.optimizer_probability - @threads_if !(options.deterministic) for j in 1:(pop.n) + # Note: we have to turn off this threading loop due to Enzyme, since we need + # to manually allocate a new task with a larger stack for Enzyme. 
+ should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) + @threads_if should_thread for j in 1:(pop.n) if options.should_simplify tree = pop.members[j].tree tree = simplify_tree!(tree, options.operators) - if tree isa Node - tree = combine_operators(tree, options.operators) - end + tree = combine_operators(tree, options.operators) pop.members[j].tree = tree end if options.should_optimize_constants && do_optimization[j] diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index f051c25da..8dfb53837 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -9,8 +9,14 @@ export Population, MutationWeights, Node, GraphNode, + ParametricNode, + Expression, + ParametricExpression, + StructuredExpression, NodeSampler, + AbstractExpression, AbstractExpressionNode, + EvalOptions, SRRegressor, MultitargetSRRegressor, LOSS_TYPE, @@ -22,6 +28,8 @@ export Population, calculate_pareto_frontier, count_nodes, compute_complexity, + @parse_expression, + parse_expression, print_tree, string_tree, eval_tree_array, @@ -31,6 +39,7 @@ export Population, set_node!, copy_node, node_to_symbolic, + node_type, symbolic_to_node, simplify_tree!, tree_mapreduce, @@ -38,6 +47,9 @@ export Population, gen_random_tree, gen_random_tree_fixed_size, @extend_operators, + get_tree, + get_contents, + get_metadata, #Operators plus, @@ -76,18 +88,28 @@ using Reexport using DynamicExpressions: Node, GraphNode, + ParametricNode, + Expression, + ParametricExpression, + StructuredExpression, NodeSampler, + AbstractExpression, AbstractExpressionNode, + @parse_expression, + parse_expression, copy_node, set_node!, string_tree, print_tree, count_nodes, get_constants, - set_constants, + get_scalar_constants, + set_constants!, + set_scalar_constants!, index_constants, NodeIndex, eval_tree_array, + EvalOptions, differentiable_eval_tree_array, eval_diff_tree_array, eval_grad_tree_array, @@ -96,8 +118,12 @@ using DynamicExpressions: combine_operators, simplify_tree!, tree_mapreduce, - set_default_variable_names! -using DynamicExpressions.EquationModule: with_type_parameters + set_default_variable_names!, + node_type, + get_tree, + get_contents, + get_metadata +using DynamicExpressions: with_type_parameters @reexport using LossFunctions: MarginLoss, DistanceLoss, @@ -173,6 +199,7 @@ using DispatchDoctor: @stable include("ProgressBars.jl") include("Migration.jl") include("SearchUtils.jl") + include("ExpressionBuilder.jl") end using .CoreModule: @@ -207,8 +234,9 @@ using .CoreModule: gamma, erf, erfc, - atanh_clip -using .UtilsModule: is_anonymous_function, recursive_merge, json3_write + atanh_clip, + create_expression +using .UtilsModule: is_anonymous_function, recursive_merge, json3_write, @ignore using .ComplexityModule: compute_complexity using .CheckConstraintsModule: check_constraints using .AdaptiveParsimonyModule: @@ -246,8 +274,7 @@ using .SearchUtilsModule: check_for_timeout, check_max_evals, ResourceMonitor, - start_work_monitor!, - stop_work_monitor!, + record_channel_state!, estimate_work_fraction, update_progress_bar!, print_search_state, @@ -258,6 +285,7 @@ using .SearchUtilsModule: save_to_file, get_cur_maxsize, update_hall_of_fame! 
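[Editorial aside, not part of the patch] Two of the headline changes in this release meet in this file: the export list now surfaces DynamicExpressions' expression types (Expression, ParametricExpression, StructuredExpression) together with their accessors, and constant optimization can route gradients through DifferentiationInterface using an ADTypes backend. A hedged sketch of both, using only calls that appear elsewhere in this diff plus an illustrative objective:

using SymbolicRegression

# New expression interface:
options = Options(; binary_operators=[+, *, -], unary_operators=[cos])
ex = parse_expression(
    :(x * x - cos(2.5 * y));
    operators=options.operators,
    variable_names=["x", "y"],
)
tree = get_tree(ex)      # underlying Node
meta = get_metadata(ex)  # operators, variable names, etc.
print_tree(ex)           # prints with the stored variable names

# Autodiff backend selection (the loss below is a stand-in, not the package's):
using ADTypes: AutoZygote
using DifferentiationInterface: value_and_gradient
using Zygote: Zygote  # the backend package must be loaded for AutoZygote to be usable

loss(c) = sum(abs2, c .- 2.0)
val, grad = value_and_gradient(loss, AutoZygote(), [1.0, 3.0])
# val == 2.0; grad == [-2.0, 2.0]

Within SymbolicRegression the backend is requested as Options(; autodiff_backend=:Zygote), as the tests later in this diff do; an AutoEnzyme backend additionally forces the serial optimization path, as the SingleIteration.jl hunk above notes.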
+using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin include("deprecates.jl") @@ -379,6 +407,7 @@ function equation_search( progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, + extra::NamedTuple=NamedTuple(), v_dim_out::Val{DIM_OUT}=Val(nothing), # Deprecated: multithreaded=nothing, @@ -406,6 +435,7 @@ function equation_search( y_variable_names, X_units, y_units, + extra, L, ) @@ -589,7 +619,7 @@ function equation_search( ) end -@stable default_mode = "disable" @noinline function _equation_search( +@noinline function _equation_search( datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state ) where {D<:Dataset} _validate_options(datasets, ropt, options) @@ -598,7 +628,7 @@ end _warmup_search!(state, datasets, ropt, options) _main_search_loop!(state, datasets, ropt, options) _tear_down!(state, ropt, options) - return _format_output(state, ropt) + return _format_output(state, datasets, ropt, options) end function _validate_options( @@ -641,7 +671,8 @@ end nout = length(datasets) example_dataset = first(datasets) - NT = with_type_parameters(options.node_type, T) + example_ex = create_expression(zero(T), options, example_dataset) + NT = typeof(example_ex) PopType = Population{T,L,NT} HallOfFameType = HallOfFame{T,L,NT} WorkerOutputType = get_worker_output_type( @@ -698,9 +729,7 @@ end for j in 1:nout ] - return SearchState{ - T,L,with_type_parameters(options.node_type, T),WorkerOutputType,ChannelType - }(; + return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; procs=procs, we_created_procs=we_created_procs, worker_output=worker_output, @@ -727,13 +756,13 @@ function _initialize_search!( init_hall_of_fame = load_saved_hall_of_fame(saved_state) if init_hall_of_fame === nothing for j in 1:nout - state.halls_of_fame[j] = HallOfFame(options, T, L) + state.halls_of_fame[j] = HallOfFame(options, datasets[j]) end else # Recompute losses for the hall of fame, in # case the dataset changed: for j in eachindex(init_hall_of_fame, datasets, state.halls_of_fame) - hof = init_hall_of_fame[j] + hof = strip_metadata(init_hall_of_fame[j], options, datasets[j]) for member in hof.members[hof.exists] score, result_loss = score_func(datasets[j], member, options) member.score = score @@ -750,17 +779,17 @@ function _initialize_search!( saved_pop = load_saved_population(saved_state; out=j, pop=i) new_pop = if saved_pop !== nothing && length(saved_pop.members) == options.population_size - saved_pop::Population{T,L,N} + _saved_pop = strip_metadata(saved_pop, options, datasets[j]) ## Update losses: - for member in saved_pop.members + for member in _saved_pop.members score, result_loss = score_func(datasets[j], member, options) member.score = score member.loss = result_loss end - copy_pop = copy(saved_pop) + copy_pop = copy(_saved_pop) @sr_spawner( begin - (copy_pop, HallOfFame(options, T, L), RecordType(), 0.0) + (copy_pop, HallOfFame(options, datasets[j]), RecordType(), 0.0) end, parallelism = ropt.parallelism, worker_idx = worker_idx @@ -779,7 +808,7 @@ function _initialize_search!( options=options, nfeatures=datasets[j].nfeatures, ), - HallOfFame(options, T, L), + HallOfFame(options, datasets[j]), RecordType(), Float64(options.population_size), ) @@ -863,10 +892,11 @@ function _main_search_loop!( end kappa = 0 resource_monitor = ResourceMonitor(; - absolute_start_time=time(), # Storing n times as many monitoring intervals as populations seems like it will # help get accurate 
resource estimates: - num_intervals_to_store=options.populations * 100 * nout, + max_recordings=options.populations * 100 * nout, + start_reporting_at=options.populations * 3 * nout, + window_size=options.populations * 2 * nout, ) while sum(state.cycles_remaining) > 0 kappa += 1 @@ -890,11 +920,12 @@ function _main_search_loop!( else true end + record_channel_state!(resource_monitor, population_ready) + # Don't start more if this output has finished its cycles: # TODO - this might skip extra cycles? population_ready &= (state.cycles_remaining[j] > 0) if population_ready - start_work_monitor!(resource_monitor) # Take the fetch operation from the channel since its ready (cur_pop, best_seen, cur_record, cur_num_evals) = if ropt.parallelism in ( @@ -989,7 +1020,6 @@ function _main_search_loop!( state.cur_maxsizes[j] = get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j] ) - stop_work_monitor!(resource_monitor) move_window!(state.all_running_search_statistics[j]) if ropt.progress head_node_occupation = estimate_work_fraction(resource_monitor) @@ -1076,10 +1106,20 @@ function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options) @recorder json3_write(state.record[], options.recorder_file) return nothing end -function _format_output(state::SearchState, ropt::RuntimeOptions) - out_hof = (ropt.dim_out == 1 ? only(state.halls_of_fame) : state.halls_of_fame) +function _format_output( + state::SearchState, datasets, ropt::RuntimeOptions, options::Options +) + nout = length(datasets) + out_hof = if ropt.dim_out == 1 + embed_metadata(only(state.halls_of_fame), options, only(datasets)) + else + map(j -> embed_metadata(state.halls_of_fame[j], options, datasets[j]), 1:nout) + end if ropt.return_state - return (state.last_pops, out_hof) + return ( + map(j -> embed_metadata(state.last_pops[j], options, datasets[j]), 1:nout), + out_hof, + ) else return out_hof end @@ -1135,10 +1175,12 @@ function __init__() @require_extensions end -macro ignore(args...) end # Hack to get static analysis to work from within tests: @ignore include("../test/runtests.jl") +# TODO: Hack to force ConstructionBase version +using ConstructionBase: ConstructionBase as _ + include("precompile.jl") redirect_stdout(devnull) do redirect_stderr(devnull) do diff --git a/src/Utils.jl b/src/Utils.jl index 9636252e6..a667b6987 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -2,7 +2,9 @@ module UtilsModule using Printf: @printf -using MacroTools: splitdef, combinedef +using MacroTools: splitdef + +macro ignore(args...) end const pseudo_time = Ref(0) @@ -190,6 +192,49 @@ macro constfield(ex) return esc(VERSION < v"1.8.0" ? ex : Expr(:const, ex)) end +json3_write(args...) = error("Please load the JSON3.jl package.") + +""" + PerThreadCache{T} + +A cache that is efficient for multithreaded code, and works +by having a separate cache for each thread. This allows +us to avoid repeated locking. We only need to lock the cache +when resizing to the number of threads. +""" +struct PerThreadCache{T} + x::Vector{T} + num_threads::Ref{Int} + lock::Threads.SpinLock + + PerThreadCache{T}() where {T} = new(Vector{T}(undef, 1), Ref(1), Threads.SpinLock()) +end + +function _get_thread_cache(cache::PerThreadCache{T}) where {T} + if cache.num_threads[] < Threads.nthreads() + Base.@lock cache.lock begin + # The reason we have this extra `.len[]` parameter is to avoid + # a race condition between a thread resizing the array concurrent + # to the check above. 
Basically we want to make sure the array is + # always big enough by the time we get to using it. Since `.len[]` + # is set last, we can safely use the array. + if cache.num_threads[] < Threads.nthreads() + resize!(cache.x, Threads.nthreads()) + cache.num_threads[] = Threads.nthreads() + end + end + end + threadid = Threads.threadid() + if !isassigned(cache.x, threadid) + cache.x[threadid] = eltype(cache.x)() + end + return cache.x[threadid] +end +function Base.get!(f::F, cache::PerThreadCache, key) where {F<:Function} + thread_cache = _get_thread_cache(cache) + return get!(f, thread_cache, key) +end + # https://discourse.julialang.org/t/performance-of-hasmethod-vs-try-catch-on-methoderror/99827/14 # Faster way to catch method errors: @enum IsGood::Int8 begin @@ -197,29 +242,26 @@ end Bad Undefined end -const SafeFunctions = Dict{Type,IsGood}() -const SafeFunctionsLock = Threads.SpinLock() +const SafeFunctions = PerThreadCache{Dict{Type,IsGood}}() function safe_call(f::F, x::T, default::D) where {F,T<:Tuple,D} - status = get(SafeFunctions, Tuple{F,T}, Undefined) + thread_cache = _get_thread_cache(SafeFunctions) + status = get(thread_cache, Tuple{F,T}, Undefined) status == Good && return (f(x...)::D, true) status == Bad && return (default, false) - return lock(SafeFunctionsLock) do - output = try - (f(x...)::D, true) - catch e - !isa(e, MethodError) && rethrow(e) - (default, false) - end - if output[2] - SafeFunctions[Tuple{F,T}] = Good - else - SafeFunctions[Tuple{F,T}] = Bad - end - return output + + output = try + (f(x...)::D, true) + catch e + !isa(e, MethodError) && rethrow(e) + (default, false) + end + if output[2] + thread_cache[Tuple{F,T}] = Good + else + thread_cache[Tuple{F,T}] = Bad end + return output end -json3_write(args...) = error("Please load the JSON3.jl package.") - end diff --git a/test/Project.toml b/test/Project.toml index b6f96c2d9..e525609a7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,12 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" @@ -25,3 +27,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[preferences.SymbolicRegression] +instability_check = "error" diff --git a/test/runtests.jl b/test/runtests.jl index dd2209874..362894174 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,167 +2,176 @@ using TestItems: @testitem using TestItemRunner: @run_package_tests ENV["SYMBOLIC_REGRESSION_TEST"] = "true" -tags_to_run = let t = get(ENV, "SYMBOLIC_REGRESSION_TEST_SUITE", "unit,integration") +tags_to_run = let t = get(ENV, "SYMBOLIC_REGRESSION_TEST_SUITE", "part1,part2,part3") t = split(t, ",") t = map(Symbol, t) t end -@eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) +@eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true -@testitem "JET tests" tags = [:integration, :jet] begin - 
test_jet_file = joinpath((@__DIR__), "test_jet.jl") - run(`$(Base.julia_cmd()) --startup-file=no $test_jet_file`) -end - -@testitem "Test custom operators and additional types" tags = [:unit] begin +# TODO: This is a very slow test +@testitem "Test custom operators and additional types" tags = [:part2] begin include("test_operators.jl") end -@testitem "Test tree construction and scoring" tags = [:unit] begin +@testitem "Test tree construction and scoring" tags = [:part3] begin include("test_tree_construction.jl") end -@testitem "Test SymbolicUtils interface" tags = [:unit] begin +include("test_graph_nodes.jl") + +@testitem "Test SymbolicUtils interface" tags = [:part1] begin include("test_symbolic_utils.jl") end -@testitem "Test constraints interface" tags = [:unit] begin +@testitem "Test constraints interface" tags = [:part2] begin include("test_constraints.jl") end -@testitem "Test custom losses" tags = [:unit] begin +@testitem "Test custom losses" tags = [:part1] begin include("test_losses.jl") end -@testitem "Test derivatives" tags = [:unit] begin +@testitem "Test derivatives" tags = [:part2] begin include("test_derivatives.jl") end +include("test_expression_derivatives.jl") -@testitem "Test simplification" tags = [:unit] begin +@testitem "Test simplification" tags = [:part3] begin include("test_simplification.jl") end -@testitem "Test printing" tags = [:unit] begin +@testitem "Test printing" tags = [:part1] begin include("test_print.jl") end -@testitem "Test validity of expression evaluation" tags = [:unit] begin +@testitem "Test validity of expression evaluation" tags = [:part2] begin include("test_evaluation.jl") end -@testitem "Test turbo mode with NaN" tags = [:unit] begin +@testitem "Test turbo mode with NaN" tags = [:part3] begin include("test_turbo_nan.jl") end -@testitem "Test validity of integer expression evaluation" tags = [:unit] begin +@testitem "Test validity of integer expression evaluation" tags = [:part1] begin include("test_integer_evaluation.jl") end -@testitem "Test tournament selection" tags = [:unit] begin +@testitem "Test tournament selection" tags = [:part2] begin include("test_prob_pick_first.jl") end -@testitem "Test crossover mutation" tags = [:unit] begin +@testitem "Test crossover mutation" tags = [:part3] begin include("test_crossover.jl") end -@testitem "Test NaN detection in evaluator" tags = [:unit] begin +# TODO: This is another very slow test +@testitem "Test NaN detection in evaluator" tags = [:part1] begin include("test_nan_detection.jl") end -@testitem "Test nested constraint checking" tags = [:unit] begin +@testitem "Test nested constraint checking" tags = [:part2] begin include("test_nested_constraints.jl") end -@testitem "Test complexity evaluation" tags = [:unit] begin +@testitem "Test complexity evaluation" tags = [:part3] begin include("test_complexity.jl") end -@testitem "Test options" tags = [:unit] begin +@testitem "Test options" tags = [:part1] begin include("test_options.jl") end -@testitem "Test hash of tree" tags = [:unit] begin +@testitem "Test hash of tree" tags = [:part2] begin include("test_hash.jl") end -@testitem "Test migration" tags = [:unit] begin +@testitem "Test migration" tags = [:part3] begin include("test_migration.jl") end -@testitem "Test deprecated options" tags = [:unit] begin +@testitem "Test deprecated options" tags = [:part1] begin include("test_deprecation.jl") end -@testitem "Test optimization mutation" tags = [:unit] begin +@testitem "Test optimization mutation" tags = [:part2] begin 
include("test_optimizer_mutation.jl") end -@testitem "Test RunningSearchStatistics" tags = [:unit] begin +@testitem "Test RunningSearchStatistics" tags = [:part3] begin include("test_search_statistics.jl") end -@testitem "Test utils" tags = [:unit] begin +@testitem "Test utils" tags = [:part1] begin include("test_utils.jl") end -@testitem "Test units" tags = [:integration] begin - include("test_units.jl") -end +include("test_units.jl") -@testitem "Dataset" tags = [:unit] begin +@testitem "Dataset" tags = [:part3] begin include("test_dataset.jl") end -@testitem "Test mixed settings." tags = [:integration] begin - include("test_mixed.jl") -end +include("test_mixed.jl") -@testitem "Testing fast-cycle and custom variable names" tags = [:integration] begin +@testitem "Testing fast-cycle and custom variable names" tags = [:part2] begin include("test_fast_cycle.jl") end -@testitem "Testing whether we can stop based on clock time." tags = [:integration] begin +@testitem "Testing whether we can stop based on clock time." tags = [:part3] begin include("test_stop_on_clock.jl") end -@testitem "Running README example." tags = [:integration] begin +@testitem "Running README example." tags = [:part1] begin + ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" include("../example.jl") end -@testitem "Testing whether the recorder works." tags = [:integration] begin +# TODO: This is the slowest test. +@testitem "Running parameterized function example." tags = [:part2] begin + ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" + include("../examples/parameterized_function.jl") +end + +@testitem "Testing whether the recorder works." tags = [:part3] begin include("test_recorder.jl") end -@testitem "Testing whether deterministic mode works." tags = [:integration] begin +@testitem "Testing whether deterministic mode works." tags = [:part1] begin include("test_deterministic.jl") end -@testitem "Testing whether early stop criteria works." tags = [:integration] begin +@testitem "Testing whether early stop criteria works." tags = [:part2] begin include("test_early_stop.jl") end -@testitem "Test MLJ integration" tags = [:integration] begin - include("test_mlj.jl") -end +include("test_mlj.jl") -@testitem "Testing whether we can move operators to workers." tags = [:integration] begin +@testitem "Testing whether we can move operators to workers." tags = [:part1] begin include("test_custom_operators_multiprocessing.jl") end -@testitem "Test whether the precompilation script works." tags = [:integration] begin +@testitem "Test whether the precompilation script works." tags = [:part2] begin include("test_precompilation.jl") end -@testitem "Test whether custom objectives work." tags = [:integration] begin +@testitem "Test whether custom objectives work." 
tags = [:part3] begin include("test_custom_objectives.jl") end -@testitem "Test abstract numbers" tags = [:integration] begin +@testitem "Test abstract numbers" tags = [:part1] begin include("test_abstract_numbers.jl") end -@testitem "Aqua tests" tags = [:integration, :aqua] begin +include("test_pretty_printing.jl") +include("test_expression_builder.jl") + +@testitem "Aqua tests" tags = [:part2, :aqua] begin include("test_aqua.jl") end + +@testitem "JET tests" tags = [:part1, :jet] begin + test_jet_file = joinpath((@__DIR__), "test_jet.jl") + run(`$(Base.julia_cmd()) --startup-file=no $test_jet_file`) +end diff --git a/test/test_derivatives.jl b/test/test_derivatives.jl index ee2d4ddfd..9615a9277 100644 --- a/test/test_derivatives.jl +++ b/test/test_derivatives.jl @@ -122,7 +122,7 @@ end println("Testing NodeIndex.") -using SymbolicRegression: get_constants, NodeIndex, index_constants +using SymbolicRegression: get_scalar_constants, NodeIndex, index_constants options = Options(; binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin) @@ -144,6 +144,6 @@ function check_tree( end end -@test check_tree(tree, index_constants(tree), get_constants(tree)) +@test check_tree(tree, index_constants(tree), first(get_scalar_constants(tree))) println("Done.") diff --git a/test/test_deterministic.jl b/test/test_deterministic.jl index d541eea70..26e80c08d 100644 --- a/test/test_deterministic.jl +++ b/test/test_deterministic.jl @@ -1,53 +1,33 @@ using SymbolicRegression using Random -macro maybe_inferred(ex) - # Only get stable inference on Julia 1.10+ - return if VERSION >= v"1.10.0-DEV.0" - quote - @inferred $ex - end - else - quote - try - # Still want to test for any bugs (JuliaLang/julia#53761) - @inferred $ex - catch - $ex - end - end - end |> esc -end +X = 2 .* randn(MersenneTwister(0), Float32, 2, 1000) +y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 -begin - X = 2 .* randn(MersenneTwister(0), Float32, 2, 1000) - y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 +options = SymbolicRegression.Options(; + binary_operators=(+, *, /, -), + unary_operators=(cos,), + crossover_probability=0.0, # required for recording, as not set up to track crossovers. + max_evals=10000, + deterministic=true, + seed=0, + verbosity=0, + progress=false, +) - options = SymbolicRegression.Options(; - binary_operators=(+, *, /, -), - unary_operators=(cos,), - crossover_probability=0.0, # required for recording, as not set up to track crossovers. 
- max_evals=10000, - deterministic=true, - seed=0, - verbosity=0, - progress=false, +all_outputs = [] +for i in 1:2 + hall_of_fame = equation_search( + X, + y; + niterations=5, + options=options, + parallelism=:serial, + v_dim_out=Val(1), + return_state=Val(false), ) - - all_outputs = [] - for i in 1:2 - hall_of_fame = @maybe_inferred equation_search( - X, - y; - niterations=5, - options=options, - parallelism=:serial, - v_dim_out=Val(1), - return_state=Val(false), - ) - dominating = calculate_pareto_frontier(hall_of_fame) - push!(all_outputs, dominating[end].tree) - end - - @test string(all_outputs[1]) == string(all_outputs[2]) + dominating = calculate_pareto_frontier(hall_of_fame) + push!(all_outputs, dominating[end].tree) end + +@test string(all_outputs[1]) == string(all_outputs[2]) diff --git a/test/test_expression_builder.jl b/test/test_expression_builder.jl new file mode 100644 index 000000000..37b9291f3 --- /dev/null +++ b/test/test_expression_builder.jl @@ -0,0 +1,66 @@ +# This file tests particular functionality of ExpressionBuilderModule +@testitem "ParametricExpression" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression.ExpressionBuilderModule: + strip_metadata, embed_metadata, init_params + + options = Options() + ex = parse_expression( + :(x1 * p1); + expression_type=ParametricExpression, + operators=options.operators, + parameters=ones(2, 1) * 3, + parameter_names=["p1", "p2"], + variable_names=["x1"], + ) + X = ones(1, 1) * 2 + y = ones(1) + dataset = Dataset(X, y; extra=(; classes=[1])) + + @test ex isa ParametricExpression + @test ex(dataset.X, dataset.extra.classes) ≈ ones(1, 1) * 6 + + # Mistake in that we gave the wrong options! + @test_throws( + AssertionError( + "Need prototype to be of type $(options.expression_type), but got $(ex)::$(typeof(ex))", + ), + init_params(options, dataset, ex, Val(true)) + ) + + options = Options(; + expression_type=ParametricExpression, expression_options=(; max_parameters=2) + ) + + # Mistake in that we also gave the wrong number of parameter names! 
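# (Editorial note, not part of the patch: judging from the assertion messages
#  exercised here, `init_params` appears to assemble the constructor metadata for
#  new expressions (operators, variable_names, parameters, parameter_names) and
#  requires the prototype's `parameter_names` to have exactly
#  `options.expression_options.max_parameters` entries, which is why removing
#  one name below triggers the error.)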
+ pop!(ex.metadata.parameter_names) + @test_throws( + AssertionError( + "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(ex.metadata.parameter_names)", + ), + init_params(options, dataset, ex, Val(true)) + ) + # So, we fix it: + push!(ex.metadata.parameter_names, "p2") + + @test ex.metadata.parameter_names == ["p1", "p2"] + @test keys(init_params(options, dataset, ex, Val(true))) == + (:operators, :variable_names, :parameters, :parameter_names) + + @test sprint(show, ex) == "x1 * p1" + stripped_ex = strip_metadata(ex, options, dataset) + # Stripping the metadata means that operations like `show` + # do not know what binary operator to use: + @test sprint(show, stripped_ex) == "binary_operator[4](x1, p1)" + + # However, it's important that parametric expressions are still parametric: + @test stripped_ex isa ParametricExpression + # And, that they still have the right parameters: + @test haskey(getfield(stripped_ex.metadata, :_data), :parameters) + @test stripped_ex.metadata.parameters ≈ ones(2, 1) * 3 + + # Now, test that we can embed metadata back in: + embedded_ex = embed_metadata(stripped_ex, options, dataset) + @test embedded_ex isa ParametricExpression + @test ex == embedded_ex +end diff --git a/test/test_expression_derivatives.jl b/test/test_expression_derivatives.jl new file mode 100644 index 000000000..c8cba75ae --- /dev/null +++ b/test/test_expression_derivatives.jl @@ -0,0 +1,142 @@ +@testitem "Test derivatives" tags = [:part1] begin + using SymbolicRegression + using Zygote: Zygote + using Random: MersenneTwister + + ex = @parse_expression( + x * x - cos(2.5 * y), + unary_operators = [cos], + binary_operators = [*, -, +], + variable_names = [:x, :y] + ) + + rng = MersenneTwister(0) + X = rand(rng, 2, 32) + + (δy,) = Zygote.gradient(X) do X + x = @view X[1, :] + y = @view X[2, :] + + sum(i -> x[i] * x[i] - cos(2.5 * y[i]), eachindex(x)) + end + δy_hat = ex'(X) + + @test δy ≈ δy_hat + + options2 = Options(; unary_operators=[sin], binary_operators=[+, *, -]) + (δy2,) = Zygote.gradient(X) do X + x = @view X[1, :] + y = @view X[2, :] + + sum(i -> (x[i] + x[i]) * sin(2.5 + y[i]), eachindex(x)) + end + δy2_hat = ex'(X, options2) + + @test δy2 ≈ δy2_hat +end + +@testitem "Test derivatives during optimization" tags = [:part1] begin + using SymbolicRegression + using SymbolicRegression.ConstantOptimizationModule: Evaluator, GradEvaluator + using DynamicExpressions + using Zygote: Zygote + using Random: MersenneTwister + using DifferentiationInterface: value_and_gradient + + rng = MersenneTwister(0) + X = rand(rng, 2, 32) + y = @. X[1, :] * X[1, :] - cos(2.6 * X[2, :]) + dataset = Dataset(X, y) + + options = Options(; + unary_operators=[cos], binary_operators=[+, *, -], autodiff_backend=:Zygote + ) + + ex = @parse_expression( + x * x - cos(2.5 * y), operators = options.operators, variable_names = [:x, :y] + ) + f = Evaluator(ex, last(get_scalar_constants(ex)), dataset, options, nothing) + fg! 
= GradEvaluator(f, options.autodiff_backend) + + @test f(first(get_scalar_constants(ex))) isa Float64 + + x = first(get_scalar_constants(ex)) + G = zero(x) + fg!(nothing, G, x) + @test G[] != 0 +end + +@testitem "Test derivatives of parametric expression during optimization" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression.ConstantOptimizationModule: + Evaluator, GradEvaluator, optimize_constants, specialized_options + using DynamicExpressions + using Zygote: Zygote + using Random: MersenneTwister + using DifferentiationInterface: value_and_gradient, AutoZygote, AutoEnzyme + enzyme_compatible = VERSION >= v"1.10.0" && VERSION < v"1.11.0-DEV.0" + @static if enzyme_compatible + using Enzyme: Enzyme + end + + rng = MersenneTwister(0) + X = rand(rng, 2, 32) + true_params = [0.5 2.0] + init_params = [0.1 0.2] + init_constants = [2.5, -0.5] + classes = rand(rng, 1:2, 32) + y = [ + X[1, i] * X[1, i] - cos(2.6 * X[2, i] - 0.2) + true_params[1, classes[i]] for + i in 1:32 + ] + + dataset = Dataset(X, y; extra=(; classes)) + + (true_val, (true_d_params, true_d_constants)) = + value_and_gradient(AutoZygote(), (init_params, init_constants)) do (params, c) + pred = [ + X[1, i] * X[1, i] - cos(c[1] * X[2, i] + c[2]) + params[1, classes[i]] for + i in 1:32 + ] + sum(abs2, pred .- y) / length(y) + end + + options = Options(; + unary_operators=[cos], binary_operators=[+, *, -], autodiff_backend=:Zygote + ) + + ex = @parse_expression( + x * x - cos(2.5 * y + -0.5) + p1, + operators = options.operators, + expression_type = ParametricExpression, + variable_names = ["x", "y"], + extra_metadata = (parameter_names=["p1"], parameters=init_params) + ) + + function test_backend(ex, @nospecialize(backend); allow_failure=false) + x0, refs = get_scalar_constants(ex) + G = zero(x0) + + f = Evaluator(ex, refs, dataset, specialized_options(options), nothing) + fg! = GradEvaluator(f, backend) + + @test f(x0) ≈ true_val + + try + val = fg!(nothing, G, x0) + @test val ≈ true_val + @test G ≈ vcat(true_d_constants[:], true_d_params[:]) + catch e + if allow_failure + @warn "Expected failure" e + else + rethrow(e) + end + end + end + + test_backend(ex, AutoZygote(); allow_failure=false) + @static if enzyme_compatible + test_backend(ex, AutoEnzyme(); allow_failure=true) + end +end diff --git a/test/test_graph_nodes.jl b/test/test_graph_nodes.jl index 639546dc0..ba4fbb6d3 100644 --- a/test/test_graph_nodes.jl +++ b/test/test_graph_nodes.jl @@ -1,26 +1,133 @@ -using SymbolicRegression +@testitem "GraphNode evaluation" tags = [:part1] begin + using SymbolicRegression -options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30) + options = Options(; + binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30 + ) -x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] + x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] -base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) -tree = sin(base_tree) + base_tree + base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) + tree = sin(base_tree) + base_tree -X = randn(3, 50) -z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] -y = @. sin(z) + z -dataset = Dataset(X, y) + X = randn(3, 50) + z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] + y = @. 
sin(z) + z + dataset = Dataset(X, y) -tree(dataset.X, options) + tree(dataset.X, options) -eval_tree_array(tree, dataset.X, options) + eval_tree_array(tree, dataset.X, options) +end -@test compute_complexity(tree, options) == 12 -@test compute_complexity(tree, options; break_sharing=Val(true)) == 22 +@testitem "GraphNode complexity" tags = [:part1] begin + using SymbolicRegression -pop = Population( - dataset, GraphNode{Float64}; nlength=3, options, nfeatures=3, population_size=100 -) + options = Options(; + binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30 + ) + x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] -equation_search([dataset], GraphNode; niterations=10000, options) + base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) + tree = sin(base_tree) + base_tree + + @test compute_complexity(tree, options) == 12 + @test compute_complexity(tree, options; break_sharing=Val(true)) == 22 +end + +@testitem "GraphNode population" tags = [:part1] begin + using SymbolicRegression + + options = Options(; + binary_operators=[+, -, *, /], + unary_operators=[cos, sin], + maxsize=30, + node_type=GraphNode, + ) + + X = randn(3, 50) + z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] + y = @. sin(z) + z + dataset = Dataset(X, y) + + pop = Population(dataset; options, nlength=3, nfeatures=3, population_size=100) + @test pop isa Population{T,T,<:Expression{T,<:GraphNode{T}}} where {T} + + # Seems to not work yet: + # equation_search([dataset]; niterations=10, options) +end + +@testitem "GraphNode break connection mutation" tags = [:part1] begin + using SymbolicRegression + using SymbolicRegression.MutationFunctionsModule: break_random_connection! + using Random: MersenneTwister + + options = Options(; + binary_operators=[+, -, *, /], + unary_operators=[cos, sin], + maxsize=30, + node_type=GraphNode, + ) + + x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] + base_tree = cos(x1 - 3.2) * x2 + tree = sin(base_tree) + base_tree + + ex = Expression(tree; operators=options.operators, variable_names=["x1", "x2", "x3"]) + + s = strip(sprint(print_tree, ex)) + @test s == "sin(cos(x1 - 3.2) * x2) + {(cos(x1 - 3.2) * x2)}" + + rng = MersenneTwister(0) + expressions = [copy(ex) for _ in 1:1000] + expressions = [break_random_connection!(ex, rng) for ex in expressions] + strings = [strip(sprint(print_tree, ex)) for ex in expressions] + strings = unique(strings) + @test Set(strings) == Set([ + "sin(cos(x1 - 3.2) * x2) + {(cos(x1 - 3.2) * x2)}", + "sin(cos(x1 - 3.2) * x2) + (cos(x1 - 3.2) * x2)", + ]) + # Either it breaks the connection or not +end + +@testitem "GraphNode form connection mutation" tags = [:part1] begin + using SymbolicRegression + using SymbolicRegression.MutationFunctionsModule: form_random_connection! 
+ using Random: MersenneTwister + + options = Options(; + binary_operators=[+, -, *, /], + unary_operators=[cos, sin], + maxsize=30, + node_type=GraphNode, + ) + + x1, x2 = [GraphNode{Float64}(; feature=i) for i in 1:2] + + tree = cos(x1 * x2 + 1.5) + ex = Expression(tree; operators=options.operators, variable_names=["x1", "x2"]) + rng = MersenneTwister(0) + expressions = [copy(ex) for _ in 1:3_000] + expressions = [form_random_connection!(ex, rng) for ex in expressions] + strings = [strip(sprint(print_tree, ex)) for ex in expressions] + strings = sort(unique(strings); by=length) + + # All possible connections that can be made + @test Set(strings) == Set([ + "cos(x1)", + "cos(x2)", + "cos(1.5)", + "cos(x1 * x2)", + "cos(x2 + 1.5)", + "cos(x1 + 1.5)", + "cos(1.5 + {1.5})", + "cos((x1 * x2) + 1.5)", + "cos((x1 * x2) + {x2})", + "cos((x1 * x2) + {x1})", + "cos((x2 * {x2}) + 1.5)", + "cos((x1 * {x1}) + 1.5)", + "cos((x1 * 1.5) + {1.5})", + "cos((1.5 * x2) + {1.5})", + "cos((x1 * x2) + {(x1 * x2)})", + ]) +end diff --git a/test/test_jet.jl b/test/test_jet.jl index 28476febb..1521e9ba9 100644 --- a/test/test_jet.jl +++ b/test/test_jet.jl @@ -1,6 +1,7 @@ -if VERSION < v"1.10.0" +if !(VERSION >= v"1.10.0" && VERSION < v"1.11.0-DEV.0") exit(0) end +# TODO: Check why is breaking on 1.11.0 dir = mktempdir() @@ -10,12 +11,17 @@ using Pkg @info "Creating environment..." Pkg.activate(dir; io=devnull) Pkg.develop(; path=dirname(@__DIR__), io=devnull) -Pkg.add(["JET", "Preferences"]; io=devnull) +Pkg.add(["JET", "Preferences", "DynamicExpressions"]; io=devnull) @info "Done!" using Preferences cd(dir) -Preferences.set_preferences!("SymbolicRegression", "instability_check" => "disable") +Preferences.set_preferences!( + "SymbolicRegression", "instability_check" => "disable"; force=true +) +Preferences.set_preferences!( + "DynamicExpressions", "instability_check" => "disable"; force=true +) using SymbolicRegression using JET diff --git a/test/test_migration.jl b/test/test_migration.jl index 13c69753e..b650e7e6f 100644 --- a/test/test_migration.jl +++ b/test/test_migration.jl @@ -1,4 +1,7 @@ using SymbolicRegression +using SymbolicRegression: strip_metadata +using DynamicExpressions: get_tree +using Test using Random: seed! 
seed!(0) @@ -10,14 +13,20 @@ options = Options(); population1 = Population( X, y; population_size=100, options=options, nfeatures=5, nlength=10 ) +dataset = Dataset(X, y) tree = Node(1, Node(; val=1.0), Node(; feature=2) * 3.2) @test !(hash(tree) in [hash(p.tree) for p in population1.members]) +ex = @parse_expression($tree, operators = options.operators, variable_names = [:x1, :x2],) +ex = strip_metadata(ex, options, dataset) + SymbolicRegression.MigrationModule.migrate!( - [PopMember(tree, 0.0, Inf, options)] => population1, options; frac=0.5 + [PopMember(ex, 0.0, Inf, options; deterministic=false)] => population1, + options; + frac=0.5, ) # Now we see that the tree is in the population: -@test tree in [p.tree for p in population1.members] +@test tree in [get_tree(p.tree) for p in population1.members] diff --git a/test/test_mixed.jl b/test/test_mixed.jl index 1da2b1a3e..a2d047166 100644 --- a/test/test_mixed.jl +++ b/test/test_mixed.jl @@ -1,150 +1,39 @@ -using SymbolicRegression -using SymbolicRegression: string_tree -using Random, Bumper, LoopVectorization -include("test_params.jl") +@testitem "Search with batching & weighted & serial & progress bar & warmup & BFGS" tags = [ + :part1 +] begin + include("test_mixed_utils.jl") + test_mixed(0, true, true, :serial) +end -for i in 0:5 - local options, X, y, tree - batching = i in [0, 1] - weighted = i in [0, 2] +@testitem "Search with multiprocessing & batching & multi-output & use_frequency & string-specified parallelism" tags = [ + :part2 +] begin + include("test_mixed_utils.jl") + test_mixed(1, true, false, :multiprocessing) +end - numprocs = 2 - progress = false - warmup_maxsize_by = 0.0f0 - optimizer_algorithm = "NelderMead" - multi = false - tournament_selection_p = 1.0 - parallelism = :multiprocessing - crossover_probability = 0.0f0 - skip_mutation_failures = false - use_frequency = false - use_frequency_in_tournament = false - turbo = false - bumper = false - T = Float32 - print("Testing with batching=$(batching) and weighted=$(weighted), ") - if i == 0 - println("with serial & progress bar & warmup & BFGS") - numprocs = nothing #Try serial computation here. - parallelism = :serial - progress = true #Also try the progress bar. 
- warmup_maxsize_by = 0.5f0 #Smaller maxsize at first, build up slowly - optimizer_algorithm = "BFGS" - tournament_selection_p = 0.8 - elseif i == 1 - println("with multi-output and use_frequency and string-specified parallelism.") - multi = true - use_frequency = true - parallelism = "multiprocessing" - elseif i == 3 - println( - "with multi-threading and crossover and use_frequency_in_tournament and bumper=true", - ) - parallelism = :multithreading - numprocs = nothing - crossover_probability = 0.02f0 - use_frequency_in_tournament = true - bumper = true - elseif i == 4 - println( - "with crossover and skip mutation failures and both frequencies options, and Float16 type", - ) - crossover_probability = 0.02f0 - skip_mutation_failures = true - use_frequency = true - use_frequency_in_tournament = true - T = Float16 - elseif i == 5 - println("with default hyperparameters, Float64 type, and turbo=true") - T = Float64 - turbo = true - end - if i == 5 - options = SymbolicRegression.Options(; - unary_operators=(cos,), - batching=batching, - parsimony=0.0f0, # Required for scoring - ) - else - options = SymbolicRegression.Options(; - default_params..., - binary_operators=(+, *), - unary_operators=(cos,), - populations=4, - batching=batching, - crossover_probability=crossover_probability, - skip_mutation_failures=skip_mutation_failures, - seed=0, - progress=progress, - warmup_maxsize_by=warmup_maxsize_by, - optimizer_algorithm=optimizer_algorithm, - tournament_selection_p=tournament_selection_p, - parsimony=0.0f0, - use_frequency=use_frequency, - use_frequency_in_tournament=use_frequency_in_tournament, - turbo=turbo, - bumper=bumper, - ) - end +@testitem "Search with multi-threading & default settings" tags = [:part3] begin + include("test_mixed_utils.jl") + test_mixed(2, false, true, :multithreading) +end - X = randn(MersenneTwister(0), T, 5, 100) - if weighted - mask = rand(100) .> 0.5 - weights = map(x -> convert(T, x), mask) - # Completely different function superimposed - need - # to use correct weights to figure it out! 
- y = (2 .* cos.(X[4, :])) .* weights .+ (1 .- weights) .* (5 .* X[2, :]) - hallOfFame = equation_search( - X, - y; - weights=weights, - niterations=2, - options=options, - parallelism=parallelism, - numprocs=numprocs, - ) - dominating = [calculate_pareto_frontier(hallOfFame)] - else - y = 2 * cos.(X[4, :]) - niterations = 2 - if multi - # Copy the same output twice; make sure we can find it twice - y = repeat(y, 1, 2) - y = transpose(y) - niterations = 20 - end - hallOfFame = equation_search( - X, - y; - niterations=niterations, - options=options, - parallelism=parallelism, - numprocs=numprocs, - ) - dominating = if multi - [calculate_pareto_frontier(hallOfFame[j]) for j in 1:2] - else - [calculate_pareto_frontier(hallOfFame)] - end - end +@testitem "Search with multi-threading & weighted & crossover & use_frequency_in_tournament & bumper" tags = [ + :part1 +] begin + include("test_mixed_utils.jl") + test_mixed(3, false, false, :multithreading) +end - # For brevity, always assume multi-output in this test: - for dom in dominating - @test length(dom) > 0 - best = dom[end] - # Assert we created the correct type of trees: - @test typeof(best.tree) == Node{T} +@testitem "Search with multi-threading & crossover & skip mutation failures & both frequencies options & Float16 type" tags = [ + :part2 +] begin + include("test_mixed_utils.jl") + test_mixed(4, false, false, :multithreading) +end - # Test the score - @test best.loss < maximum_residual - # Test the actual equation found: - testX = randn(MersenneTwister(1), T, 5, 100) - true_y = 2 * cos.(testX[4, :]) - predicted_y, flag = eval_tree_array(best.tree, testX, options) - @test flag - @test sum(abs, true_y .- predicted_y) < maximum_residual - # eval evaluates inside global - end - - println("Passed.") -end # for i=1... +@testitem "Search with multiprocessing & default hyperparameters & Float64 type & turbo" tags = [ + :part3 +] begin + include("test_mixed_utils.jl") + test_mixed(5, false, false, :multiprocessing) +end diff --git a/test/test_mixed_utils.jl b/test/test_mixed_utils.jl new file mode 100644 index 000000000..2ad9e7636 --- /dev/null +++ b/test/test_mixed_utils.jl @@ -0,0 +1,149 @@ +using SymbolicRegression +using SymbolicRegression: string_tree +using Random, Bumper, LoopVectorization + +include("test_params.jl") + +function test_mixed(i, batching::Bool, weighted::Bool, parallelism) + progress = false + warmup_maxsize_by = 0.0f0 + optimizer_algorithm = "NelderMead" + multi = false + tournament_selection_p = 1.0 + crossover_probability = 0.0f0 + skip_mutation_failures = false + use_frequency = false + use_frequency_in_tournament = false + turbo = false + bumper = false + T = Float32 + Random.seed!(0) + + if i == 0 + progress = true #Also try the progress bar. + warmup_maxsize_by = 0.5f0 #Smaller maxsize at first, build up slowly + optimizer_algorithm = "BFGS" + tournament_selection_p = 0.8 + elseif i == 1 + multi = true + use_frequency = true + elseif i == 3 + crossover_probability = 0.02f0 + use_frequency_in_tournament = true + bumper = true + elseif i == 4 + crossover_probability = 0.02f0 + skip_mutation_failures = true + use_frequency = true + use_frequency_in_tournament = true + T = Float16 + elseif i == 5 + T = Float64 + turbo = true + end + + numprocs = parallelism == :multiprocessing ? 
2 : nothing + + options = if i == 5 + SymbolicRegression.Options(; + unary_operators=(cos,), + batching=batching, + parsimony=0.0f0, # Required for scoring + early_stop_condition=1e-6, + ) + else + SymbolicRegression.Options(; + default_params..., + binary_operators=(+, *), + unary_operators=(cos,), + populations=4, + batching=batching, + crossover_probability=crossover_probability, + skip_mutation_failures=skip_mutation_failures, + seed=0, + progress=progress, + warmup_maxsize_by=warmup_maxsize_by, + optimizer_algorithm=optimizer_algorithm, + tournament_selection_p=tournament_selection_p, + parsimony=0.0f0, + use_frequency=use_frequency, + use_frequency_in_tournament=use_frequency_in_tournament, + turbo=turbo, + bumper=bumper, + early_stop_condition=1e-6, + ) + end + + X = randn(MersenneTwister(0), T, 5, 100) + + (y, hallOfFame, dominating) = if weighted + mask = rand(100) .> 0.5 + weights = map(x -> convert(T, x), mask) + # Completely different function superimposed - need + # to use correct weights to figure it out! + y = (2 .* cos.(X[4, :])) .* weights .+ (1 .- weights) .* (5 .* X[2, :]) + hallOfFame = equation_search( + X, + y; + weights=weights, + niterations=2, + options=options, + parallelism=parallelism, + numprocs=numprocs, + ) + dominating = [calculate_pareto_frontier(hallOfFame)] + + (y, hallOfFame, dominating) + else + y = 2 * cos.(X[4, :]) + niterations = 2 + if multi + # Copy the same output twice; make sure we can find it twice + y = repeat(y, 1, 2) + y = transpose(y) + niterations = 20 + end + hallOfFame = equation_search( + X, + y; + niterations=niterations, + options=options, + parallelism=parallelism, + numprocs=numprocs, + ) + dominating = if multi + [calculate_pareto_frontier(hallOfFame[j]) for j in 1:2] + else + [calculate_pareto_frontier(hallOfFame)] + end + + (y, hallOfFame, dominating) + end + + # For brevity, always assume multi-output in this test: + for dom in dominating + @test length(dom) > 0 + best = dom[end] + # Assert we created the correct type of trees: + @test node_type(typeof(best.tree)) == Node{T} + + # Test the score + @test best.loss < maximum_residual + # Test the actual equation found: + testX = randn(MersenneTwister(1), T, 5, 100) + true_y = 2 * cos.(testX[4, :]) + predicted_y, flag = eval_tree_array(best.tree, testX, options) + + @test flag + if parallelism == :multiprocessing && turbo + # TODO: For some reason this test does a bit worse + @test sum(abs, true_y .- predicted_y) < maximum_residual * 50 + else + @test sum(abs, true_y .- predicted_y) < maximum_residual + end + + # eval evaluates inside global + end + + return println("Passed.") +end diff --git a/test/test_mlj.jl b/test/test_mlj.jl index ca0416209..3ef8d7b93 100644 --- a/test/test_mlj.jl +++ b/test/test_mlj.jl @@ -1,22 +1,8 @@ -using SymbolicRegression: SymbolicRegression -using SymbolicRegression: - Node, SRRegressor, MultitargetSRRegressor, node_to_symbolic, symbolic_to_node -using MLJTestInterface: MLJTestInterface as MTI -using MLJBase: machine, fit!, report, predict -using SymbolicUtils: SymbolicUtils -using Suppressor: @capture_err - -macro quiet(ex) - return quote - redirect_stderr(devnull) do - $ex - end - end |> esc -end - -stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-7) +@testitem "Generic interface tests" tags = [:part1] begin + using SymbolicRegression + using MLJTestInterface: MLJTestInterface as MTI + include("test_params.jl") -@testset "Generic interface tests" begin failures, summary = MTI.test( [SRRegressor], MTI.make_regression()...; 
mod=@__MODULE__, verbosity=0, throw=true ) @@ -32,87 +18,116 @@ stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-7) @test isempty(failures) end -@testset "Variable names" begin - @testset "Single outputs" begin - X = (a=rand(32), b=rand(32)) - y = X.a .^ 2.1 - # We also make sure the deprecated npop and npopulations still work: - model = SRRegressor(; niterations=10, npop=33, npopulations=15, stop_kws...) - mach = machine(model, X, y) - fit!(mach) - rep = report(mach) - @test occursin("a", rep.equation_strings[rep.best_idx]) - ypred_good = predict(mach, X) - @test sum(abs2, predict(mach, X) .- y) / length(y) < 1e-5 - - @testset "Check that we can choose the equation" begin - ypred_same = predict(mach, (data=X, idx=rep.best_idx)) - @test ypred_good == ypred_same - - ypred_bad = predict(mach, (data=X, idx=1)) - @test ypred_good != ypred_bad - end - - @testset "Smoke test SymbolicUtils" begin - eqn = node_to_symbolic(rep.equations[rep.best_idx], model) - n = symbolic_to_node(eqn, model) - eqn2 = convert(SymbolicUtils.Symbolic, n, model) - n2 = convert(Node, eqn2, model) - end - end +@testitem "Variable names - single outputs" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression: Node + using MLJBase + using SymbolicUtils + using Random: MersenneTwister - @testset "Multiple outputs" begin - X = (a=rand(32), b=rand(32)) - y = X.a .^ 2.1 - model = MultitargetSRRegressor(; niterations=10, stop_kws...) - mach = machine(model, X, reduce(hcat, [reshape(y, :, 1) for i in 1:3])) - fit!(mach) - rep = report(mach) - @test all( - eq -> occursin("a", eq), [rep.equation_strings[i][rep.best_idx[i]] for i in 1:3] - ) - ypred_good = predict(mach, X) - - @testset "Test that we can choose the equation" begin - ypred_same = predict(mach, (data=X, idx=rep.best_idx)) - @test ypred_good == ypred_same - - ypred_bad = predict(mach, (data=X, idx=[1, 1, 1])) - @test ypred_good != ypred_bad - - ypred_mixed = predict(mach, (data=X, idx=[rep.best_idx[1], 1, rep.best_idx[3]])) - @test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3]) - - @test_throws AssertionError predict(mach, (data=X,)) - VERSION >= v"1.8" && - @test_throws "If specifying an equation index during" predict( - mach, (data=X,) - ) - VERSION >= v"1.8" && - @test_throws "If specifying an equation index during" predict( - mach, (X=X, idx=1) - ) - end - end + include("test_params.jl") - @testset "Named outputs" begin - X = (b1=randn(32), b2=randn(32)) - Y = (c1=X.b1 .* X.b2, c2=X.b1 .+ X.b2) - w = ones(32) - model = MultitargetSRRegressor(; niterations=10, stop_kws...) - mach = machine(model, X, Y, w) - fit!(mach) - test_outs = predict(mach, X) - @test isempty(setdiff((:c1, :c2), keys(test_outs))) - @test_throws AssertionError predict(mach, (a1=randn(32), b2=randn(32))) - VERSION >= v"1.8" && @test_throws "Variable names do not match fitted" predict( - mach, (b1=randn(32), a2=randn(32)) - ) - end + stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) + + rng = MersenneTwister(0) + X = (a=rand(rng, 32), b=rand(rng, 32)) + y = X.a .^ 2.1 + # We also make sure the deprecated npop and npopulations still work: + model = SRRegressor(; niterations=10, npop=1000, npopulations=15, stop_kws...) 
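# (Editorial note, not part of the patch: `npop` and `npopulations` are the
#  deprecated spellings handled by deprecates.jl; assuming the usual mapping to
#  `population_size` and `populations`, the model above would be equivalent to
#  SRRegressor(; niterations=10, population_size=1000, populations=15, stop_kws...).)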
+ mach = machine(model, X, y) + fit!(mach) + rep = report(mach) + @test occursin("a", rep.equation_strings[rep.best_idx]) + ypred_good = predict(mach, X) + @test sum(abs2, predict(mach, X) .- y) / length(y) < 1e-5 + + # Check that we can choose the equation + ypred_same = predict(mach, (data=X, idx=rep.best_idx)) + @test ypred_good == ypred_same + + ypred_bad = predict(mach, (data=X, idx=1)) + @test ypred_good != ypred_bad + + # Smoke test SymbolicUtils + eqn = node_to_symbolic(rep.equations[rep.best_idx], model) + n = symbolic_to_node(eqn, model) + eqn2 = convert(SymbolicUtils.Symbolic, n, model) + n2 = convert(Node, eqn2, model) end -@testset "Good predictions" begin - X = randn(100, 3) +@testitem "Variable names - multiple outputs" tags = [:part1] begin + using SymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + + stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) + + rng = MersenneTwister(0) + X = (a=rand(rng, 32), b=rand(rng, 32)) + y = X.a .^ 2.1 + model = MultitargetSRRegressor(; niterations=10, stop_kws...) + mach = machine(model, X, reduce(hcat, [reshape(y, :, 1) for i in 1:3])) + fit!(mach) + rep = report(mach) + @test all( + eq -> occursin("a", eq), [rep.equation_strings[i][rep.best_idx[i]] for i in 1:3] + ) + ypred_good = predict(mach, X) + + # Test that we can choose the equation + ypred_same = predict(mach, (data=X, idx=rep.best_idx)) + @test ypred_good == ypred_same + + ypred_bad = predict(mach, (data=X, idx=[1, 1, 1])) + @test ypred_good != ypred_bad + + ypred_mixed = predict(mach, (data=X, idx=[rep.best_idx[1], 1, rep.best_idx[3]])) + @test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3]) + + @test_throws AssertionError predict(mach, (data=X,)) + VERSION >= v"1.8" && + @test_throws "If specifying an equation index during" predict(mach, (data=X,)) + VERSION >= v"1.8" && + @test_throws "If specifying an equation index during" predict(mach, (X=X, idx=1)) +end + +@testitem "Variable names - named outputs" tags = [:part1] begin + using SymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + + stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) + + rng = MersenneTwister(0) + X = (b1=randn(rng, 32), b2=randn(rng, 32)) + Y = (c1=X.b1 .* X.b2, c2=X.b1 .+ X.b2) + w = ones(32) + model = MultitargetSRRegressor(; niterations=10, stop_kws...) + mach = machine(model, X, Y, w) + fit!(mach) + test_outs = predict(mach, X) + @test isempty(setdiff((:c1, :c2), keys(test_outs))) + @test_throws AssertionError predict(mach, (a1=randn(32), b2=randn(32))) + VERSION >= v"1.8" && @test_throws "Variable names do not match fitted" predict( + mach, (b1=randn(32), a2=randn(32)) + ) +end + +@testitem "Good predictions" tags = [:part1] begin + using SymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + + stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) + + rng = MersenneTwister(0) + X = randn(rng, 100, 3) Y = X model = MultitargetSRRegressor(; niterations=10, stop_kws...) 
mach = machine(model, X, Y) @@ -120,27 +135,42 @@ end @test sum(abs2, predict(mach, X) .- Y) / length(X) < 1e-6 end -@testset "Helpful errors" begin +@testitem "Helpful errors" tags = [:part3] begin + using SymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + model = MultitargetSRRegressor() - mach = machine(model, randn(32, 3), randn(32); scitype_check_level=0) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32); scitype_check_level=0) @test_throws AssertionError @quiet(fit!(mach)) VERSION >= v"1.8" && @test_throws "For single-output regression, please" @quiet(fit!(mach)) model = SRRegressor() - mach = machine(model, randn(32, 3), randn(32, 2); scitype_check_level=0) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 2); scitype_check_level=0) @test_throws AssertionError @quiet(fit!(mach)) VERSION >= v"1.8" && @test_throws "For multi-output regression, please" @quiet(fit!(mach)) model = SRRegressor(; verbosity=0) - mach = machine(model, randn(32, 3), randn(32)) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32)) @test_throws ErrorException @quiet(fit!(mach; verbosity=0)) end -@testset "Unfinished search" begin +@testitem "Unfinished search" tags = [:part3] begin + using SymbolicRegression + using MLJBase + using Suppressor + using Random: MersenneTwister + model = SRRegressor(; timeout_in_seconds=1e-10) - mach = machine(model, randn(32, 3), randn(32)) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32)) fit!(mach) # Ensure that the hall of fame is empty: _, hof = mach.fitresult.state @@ -157,7 +187,8 @@ end @test occursin("Evaluation failed either due to", msg) model = MultitargetSRRegressor(; timeout_in_seconds=1e-10) - mach = machine(model, randn(32, 3), randn(32, 3)) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 3)) fit!(mach) # Ensure that the hall of fame is empty: _, hofs = mach.fitresult.state diff --git a/test/test_nan_detection.jl b/test/test_nan_detection.jl index fb2a21d1f..5626a16c8 100644 --- a/test/test_nan_detection.jl +++ b/test/test_nan_detection.jl @@ -1,5 +1,6 @@ println("Testing NaN detection.") using SymbolicRegression +using LoopVectorization for T in [Float16, Float32, Float64], turbo in [true, false] T == Float16 && turbo && continue diff --git a/test/test_optimizer_mutation.jl b/test/test_optimizer_mutation.jl index 41f3bd67d..48560190a 100644 --- a/test/test_optimizer_mutation.jl +++ b/test/test_optimizer_mutation.jl @@ -3,7 +3,7 @@ using SymbolicRegression: SymbolicRegression using SymbolicRegression: Dataset, RunningSearchStatistics, RecordType using Optim: Optim using SymbolicRegression.MutateModule: next_generation -using DynamicExpressions: get_constants +using DynamicExpressions: get_scalar_constants mutation_weights = (; optimize=1e30) # We also test whether a named tuple works. 
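[Editorial aside, not part of the patch] This test_optimizer_mutation.jl change tracks the DynamicExpressions rename from get_constants to get_scalar_constants, which now also returns the references needed to write optimized values back into the tree. A small sketch of the migrated pattern (the setter call is assumed from the exports added elsewhere in this patch; values are illustrative):

using SymbolicRegression
using DynamicExpressions: get_scalar_constants, set_scalar_constants!

options = Options(; binary_operators=(+, *))
tree = Node(; val=2.1) * Node(Float64; feature=1) + Node(; val=0.8)
vals, refs = get_scalar_constants(tree)         # previously: vals = get_constants(tree)
set_scalar_constants!(tree, vals .+ 0.1, refs)  # assumed counterpart setter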
 options = Options(;
@@ -35,7 +35,7 @@ new_member, _, _ = next_generation(
     tmp_recorder=RecordType(),
 )
 
-resultant_constants = get_constants(new_member.tree)
+resultant_constants, refs = get_scalar_constants(new_member.tree)
 for k in [0.0, 0.2, 0.5, 1.0]
     @test sin(resultant_constants[1] * k + resultant_constants[2]) ≈ sin(2.1 * k + 0.8) atol =
         1e-3
diff --git a/test/test_params.jl b/test/test_params.jl
index 7aafd7f48..b74b58013 100644
--- a/test/test_params.jl
+++ b/test/test_params.jl
@@ -1,11 +1,15 @@
 using SymbolicRegression: L2DistLoss, MutationWeights
+using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals!
 using Optim: Optim
 using LineSearches: LineSearches
 using Test: Test
 
 ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true"
 
+empty_all_globals!()
+
 const maximum_residual = 1e-2
+
 if !@isdefined(custom_cos) || !hasmethod(custom_cos, (String,))
     @eval custom_cos(x) = cos(x)
 end
@@ -69,3 +73,11 @@ const default_params = (
 test_info(_, x) = error("Test failed: $x")
 test_info(_, ::Test.Pass) = nothing
 test_info(f::F, ::Test.Fail) where {F} = f()
+
+macro quiet(ex)
+    return quote
+        redirect_stderr(devnull) do
+            $ex
+        end
+    end |> esc
+end
diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl
new file mode 100644
index 000000000..42a28d14e
--- /dev/null
+++ b/test/test_pretty_printing.jl
@@ -0,0 +1,111 @@
+@testitem "pretty print member" tags = [:part3] begin
+    using SymbolicRegression
+
+    options = Options(; binary_operators=[+, ^])
+
+    ex = @parse_expression(x^2.0 + 1.5, binary_operators = [+, ^], variable_names = [:x])
+    shower(x) = sprint((io, e) -> show(io, MIME"text/plain"(), e), x)
+    s = shower(ex)
+    @test s == "(x ^ 2.0) + 1.5"
+
+    X = [1.0 2.0 3.0]
+    y = [2.0, 3.0, 4.0]
+    dataset = Dataset(X, y)
+    member = PopMember(dataset, ex, options; deterministic=false)
+    member.score = 1.0
+    @test member isa PopMember{Float64,Float64,<:Expression{Float64,Node{Float64}}}
+    s_member = shower(member)
+    @test s_member == "PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0)"
+
+    # New options shouldn't change this
+    options = Options(; binary_operators=[-, /])
+    s_member = shower(member)
+    @test s_member == "PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0)"
+end
+
+@testitem "pretty print hall of fame" tags = [:part1] begin
+    using SymbolicRegression
+    using SymbolicRegression: embed_metadata
+    using SymbolicRegression.CoreModule: safe_pow
+
+    options = Options(; binary_operators=[+, safe_pow], maxsize=7)
+
+    ex = @parse_expression(
+        $safe_pow(x, 2.0) + 1.5, binary_operators = [+, safe_pow], variable_names = [:x]
+    )
+    shower(x) = sprint((io, e) -> show(io, MIME"text/plain"(), e), x)
+    s = shower(ex)
+    @test s == "(x ^ 2.0) + 1.5"
+
+    X = [1.0 2.0 3.0]
+    y = [2.0, 3.0, 4.0]
+    dataset = Dataset(X, y)
+    member = PopMember(dataset, ex, options; deterministic=false)
+    member.score = 1.0
+    @test member isa PopMember{Float64,Float64,<:Expression{Float64,Node{Float64}}}
+
+    hof = HallOfFame(options, dataset)
+    hof = embed_metadata(hof, options, dataset)
+    hof.members[5] = member
+    hof.exists[5] = true
+    s_hof = strip(shower(hof))
+    true_s = "HallOfFame{...}:
+    .exists[1] = false
+    .members[1] = undef
+    .exists[2] = false
+    .members[2] = undef
+    .exists[3] = false
+    .members[3] = undef
+    .exists[4] = false
+    .members[4] = undef
+    .exists[5] = true
+    .members[5] = PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0)
+    .exists[6] = false
+    .members[6] = undef
+    .exists[7] = false
+    .members[7] = undef
+    .exists[8] = false
+    .members[8] = undef
+    .exists[9] = false
+    .members[9] = undef"
+
+    @test s_hof == true_s
+end
+
+@testitem "pretty print expression" tags = [:part2] begin
+    using SymbolicRegression
+    using Suppressor: @capture_out
+
+    options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos])
+    ex = @parse_expression(
+        cos(x) + y * y, operators = options.operators, variable_names = [:x, :y]
+    )
+
+    s = sprint((io, ex) -> print_tree(io, ex, options), ex)
+    @test strip(s) == "cos(x) + (y * y)"
+
+    s = @capture_out begin
+        print_tree(ex, options)
+    end
+    @test strip(s) == "cos(x) + (y * y)"
+
+    # Works with the tree itself too
+    s = @capture_out begin
+        print_tree(get_tree(ex), options)
+    end
+    @test strip(s) == "cos(x1) + (x2 * x2)"
+    s = sprint((io, ex) -> print_tree(io, ex, options), get_tree(ex))
+    @test strip(s) == "cos(x1) + (x2 * x2)"
+
+    # Updating options won't change printout, UNLESS
+    # we pass the options.
+    options = Options(; binary_operators=[/, *, -, +], unary_operators=[sin])
+
+    s = @capture_out begin
+        print_tree(ex)
+    end
+    @test strip(s) == "cos(x) + (y * y)"
+
+    s = sprint((io, ex) -> print_tree(io, ex, options), ex)
+    @test strip(s) == "sin(x) / (y - y)"
+end
diff --git a/test/test_prob_pick_first.jl b/test/test_prob_pick_first.jl
index 8967f0e04..b2f73704e 100644
--- a/test/test_prob_pick_first.jl
+++ b/test/test_prob_pick_first.jl
@@ -1,6 +1,7 @@
 println("Testing whether tournament_selection_p works.")
 using SymbolicRegression
-using DynamicExpressions.EquationModule: with_type_parameters
+using DynamicExpressions: with_type_parameters, @parse_expression
+using Test
 include("test_params.jl")
 
 n = 10
@@ -15,19 +16,21 @@ options = Options(;
 
 for reverse in [false, true]
     T = Float32
-    NT = with_type_parameters(options.node_type, T)
-    members = PopMember{T,T,NT}[]
     # Generate members with scores from 0 to 1:
-    for i in 1:n
-        tree = Node("x1") * 3.2f0
-        score = Float32(i - 1) / (n - 1)
-        if reverse
-            score = 1 - score
-        end
-        test_loss = 1.0f0 # (arbitrary for this test)
-        push!(members, PopMember(tree, score, test_loss, options))
-    end
+    members = [
+        let
+            ex = @parse_expression(
+                x1 * 3.2, operators = options.operators, variable_names = [:x1],
+            )
+            score = Float32(i - 1) / (n - 1)
+            if reverse
+                score = 1 - score
+            end
+            test_loss = 1.0f0 # (arbitrary for this test)
+            PopMember(ex, score, test_loss, options; deterministic=false)
+        end for i in 1:n
+    ]
 
     pop = Population(members)
diff --git a/test/test_simplification.jl b/test/test_simplification.jl
index ad6c0a562..510ed5ab7 100644
--- a/test/test_simplification.jl
+++ b/test/test_simplification.jl
@@ -1,8 +1,11 @@
 include("test_params.jl")
 using SymbolicRegression, Test
 using SymbolicUtils: simplify, Symbolic
-using Random: MersenneTwister
-using Base: ≈
+using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals!
+#! format: off
+using Base: ≈; using Random: MersenneTwister
+#! format: on
+# ^ Can't end line with ≈ due to JuliaSyntax.jl bug
 
 function Base.:≈(a::String, b::String)
     a = replace(a, r"\s+" => "")
@@ -10,6 +13,8 @@ function Base.:≈(a::String, b::String)
     return a == b
 end
 
+empty_all_globals!()
+
 binary_operators = (+, -, /, *)
 index_of_mult = [i for (i, op) in enumerate(binary_operators) if op == *][1]
diff --git a/test/test_stop_on_clock.jl b/test/test_stop_on_clock.jl
index 295586e57..a7f925a20 100644
--- a/test/test_stop_on_clock.jl
+++ b/test/test_stop_on_clock.jl
@@ -1,13 +1,25 @@
 using SymbolicRegression
 using Random
+using Distributed: rmprocs
 include("test_params.jl")
 
 X = randn(MersenneTwister(0), Float32, 5, 100)
 y = 2 * cos.(X[4, :])
-options = Options(; default_params..., timeout_in_seconds=1)
+# Ensure is precompiled:
+options = Options(;
+    default_params...,
+    population_size=10,
+    ncycles_per_iteration=100,
+    maxsize=15,
+    timeout_in_seconds=1,
+)
+equation_search(X, y; niterations=1, options=options, parallelism=:serial)
+
+# Ensure nothing might prevent slow checking of the clock:
+rmprocs()
+GC.gc(true) # full=true
 
 start_time = time()
-# With multithreading:
-equation_search(X, y; niterations=10000000, options=options, parallelism=:multithreading)
+equation_search(X, y; niterations=10000000, options=options, parallelism=:serial)
 end_time = time()
 @test end_time - start_time < 100
diff --git a/test/test_turbo_nan.jl b/test/test_turbo_nan.jl
index 2447b8253..09d66c11a 100644
--- a/test/test_turbo_nan.jl
+++ b/test/test_turbo_nan.jl
@@ -1,4 +1,5 @@
 using SymbolicRegression
+using LoopVectorization
 
 bad_op(x::T) where {T} = (x >= 0) ? x : T(0)
diff --git a/test/test_units.jl b/test/test_units.jl
index 0e58173e1..e35e0a185 100644
--- a/test/test_units.jl
+++ b/test/test_units.jl
@@ -1,52 +1,22 @@
-using SymbolicRegression
-using SymbolicRegression:
-    square,
-    cube,
-    plus,
-    sub,
-    mult,
-    greater,
-    cond,
-    relu,
-    logical_or,
-    logical_and,
-    safe_pow,
-    atanh_clip
-using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_units, get_dimensions_type
-using SymbolicRegression.MLJInterfaceModule: unwrap_units_single
-using SymbolicRegression.DimensionalAnalysisModule:
-    violates_dimensional_constraints, @maybe_return_call, WildcardQuantity
-using DynamicQuantities:
-    DEFAULT_DIM_BASE_TYPE,
-    RealQuantity,
-    Quantity,
-    QuantityArray,
-    SymbolicDimensions,
-    Dimensions,
-    DimensionError,
-    @u_str,
-    @us_str,
-    uparse,
-    sym_uparse,
-    ustrip,
-    dimension
-using MLJBase: MLJBase as MLJ
-using MLJModelInterface: MLJModelInterface as MMI
-include("utils.jl")
-
-custom_op(x, y) = x + y
-
-options = Options(;
-    binary_operators=[-, *, /, custom_op, ^], unary_operators=[cos, cbrt, sqrt, abs, inv]
-)
-@extend_operators options
-
-(x1, x2, x3) = (i -> Node(Float64; feature=i)).(1:3)
-
-@testset "Dimensional analysis" begin
+@testitem "Dimensional analysis" tags = [:part3] begin
+    using SymbolicRegression
+    using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_units
+    using SymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints
+    using DynamicQuantities
+    using DynamicQuantities: DEFAULT_DIM_BASE_TYPE
+
     X = randn(3, 100)
     y = @. cos(X[3, :] * 2.1 - 0.2) + 0.5
+    custom_op(x, y) = x + y
+    options = Options(;
+        binary_operators=[-, *, /, custom_op, ^],
+        unary_operators=[cos, cbrt, sqrt, abs, inv],
+    )
+    @extend_operators options
+
+    (x1, x2, x3) = (i -> Node(Float64; feature=i)).(1:3)
+
     D = Dimensions{DEFAULT_DIM_BASE_TYPE}
     SD = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
@@ -130,15 +100,24 @@
     end
 end
 
-options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos])
-@extend_operators options
+@testitem "Search with dimensional constraints" tags = [:part3] begin
+    using SymbolicRegression
+    using SymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints
+    using Random: MersenneTwister
 
-@testset "Search with dimensional constraints" begin
-    X = rand(1, 128) .* 10
+    rng = MersenneTwister(0)
+    X = rand(rng, 1, 128) .* 20
     y = @. cos(X[1, :]) + X[1, :]
     dataset = Dataset(X, y; X_units=["kg"], y_units="1")
+    custom_op(x, y) = x + y
+    options = Options(;
+        binary_operators=[-, *, /, custom_op],
+        unary_operators=[cos],
+        early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 8),
+    )
+    @extend_operators options
 
-    hof = EquationSearch(dataset; options)
+    hof = equation_search(dataset; niterations=1000, options)
 
     # Solutions should be like cos([cons] * X[1]) + [cons]*X[1]
     dominating = calculate_pareto_frontier(hof)
@@ -151,26 +130,27 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]
 
     # Check that every cos(...) which contains x1 also has complexity
     has_cos(tree) =
-        any(tree) do t
+        any(get_tree(tree)) do t
            t.degree == 1 && options.operators.unaops[t.op] == cos
        end
     valid_trees = [
-        !has_cos(member.tree) || any(member.tree) do t
-            if (
+        !has_cos(member.tree) || any(
+            t ->
                t.degree == 1 &&
-                options.operators.unaops[t.op] == cos &&
-                Node(Float64; feature=1) in t
-            )
-                return compute_complexity(t, options) > 1
-            end
-            return false
-        end for member in dominating
+                options.operators.unaops[t.op] == cos &&
+                Node(Float64; feature=1) in t &&
+                compute_complexity(t, options) > 1,
+            get_tree(member.tree),
+        ) for member in dominating
     ]
     @test all(valid_trees)
     @test length(valid_trees) > 0
 end
 
-@testset "Operator compatibility" begin
+@testitem "Operator compatibility" tags = [:part3] begin
+    using SymbolicRegression
+    using DynamicQuantities
+
     ## square cube plus sub mult greater cond relu logical_or logical_and safe_pow atanh_clip
     # Want to ensure these operators perform correctly in the context of units
     @test square(1.0u"m") == 1.0u"m^2"
@@ -210,11 +190,24 @@ end
     @test_throws DimensionError atanh_clip(1.0u"m")
 end
 
-options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos])
-@extend_operators options
+@testitem "Search with dimensional constraints on output" tags = [:part3] begin
+    using SymbolicRegression
+    using MLJBase: MLJBase as MLJ
+    using DynamicQuantities
+    using Random: MersenneTwister
+
+    include("utils.jl")
 
-@testset "Search with dimensional constraints on output" begin
-    X = randn(2, 128)
+    custom_op(x, y) = x + y
+    options = Options(;
+        binary_operators=[-, *, /, custom_op],
+        unary_operators=[cos],
+        early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity == 3),
+    )
+    @extend_operators options
+
+    rng = MersenneTwister(0)
+    X = randn(rng, 2, 128)
     X[2, :] .= X[1, :]
     y = X[1, :] .^ 2
 
@@ -224,7 +217,7 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]
 
     # Solution should be x2 * x2
     dominating = calculate_pareto_frontier(hof)
-    best = first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree
+    best = get_tree(first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree)
 
     x2 = Node(Float64; feature=2)
@@ -236,18 +229,23 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]
        @warn "Complexity of best solution is not 3; search with units might have failed"
     end
 
-    X = randn(2, 128)
+    rng = MersenneTwister(0)
+    X = randn(rng, 2, 128)
     y = @. cbrt(X[1, :]) .+ sqrt(abs(X[2, :]))
-    options2 = Options(; binary_operators=[+, *], unary_operators=[sqrt, cbrt, abs])
+    options2 = Options(;
+        binary_operators=[+, *],
+        unary_operators=[sqrt, cbrt, abs],
+        early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity == 6),
+    )
     hof = EquationSearch(X, y; options=options2, X_units=["kg^3", "kg^2"], y_units="kg")
 
     dominating = calculate_pareto_frontier(hof)
     best = first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree
     @test compute_complexity(best, options2) == 6
-    @test any(best) do t
+    @test any(get_tree(best)) do t
        t.degree == 1 && options2.operators.unaops[t.op] == cbrt
     end
-    @test any(best) do t
+    @test any(get_tree(best)) do t
        t.degree == 1 && options2.operators.unaops[t.op] == safe_sqrt
     end
@@ -269,10 +267,10 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]
     report = MLJ.report(mach)
     best_idx = findfirst(report.losses .< 1e-7)::Int
     @test report.complexities[best_idx] <= 6
-    @test any(report.equations[best_idx]) do t
+    @test any(get_tree(report.equations[best_idx])) do t
        t.degree == 1 && t.op == 2 # cbrt
     end
-    @test any(report.equations[best_idx]) do t
+    @test any(get_tree(report.equations[best_idx])) do t
        t.degree == 1 && t.op == 1 # safe_sqrt
     end
@@ -318,14 +316,20 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]
     end
 end
 
-@testset "Should error on mismatched units" begin
+@testitem "Should error on mismatched units" tags = [:part3] begin
+    using SymbolicRegression
+    using DynamicQuantities
+
     X = randn(11, 50)
     y = randn(50)
     VERSION >= v"1.8.0" &&
        @test_throws("Number of features", Dataset(X, y; X_units=["m", "1"], y_units="kg"))
 end
 
-@testset "Should print units" begin
+@testitem "Should print units" tags = [:part3] begin
+    using SymbolicRegression
+    using DynamicQuantities
+
     X = randn(5, 64)
     y = randn(64)
     dataset = Dataset(X, y; X_units=["m^3", "km/s", "kg", "1", "1"], y_units="kg")
@@ -386,7 +390,13 @@ end
     ) == "x₅[5.0 m] * 3.2"
 end
 
-@testset "Dimensionless constants" begin
+@testitem "Dimensionless constants" tags = [:part3] begin
+    using SymbolicRegression
+    using SymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints
+    using DynamicQuantities
+
+    include("utils.jl")
+
     options = Options(;
         binary_operators=[+, -, *, /, square, cube],
         unary_operators=[cos, sin],
@@ -422,7 +432,14 @@ end
     end
 end
 
-@testset "Miscellaneous" begin
+@testitem "Miscellaneous tests of unit interface" tags = [:part3] begin
+    using SymbolicRegression
+    using DynamicQuantities
+    using SymbolicRegression.DimensionalAnalysisModule: @maybe_return_call, WildcardQuantity
+    using SymbolicRegression.MLJInterfaceModule: unwrap_units_single
+    using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_dimensions_type
+    using MLJModelInterface: MLJModelInterface as MMI
+
     function test_return_call(op::Function, w...)
        @maybe_return_call(typeof(first(w)), op, w)
        return nothing