diff --git a/CHANGELOG.md b/CHANGELOG.md index 170c6dfc5..4bcd9da91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,68 +91,61 @@ A `TemplateExpression` is constructed by specifying: For example, you can create a `TemplateExpression` that enforces the constraint: `sin(f(x1, x2)) + g(x3)^2` - where we evolve `f` and `g` simultaneously. -Let's see some code for this. First, we define some base expressions for each input feature: +To do this, we first describe the structure using `TemplateStructure` +that takes a single closure function that maps a named tuple of +`ComposableExpression` expressions and a tuple of features: ```julia using SymbolicRegression -options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) -operators = options.operators -variable_names = ["x1", "x2", "x3"] - -# Base expressions: -x1 = Expression(Node{Float64}(; feature=1); operators, variable_names) -x2 = Expression(Node{Float64}(; feature=2); operators, variable_names) -x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) +structure = TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 +) ``` -To build a `TemplateExpression`, we specify the structure using -a `TemplateStructure` object. This class has several fields: +This defines how the `TemplateExpression` should be +evaluated numerically on a given input. -- `combine`: Optional function taking a `NamedTuple` of function keys => expressions, - returning a single expression. Fallback method used by `get_tree` - on a `TemplateExpression` to generate a single `Expression`. -- `combine_vectors`: Optional function taking a `NamedTuple` of function keys => vectors, - returning a single vector. Used for evaluating the expression tree. - You may optionally define a method with a second argument `X` for if you wish - to include the data matrix `X` (of shape `[num_features, num_rows]`) in the - computation. -- `combine_strings`: Optional function taking a `NamedTuple` of function keys => strings, - returning a single string. Used for printing the expression tree. -- `variable_constraints`: Optional `NamedTuple` that defines which variables each sub-expression is allowed to access. - For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`. - -Let's see an example: +The number of arguments allowed by each expression object +is inferred using this closure, though it can also +be passed explicitly with the `num_features` kwarg. ```julia - -# Combine f and g them into a single scalar expression: -structure = TemplateStructure(; - combine_strings=e -> "sin(" * e.f * ") + (" * e.g * ")^2", - combine_vectors=e -> map((f, g) -> sin(f) + g * g, e.f, e.g), - variable_constraints = (; f=[1, 2], g=[3]), # We constrain it to f(x1, x2) and g(x3) -) +operators = Options(binary_operators=(+, -, *, /)).operators +variable_names = ["x1", "x2", "x3"] +x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) +x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) +x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) ``` -This defines how the `TemplateExpression` should be evaluated numerically on a given input, -and also how it should be represented as a string: +Note that using `x1` here refers to the +_relative_ argument to the expression. +So the node with feature equal to 1 will reference +the first argument, regardless of what it is. ```julia -julia> f_example = x1 - x2 * x2; # Normal `Expression` object - -julia> g_example = 1.5 * x3; - -julia> # Create TemplateExpression from these sub-expressions: - st_expr = TemplateExpression((; f=f_example, g=g_example); structure, operators, variable_names); +st_expr = TemplateExpression( + (; f=x1 - x2 * x2, g=1.5 * x1); + structure, + operators, + variable_names +) # Prints as: f = #1 - (#2 * #2); g = 1.5 * #1 + +# Evaluation combines evaluation of `f` and `g`, and combines them +# with the structure function: +st_expr([0.0; 1.0; 2.0;;]) +``` -julia> st_expr # Prints using `my_structure`! -sin(x1 - (x2 * x2)) + 1.5 * x3^2 +This also work with hierarchical expressions! For example, -julia> st_expr([0.0; 1.0; 2.0;;]) # Combines evaluation of `f` and `g` via `my_structure`! -1-element Vector{Float64}: - 8.158529015192103 +```julia +structure = TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> f(x1, g(x2), x3^2) - g(x3) +) ``` +this is a valid structure! + We can also use this `TemplateExpression` in SymbolicRegression.jl searches!
@@ -168,11 +161,17 @@ This also has our variable mapping, which says we are fitting `f(x1, x2)`, `g1(x3)`, and `g2(x3)`: ```julia -structure = TemplateStructure(; - combine_strings=e -> "( " * e.f * " + " * e.g1 * ", " * e.f * " + " * e.g2 * " )", - combine_vectors=e -> map(i -> (e.f[i] + e.g1[i], e.f[i] + e.g2[i]), eachindex(e.f)), - variable_constraints = (; f=[1, 2], g1=[3], g2=[3]), -) +function my_structure((; f, g1, g2), (x1, x2, x3)) + _f = f(x1, x2) + _g1 = g1(x3) + _g2 = g2(x3) + + # We use `.x` to get the underlying vector + out = map((fi, g1i, g2i) -> (fi + g1i, fi + g2i), _f.x, _g1.x, _g2.x) + # And `.valid` to see whether the evaluations + return ValidVector(out, _f.valid && _g1.valid && _g2.valid) +end +structure = TemplateStructure{(:f, :g1, :g2)}(my_structure) ``` Now, our dataset is a regular 2D array of inputs for `X`. @@ -182,10 +181,7 @@ But our `y` is actually a _vector of 2-tuples_! X = rand(100, 3) .* 10 y = [ - ( - sin(X[i, 1]) + X[i, 3]^2, - sin(X[i, 1]) + X[i, 3] - ) + (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1)) ] ``` diff --git a/Project.toml b/Project.toml index ba95ee1f4..ce0f2f397 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ Dates = "1" DifferentiationInterface = "0.5, 0.6" DispatchDoctor = "^0.4.17" Distributed = "<0.0.1, 1" -DynamicExpressions = "1.4" +DynamicExpressions = "1.5.0" DynamicQuantities = "1" Enzyme = "0.12" JSON3 = "1" diff --git a/docs/Project.toml b/docs/Project.toml index 6399bf082..f66ed72c1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,7 +3,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Gumbo = "708ec375-b3d6-5a57-a7ce-8257bf98657a" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "0.27" diff --git a/docs/make.jl b/docs/make.jl index 336f2fcb0..678d08f00 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -106,6 +106,7 @@ makedocs(; "Examples" => [ "Short Examples" => "examples.md", "Template Expressions" => "examples/template_expression.md", + "Parameterized Expressions" => "examples/parameterized_function.md", ], "API" => "api.md", "Losses" => "losses.md", diff --git a/docs/src/types.md b/docs/src/types.md index bf954dfac..baa9e3800 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -63,32 +63,21 @@ These types allow you to define expressions with parameters that can be tuned to ## Template Expressions Template expressions allow you to specify predefined structures and constraints for your expressions. -These use the new `TemplateStructure` type to define how expressions should be combined and evaluated. +These use `ComposableExpressions` as their internal expression type, which makes them +flexible for creating a structure out of a single function. + +These use the `TemplateStructure` type to define how expressions should be combined and evaluated. ```@docs TemplateExpression TemplateStructure ``` -Example usage: - -```julia -# Define a template structure -structure = TemplateStructure( - combine=e -> e.f + e.g, # Create normal `Expression` - combine_vectors=e -> (e.f .+ e.g), # Output vector - combine_strings=e -> "($e.f) + ($e.g)", # Output string - variable_constraints=(; f=[1, 2], g=[3]) # Constrain dependencies -) - -# Use in options -model = SRRegressor(; - expression_type=TemplateExpression, - expression_options=(; structure=structure) -) -``` +Composable expressions allow you to combine multiple expressions together. -The `variable_constraints` field allows you to specify which variables can be used in different parts of the expression. +```@docs +ComposableExpression +``` ## Population diff --git a/examples/parameterized_function.jl b/examples/parameterized_function.jl index 9faefc97c..13d5fa370 100644 --- a/examples/parameterized_function.jl +++ b/examples/parameterized_function.jl @@ -1,21 +1,62 @@ +#literate_begin file="src/examples/parameterized_function.md" +#= +# Learning Parameterized Expressions + +_Note: Parametric expressions are currently considered experimental and may change in the future._ + +Parameterized expressions in SymbolicRegression.jl allow you to discover symbolic expressions that contain +optimizable parameters. This is particularly useful when you have data that follows different patterns +based on some categorical variable, or when you want to learn an expression with constants that should +be optimized during the search. + +In this tutorial, we'll generate synthetic data with class-dependent parameters and use symbolic regression to discover the parameterized expressions. + +## The Problem + +Let's create a synthetic dataset where the underlying function changes based on a class label: + +```math +y = 2\cos(x_2 + 0.1) + x_1^2 - 3.2 \ \ \ \ \text{[class 1]} \\ +\text{OR} \\ +y = 2\cos(x_2 + 1.5) + x_1^2 - 0.5 \ \ \ \ \text{[class 2]} +``` + +We will need to simultaneously learn the symbolic expression and per-class parameters! +=# using SymbolicRegression using Random: MersenneTwister using Zygote using MLJBase: machine, fit!, predict, report using Test -rng = MersenneTwister(0) -X = NamedTuple{(:x1, :x2, :x3, :x4, :x5)}(ntuple(_ -> randn(rng, Float32, 30), Val(5))) -X = (; X..., classes=rand(rng, 1:2, 30)) -p1 = [0.0f0, 3.2f0] -p2 = [1.5f0, 0.5f0] +#= +Now, we generate synthetic data, with these 2 different classes. + +Note that the `class` feature is given special treatment for the [`SRRegressor`](@ref) +as a categorical variable: +=# + +X = let rng = MersenneTwister(0), n = 30 + (; x1=randn(rng, n), x2=randn(rng, n), class=rand(rng, 1:2, n)) +end + +#= +Now, we generate target values using the true model that +has class-dependent parameters: +=# +y = let P1 = [0.1, 1.5], P2 = [3.2, 0.5] + [2 * cos(x2 + P1[class]) + x1^2 - P2[class] for (x1, x2, class) in zip(X.x1, X.x2, X.class)] +end -y = [ - 2 * cos(X.x4[i] + p1[X.classes[i]]) + X.x1[i]^2 - p2[X.classes[i]] for - i in eachindex(X.classes) -] +#= +## Setting up the Search -stop_at = Ref(1e-4) +We'll configure the symbolic regression search to: +- Use parameterized expressions with up to 2 parameters +- Use Zygote.jl for automatic differentiation during parameter optimization (important when using parametric expressions, as it is higher dimensional) +=# + +stop_at = Ref(1e-4) #src model = SRRegressor(; niterations=100, @@ -25,12 +66,38 @@ model = SRRegressor(; expression_type=ParametricExpression, expression_options=(; max_parameters=2), autodiff_backend=:Zygote, - parallelism=:multithreading, - early_stop_condition=(loss, _) -> loss < stop_at[], -) + early_stop_condition=(loss, _) -> loss < stop_at[], #src +); + +#= +Now, let's set up the machine and fit it: +=# mach = machine(model, X, y) +#= +At this point, you would run: + +```julia +fit!(mach) +``` + +You can extract the best expression and parameters with: + +```julia +report(mach).equations[end] +``` + +## Key Takeaways + +1. [`ParametricExpression`](@ref)s allows us to discover symbolic expressions with optimizable parameters +2. The parameters can capture class-dependent variations in the underlying model + +This approach is particularly useful when you suspect your data follows a common +functional form, but with varying parameters across different conditions or class! +=# +#literate_end + fit!(mach) idx1 = lastindex(report(mach).equations) ypred1 = predict(mach, (data=X, idx=idx1)) diff --git a/examples/template_expression.jl b/examples/template_expression.jl index 8c2465b1a..5b1729229 100644 --- a/examples/template_expression.jl +++ b/examples/template_expression.jl @@ -6,33 +6,41 @@ using Test: @test options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) operators = options.operators variable_names = (i -> "x$i").(1:3) -x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - -structure = TemplateStructure{(:f, :g1, :g2)}(; - combine_vectors=e -> map((f, g1, g2) -> (f + g1, f + g2), e.f, e.g1, e.g2), - combine_strings=e -> "( $(e.f) + $(e.g1), $(e.f) + $(e.g2) )", - variable_constraints=(; f=[1, 2], g1=[3], g2=[3]), +x1, x2, x3 = + (i -> ComposableExpression(Node(Float64; feature=i); operators, variable_names)).(1:3) + +structure = TemplateStructure{(:f, :g1, :g2)}( + ((; f, g1, g2), (x1, x2, x3)) -> let + _f = f(x1, x2) + _g1 = g1(x3) + _g2 = g2(x3) + _out1 = _f + _g1 + _out2 = _f + _g2 + ValidVector(map(tuple, _out1.x, _out2.x), _out1.valid && _out2.valid) + end, ) st_expr = TemplateExpression((; f=x1, g1=x3, g2=x3); structure, operators, variable_names) -X = rand(100, 3) .* 10 +x1 = rand(100) +x2 = rand(100) +x3 = rand(100) # Our dataset is a vector of 2-tuples -y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))] +y = [(sin(x1[i]) + x3[i]^2, sin(x1[i]) + x3[i]) for i in eachindex(x1, x2, x3)] model = SRRegressor(; binary_operators=(+, *), unary_operators=(sin,), - maxsize=15, + maxsize=20, expression_type=TemplateExpression, expression_options=(; structure), # The elementwise needs to operate directly on each row of `y`: elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2, - early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7, + early_stop_condition=(loss, complexity) -> loss < 1e-6 && complexity <= 7, ) -mach = machine(model, X, y) +mach = machine(model, [x1 x2 x3], y) fit!(mach) # Check the performance of the model @@ -48,6 +56,6 @@ best_f = get_contents(best_expr).f best_g1 = get_contents(best_expr).g1 best_g2 = get_contents(best_expr).g2 -@test best_f(X') ≈ (@. sin(X[:, 1])) -@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3]) -@test best_g2(X') ≈ (@. X[:, 3]) +@test best_f(x1, x2) ≈ @. sin.(x1) +@test best_g1(x3) ≈ (@. x3 * x3) +@test best_g2(x3) ≈ (@. x3) diff --git a/examples/template_expression_complex.jl b/examples/template_expression_complex.jl index b3794e823..8fad30577 100644 --- a/examples/template_expression_complex.jl +++ b/examples/template_expression_complex.jl @@ -7,9 +7,9 @@ Template expressions are a powerful feature in SymbolicRegression.jl that allow on the symbolic regression search. Rather than searching for a completely free-form expression, you can specify a template that combines multiple sub-expressions in a prescribed way. -This is particularly useful when: +This is particularly useful when any of the following are true: - You have domain knowledge about the functional form of your solution -- You want to learn vector-valued expressions (e.g., force fields, velocity fields) +- You want to learn expressions for a vector-valued output - You need to enforce constraints on which variables can appear in different parts of the expression - You want to share sub-expressions between multiple components @@ -29,7 +29,9 @@ the components of a particle's motion under magnetic and drag forces. We'll see Let's get started! =# -using SymbolicRegression, Random +using SymbolicRegression +using SymbolicRegression: ValidVector +using Random using MLJBase: machine, fit!, predict, report #= @@ -168,69 +170,71 @@ variable_names = ["t", "v_x", "v_y", "v_z", "T"] Template expressions require you to define a _structure_ function, which describes how to combine the sub-expressions into a single expression, numerically evaluate them, and print them. +These are evaluated using `ComposableExpression` for the individual +subexpressions (which allow them to be composed into new expressions), +and `ValidVector` for carrying through evaluation results. -First, let's just make a function that prints the expression: +Let's define our structure function. Note that this takes two arguments, +one being a named tuple of our expressions (`::ComposableExpression`), and the other being a tuple +of the input variables (`::ValidVector`). =# -function combine_strings(e) - ## e is a named tuple of strings representing each formula - return " ╭ 𝐁 = [ " * e.B_x * " , " * e.B_y * " , " * e.B_z * " ]\n ╰ 𝐅 = (" * e.F_d_scale * ") * 𝐯" - ## (Note that string interpolation will erase the colors, so use `*` instead) -end - -#= -So, this will just print the separate B and F_d expressions we've learned. - -Then, let's define an expression that takes the numerical values -evaluated in the TemplateExpression, and combines them into the resultant -force vector. Inside this function, we can do whatever we want. -=# -function combine_vectors(e, X) - ## This time, e is a named tuple of *vectors*, representing the batched - ## evaluation of each formula. - - ## First, extract the 3D velocity vectors from the input matrix: - v = [(X[2, i], X[3, i], X[4, i]) for i in eachindex(axes(X, 2))] - - ## Use this to compute the full drag force: - F_d = [F_d_scale_i .* vi for (F_d_scale_i, vi) in zip(e.F_d_scale, v)] - - ## Collect the magnetic field components that we've learned into the vector: - B = [(bx, by, bz) for (bx, by, bz) in zip(e.B_x, e.B_y, e.B_z)] - - ## Using this, we compute the magnetic force with a cross product: +function compute_force((; B_x, B_y, B_z, F_d_scale), (t, v_x, v_y, v_z, T)) + ## First, we evaluate each subexpression on the variables we wish + ## to have each depend on: + _B_x = B_x(t) + _B_y = B_y(t) + _B_z = B_z(t) + _F_d_scale = F_d_scale(T) + ## Note that we can also evaluate an expression multiple times, + ## including in a hierarchy! + + ## Now, let's do the same computation we did above to + ## get the total force vectors. Note that the evaluation + ## output is wrapped in `ValidVector`, so we need + ## to extract the `.x` to get raw vectors: + B = [(bx, by, bz) for (bx, by, bz) in zip(_B_x.x, _B_y.x, _B_z.x)] + v = [(vx, vy, vz) for (vx, vy, vz) in zip(v_x.x, v_y.x, v_z.x)] + + + ## Now, let's compute the drag force using our model: + F_d = [_F_d_scale .* vi for (vi, _F_d_scale) in zip(v, _F_d_scale.x)] + + ## Now, the magnetic force: F_mag = [cross(vi, Bi) for (vi, Bi) in zip(v, B)] ## Finally, we combine the drag and magnetic forces into the total force: - return [Force((fd .+ fm)...) for (fd, fm) in zip(F_d, F_mag)] + F = [Force((fd .+ fm)...) for (fd, fm) in zip(F_d, F_mag)] + + ## The output of this function needs to be another `ValidVector`, + ## which carries through the validity of the evaluation. We compute + ## this below. + ValidVector(F, _B_x.valid && _B_y.valid && _B_z.valid && _F_d_scale.valid) + ## (Note that if you were doing operations that could not handle NaNs, + ## you may need to return early - just be sure to also return the `ValidVector`!) end #= -For the functions we wish to learn, we can constraint what variables -each of them depends on, explicitly. Let's say B only depends on time, -and the drag force scale only depends on temperature (we explicitly -multiply the velocity in). -=# -variable_constraints = (; B_x=[1], B_y=[1], B_z=[1], F_d_scale=[5]) +Note above that we have constrained what variables each subexpression depends on. -#= -Now, we can create our template expression: +We have constrained the magnetic field to only depend on time, +and the drag force scale to only depend on temperature. +The other variables we simply pass through and use in the evaluation. + +Now, we can create our template expression, with the +subexpression symbols we wish to learn: =# -structure = TemplateStructure{(:B_x, :B_y, :B_z, :F_d_scale)}(; - combine_strings=combine_strings, - combine_vectors=combine_vectors, - variable_constraints=variable_constraints, -) +structure = TemplateStructure{(:B_x, :B_y, :B_z, :F_d_scale)}(compute_force) #= -Let's look at an example of how this would be used +First, let's look at an example of how this would be used in a TemplateExpression, for some guess at the form of the solution: =# options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos, sqrt, exp)) ## The inner operators are an `DynamicExpressions.OperatorEnum` which is used by `Expression`: operators = options.operators -t = Expression(Node{Float64}(; feature=1); operators, variable_names) -T = Expression(Node{Float64}(; feature=5); operators, variable_names) +t = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) +T = ComposableExpression(Node{Float64}(; feature=5); operators, variable_names) B_x = B_y = B_z = 2.1 * cos(t) F_d_scale = 1.0 * sqrt(T) @@ -243,7 +247,7 @@ ex = TemplateExpression( So we can see that it prints the expression as we've defined it. Now, we can create a regressor that builds template expressions -which follow this structure: +which follow this structure! =# model = SRRegressor(; binary_operators=(+, -, *, /), @@ -252,7 +256,7 @@ model = SRRegressor(; maxsize=35, expression_type=TemplateExpression, expression_options=(; structure=structure), - ## The elementwise needs to operate directly on each row of `y`: + ## Note that the elementwise loss needs to operate directly on each row of `y`: elementwise_loss=(F1, F2) -> (F1.x - F2.x)^2 + (F1.y - F2.y)^2 + (F1.z - F2.z)^2, batching=true, batch_size=30, diff --git a/src/AdaptiveParsimony.jl b/src/AdaptiveParsimony.jl index e3fded95c..f45891faa 100644 --- a/src/AdaptiveParsimony.jl +++ b/src/AdaptiveParsimony.jl @@ -1,6 +1,6 @@ module AdaptiveParsimonyModule -using ..CoreModule: AbstractOptions, MAX_DEGREE +using ..CoreModule: AbstractOptions """ RunningSearchStatistics diff --git a/src/Complexity.jl b/src/Complexity.jl index dec8fb63e..54b101898 100644 --- a/src/Complexity.jl +++ b/src/Complexity.jl @@ -20,12 +20,17 @@ if these are defined. function compute_complexity( tree::AbstractExpression, options::AbstractOptions; break_sharing=Val(false) ) - return compute_complexity(get_tree(tree), options; break_sharing) + if options.complexity_mapping isa Function + return options.complexity_mapping(tree)::Int + else + return compute_complexity(get_tree(tree), options; break_sharing) + end end function compute_complexity( tree::AbstractExpressionNode, options::AbstractOptions; break_sharing=Val(false) )::Int - if options.complexity_mapping.use + complexity_mapping = options.complexity_mapping + if complexity_mapping isa ComplexityMapping && complexity_mapping.use raw_complexity = _compute_complexity( tree, options.complexity_mapping; break_sharing ) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl new file mode 100644 index 000000000..866b56e5d --- /dev/null +++ b/src/ComposableExpression.jl @@ -0,0 +1,242 @@ +module ComposableExpressionModule + +using DispatchDoctor: @unstable +using DynamicExpressions: + AbstractExpression, + Expression, + AbstractExpressionNode, + AbstractOperatorEnum, + Metadata, + constructorof, + get_metadata, + eval_tree_array, + set_node!, + get_contents, + with_contents, + DynamicExpressions as DE +using DynamicExpressions.InterfacesModule: + ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments +using DynamicExpressions.ValueInterfaceModule: is_valid_array + +using ..ConstantOptimizationModule: ConstantOptimizationModule as CO +using ..CoreModule: get_safe_op + +abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end + +""" + ComposableExpression{T,N,D} <: AbstractComposableExpression{T,N} <: AbstractExpression{T,N} + +A symbolic expression representing a mathematical formula as an expression tree (`tree::N`) with associated metadata (`metadata::Metadata{D}`). Used to construct and manipulate expressions in symbolic regression tasks. + +Example: + +Create variables `x1` and `x2`, and build an expression `f = x1 * sin(x2)`: + +```julia +operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +variable_names = ["x1", "x2"] +x1 = ComposableExpression(Node(Float64; feature=1); operators, variable_names) +x2 = ComposableExpression(Node(Float64; feature=2); operators, variable_names) +f = x1 * sin(x2) +# ^This now references the first and second arguments of things passed to it: + +f(x1, x1) # == x1 * sin(x1) +f(randn(5), randn(5)) # == randn(5) .* sin.(randn(5)) + +# You can even pass it to itself: +f(f, f) # == (x1 * sin(x2)) * sin((x1 * sin(x2))) +``` +""" +struct ComposableExpression{ + T, + N<:AbstractExpressionNode{T}, + D<:@NamedTuple{operators::O, variable_names::V} where {O<:AbstractOperatorEnum,V}, +} <: AbstractComposableExpression{T,N} + tree::N + metadata::Metadata{D} +end + +@inline function ComposableExpression( + tree::AbstractExpressionNode{T}; metadata... +) where {T} + d = (; metadata...) + return ComposableExpression(tree, Metadata(d)) +end + +@unstable DE.constructorof(::Type{<:ComposableExpression}) = ComposableExpression + +DE.get_metadata(ex::AbstractComposableExpression) = ex.metadata +DE.get_contents(ex::AbstractComposableExpression) = ex.tree +DE.get_tree(ex::AbstractComposableExpression) = ex.tree + +function DE.get_operators( + ex::AbstractComposableExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) + return @something(operators, DE.get_metadata(ex).operators) +end +function DE.get_variable_names( + ex::AbstractComposableExpression, + variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, +) + return @something(variable_names, DE.get_metadata(ex).variable_names, Some(nothing)) +end + +function DE.get_scalar_constants(ex::AbstractComposableExpression) + return DE.get_scalar_constants(DE.get_contents(ex)) +end +function DE.set_scalar_constants!(ex::AbstractComposableExpression, constants, refs) + return DE.set_scalar_constants!(DE.get_contents(ex), constants, refs) +end + +function Base.copy(ex::AbstractComposableExpression) + return ComposableExpression(copy(ex.tree), copy(ex.metadata)) +end + +function Base.convert(::Type{E}, ex::AbstractComposableExpression) where {E<:Expression} + return constructorof(E)(get_contents(ex), get_metadata(ex)) +end + +for name in (:combine_operators, :simplify_tree!) + @eval function DE.$name( + ex::AbstractComposableExpression{T,N}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing, + ) where {T,N} + inner_ex = DE.$name(convert(Expression, ex), operators) + return with_contents(ex, inner_ex) + end +end + +function CO.count_constants_for_optimization(ex::AbstractComposableExpression) + return CO.count_constants_for_optimization(convert(Expression, ex)) +end + +@implements( + ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] +) + +""" + ValidVector{A<:AbstractVector} + +A wrapper for an AbstractVector paired with a validity flag (valid::Bool). +It represents a vector along with a boolean indicating whether the data is valid. +This is useful in computations where certain operations might produce invalid data +(e.g., division by zero), allowing the validity to propagate through calculations. +Operations on `ValidVector` instances automatically handle the valid flag: if all +operands are valid, the result is valid; if any operand is invalid, the result is +marked invalid. + +You will need to work with this to do highly custom operations with +`ComposableExpression` and `TemplateExpression`. + +# Fields: + +- `x::A`: The vector data. +- `valid::Bool`: Indicates if the data is valid. +""" +struct ValidVector{A<:AbstractVector} + x::A + valid::Bool +end +ValidVector(x::Tuple{Vararg{Any,2}}) = ValidVector(x...) + +function (ex::AbstractComposableExpression)(x) + return error("ComposableExpression does not support input of type $(typeof(x))") +end +function (ex::AbstractComposableExpression)( + x::AbstractVector, _xs::Vararg{AbstractVector,N} +) where {N} + __xs = (x, _xs...) + # Wrap it up for the recursive call + xs = map(Base.Fix2(ValidVector, true), __xs) + result = ex(xs...) + # Unwrap it + if _is_valid(result) + return _get_value(result) + else + # TODO: Make this more general. Like checking if the eltype is numeric. + x = _get_value(result) + nan = convert(eltype(x), NaN) + return x .* nan + end +end +function (ex::AbstractComposableExpression)( + x::ValidVector, _xs::Vararg{ValidVector,N} +) where {N} + xs = (x, _xs...) + valid = all(_is_valid, xs) + if !valid + return ValidVector(_get_value(first(xs)), false) + else + X = Matrix(stack(map(_get_value, xs))') + return ValidVector(eval_tree_array(ex, X)) + end +end +function (ex::AbstractComposableExpression)( + x::AbstractComposableExpression, _xs::Vararg{AbstractComposableExpression,N} +) where {N} + xs = (x, _xs...) + # To do this, we basically want to put the tree of x + # into the position of variable 1, and so on! + tree = copy(get_contents(ex)) + xs_trees = map(get_contents, xs) + # TODO: This is a bit dangerous, no? We are assuming + # that `foreach` won't try to go down the copied trees + foreach(tree) do node + if node.degree == 0 && !node.constant + set_node!(node, copy(xs_trees[node.feature])) + end + end + return with_contents(ex, tree) +end + +# Basically we want to vectorize every single operation on ValidVector, +# so that the user can use it easily. + +function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N} + if all(_is_valid, x) + vx = map(_get_value, x) + safe_op = get_safe_op(op) + result = safe_op.(vx...) + return ValidVector(result, is_valid_array(result)) + else + example_vector = + something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector + return ValidVector(_get_value(example_vector), false) + end +end +_is_valid(x::ValidVector) = x.valid +_is_valid(x) = true +_get_value(x::ValidVector) = x.x +_get_value(x) = x + +#! format: off +# First, binary operators: +for op in ( + :*, :/, :+, :-, :^, :÷, :mod, :log, + :atan, :atand, :copysign, :flipsign, + :&, :|, :⊻, ://, :\, +) + @eval begin + Base.$(op)(x::ValidVector, y::ValidVector) = apply_operator(Base.$(op), x, y) + Base.$(op)(x::ValidVector, y::Number) = apply_operator(Base.$(op), x, y) + Base.$(op)(x::Number, y::ValidVector) = apply_operator(Base.$(op), x, y) + end +end + +for op in ( + :sin, :cos, :tan, :sinh, :cosh, :tanh, :asin, :acos, + :asinh, :acosh, :atanh, :sec, :csc, :cot, :asec, :acsc, :acot, :sech, :csch, + :coth, :asech, :acsch, :acoth, :sinc, :cosc, :cosd, :cotd, :cscd, :secd, + :sinpi, :cospi, :sind, :tand, :acosd, :acotd, :acscd, :asecd, :asind, + :log, :log2, :log10, :log1p, :exp, :exp2, :exp10, :expm1, :frexp, :exponent, + :float, :abs, :real, :imag, :conj, :unsigned, + :nextfloat, :prevfloat, :transpose, :significand, + :modf, :rem, :floor, :ceil, :round, :trunc, + :inv, :sqrt, :cbrt, :abs2, :angle, :factorial, + :(!), :-, :+, :sign, :identity, +) + @eval Base.$(op)(x::ValidVector) = apply_operator(Base.$(op), x) +end +#! format: on + +end diff --git a/src/Configure.jl b/src/Configure.jl index d8f029bfa..1151c4cb2 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -120,7 +120,12 @@ function move_functions_to_workers( ) where {T} # All the types of functions we need to move to workers: function_sets = ( - :unaops, :binops, :elementwise_loss, :early_stop_condition, :loss_function + :unaops, + :binops, + :elementwise_loss, + :early_stop_condition, + :loss_function, + :complexity_mapping, ) for function_set in function_sets @@ -152,6 +157,12 @@ function move_functions_to_workers( end ops = (options.loss_function,) example_inputs = (Node(T; val=zero(T)), dataset, options) + elseif function_set == :complexity_mapping + if !(options.complexity_mapping isa Function) + continue + end + ops = (options.complexity_mapping,) + example_inputs = (create_expression(zero(T), options, dataset),) else error("Invalid function set: $function_set") end @@ -171,7 +182,9 @@ function move_functions_to_workers( end end -function copy_definition_to_workers(op, procs, options::AbstractOptions, verbosity) +function copy_definition_to_workers( + @nospecialize(op), procs, @nospecialize(options::AbstractOptions), verbosity +) name = nameof(op) verbosity > 0 && @info "Copying definition of $op to workers..." src_ms = methods(op).ms @@ -195,7 +208,7 @@ function test_function_on_workers(example_inputs, op, procs) end function activate_env_on_workers( - procs, project_path::String, options::AbstractOptions, verbosity + procs, project_path::String, @nospecialize(options::AbstractOptions), verbosity ) verbosity > 0 && @info "Activating environment on workers." @everywhere procs begin @@ -286,7 +299,7 @@ function test_entire_pipeline( population_size=20, nlength=3, options=options, - nfeatures=dataset.nfeatures, + nfeatures=max_features(dataset, options), ) tmp_pop = s_r_cycle( dataset, diff --git a/src/Core.jl b/src/Core.jl index 6000412ce..c442efc73 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -10,9 +10,8 @@ include("OptionsStruct.jl") include("Operators.jl") include("Options.jl") -using .ProgramConstantsModule: - MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE -using .DatasetModule: Dataset, is_weighted, has_units +using .ProgramConstantsModule: RecordType, DATA_TYPE, LOSS_TYPE +using .DatasetModule: Dataset, is_weighted, has_units, max_features using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation using .OptionsStructModule: AbstractOptions, @@ -21,6 +20,7 @@ using .OptionsStructModule: specialized_options, operator_specialization using .OperatorsModule: + get_safe_op, plus, sub, mult, diff --git a/src/Dataset.jl b/src/Dataset.jl index 49a452938..2818fd1db 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -3,11 +3,9 @@ module DatasetModule using DynamicQuantities: Quantity using ..UtilsModule: subscriptify, get_base_type -using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE +using ..ProgramConstantsModule: DATA_TYPE, LOSS_TYPE using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units -import ...deprecate_varmap - """ Dataset{T<:DATA_TYPE,L<:LOSS_TYPE} @@ -102,13 +100,11 @@ function Dataset( X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, # Deprecated: - varMap=nothing, kws..., ) where {T<:DATA_TYPE,L} Base.require_one_based_indexing(X) y !== nothing && Base.require_one_based_indexing(y) # Deprecation warning: - variable_names = deprecate_varmap(variable_names, varMap, :Dataset) if haskey(kws, :loss_type) Base.depwarn( "The `loss_type` keyword argument is deprecated. Pass as an argument instead.", @@ -129,8 +125,8 @@ function Dataset( ) end - n = size(X, BATCH_DIM) - nfeatures = size(X, FEATURE_DIM) + n = size(X, 2) + nfeatures = size(X, 1) variable_names = @something(variable_names, ["x$(i)" for i in 1:nfeatures]) display_variable_names = @something( display_variable_names, ["x$(subscriptify(i))" for i in 1:nfeatures] @@ -239,4 +235,8 @@ _fill!(x::NamedTuple, val) = foreach(v -> _fill!(v, val), values(x)) _fill!(::Nothing, val) = nothing _fill!(x, val) = x +function max_features(dataset::Dataset, _) + return dataset.nfeatures +end + end diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index d7bc5f5d6..00264be6a 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -7,19 +7,9 @@ module ExpressionBuilderModule using DispatchDoctor: @unstable using Compat: Fix using DynamicExpressions: - AbstractExpressionNode, - AbstractExpression, - Expression, - constructorof, - get_tree, - get_contents, - get_metadata, - with_contents, - with_metadata, - count_scalar_constants, - eval_tree_array + AbstractExpressionNode, AbstractExpression, constructorof, with_metadata using StatsBase: StatsBase -using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE +using ..CoreModule: AbstractOptions, Dataset using ..HallOfFameModule: HallOfFame using ..PopulationModule: Population using ..PopMemberModule: PopMember diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index d09990ad7..44acf2ea7 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -3,8 +3,7 @@ module HallOfFameModule using StyledStrings: @styled_str using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer -using ..CoreModule: - MAX_DEGREE, AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression using ..ComplexityModule: compute_complexity using ..PopMemberModule: PopMember using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING @@ -133,7 +132,7 @@ const HEADER = let end function string_dominating_pareto_curve( - hallOfFame, dataset, options; width::Union{Integer,Nothing}=nothing + hallOfFame, dataset, options; width::Union{Integer,Nothing}=nothing, pretty::Bool=true ) terminal_width = (width === nothing) ? 100 : max(100, width::Integer) _buffer = IOBuffer() @@ -150,7 +149,7 @@ function string_dominating_pareto_curve( display_variable_names=dataset.display_variable_names, X_sym_units=dataset.X_sym_units, y_sym_units=dataset.y_sym_units, - raw=false, + pretty, ) y_prefix = dataset.y_variable_name unit_str = format_dimensions(dataset.y_sym_units) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 86f14d3be..28a284f72 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -1,6 +1,7 @@ module InterfaceDynamicExpressionsModule using Printf: @sprintf +using DispatchDoctor: @stable using Compat: Fix using DynamicExpressions: DynamicExpressions as DE, @@ -15,8 +16,6 @@ using ..CoreModule: AbstractOptions, Dataset using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap using ..UtilsModule: subscriptify -import ..deprecate_varmap - """ eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions; kws...) @@ -50,23 +49,31 @@ which speed up evaluation significantly. or nan was encountered, and a large loss should be assigned to the equation. """ -function DE.eval_tree_array( - tree::Union{AbstractExpressionNode,AbstractExpression}, - X::AbstractMatrix, - options::AbstractOptions; - kws..., -) - A = expected_array_type(X, typeof(tree)) - out, complete = DE.eval_tree_array( - tree, - X, - DE.get_operators(tree, options); - turbo=options.turbo, - bumper=options.bumper, +@stable( + default_mode = "disable", + default_union_limit = 2, + function DE.eval_tree_array( + tree::Union{AbstractExpressionNode,AbstractExpression}, + X::AbstractMatrix, + options::AbstractOptions; kws..., ) - return out::A, complete::Bool -end + A = expected_array_type(X, typeof(tree)) + out, complete = DE.eval_tree_array( + tree, + X, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., + ) + if isnothing(out) + return nothing, false + else + return out::A, complete::Bool + end + end +) """Improve type inference by telling Julia the expected array returned.""" function expected_array_type(X::AbstractArray, ::Type) @@ -180,23 +187,21 @@ Convert an equation to a string. @inline function DE.string_tree( tree::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions; - raw::Bool=true, + pretty::Bool=false, X_sym_units=nothing, y_sym_units=nothing, variable_names=nothing, display_variable_names=variable_names, - varMap=nothing, kws..., ) - variable_names = deprecate_varmap(variable_names, varMap, :string_tree) - - if raw + if !pretty tree = tree isa GraphNode ? convert(Node, tree) : tree return DE.string_tree( tree, DE.get_operators(tree, options); f_variable=string_variable_raw, variable_names, + pretty, ) end @@ -213,6 +218,7 @@ Convert an equation to a string. ) end, variable_names=display_variable_names, + pretty, kws..., ) else @@ -222,6 +228,7 @@ Convert an equation to a string. f_variable=string_variable, f_constant=Fix{2}(Fix{3}(string_constant, ""), options.v_print_precision), variable_names=display_variable_names, + pretty, kws..., ) end diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index 637bb0fa4..ee9fbd496 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -1,5 +1,6 @@ module LossFunctionsModule +using DispatchDoctor: @stable using StatsBase: StatsBase using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree, eval_tree_array @@ -42,20 +43,38 @@ end end end -function eval_tree_dispatch( - tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx -) - A = expected_array_type(dataset.X, typeof(tree)) - out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) - return out::A, complete::Bool -end -function eval_tree_dispatch( - tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx +@stable( + default_mode = "disable", + default_union_limit = 2, + begin + function eval_tree_dispatch( + tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx + ) + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array( + tree, maybe_getindex(dataset.X, :, idx), options + ) + if isnothing(out) + return out, false + else + return out::A, complete::Bool + end + end + function eval_tree_dispatch( + tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx + ) + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array( + tree, maybe_getindex(dataset.X, :, idx), options + ) + if isnothing(out) + return out, false + else + return out::A, complete::Bool + end + end + end ) - A = expected_array_type(dataset.X, typeof(tree)) - out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) - return out::A, complete::Bool -end # Evaluate the loss of a particular expression on the input dataset. function _eval_loss( @@ -66,7 +85,7 @@ function _eval_loss( idx, )::L where {T<:DATA_TYPE,L<:LOSS_TYPE} (prediction, completion) = eval_tree_dispatch(tree, dataset, options, idx) - if !completion + if !completion || isnothing(prediction) return L(Inf) end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 395837ef2..84ae4d563 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -24,7 +24,8 @@ using DynamicQuantities: dimension using LossFunctions: SupervisedLoss using ..InterfaceDynamicQuantitiesModule: get_dimensions_type -using ..CoreModule: Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE +using ..CoreModule: + Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE, ComplexityMapping using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame @@ -141,21 +142,21 @@ function MMI.update( options = old_fitresult === nothing ? get_options(m) : old_fitresult.options return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing) end -function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes) - if isnothing(classes) && MMI.istable(X) && haskey(X, :classes) +function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, class) + if isnothing(class) && MMI.istable(X) && haskey(X, :class) if !(X isa NamedTuple) error("Classes can only be specified with named tuples.") end - new_X = Base.structdiff(X, (; X.classes)) - new_classes = X.classes + new_X = Base.structdiff(X, (; X.class)) + new_class = X.class return _update( - m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_classes + m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_class ) end if !isnothing(old_fitresult) @assert( - old_fitresult.has_classes == !isnothing(classes), - "If the first fit used classes, the second fit must also use classes." + old_fitresult.has_class == !isnothing(class), + "If the first fit used class, the second fit must also use class." ) end # To speed up iterative fits, we cache the types: @@ -209,7 +210,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, class X_units=X_units_clean, y_units=y_units_clean, verbosity=verbosity, - extra=isnothing(classes) ? (;) : (; classes), + extra=isnothing(class) ? (;) : (; class), # Help out with inference: v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), ) @@ -220,7 +221,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, class variable_names=variable_names, y_variable_names=y_variable_names, y_is_table=MMI.istable(y), - has_classes=!isnothing(classes), + has_class=!isnothing(class), X_units=X_units_clean, y_units=y_units_clean, types=( @@ -376,17 +377,17 @@ end function eval_tree_mlj( tree::AbstractExpression, X_t, - classes, + class, m::AbstractSRRegressor, ::Type{T}, fitresult, i, prototype, ) where {T} - out, completed = if isnothing(classes) + out, completed = if isnothing(class) eval_tree_array(tree, X_t, fitresult.options) else - eval_tree_array(tree, X_t, classes, fitresult.options) + eval_tree_array(tree, X_t, class, fitresult.options) end if completed return wrap_units(out, fitresult.y_units, i) @@ -396,30 +397,29 @@ function eval_tree_mlj( end function MMI.predict( - m::M, fitresult, Xnew; idx=nothing, classes=nothing + m::M, fitresult, Xnew; idx=nothing, class=nothing ) where {M<:AbstractSRRegressor} - return _predict(m, fitresult, Xnew, idx, classes) + return _predict(m, fitresult, Xnew, idx, class) end -function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegressor} +function _predict(m::M, fitresult, Xnew, idx, class) where {M<:AbstractSRRegressor} if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data)) @assert( haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2, "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`." ) - return _predict(m, fitresult, Xnew.data, Xnew.idx, classes) + return _predict(m, fitresult, Xnew.data, Xnew.idx, class) end - if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes) + if isnothing(class) && MMI.istable(Xnew) && haskey(Xnew, :class) if !(Xnew isa NamedTuple) error("Classes can only be specified with named tuples.") end - Xnew2 = Base.structdiff(Xnew, (; Xnew.classes)) - return _predict(m, fitresult, Xnew2, idx, Xnew.classes) + Xnew2 = Base.structdiff(Xnew, (; Xnew.class)) + return _predict(m, fitresult, Xnew2, idx, Xnew.class) end - if fitresult.has_classes + if fitresult.has_class @assert( - !isnothing(classes), - "Classes must be specified if the model was fit with classes." + !isnothing(class), "Classes must be specified if the model was fit with class." ) end @@ -441,12 +441,12 @@ function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegre if M <: SRRegressor return eval_tree_mlj( - params.equations[_idx], Xnew_t, classes, m, T, fitresult, nothing, prototype + params.equations[_idx], Xnew_t, class, m, T, fitresult, nothing, prototype ) elseif M <: MultitargetSRRegressor outs = [ eval_tree_mlj( - params.equations[i][_idx[i]], Xnew_t, classes, m, T, fitresult, i, prototype + params.equations[i][_idx[i]], Xnew_t, class, m, T, fitresult, i, prototype ) for i in eachindex(_idx, params.equations) ] out_matrix = reduce(hcat, outs) @@ -459,11 +459,14 @@ function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegre end function get_equation_strings_for(::SRRegressor, trees, options, variable_names) - return (t -> string_tree(t, options; variable_names=variable_names)).(trees) + return ( + t -> string_tree(t, options; variable_names=variable_names, pretty=false) + ).(trees) end function get_equation_strings_for(::MultitargetSRRegressor, trees, options, variable_names) return [ - (t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees + (t -> string_tree(t, options; variable_names=variable_names, pretty=false)).(ts) for + ts in trees ] end diff --git a/src/Mutate.jl b/src/Mutate.jl index 7b828f6f3..e3ab8993f 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -2,21 +2,24 @@ module MutateModule using DynamicExpressions: AbstractExpression, - with_contents, get_tree, preserve_sharing, count_scalar_constants, simplify_tree!, combine_operators using ..CoreModule: - AbstractOptions, AbstractMutationWeights, Dataset, RecordType, sample_mutation + AbstractOptions, + AbstractMutationWeights, + Dataset, + RecordType, + sample_mutation, + max_features using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, score_func_batched using ..CheckConstraintsModule: check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..PopMemberModule: PopMember using ..MutationFunctionsModule: - gen_random_tree_fixed_size, mutate_constant, mutate_operator, swap_operands, @@ -173,7 +176,7 @@ function next_generation( member.score, member.loss end - nfeatures = dataset.nfeatures + nfeatures = max_features(dataset, options) weights = copy(options.mutation_weights) diff --git a/src/Operators.jl b/src/Operators.jl index f99cc3bed..b38ccf97f 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -123,4 +123,13 @@ DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt @ignore pow(x, y) = safe_pow(x, y) @ignore pow_abs(x, y) = safe_pow(x, y) +get_safe_op(op::F) where {F<:Function} = op +get_safe_op(::typeof(^)) = safe_pow +get_safe_op(::typeof(log)) = safe_log +get_safe_op(::typeof(log2)) = safe_log2 +get_safe_op(::typeof(log10)) = safe_log10 +get_safe_op(::typeof(log1p)) = safe_log1p +get_safe_op(::typeof(sqrt)) = safe_sqrt +get_safe_op(::typeof(acosh)) = safe_acosh + end diff --git a/src/Options.jl b/src/Options.jl index aa247709e..48df70bd9 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -326,6 +326,10 @@ const OPTION_DESCRIPTIONS = """- `defaults`: What set of defaults to use for `Op - `complexity_of_variables`: What complexity should be assigned to use of a variable, which can also be a vector indicating different per-variable complexity. By default, this is 1. +- `complexity_mapping`: Alternatively, you can pass a function that takes + the expression as input and returns the complexity. Make sure that + this operates on `AbstractExpression` (and unpacks to `AbstractExpressionNode`), + and returns an integer. - `alpha`: The probability of accepting an equation mutation during regularized evolution is given by exp(-delta_loss/(alpha * T)), where T goes from 1 to 0. Thus, alpha=infinite is the same as no annealing. @@ -474,6 +478,7 @@ $(OPTION_DESCRIPTIONS) @nospecialize(complexity_of_operators = nothing), @nospecialize(complexity_of_constants::Union{Nothing,Real} = nothing), @nospecialize(complexity_of_variables::Union{Nothing,Real,AbstractVector} = nothing), + ### complexity_mapping @nospecialize(warmup_maxsize_by::Union{Real,Nothing} = nothing), ### use_frequency ### use_frequency_in_tournament @@ -541,6 +546,7 @@ $(OPTION_DESCRIPTIONS) ## 3. The Objective: dimensionless_constants_only::Bool=false, ## 4. Working with Complexities: + complexity_mapping::Union{Function,ComplexityMapping,Nothing}=nothing, use_frequency::Bool=true, use_frequency_in_tournament::Bool=true, should_simplify::Union{Nothing,Bool}=nothing, @@ -689,6 +695,11 @@ $(OPTION_DESCRIPTIONS) error("You cannot specify both `elementwise_loss` and `loss_function`.") end end + if complexity_mapping !== nothing + @assert complexity_of_operators === nothing && + complexity_of_constants === nothing && + complexity_of_variables === nothing + end ################################# #### Supply defaults ############ @@ -761,12 +772,15 @@ $(OPTION_DESCRIPTIONS) una_constraints, bin_constraints, unary_operators, binary_operators ) - complexity_mapping = ComplexityMapping( - complexity_of_operators, - complexity_of_variables, - complexity_of_constants, - binary_operators, - unary_operators, + complexity_mapping = @something( + complexity_mapping, + ComplexityMapping( + complexity_of_operators, + complexity_of_variables, + complexity_of_constants, + binary_operators, + unary_operators, + ) ) if maxdepth === nothing diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index b39dbf0b5..22871606d 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -178,7 +178,7 @@ all properties of `Options` available for internal methods in SymbolicRegression abstract type AbstractOptions end struct Options{ - CM<:ComplexityMapping, + CM<:Union{ComplexityMapping,Function}, OP<:AbstractOperatorEnum, N<:AbstractExpressionNode, E<:AbstractExpression, diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index f98a1de08..a5664fd45 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -6,15 +6,12 @@ module ParametricExpressionModule using DynamicExpressions: DynamicExpressions as DE, - AbstractExpression, ParametricExpression, ParametricNode, get_metadata, - with_metadata, get_contents, with_contents, - get_tree, - eval_tree_array + get_tree using StatsBase: StatsBase using Random: default_rng, AbstractRNG @@ -35,7 +32,7 @@ function EB.extra_init_params( ::Val{embed}, ) where {T,embed,E<:ParametricExpression} num_params = options.expression_options.max_parameters - num_classes = length(unique(dataset.extra.classes)) + num_classes = length(unique(dataset.extra.class)) parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing _parameters = if prototype === nothing randn(T, (num_params, num_classes)) @@ -64,7 +61,7 @@ end function DE.eval_tree_array( tree::ParametricExpression, X::AbstractMatrix, - classes::AbstractVector{<:Integer}, + class::AbstractVector{<:Integer}, options::AbstractOptions; kws..., ) @@ -72,7 +69,7 @@ function DE.eval_tree_array( out, complete = DE.eval_tree_array( tree, X, - classes, + class, DE.get_operators(tree, options); turbo=options.turbo, bumper=options.bumper, @@ -87,7 +84,7 @@ function LF.eval_tree_dispatch( out, complete = DE.eval_tree_array( tree, LF.maybe_getindex(dataset.X, :, idx), - LF.maybe_getindex(dataset.extra.classes, idx), + LF.maybe_getindex(dataset.extra.class, idx), options.operators, ) return out::A, complete::Bool diff --git a/src/Population.jl b/src/Population.jl index 6b9173c5c..ce76ee326 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -205,7 +205,7 @@ function record_population(pop::Population, options::AbstractOptions)::RecordTyp return RecordType( "population" => [ RecordType( - "tree" => string_tree(member.tree, options), + "tree" => string_tree(member.tree, options; pretty=false), "loss" => member.loss, "score" => member.score, "complexity" => compute_complexity(member, options), diff --git a/src/ProgramConstants.jl b/src/ProgramConstants.jl index 607ce08b2..7ae2ccd7b 100644 --- a/src/ProgramConstants.jl +++ b/src/ProgramConstants.jl @@ -1,8 +1,5 @@ module ProgramConstantsModule -const MAX_DEGREE = 2 -const BATCH_DIM = 2 -const FEATURE_DIM = 1 const RecordType = Dict{String,Any} const DATA_TYPE = Number diff --git a/src/ProgressBars.jl b/src/ProgressBars.jl index 1b6bc402d..c32b0c82f 100644 --- a/src/ProgressBars.jl +++ b/src/ProgressBars.jl @@ -1,7 +1,7 @@ module ProgressBarsModule using Compat: Fix -using ProgressMeter: Progress, next! +using ProgressMeter: ProgressMeter, Progress, next! using StyledStrings: @styled_str, annotatedstring using ..UtilsModule: AnnotatedString @@ -26,6 +26,11 @@ function barlen(pbar::WrappedProgressBar)::Int return @something(pbar.bar.barlen, displaysize(stdout)[2]) end +function ProgressMeter.finish!(pbar::WrappedProgressBar) + ProgressMeter.finish!(pbar.bar) + return nothing +end + """Iterate a progress bar.""" function manually_iterate!(pbar::WrappedProgressBar) width = barlen(pbar) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ed433df65..835191250 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -9,11 +9,10 @@ using Distributed: Distributed, @spawnat, Future, procs, addprocs using StatsBase: mean using StyledStrings: @styled_str using DispatchDoctor: @unstable -using Compat: Fix using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, AbstractOptions, Options, MAX_DEGREE, RecordType +using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features using ..ComplexityModule: compute_complexity using ..PopulationModule: Population using ..PopMemberModule: PopMember @@ -269,7 +268,7 @@ function init_dummy_pops( first(datasets); population_size=1, options=options, - nfeatures=first(datasets).nfeatures, + nfeatures=max_features(first(datasets), options), ) # ^ Due to occasional inference issue, we manually specify the return type return [ @@ -281,7 +280,7 @@ function init_dummy_pops( datasets[j]; population_size=1, options=options, - nfeatures=datasets[j].nfeatures, + nfeatures=max_features(datasets[j], options), ) end for i in 1:npops ] for j in 1:length(datasets) @@ -470,7 +469,7 @@ function print_search_state( 100.0 * cycles_elapsed / total_cycles / nout ) - print("="^twidth * "\n") + print("═"^twidth * "\n") for (j, (hall_of_fame, dataset)) in enumerate(zip(hall_of_fames, datasets)) if nout > 1 @printf("Best equations for output %d\n", j) @@ -479,7 +478,7 @@ function print_search_state( hall_of_fame, dataset, options; width=width ) print(equation_strings * "\n") - print("="^twidth * "\n") + print("═"^twidth * "\n") end return print("Press 'q' and then to stop execution early.\n") end @@ -580,7 +579,7 @@ function save_to_file( complexities[i] = compute_complexity(member, options) losses[i] = member.loss strings[i] = string_tree( - member.tree, options; variable_names=dataset.variable_names + member.tree, options; variable_names=dataset.variable_names, pretty=false ) end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 75bdbfa19..62cf8626e 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -14,14 +14,14 @@ export Population, ParametricExpression, TemplateExpression, TemplateStructure, + ValidVector, + ComposableExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, EvalOptions, SRRegressor, MultitargetSRRegressor, - LOSS_TYPE, - DATA_TYPE, #Functions: equation_search, @@ -41,7 +41,6 @@ export Population, set_node!, copy_node, node_to_symbolic, - node_type, symbolic_to_node, simplify_tree!, tree_mapreduce, @@ -88,6 +87,7 @@ using Pkg: Pkg using TOML: parsefile using Random: seed!, shuffle! using Reexport +using ProgressMeter: finish! using DynamicExpressions: Node, GraphNode, @@ -159,16 +159,17 @@ using DynamicExpressions: with_type_parameters LogCoshLoss using Compat: @compat, Fix -@compat public AbstractOptions, -AbstractRuntimeOptions, -RuntimeOptions, -AbstractMutationWeights, -mutate!, -condition_mutation_weights!, -sample_mutation, -MutationResult, -AbstractSearchState, -SearchState +#! format: off +@compat( + public, + ( + AbstractOptions, AbstractRuntimeOptions, RuntimeOptions, + AbstractMutationWeights, mutate!, condition_mutation_weights!, + sample_mutation, MutationResult, AbstractSearchState, SearchState, + LOSS_TYPE, DATA_TYPE, node_type, + ) +) +#! format: on # ^ We can add new functions here based on requests from users. # However, I don't want to add many functions without knowing what # users will actually want to overload. @@ -187,15 +188,6 @@ catch VersionNumber(0, 0, 0) end -function deprecate_varmap(variable_names, varMap, func_name) - if varMap !== nothing - Base.depwarn("`varMap` is deprecated; use `variable_names` instead", func_name) - @assert variable_names === nothing "Cannot pass both `varMap` and `variable_names`" - variable_names = varMap - end - return variable_names -end - using DispatchDoctor: @stable @stable default_mode = "disable" begin @@ -221,22 +213,23 @@ using DispatchDoctor: @stable include("Migration.jl") include("SearchUtils.jl") include("ExpressionBuilder.jl") + include("ComposableExpression.jl") include("TemplateExpression.jl") include("ParametricExpression.jl") end using .CoreModule: - MAX_DEGREE, - BATCH_DIM, - FEATURE_DIM, DATA_TYPE, LOSS_TYPE, RecordType, Dataset, AbstractOptions, Options, + ComplexityMapping, AbstractMutationWeights, MutationWeights, + get_safe_op, + max_features, is_weighted, sample_mutation, plus, @@ -317,6 +310,8 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame! using .TemplateExpressionModule: TemplateExpression, TemplateStructure +using .TemplateExpressionModule: TemplateExpression, TemplateStructure, ValidVector +using .ComposableExpressionModule: ComposableExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin @@ -444,7 +439,6 @@ function equation_search( v_dim_out::Val{DIM_OUT}=Val(nothing), # Deprecated: multithreaded=nothing, - varMap=nothing, ) where {T<:DATA_TYPE,L,DIM_OUT} if multithreaded !== nothing error( @@ -452,7 +446,6 @@ function equation_search( "Choose one of :multithreaded, :multiprocessing, or :serial.", ) end - variable_names = deprecate_varmap(variable_names, varMap, :equation_search) if weights !== nothing @assert length(weights) == length(y) @@ -532,6 +525,7 @@ end _warmup_search!(state, datasets, ropt, options) _main_search_loop!(state, datasets, ropt, options) _tear_down!(state, ropt, options) + _info_dump(state, datasets, ropt, options) return _format_output(state, datasets, ropt, options) end @@ -715,7 +709,7 @@ function _initialize_search!( population_size=options.population_size, nlength=3, options=options, - nfeatures=datasets[j].nfeatures, + nfeatures=max_features(datasets[j], options), ), HallOfFame(options, datasets[j]), RecordType(), @@ -939,7 +933,7 @@ function _main_search_loop!( options, total_cycles, cycles_remaining=state.cycles_remaining[j] ) move_window!(state.all_running_search_statistics[j]) - if progress_bar !== nothing + if !isnothing(progress_bar) head_node_occupation = estimate_work_fraction(resource_monitor) update_progress_bar!( progress_bar, @@ -1009,6 +1003,9 @@ function _main_search_loop!( end ################################################################ end + if !isnothing(progress_bar) + finish!(progress_bar) + end return nothing end function _tear_down!( @@ -1088,6 +1085,50 @@ end end return (out_pop, best_seen, record, num_evals) end +function _info_dump( + state::AbstractSearchState, + datasets::Vector{D}, + ropt::AbstractRuntimeOptions, + options::AbstractOptions, +) where {D<:Dataset} + ropt.verbosity <= 0 && return nothing + + nout = length(state.halls_of_fame) + if nout > 1 + @info "Final populations:" + else + @info "Final population:" + end + for (j, (hall_of_fame, dataset)) in enumerate(zip(state.halls_of_fame, datasets)) + if nout > 1 + @info "Output $j:" + end + equation_strings = string_dominating_pareto_curve( + hall_of_fame, + dataset, + options; + width=@something( + options.terminal_width, + ropt.progress ? displaysize(stdout)[2] : nothing, + Some(nothing) + ) + ) + println(equation_strings) + end + + if options.save_to_file + output_directory = joinpath( + something(options.output_directory, "outputs"), ropt.run_id + ) + @info "Results saved to:" + for j in 1:nout + filename = nout > 1 ? "hall_of_fame_output$(j).csv" : "hall_of_fame.csv" + output_file = joinpath(output_directory, filename) + println(" - ", output_file) + end + end + return nothing +end include("MLJInterface.jl") using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 586589fab..39ceab3d6 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -2,8 +2,8 @@ module TemplateExpressionModule using Random: AbstractRNG using Compat: Fix -using DispatchDoctor: @unstable -using StyledStrings: @styled_str +using DispatchDoctor: @unstable, @stable +using StyledStrings: @styled_str, annotatedstring using DynamicExpressions: DynamicExpressions as DE, AbstractStructuredExpression, @@ -11,7 +11,6 @@ using DynamicExpressions: AbstractExpression, AbstractOperatorEnum, OperatorEnum, - Expression, Metadata, get_contents, with_contents, @@ -20,13 +19,12 @@ using DynamicExpressions: get_variable_names, get_tree, node_type, - eval_tree_array, count_nodes using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments using ..CoreModule: - AbstractOptions, Dataset, CoreModule as CM, AbstractMutationWeights, has_units + AbstractOptions, Options, Dataset, CoreModule as CM, AbstractMutationWeights, has_units using ..ConstantOptimizationModule: ConstantOptimizationModule as CO using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..MutationFunctionsModule: MutationFunctionsModule as MF @@ -37,131 +35,76 @@ using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM using ..PopMemberModule: PopMember +using ..ComposableExpressionModule: ComposableExpression, ValidVector """ - TemplateStructure{K,S,N,E,C} <: Function + TemplateStructure{K,E,NF} <: Function A struct that defines a prescribed structure for a `TemplateExpression`, -including functions that define the result of combining sub-expressions in different contexts. +including functions that define the result in different contexts. The `K` parameter is used to specify the symbols representing the inner expressions. If not declared using the constructor `TemplateStructure{K}(...)`, the keys of the `variable_constraints` `NamedTuple` will be used to infer this. # Fields -- `combine`: Optional function taking a `NamedTuple` of function keys => expressions, - returning a single expression. Fallback method used by `get_tree` - on a `TemplateExpression` to generate a single `Expression`. -- `combine_vectors`: Optional function taking a `NamedTuple` of function keys => vectors, - returning a single vector. Used for evaluating the expression tree. - You may optionally define a method with a second argument `X` for if you wish - to include the data matrix `X` (of shape `[num_features, num_rows]`) in the - computation. -- `combine_strings`: Optional function taking a `NamedTuple` of function keys => strings, - returning a single string. Used for printing the expression tree. -- `variable_constraints`: Optional `NamedTuple` that defines which variables each sub-expression is allowed to access. - For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`. +- `combine`: Required function taking a `NamedTuple` of `ComposableExpression`s (sharing the keys `K`), + and then tuple representing the data of `ValidVector`s. For example, + `((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)` would be a valid `combine` function. You may also + re-use the callable expressions and use different inputs, such as + `((; f, g), (x1, x2)) -> f(x1 + g(x2)) - g(x1)` is another valid choice. +- `num_features`: Optional `NamedTuple` of function keys => integers representing the number of + features used by each expression. If not provided, it will be inferred using the `combine` + function. For example, if `f` takes two arguments, and `g` takes one, then + `num_features = (; f=2, g=1)`. """ -struct TemplateStructure{ - K, - E<:Union{Nothing,Function}, - N<:Union{Nothing,Function}, - S<:Union{Nothing,Function}, - C<:Union{Nothing,NamedTuple{<:Any,<:Tuple{Vararg{Vector{Int}}}}}, -} <: Function +struct TemplateStructure{K,E<:Function,NF<:NamedTuple{K}} <: Function combine::E - combine_vectors::N - combine_strings::S - variable_constraints::C + num_features::NF end -function TemplateStructure{K}(combine::E; kws...) where {K,E<:Function} - return TemplateStructure{K}(; combine, kws...) +function TemplateStructure{K}(combine::E, num_features=nothing) where {K,E<:Function} + num_features = @something(num_features, infer_variable_constraints(Val(K), combine)) + return TemplateStructure{K,E,typeof(num_features)}(combine, num_features) end -function TemplateStructure{K}(; kws...) where {K} - return TemplateStructure(; _function_keys=Val(K), kws...) -end -function TemplateStructure(combine::E; kws...) where {E<:Function} - return TemplateStructure(; combine, kws...) + +@unstable function combine(template::TemplateStructure, args...) + return template.combine(args...) end -function TemplateStructure(; - combine::E=nothing, - combine_vectors::N=nothing, - combine_strings::S=nothing, - variable_constraints::C=nothing, - _function_keys::Val{K}=Val(nothing), -) where { - K, - E<:Union{Nothing,Function}, - N<:Union{Nothing,Function}, - S<:Union{Nothing,Function}, - C<:Union{Nothing,NamedTuple{<:Any,<:Tuple{Vararg{Vector{Int}}}}}, -} - Kout = if K !== nothing && variable_constraints !== nothing - K != keys(variable_constraints) && - throw(ArgumentError("`K` must match the keys of `variable_constraints`.")) - K - elseif K !== nothing - K - elseif variable_constraints !== nothing - keys(variable_constraints) - else - throw( - ArgumentError( - "If `variable_constraints` is not provided, " * - "you must initialize `TemplateStructure` with " * - "`TemplateStructure{K}(...)`, for tuple of symbols `K`.", - ), - ) + +get_function_keys(::TemplateStructure{K}) where {K} = K + +function _record_composable_expression!(variable_constraints, ::Val{k}, args...) where {k} + vc = variable_constraints[k][] + if vc == -1 + variable_constraints[k][] = length(args) + elseif vc != length(args) + throw(ArgumentError("Inconsistent number of arguments passed to $k")) end - return TemplateStructure{Kout,E,N,S,C}( - combine, combine_vectors, combine_strings, variable_constraints + return first(args) +end + +"""Infers number of features used by each subexpression, by passing in test data.""" +function infer_variable_constraints(::Val{K}, combiner::F) where {K,F} + variable_constraints = NamedTuple{K}(map(_ -> Ref(-1), K)) + # Now, we need to evaluate the `combine` function to see how many + # features are used for each function call. If unset, we record it. + # If set, we validate. + inner = Fix{1}(_record_composable_expression!, variable_constraints) + _recorders_of_composable_expressions = NamedTuple{K}(map(k -> Fix{1}(inner, Val(k)), K)) + # We use an evaluation to get the variable constraints + combiner( + _recorders_of_composable_expressions, + Base.Iterators.repeated(ValidVector(ones(Float64, 1), true)), ) -end -# TODO: This interface is ugly. Part of this is due to AbstractStructuredExpression, -# which was not written with this `TemplateStructure` in mind, but just with a -# single callable function. - -function combine(template::TemplateStructure, nt::NamedTuple) - return (template.combine::Function)(nt)::AbstractExpression -end -function combine_vectors( - template::TemplateStructure, nt::NamedTuple, X::Union{AbstractMatrix,Nothing}=nothing -) - combiner = template.combine_vectors::Function - if X !== nothing && hasmethod(combiner, typeof((nt, X))) - # TODO: Refactor this - return combiner(nt, X)::AbstractVector - else - return combiner(nt)::AbstractVector + inferred = NamedTuple{K}(map(x -> x[], values(variable_constraints))) + if any(==(-1), values(inferred)) + failed_keys = filter(k -> inferred[k] == -1, K) + throw(ArgumentError("Failed to infer number of features used by $failed_keys")) end -end -function combine_strings(template::TemplateStructure, nt::NamedTuple) - return (template.combine_strings::Function)(nt)::AbstractString -end - -function (template::TemplateStructure)( - nt::NamedTuple{<:Any,<:Tuple{AbstractExpression,Vararg{AbstractExpression}}} -) - return combine(template, nt) -end -function (template::TemplateStructure)( - nt::NamedTuple{<:Any,<:Tuple{AbstractVector,Vararg{AbstractVector}}}, - X::Union{AbstractMatrix,Nothing}=nothing, -) - return combine_vectors(template, nt, X) -end -function (template::TemplateStructure)( - nt::NamedTuple{<:Any,<:Tuple{AbstractString,Vararg{AbstractString}}} -) - return combine_strings(template, nt) + return inferred end -can_combine(template::TemplateStructure) = template.combine !== nothing -can_combine_vectors(template::TemplateStructure) = template.combine_vectors !== nothing -can_combine_strings(template::TemplateStructure) = template.combine_strings !== nothing -get_function_keys(::TemplateStructure{K}) where {K} = K - """ TemplateExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} @@ -199,20 +142,8 @@ x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) example_expr = (; f=x1, g=x3) st_expr = TemplateExpression( example_expr; - structure=TemplateStructure{(:f, :g)}(nt -> sin(nt.f) + nt.g * nt.g), - operators, - variable_names, -) -``` - -We can also define constraints on which variables each sub-expression is allowed to access: - -```julia -variable_constraints = (; f=[1, 2], g=[3]) -st_expr = TemplateExpression( - example_expr; - structure=TemplateStructure( - nt -> sin(nt.f) + nt.g * nt.g; variable_constraints + structure=TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 ), operators, variable_names, @@ -228,9 +159,11 @@ struct TemplateExpression{ T, F<:TemplateStructure, N<:AbstractExpressionNode{T}, - E<:Expression{T,N}, # TODO: Generalize this + E<:ComposableExpression{T,N}, TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, - D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, + D<:@NamedTuple{ + structure::F, operators::O, variable_names::V + } where {O<:AbstractOperatorEnum,V}, } <: AbstractStructuredExpression{T,F,N,E,D} trees::TS metadata::Metadata{D} @@ -268,28 +201,28 @@ end ExpressionInterface{all_ei_methods_except(())}, TemplateExpression, [Arguments()] ) -function combine(ex::TemplateExpression, nt::NamedTuple) - return combine(get_metadata(ex).structure, nt) -end -function combine_vectors( - ex::TemplateExpression, nt::NamedTuple, X::Union{AbstractMatrix,Nothing}=nothing -) - return combine_vectors(get_metadata(ex).structure, nt, X) -end -function combine_strings(ex::TemplateExpression, nt::NamedTuple) - return combine_strings(get_metadata(ex).structure, nt) +@unstable function combine(ex::TemplateExpression, args...) + return combine(get_metadata(ex).structure, args...) end -function can_combine(ex::TemplateExpression) - return can_combine(get_metadata(ex).structure) -end -function can_combine_vectors(ex::TemplateExpression) - return can_combine_vectors(get_metadata(ex).structure) -end -function can_combine_strings(ex::TemplateExpression) - return can_combine_strings(get_metadata(ex).structure) +function DE.get_tree(ex::TemplateExpression{<:Any,<:Any,<:Any,E}) where {E} + raw_contents = get_contents(ex) + total_num_features = max(values(get_metadata(ex).structure.num_features)...) + example_inner_ex = first(values(raw_contents)) + example_tree = get_contents(example_inner_ex)::AbstractExpressionNode + + variable_trees = [ + DE.constructorof(typeof(example_tree))(; feature=i) for i in 1:total_num_features + ] + variable_expressions = [ + with_contents(inner_ex, variable_tree) for + (inner_ex, variable_tree) in zip(values(raw_contents), variable_trees) + ] + + return DE.get_tree( + combine(get_metadata(ex).structure, raw_contents, variable_expressions) + ) end -get_function_keys(ex::TemplateExpression) = get_function_keys(get_metadata(ex).structure) function EB.create_expression( t::AbstractExpressionNode{T}, @@ -305,7 +238,8 @@ function EB.create_expression( operators = options.operators variable_names = embed ? dataset.variable_names : nothing inner_expressions = ntuple( - _ -> Expression(copy(t); operators, variable_names), length(function_keys) + _ -> ComposableExpression(copy(t); operators, variable_names), + Val(length(function_keys)), ) # TODO: Generalize to other inner expression types return DE.constructorof(E)( @@ -338,61 +272,85 @@ function ComplexityModule.compute_complexity( ) end +# Rather than using iterator with repeat, just make a tuple: +function _colors(::Val{n}) where {n} + return ntuple( + (i -> (:magenta, :green, :red, :blue, :yellow, :cyan)[mod1(i, n)]), Val(n) + ) +end + _color_string(s::AbstractString, c::Symbol) = styled"{$c:$s}" function DE.string_tree( - tree::TemplateExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... -) - raw_contents = get_contents(tree) - if can_combine_strings(tree) - function_keys = keys(raw_contents) - colors = Base.Iterators.cycle((:magenta, :green, :red, :blue, :yellow, :cyan)) - inner_strings = NamedTuple{function_keys}( - map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) - ) - colored_strings = NamedTuple{function_keys}( - map(_color_string, inner_strings, colors) - ) - return combine_strings(tree, colored_strings) - else - @assert can_combine(tree) - return DE.string_tree(combine(tree, raw_contents), operators; kws...) - end -end -function DE.eval_tree_array( - tree::TemplateExpression{T}, - cX::AbstractMatrix{T}, + tree::TemplateExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; + pretty::Bool=false, + variable_names=nothing, kws..., -) where {T} +) raw_contents = get_contents(tree) - if can_combine_vectors(tree) - # Raw numerical results of each inner expression: - outs = map( - ex -> DE.eval_tree_array(ex, cX, operators; kws...), values(raw_contents) + function_keys = keys(raw_contents) + num_features = get_metadata(tree).structure.num_features + total_num_features = max(values(num_features)...) + colors = _colors(Val(length(function_keys))) + variable_names = ["#" * string(i) for i in 1:total_num_features] + inner_strings = NamedTuple{function_keys}( + map( + ex -> DE.string_tree(ex, operators; pretty, variable_names, kws...), + values(raw_contents), + ), + ) + strings = NamedTuple{function_keys}( + map( + (k, s, c) -> let + prefix = if !pretty || length(function_keys) == 1 + "" + elseif k == first(function_keys) + "╭ " + elseif k == last(function_keys) + "╰ " + else + "├ " + end + annotatedstring(prefix * string(k) * " = ", _color_string(s, c)) + end, + function_keys, + values(inner_strings), + colors, + ), + ) + return annotatedstring(join(strings, pretty ? styled"\n" : "; ")) +end +@stable( + default_mode = "disable", + default_union_limit = 2, + begin + function DE.eval_tree_array( + tree::TemplateExpression{T}, + cX::AbstractMatrix{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., + ) where {T} + raw_contents = get_contents(tree) + if has_invalid_variables(tree) + return (nothing, false) + end + result = combine( + tree, raw_contents, map(x -> ValidVector(copy(x), true), eachrow(cX)) + ) + return result.x, result.valid + end + function (ex::TemplateExpression)( + X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... ) - # Combine them using the structure function: - results = NamedTuple{keys(raw_contents)}(map(first, outs)) - return combine_vectors(tree, results, cX), all(last, outs) - else - @assert can_combine(tree) - return DE.eval_tree_array(combine(tree, raw_contents), cX, operators; kws...) + result, valid = DE.eval_tree_array(ex, X, operators; kws...) + if valid + return result + else + return nothing + end + end end -end -function (ex::TemplateExpression)( - X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... ) - raw_contents = get_contents(ex) - if can_combine_vectors(ex) - results = NamedTuple{keys(raw_contents)}( - map(ex -> ex(X, operators; kws...), values(raw_contents)) - ) - return combine_vectors(ex, results, X) - else - @assert can_combine(ex) - callable = combine(ex, raw_contents) - return callable(X, operators; kws...) - end -end @unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:TemplateExpression}) = Any function DA.violates_dimensional_constraints( @@ -422,6 +380,13 @@ function CM.operator_specialization( return O end +function CM.max_features( + dataset::Dataset, options::Options{<:Any,<:Any,<:Any,<:TemplateExpression} +) + num_features = options.expression_options.structure.num_features + return max(values(num_features)...) +end + """ We pick a random subexpression to mutate, and also return the symbol we mutated on so that we can put it back together later. @@ -493,24 +458,17 @@ function CC.check_constraints( maxsize::Int, cursize::Union{Int,Nothing}=nothing, )::Bool - raw_contents = get_contents(ex) - variable_constraints = get_metadata(ex).structure.variable_constraints - # First, we check the variable constraints at the top level: - has_invalid_variables = any(keys(raw_contents)) do key - tree = raw_contents[key] - allowed_variables = variable_constraints[key] - contains_other_features_than(tree, allowed_variables) - end - if has_invalid_variables + if has_invalid_variables(ex) return false end # We also check the combined complexity: - ((cursize === nothing) ? ComplexityModule.compute_complexity(ex, options) : cursize) > - maxsize && return false + @something(cursize, ComplexityModule.compute_complexity(ex, options)) > maxsize && + return false # Then, we check other constraints for inner expressions: + raw_contents = get_contents(ex) for t in values(raw_contents) if !CC.check_constraints(t, options, maxsize, nothing) return false @@ -519,12 +477,21 @@ function CC.check_constraints( return true # TODO: The concept of `cursize` doesn't really make sense here. end -function contains_other_features_than(tree::AbstractExpression, features) - return contains_other_features_than(get_tree(tree), features) +function has_invalid_variables(ex::TemplateExpression) + raw_contents = get_contents(ex) + num_features = get_metadata(ex).structure.num_features + any(keys(raw_contents)) do key + tree = raw_contents[key] + max_feature = num_features[key] + contains_features_greater_than(tree, max_feature) + end +end +function contains_features_greater_than(tree::AbstractExpression, max_feature) + return contains_features_greater_than(get_tree(tree), max_feature) end -function contains_other_features_than(tree::AbstractExpressionNode, features) +function contains_features_greater_than(tree::AbstractExpressionNode, max_feature) any(tree) do node - node.degree == 0 && !node.constant && node.feature ∉ features + node.degree == 0 && !node.constant && node.feature > max_feature end end diff --git a/src/deprecates.jl b/src/deprecates.jl index c8e0b4d57..6b6fb29ac 100644 --- a/src/deprecates.jl +++ b/src/deprecates.jl @@ -4,7 +4,7 @@ import .HallOfFameModule: calculate_pareto_frontier import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size @deprecate( - calculate_pareto_frontier(X, y, hallOfFame, options; weights=nothing, varMap=nothing), + calculate_pareto_frontier(X, y, hallOfFame, options; weights=nothing), calculate_pareto_frontier(hallOfFame) ) @deprecate( @@ -41,7 +41,6 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size loss_type::Type=Nothing, # Deprecated: multithreaded=nothing, - varMap=nothing, ) where {T<:DATA_TYPE}, equation_search( X, @@ -58,7 +57,6 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size saved_state, loss_type, multithreaded, - varMap, ) ) diff --git a/test/runtests.jl b/test/runtests.jl index fcc2c5b08..db9676b16 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -77,9 +77,7 @@ end include("test_nested_constraints.jl") end -@testitem "Test complexity evaluation" tags = [:part3] begin - include("test_complexity.jl") -end +include("test_complexity.jl") @testitem "Test options" tags = [:part1] begin include("test_options.jl") @@ -131,7 +129,6 @@ end ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" include("../examples/parameterized_function.jl") end -include("test_template_expression.jl") @testitem "Testing whether the recorder works." tags = [:part3] begin include("test_recorder.jl") @@ -165,6 +162,7 @@ end include("test_pretty_printing.jl") include("test_expression_builder.jl") +include("test_composable_expression.jl") @testitem "Aqua tests" tags = [:part2, :aqua] begin include("test_aqua.jl") diff --git a/test/test_complexity.jl b/test/test_complexity.jl index deaad6813..e3a3efd87 100644 --- a/test/test_complexity.jl +++ b/test/test_complexity.jl @@ -1,54 +1,81 @@ -println("Testing custom complexities.") -using SymbolicRegression, Test +@testitem "Test complexity evaluation" tags = [:part3] begin + using SymbolicRegression -x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") + x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") -# First, test regular complexities: -function make_options(; kw...) - return Options(; binary_operators=(+, -, *, /, ^), unary_operators=(cos, sin), kw...) + # First, test regular complexities: + function make_options(; kw...) + return Options(; + binary_operators=(+, -, *, /, ^), unary_operators=(cos, sin), kw... + ) + end + options = make_options() + @extend_operators options + tree = sin((x1 + x2 + x3)^2.3) + @test compute_complexity(tree, options) == 8 + + options = make_options(; complexity_of_operators=[sin => 3]) + @test compute_complexity(tree, options) == 10 + options = make_options(; complexity_of_operators=[sin => 3, (+) => 2]) + @test compute_complexity(tree, options) == 12 + + # Real numbers: + options = make_options(; complexity_of_operators=[sin => 3, (+) => 2, (^) => 3.2]) + @test compute_complexity(tree, options) == round(Int, 12 + (3.2 - 1)) +end + +@testitem "Test other things about complexity" tags = [:part3] begin + using SymbolicRegression + + x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") + + function make_options(; kw...) + return Options(; + binary_operators=(+, -, *, /, ^), unary_operators=(cos, sin), kw... + ) + end + + options = make_options(; + complexity_of_operators=[sin => 3, (+) => 2], complexity_of_variables=2 + ) + tree = sin((x1 + x2 + x3)^2.3) + @test compute_complexity(tree, options) == 12 + 3 * 1 + options = make_options(; + complexity_of_operators=[sin => 3, (+) => 2], + complexity_of_variables=2, + complexity_of_constants=2, + ) + @test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + options = make_options(; + complexity_of_operators=[sin => 3, (+) => 2], + complexity_of_variables=2, + complexity_of_constants=2.6, + ) + @test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + 1 + + # Custom variables + options = make_options(; + complexity_of_variables=[1, 2, 3], complexity_of_operators=[(+) => 5, (*) => 2] + ) + x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3] + tree = x1 + x2 * x3 + @test compute_complexity(tree, options) == 1 + 5 + 2 + 2 + 3 + options = make_options(; + complexity_of_variables=2, complexity_of_operators=[(+) => 5, (*) => 2] + ) + @test compute_complexity(tree, options) == 2 + 5 + 2 + 2 + 2 +end + +@testitem "Custom complexity mapping" tags = [:part3] begin + using SymbolicRegression + + function custom_complexity(tree) + @test tree isa AbstractExpression + return 10 + end + + options = Options(; complexity_mapping=custom_complexity) + variable_names = ["x1"] + x1 = Expression(Node{Float64}(; feature=1); options.operators, variable_names) + @test compute_complexity(x1, options) == 10 end -options = make_options() -@extend_operators options -tree = sin((x1 + x2 + x3)^2.3) -@test compute_complexity(tree, options) == 8 - -options = make_options(; complexity_of_operators=[sin => 3]) -@test compute_complexity(tree, options) == 10 -options = make_options(; complexity_of_operators=[sin => 3, (+) => 2]) -@test compute_complexity(tree, options) == 12 - -# Real numbers: -options = make_options(; complexity_of_operators=[sin => 3, (+) => 2, (^) => 3.2]) -@test compute_complexity(tree, options) == round(Int, 12 + (3.2 - 1)) - -# Now, test other things, like variables and constants: -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], complexity_of_variables=2 -) -@test compute_complexity(tree, options) == 12 + 3 * 1 -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], - complexity_of_variables=2, - complexity_of_constants=2, -) -@test compute_complexity(tree, options) == 12 + 3 * 1 + 1 -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], - complexity_of_variables=2, - complexity_of_constants=2.6, -) -@test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + 1 - -# Custom variables -options = make_options(; - complexity_of_variables=[1, 2, 3], complexity_of_operators=[(+) => 5, (*) => 2] -) -x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3] -tree = x1 + x2 * x3 -@test compute_complexity(tree, options) == 1 + 5 + 2 + 2 + 3 -options = make_options(; - complexity_of_variables=2, complexity_of_operators=[(+) => 5, (*) => 2] -) -@test compute_complexity(tree, options) == 2 + 5 + 2 + 2 + 2 - -println("Passed.") diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl new file mode 100644 index 000000000..fc17a6b5b --- /dev/null +++ b/test/test_composable_expression.jl @@ -0,0 +1,296 @@ +@testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin + include("../examples/template_expression.jl") +end +@testitem "Test ComposableExpression" tags = [:part2] begin + using SymbolicRegression: ComposableExpression, Node + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + ex = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + x = randn(32) + y = randn(32) + + @test ex(x, y) == x +end + +@testitem "Test interface for ComposableExpression" tags = [:part2] begin + using SymbolicRegression: ComposableExpression + using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + x1 = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + x2 = ComposableExpression(Node(Float64; feature=2); operators, variable_names) + + f = x1 * sin(x2) + g = f(f, f) + + @test string_tree(f) == "x1 * sin(x2)" + @test string_tree(g) == "(x1 * sin(x2)) * sin(x1 * sin(x2))" + + @test Interfaces.test(ExpressionInterface, ComposableExpression, [f, g]) +end + +@testitem "Test interface for TemplateExpression" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: TemplateExpression + using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + x1 = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + x2 = ComposableExpression(Node(Float64; feature=2); operators, variable_names) + + structure = TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2)) -> f(f(f(x1))) - f(g(x2, x1)) + ) + @test structure.num_features == (; f=1, g=2) + + expr = TemplateExpression((; f=x1, g=x2 * x2); structure, operators, variable_names) + + @test String(string_tree(expr)) == "f = #1; g = #2 * #2" + @test String(string_tree(expr; pretty=true)) == "╭ f = #1\n╰ g = #2 * #2" + @test string_tree(get_tree(expr), operators) == "x1 - (x1 * x1)" + @test Interfaces.test(ExpressionInterface, TemplateExpression, [expr]) +end + +@testitem "Printing and evaluation of TemplateExpression" tags = [:part2] begin + using SymbolicRegression + + structure = TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 + ) + operators = Options().operators + variable_names = ["x1", "x2", "x3"] + + x1, x2, x3 = [ + ComposableExpression(Node{Float64}(; feature=i); operators, variable_names) for + i in 1:3 + ] + f = x1 * x2 + g = x1 + expr = TemplateExpression((; f, g); structure, operators, variable_names) + + # Default printing strategy: + @test String(string_tree(expr)) == "f = #1 * #2; g = #1" + + x1_val = randn(5) + x2_val = randn(5) + + # The feature indicates the index passed as argument: + @test x1(x1_val) ≈ x1_val + @test x2(x1_val, x2_val) ≈ x2_val + @test x1(x2_val) ≈ x2_val + + # Composing expressions and then calling: + @test String(string_tree((x1 * x2)(x3, x3))) == "x3 * x3" + + # Can evaluate with `sin` even though it's not in the allowed operators! + X = randn(3, 5) + x1_val = X[1, :] + x2_val = X[2, :] + x3_val = X[3, :] + @test expr(X) ≈ @. sin(x1_val * x2_val) + x3_val^2 + + # This is even though `g` is defined on `x1` only: + @test g(x3_val) ≈ x3_val +end + +@testitem "Test error handling" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: ComposableExpression, Node, ValidVector + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + ex = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + + # Test error for unsupported input type with specific message + @test_throws "ComposableExpression does not support input of type String" ex( + "invalid input" + ) + + # Test ValidVector operations with numbers + x = ValidVector([1.0, 2.0, 3.0], true) + + # Test binary operations between ValidVector and Number + @test (x + 2.0).x ≈ [3.0, 4.0, 5.0] + @test (2.0 + x).x ≈ [3.0, 4.0, 5.0] + @test (x * 2.0).x ≈ [2.0, 4.0, 6.0] + @test (2.0 * x).x ≈ [2.0, 4.0, 6.0] + + # Test unary operations on ValidVector + @test sin(x).x ≈ sin.([1.0, 2.0, 3.0]) + @test cos(x).x ≈ cos.([1.0, 2.0, 3.0]) + @test abs(x).x ≈ [1.0, 2.0, 3.0] + @test (-x).x ≈ [-1.0, -2.0, -3.0] + + # Test propagation of invalid flag + invalid_x = ValidVector([1.0, 2.0, 3.0], false) + @test !((invalid_x + 2.0).valid) + @test !((2.0 + invalid_x).valid) + @test !(sin(invalid_x).valid) + + # Test that regular numbers are considered valid + @test (x + 2).valid + @test sin(x).valid +end +@testitem "Test validity propagation with NaN" tags = [:part2] begin + using SymbolicRegression: ComposableExpression, Node, ValidVector + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) + + ex = 1.0 + x2 / x1 + + @test ex([1.0], [2.0]) ≈ [3.0] + + @test ex([1.0, 1.0], [2.0, 2.0]) |> Base.Fix1(count, isnan) == 0 + @test ex([1.0, 0.0], [2.0, 2.0]) |> Base.Fix1(count, isnan) == 2 + + x1_val = ValidVector([1.0, 2.0], false) + x2_val = ValidVector([1.0, 2.0], false) + @test ex(x1_val, x2_val).valid == false +end + +@testitem "Test nothing return and type inference for TemplateExpression" tags = [:part2] begin + using SymbolicRegression + using Test: @inferred + + # Create a template expression that divides by x1 + structure = TemplateStructure{(:f,)}(((; f), (x1, x2)) -> 1.0 + f(x1) / x1) + operators = Options(; binary_operators=(+, -, *, /)).operators + variable_names = ["x1", "x2"] + + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + expr = TemplateExpression((; f=x1); structure, operators, variable_names) + + # Test division by zero returns nothing + X = [0.0 1.0]' + @test expr(X) === nothing + + # Test type inference + X_good = [1.0 2.0]' + @test @inferred(Union{Nothing,Vector{Float64}}, expr(X_good)) ≈ [2.0] + + # Test type inference with ValidVector input + x1_val = ValidVector([1.0], true) + x2_val = ValidVector([2.0], true) + @test @inferred(ValidVector{Vector{Float64}}, x1(x1_val, x2_val)).x ≈ [1.0] + + x2_val_false = ValidVector([2.0], false) + @test @inferred(x1(x1_val, x2_val_false)).valid == false +end +@testitem "Test compatibility with power laws" tags = [:part3] begin + using SymbolicRegression + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, -, *, /, ^)) + variable_names = ["x1", "x2"] + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + + structure = TemplateStructure{(:f,)}(((; f), (x1, x2)) -> f(x1)^f(x2)) + expr = TemplateExpression((; f=x1); structure, operators, variable_names) + + # There shouldn't be an error when we evaluate with invalid + # expressions, even though the source of the NaN comes from the structure + # function itself: + X = -rand(2, 32) + @test expr(X) === nothing +end + +@testitem "Test constraints checking in TemplateExpression" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: CheckConstraintsModule as CC + + # Create a template expression with nested exponentials + options = Options(; + binary_operators=(+, -, *, /), + unary_operators=(exp,), + nested_constraints=[exp => [exp => 1]], # Only allow one nested exp + ) + operators = options.operators + variable_names = ["x1", "x2"] + + # Create a valid inner expression + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + valid_expr = exp(x1) # One exp is ok + + # Create an invalid inner expression with too many nested exp + invalid_expr = exp(exp(exp(x1))) + # Three nested exp's violates constraint + + @test CC.check_constraints(valid_expr, options, 20) + @test !CC.check_constraints(invalid_expr, options, 20) +end + +@testitem "Test feature constraints in TemplateExpression" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions: Node + + operators = Options(; binary_operators=(+, -, *, /)).operators + variable_names = ["x1", "x2", "x3"] + + # Create a structure where f only gets access to x1, x2 + # and g only gets access to x3 + structure = TemplateStructure{(:f, :g)}(((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)) + + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) + + # Test valid case - each function only uses allowed features + valid_f = x1 + x2 + valid_g = x1 + valid_template = TemplateExpression( + (; f=valid_f, g=valid_g); structure, operators, variable_names + ) + @test valid_template([1.0 2.0 3.0]') ≈ [6.0] # (1 + 2) + 3 + + # Test invalid case - f tries to use x3 which it shouldn't have access to + invalid_f = x1 + x3 + invalid_template = TemplateExpression( + (; f=invalid_f, g=valid_g); structure, operators, variable_names + ) + @test invalid_template([1.0 2.0 3.0]') === nothing + + # Test invalid case - g tries to use x2 which it shouldn't have access to + invalid_g = x2 + invalid_template2 = TemplateExpression( + (; f=valid_f, g=invalid_g); structure, operators, variable_names + ) + @test invalid_template2([1.0 2.0 3.0]') === nothing +end +@testitem "Test invalid structure" tags = [:part3] begin + using SymbolicRegression + + operators = Options(; binary_operators=(+, -, *, /)).operators + variable_names = ["x1", "x2", "x3"] + + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) + + @test_throws ArgumentError TemplateStructure{(:f,)}( + ((; f), (x1, x2)) -> f(x1) + f(x1, x2) + ) + @test_throws "Inconsistent number of arguments passed to f" TemplateStructure{(:f,)}( + ((; f), (x1, x2)) -> f(x1) + f(x1, x2) + ) + + @test_throws ArgumentError TemplateStructure{(:f, :g)}(((; f, g), (x1, x2)) -> f(x1)) + @test_throws "Failed to infer number of features used by (:g,)" TemplateStructure{( + :f, :g + )}( + ((; f, g), (x1, x2)) -> f(x1) + ) +end diff --git a/test/test_custom_operators_multiprocessing.jl b/test/test_custom_operators_multiprocessing.jl index 2fca2298e..22c978771 100644 --- a/test/test_custom_operators_multiprocessing.jl +++ b/test/test_custom_operators_multiprocessing.jl @@ -1,4 +1,5 @@ using SymbolicRegression +using Test defs = quote _plus(x, y) = x + y @@ -7,14 +8,18 @@ defs = quote _min(x, y) = x - y _cos(x) = cos(x) _exp(x) = exp(x) - early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10)) + early_stop(loss, c) = ((loss <= 1e-10) && (c <= 6)) my_loss(x, y, w) = abs(x - y)^2 * w + my_complexity(ex) = ceil(Int, length($(get_tree)(ex)) / 2) end # This is needed as workers are initialized in `Core.Main`! if (@__MODULE__) != Core.Main Core.eval(Core.Main, defs) - eval(:(using Main: _plus, _mult, _div, _min, _cos, _exp, early_stop, my_loss)) + eval( + :(using Main: + _plus, _mult, _div, _min, _cos, _exp, early_stop, my_loss, my_complexity), + ) else eval(defs) end @@ -26,8 +31,10 @@ options = SymbolicRegression.Options(; binary_operators=(_plus, _mult, _div, _min), unary_operators=(_cos, _exp), populations=20, + maxsize=15, early_stop_condition=early_stop, elementwise_loss=my_loss, + complexity_mapping=my_complexity, ) hof = equation_search( @@ -41,5 +48,6 @@ hof = equation_search( ) @test any( - early_stop(member.loss, count_nodes(member.tree)) for member in hof.members[hof.exists] + early_stop(member.loss, my_complexity(member.tree)) for + member in hof.members[hof.exists] ) diff --git a/test/test_expression_builder.jl b/test/test_expression_builder.jl index 37b9291f3..50028ff4b 100644 --- a/test/test_expression_builder.jl +++ b/test/test_expression_builder.jl @@ -15,10 +15,10 @@ ) X = ones(1, 1) * 2 y = ones(1) - dataset = Dataset(X, y; extra=(; classes=[1])) + dataset = Dataset(X, y; extra=(; class=[1])) @test ex isa ParametricExpression - @test ex(dataset.X, dataset.extra.classes) ≈ ones(1, 1) * 6 + @test ex(dataset.X, dataset.extra.class) ≈ ones(1, 1) * 6 # Mistake in that we gave the wrong options! @test_throws( diff --git a/test/test_expression_derivatives.jl b/test/test_expression_derivatives.jl index c8cba75ae..359b405bf 100644 --- a/test/test_expression_derivatives.jl +++ b/test/test_expression_derivatives.jl @@ -84,18 +84,18 @@ end true_params = [0.5 2.0] init_params = [0.1 0.2] init_constants = [2.5, -0.5] - classes = rand(rng, 1:2, 32) + class = rand(rng, 1:2, 32) y = [ - X[1, i] * X[1, i] - cos(2.6 * X[2, i] - 0.2) + true_params[1, classes[i]] for + X[1, i] * X[1, i] - cos(2.6 * X[2, i] - 0.2) + true_params[1, class[i]] for i in 1:32 ] - dataset = Dataset(X, y; extra=(; classes)) + dataset = Dataset(X, y; extra=(; class)) (true_val, (true_d_params, true_d_constants)) = value_and_gradient(AutoZygote(), (init_params, init_constants)) do (params, c) pred = [ - X[1, i] * X[1, i] - cos(c[1] * X[2, i] + c[2]) + params[1, classes[i]] for + X[1, i] * X[1, i] - cos(c[1] * X[2, i] + c[2]) + params[1, class[i]] for i in 1:32 ] sum(abs2, pred .- y) / length(y) diff --git a/test/test_mixed_utils.jl b/test/test_mixed_utils.jl index 2ad9e7636..25af052c7 100644 --- a/test/test_mixed_utils.jl +++ b/test/test_mixed_utils.jl @@ -1,6 +1,5 @@ -using SymbolicRegression -using SymbolicRegression: string_tree -using Random, Bumper, LoopVectorization +using SymbolicRegression, Random, Bumper, LoopVectorization +using SymbolicRegression: string_tree, node_type include("test_params.jl") diff --git a/test/test_mlj.jl b/test/test_mlj.jl index a4348fd28..ecc9f5331 100644 --- a/test/test_mlj.jl +++ b/test/test_mlj.jl @@ -144,7 +144,7 @@ end fit!(mach) # Check predictions - @test sum(abs2, predict(mach, X) .- Y) / length(X) < 1e-6 + @test sum(abs2, predict(mach, X) .- Y) / length(X) < 1e-5 # Load the output CSV file for i in 1:3 diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl deleted file mode 100644 index 04836cf15..000000000 --- a/test/test_template_expression.jl +++ /dev/null @@ -1,227 +0,0 @@ -@testitem "Basic utility of the TemplateExpression" tags = [:part3] begin - using SymbolicRegression - using SymbolicRegression: SymbolicRegression as SR - using SymbolicRegression.CheckConstraintsModule: check_constraints - using DynamicExpressions: OperatorEnum - - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - operators = options.operators - variable_names = ["x1", "x2", "x3"] - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - # For combining expressions to a single expression: - structure = TemplateStructure(; - combine=e -> sin(e.f) + e.g * e.g, - combine_vectors=e -> (@. sin(e.f) + e.g^2), - combine_strings=e -> "sin($(e.f)) + $(e.g)^2", - variable_constraints=(; f=[1, 2], g=[3]), - ) - - @test structure isa TemplateStructure{(:f, :g)} - - st_expr = TemplateExpression((; f=x1, g=cos(x3)); structure, operators, variable_names) - @test string_tree(st_expr) == "sin(x1) + cos(x3)^2" - operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(cos, sin)) - - # Changing the operators will change how the expression is interpreted for - # parts that are already evaluated: - @test string_tree(st_expr, operators) == "sin(x1) + sin(x3)^2" - - # We can evaluate with this too: - cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] - out = st_expr(cX) - @test out ≈ [sin(1.0) + cos(5.0)^2, sin(2.0) + cos(6.0)^2] - - # And also check the contents: - @test check_constraints(st_expr, options, 100) - - # We can see that violating the constraints will cause a violation: - new_expr = with_contents(st_expr, (; f=x3, g=cos(x3))) - @test !check_constraints(new_expr, options, 100) - new_expr = with_contents(st_expr, (; f=x2, g=cos(x3))) - @test check_constraints(new_expr, options, 100) - new_expr = with_contents(st_expr, (; f=x2, g=cos(x1))) - @test !check_constraints(new_expr, options, 100) - - # Checks the size of each individual expression: - new_expr = with_contents(st_expr, (; f=x2, g=cos(x3))) - - @test compute_complexity(new_expr, options) == 3 - @test check_constraints(new_expr, options, 3) - @test !check_constraints(new_expr, options, 2) -end -@testitem "Expression interface" tags = [:part3] begin - using SymbolicRegression - using DynamicExpressions: OperatorEnum - using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface - - operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - variable_names = (i -> "x$i").(1:3) - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - # For combining expressions to a single expression: - structure = TemplateStructure{(:f, :g)}(; - combine=e -> sin(e.f) + e.g * e.g, - combine_strings=e -> "sin($(e.f)) + $(e.g)^2", - combine_vectors=e -> (@. sin(e.f) + e.g^2), - variable_constraints=(; f=[1, 2], g=[3]), - ) - st_expr = TemplateExpression((; f=x1, g=x3); structure, operators, variable_names) - @test Interfaces.test(ExpressionInterface, TemplateExpression, [st_expr]) -end -@testitem "Utilising TemplateExpression to build vector expressions" tags = [:part3] begin - using SymbolicRegression - using Random: rand - - # Define the structure function, which returns a tuple: - structure = TemplateStructure{(:f, :g1, :g2, :g3)}(; - combine_strings=e -> "( $(e.f) + $(e.g1), $(e.f) + $(e.g2), $(e.f) + $(e.g3) )", - combine_vectors=e -> - map((f, g1, g2, g3) -> (f + g1, f + g2, f + g3), e.f, e.g1, e.g2, e.g3), - ) - - # Set up operators and variable names - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - variable_names = (i -> "x$i").(1:3) - - # Create expressions - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); options.operators, variable_names)).(1:3) - - # Test with vector inputs: - nt_vector = NamedTuple{(:f, :g1, :g2, :g3)}((1:3, 4:6, 7:9, 10:12)) - @test structure(nt_vector) == [(5, 8, 11), (7, 10, 13), (9, 12, 15)] - - # And string inputs: - nt_string = NamedTuple{(:f, :g1, :g2, :g3)}(("x1", "x2", "x3", "x2")) - @test structure(nt_string) == "( x1 + x2, x1 + x3, x1 + x2 )" - - # Now, using TemplateExpression: - st_expr = TemplateExpression( - (; f=x1, g1=x2, g2=x3, g3=x2); structure, options.operators, variable_names - ) - @test string_tree(st_expr) == "( x1 + x2, x1 + x3, x1 + x2 )" - - # We can directly call it: - cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] - out = st_expr(cX) - @test out == [(1 + 3, 1 + 5, 1 + 3), (2 + 4, 2 + 6, 2 + 4)] -end -@testitem "TemplateExpression getters" tags = [:part3] begin - using SymbolicRegression - using DynamicExpressions: get_operators, get_variable_names - - operators = - Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)).operators - variable_names = (i -> "x$i").(1:3) - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - structure = TemplateStructure(; - combine=e -> e.f, variable_constraints=(; f=[1, 2], g1=[3], g2=[3], g3=[3]) - ) - - st_expr = TemplateExpression( - (; f=x1, g1=x3, g2=x3, g3=x3); structure, operators, variable_names - ) - - @test st_expr isa TemplateExpression - @test get_operators(st_expr) == operators - @test get_variable_names(st_expr) == variable_names - @test get_metadata(st_expr).structure == structure -end -@testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin - include("../examples/template_expression.jl") -end -@testitem "TemplateExpression with only combine function" tags = [:part3] begin - using SymbolicRegression - using SymbolicRegression.TemplateExpressionModule: - can_combine_vectors, can_combine, get_function_keys - using SymbolicRegression.InterfaceDynamicExpressionsModule: expected_array_type - using DynamicExpressions: constructorof - - # Set up basic operators and variables - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - operators = options.operators - variable_names = ["x1", "x2", "x3"] - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - # Create a TemplateStructure with only combine (no combine_vectors) - structure = TemplateStructure(; - combine=e -> sin(e.f) + e.g * e.g, # Only define combine - variable_constraints=(; f=[1, 2], g=[3]), - ) - - # Create the TemplateExpression - st_expr = TemplateExpression((; f=x1, g=cos(x3)); structure, operators, variable_names) - - @test constructorof(typeof(st_expr)) === TemplateExpression - @test get_function_keys(st_expr) == (:f, :g) - - # Test evaluation - cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] - out = st_expr(cX) - out_2, complete = eval_tree_array(st_expr, cX) - - # The expression should evaluate by first combining to a single expression, - # then evaluating that expression - expected = sin.(cX[1, :]) .+ cos.(cX[3, :]) .^ 2 - @test out ≈ expected - - @test complete - @test out_2 ≈ expected - - # Verify that can_combine_vectors is false but can_combine is true - @test !can_combine_vectors(st_expr) - @test can_combine(st_expr) - - @test expected_array_type(cX, typeof(st_expr)) === Any - - @test string_tree(st_expr) == "sin(x1) + (cos(x3) * cos(x3))" -end -@testitem "TemplateExpression with data in combine_vectors" tags = [:part3] begin - using SymbolicRegression - - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos, exp)) - operators = options.operators - variable_names = ["x1", "x2", "x3"] - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - f = exp(2.5 * x3) - g = x1 - structure = TemplateStructure(; - combine_vectors=(e, X) -> e.f .+ X[2, :], variable_constraints=(; f=[3], g=[1]) - ) - st_expr = TemplateExpression((; f, g); structure, operators, variable_names) - X = randn(3, 100) - @test st_expr(X) ≈ @. exp(2.5 * X[3, :]) + X[2, :] -end -@testitem "TemplateStructure constructors" tags = [:part3] begin - using SymbolicRegression - - operators = Options(; binary_operators=(+, *, /, -)).operators - variable_names = ["x1", "x2"] - - # Create simple expressions with constant values - f = Expression(Node(Float64; val=1.0); operators, variable_names) - g = Expression(Node(Float64; val=2.0); operators, variable_names) - - # Test TemplateStructure{K}(combine; kws...) - st1 = TemplateStructure{(:f, :g)}(e -> e.f + e.g) - @test st1.combine((; f, g)) == f + g - - # Test TemplateStructure(combine; kws...) - st2 = TemplateStructure(e -> e.f + e.g; variable_constraints=(; f=[1], g=[2])) - @test st2.combine((; f, g)) == f + g - - # Test error when no K or variable_constraints provided - @test_throws ArgumentError TemplateStructure(e -> e.f + e.g) - @test_throws ArgumentError( - "If `variable_constraints` is not provided, " * - "you must initialize `TemplateStructure` with " * - "`TemplateStructure{K}(...)`, for tuple of symbols `K`.", - ) TemplateStructure(e -> e.f + e.g) -end diff --git a/test/test_units.jl b/test/test_units.jl index a586f5e3c..da7f45fa3 100644 --- a/test/test_units.jl +++ b/test/test_units.jl @@ -337,15 +337,15 @@ end @test string_tree(tree, options) == "(1.0 * (x1 + ((x2 * x3) * 5.32))) - cos(1.5 * (x1 - 0.5))" - @test string_tree(tree, options; raw=false) == + @test string_tree(tree, options; pretty=true) == "(1 * (x₁ + ((x₂ * x₃) * 5.32))) - cos(1.5 * (x₁ - 0.5))" @test string_tree( - tree, options; raw=false, display_variable_names=dataset.display_variable_names + tree, options; pretty=true, display_variable_names=dataset.display_variable_names ) == "(1 * (x₁ + ((x₂ * x₃) * 5.32))) - cos(1.5 * (x₁ - 0.5))" @test string_tree( tree, options; - raw=false, + pretty=true, display_variable_names=dataset.display_variable_names, X_sym_units=dataset.X_sym_units, y_sym_units=dataset.y_sym_units, @@ -355,7 +355,7 @@ end @test string_tree( x5 * 3.2, options; - raw=false, + pretty=true, display_variable_names=dataset.display_variable_names, X_sym_units=dataset.X_sym_units, y_sym_units=dataset.y_sym_units, @@ -366,7 +366,7 @@ end @test string_tree( x5 * 3.2, options; - raw=false, + pretty=true, display_variable_names=dataset2.display_variable_names, X_sym_units=dataset2.X_sym_units, y_sym_units=dataset2.y_sym_units, @@ -381,7 +381,7 @@ end @test string_tree( x5 * 3.2, options; - raw=false, + pretty=true, display_variable_names=dataset2.display_variable_names, X_sym_units=dataset2.X_sym_units, y_sym_units=dataset2.y_sym_units,