Merge pull request #365 from MilesCranmer/composable-expressions
Rewrite `TemplateExpression` to enable hierarchical expressions
MilesCranmer authored Nov 7, 2024
2 parents bc9edaf + 113f2c6 commit cf4c0c2
Showing 41 changed files with 1,294 additions and 818 deletions.
104 changes: 50 additions & 54 deletions CHANGELOG.md
@@ -91,68 +91,61 @@ A `TemplateExpression` is constructed by specifying:
For example, you can create a `TemplateExpression` that enforces
the constraint: `sin(f(x1, x2)) + g(x3)^2` - where we evolve `f` and `g` simultaneously.

Let's see some code for this. First, we define some base expressions for each input feature:
To do this, we first describe the structure using a `TemplateStructure`,
which takes a single closure that maps a named tuple of
`ComposableExpression` objects and a tuple of features to the combined output:

```julia
using SymbolicRegression

options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
operators = options.operators
variable_names = ["x1", "x2", "x3"]

# Base expressions:
x1 = Expression(Node{Float64}(; feature=1); operators, variable_names)
x2 = Expression(Node{Float64}(; feature=2); operators, variable_names)
x3 = Expression(Node{Float64}(; feature=3); operators, variable_names)
structure = TemplateStructure{(:f, :g)}(
((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2
)
```

To build a `TemplateExpression`, we specify the structure using
a `TemplateStructure` object. This class has several fields:
This defines how the `TemplateExpression` should be
evaluated numerically on a given input.

- `combine`: Optional function taking a `NamedTuple` of function keys => expressions,
returning a single expression. Fallback method used by `get_tree`
on a `TemplateExpression` to generate a single `Expression`.
- `combine_vectors`: Optional function taking a `NamedTuple` of function keys => vectors,
returning a single vector. Used for evaluating the expression tree.
You may optionally define a method with a second argument `X` if you wish
to include the data matrix `X` (of shape `[num_features, num_rows]`) in the
computation.
- `combine_strings`: Optional function taking a `NamedTuple` of function keys => strings,
returning a single string. Used for printing the expression tree.
- `variable_constraints`: Optional `NamedTuple` that defines which variables each sub-expression is allowed to access.
For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`.

Let's see an example:
The number of arguments allowed by each expression object
is inferred using this closure, though it can also
be passed explicitly with the `num_features` kwarg.

```julia

# Combine f and g into a single scalar expression:
structure = TemplateStructure(;
combine_strings=e -> "sin(" * e.f * ") + (" * e.g * ")^2",
combine_vectors=e -> map((f, g) -> sin(f) + g * g, e.f, e.g),
variable_constraints = (; f=[1, 2], g=[3]), # We constrain it to f(x1, x2) and g(x3)
)
operators = Options(binary_operators=(+, -, *, /)).operators
variable_names = ["x1", "x2", "x3"]
x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)
x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names)
x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names)
```

This defines how the `TemplateExpression` should be evaluated numerically on a given input,
and also how it should be represented as a string:
Note that using `x1` here refers to the
_relative_ argument to the expression.
So the node with feature equal to 1 will reference
the first argument, regardless of what it is.

```julia
julia> f_example = x1 - x2 * x2; # Normal `Expression` object

julia> g_example = 1.5 * x3;

julia> # Create TemplateExpression from these sub-expressions:
st_expr = TemplateExpression((; f=f_example, g=g_example); structure, operators, variable_names);
st_expr = TemplateExpression(
(; f=x1 - x2 * x2, g=1.5 * x1);
structure,
operators,
variable_names
) # Prints as: f = #1 - (#2 * #2); g = 1.5 * #1

# Evaluation combines evaluation of `f` and `g`, and combines them
# with the structure function:
st_expr([0.0; 1.0; 2.0;;])
```
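
One way to picture the relative-argument convention is with ordinary Julia closures. The sketch below is only an analogy: `f`, `g`, and `template` are plain anonymous functions invented for illustration (they are not `ComposableExpression` or `TemplateStructure` objects), and they mirror the sub-expressions `#1 - (#2 * #2)` and `1.5 * #1` from the block above.

```julia
# Toy stand-ins for the evolved sub-expressions (assumption: plain closures,
# not the package's types). Each one only knows its arguments by position.
f = (a1, a2) -> a1 .- a2 .* a2   # plays the role of `#1 - (#2 * #2)`
g = a1 -> 1.5 .* a1              # plays the role of `1.5 * #1`

# The template decides which dataset features fill those positions:
# `g` receives x3 here, even though it is written in terms of "argument 1".
template = (x1, x2, x3) -> sin.(f(x1, x2)) .+ g(x3) .^ 2

x1, x2, x3 = [0.0], [1.0], [2.0]
result = template(x1, x2, x3)    # ≈ [8.1585], i.e. sin(-1) + (1.5 * 2)^2
```

The point of the convention is that `g` never needs to know it is bound to the third feature; swapping `g(x3)` for `g(x1)` in the template changes what it sees without touching `g` itself.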

julia> st_expr # Prints using `my_structure`!
sin(x1 - (x2 * x2)) + 1.5 * x3^2
This also works with hierarchical expressions! For example,

julia> st_expr([0.0; 1.0; 2.0;;]) # Combines evaluation of `f` and `g` via `my_structure`!
1-element Vector{Float64}:
8.158529015192103
```julia
structure = TemplateStructure{(:f, :g)}(
((; f, g), (x1, x2, x3)) -> f(x1, g(x2), x3^2) - g(x3)
)
```

this is a valid structure!
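
To make the nesting concrete, here is a hedged toy version of that structure in plain Julia. The bodies chosen for `f` and `g` are arbitrary placeholders (in a real search they would be evolved), and `template` is a hypothetical name; none of this uses the package's types.

```julia
# Arbitrary placeholder bodies, for illustration only.
f = (a1, a2, a3) -> a1 .+ a2 .* a3   # a three-argument sub-expression
g = a1 -> 2.0 .* a1                  # a one-argument sub-expression

# Same shape as the structure above: `g` is used twice, once nested
# inside `f` (applied to x2) and once on its own (applied to x3).
template = (x1, x2, x3) -> f(x1, g(x2), x3 .^ 2) .- g(x3)

x1, x2, x3 = [1.0], [2.0], [3.0]
result = template(x1, x2, x3)        # f(1, 4, 9) - g(3) = 37 - 6 = 31
```

Because the sub-expressions are called like functions, the same `g` can appear at several places in the template with different inputs.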

We can also use this `TemplateExpression` in SymbolicRegression.jl searches!

<details>
@@ -168,11 +161,17 @@ This also has our variable mapping, which says
we are fitting `f(x1, x2)`, `g1(x3)`, and `g2(x3)`:

```julia
structure = TemplateStructure(;
combine_strings=e -> "( " * e.f * " + " * e.g1 * ", " * e.f * " + " * e.g2 * " )",
combine_vectors=e -> map(i -> (e.f[i] + e.g1[i], e.f[i] + e.g2[i]), eachindex(e.f)),
variable_constraints = (; f=[1, 2], g1=[3], g2=[3]),
)
function my_structure((; f, g1, g2), (x1, x2, x3))
_f = f(x1, x2)
_g1 = g1(x3)
_g2 = g2(x3)

# We use `.x` to get the underlying vector
out = map((fi, g1i, g2i) -> (fi + g1i, fi + g2i), _f.x, _g1.x, _g2.x)
# And `.valid` to see whether the evaluations were valid
return ValidVector(out, _f.valid && _g1.valid && _g2.valid)
end
structure = TemplateStructure{(:f, :g1, :g2)}(my_structure)
```
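
The `.x` / `.valid` pattern can be sketched without the package. `ToyValidVector` and `combine_outputs` below are hypothetical stand-ins written for this illustration; they only assume what the snippet above shows, namely a data field `x` and a Boolean `valid` flag.

```julia
# Hypothetical stand-in for illustration only; not the package's ValidVector.
struct ToyValidVector{T}
    x::Vector{T}
    valid::Bool
end

# Pair up the per-row outputs into 2-tuples and propagate validity:
# the result is valid only if every input evaluation was valid.
function combine_outputs(f_out, g1_out, g2_out)
    out = map((fi, g1i, g2i) -> (fi + g1i, fi + g2i), f_out.x, g1_out.x, g2_out.x)
    return ToyValidVector(out, f_out.valid && g1_out.valid && g2_out.valid)
end

f_out = ToyValidVector([1.0, 2.0], true)
g1_out = ToyValidVector([10.0, 20.0], true)
g2_out = ToyValidVector([0.5, 0.5], false)   # pretend this evaluation failed

combined = combine_outputs(f_out, g1_out, g2_out)
# combined.x is a vector of 2-tuples; combined.valid is false here
```

Carrying the flag alongside the data lets a single failed sub-expression mark the whole combined output as invalid.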

Now, our dataset is a regular 2D array of inputs for `X`.
@@ -182,10 +181,7 @@ But our `y` is actually a _vector of 2-tuples_!
X = rand(100, 3) .* 10

y = [
(
sin(X[i, 1]) + X[i, 3]^2,
sin(X[i, 1]) + X[i, 3]
)
(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3])
for i in eachindex(axes(X, 1))
]
```
2 changes: 1 addition & 1 deletion Project.toml
@@ -47,7 +47,7 @@ Dates = "1"
DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicExpressions = "1.4"
DynamicExpressions = "1.5.0"
DynamicQuantities = "1"
Enzyme = "0.12"
JSON3 = "1"
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -3,7 +3,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
Gumbo = "708ec375-b3d6-5a57-a7ce-8257bf98657a"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Documenter = "0.27"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -106,6 +106,7 @@ makedocs(;
"Examples" => [
"Short Examples" => "examples.md",
"Template Expressions" => "examples/template_expression.md",
"Parameterized Expressions" => "examples/parameterized_function.md",
],
"API" => "api.md",
"Losses" => "losses.md",
27 changes: 8 additions & 19 deletions docs/src/types.md
@@ -63,32 +63,21 @@ These types allow you to define expressions with parameters that can be tuned to
## Template Expressions

Template expressions allow you to specify predefined structures and constraints for your expressions.
These use the new `TemplateStructure` type to define how expressions should be combined and evaluated.
These use `ComposableExpression`s as their internal expression type, which makes it
possible to define the whole structure with a single function.

These use the `TemplateStructure` type to define how expressions should be combined and evaluated.

```@docs
TemplateExpression
TemplateStructure
```

Example usage:

```julia
# Define a template structure
structure = TemplateStructure(
combine=e -> e.f + e.g, # Create normal `Expression`
combine_vectors=e -> (e.f .+ e.g), # Output vector
combine_strings=e -> "($(e.f)) + ($(e.g))", # Output string
variable_constraints=(; f=[1, 2], g=[3]) # Constrain dependencies
)

# Use in options
model = SRRegressor(;
expression_type=TemplateExpression,
expression_options=(; structure=structure)
)
```
Composable expressions allow you to combine multiple expressions together.

The `variable_constraints` field allows you to specify which variables can be used in different parts of the expression.
```@docs
ComposableExpression
```

## Population

93 changes: 80 additions & 13 deletions examples/parameterized_function.jl
@@ -1,21 +1,62 @@
#literate_begin file="src/examples/parameterized_function.md"
#=
# Learning Parameterized Expressions
_Note: Parametric expressions are currently considered experimental and may change in the future._
Parameterized expressions in SymbolicRegression.jl allow you to discover symbolic expressions that contain
optimizable parameters. This is particularly useful when you have data that follows different patterns
based on some categorical variable, or when you want to learn an expression with constants that should
be optimized during the search.
In this tutorial, we'll generate synthetic data with class-dependent parameters and use symbolic regression to discover the parameterized expressions.
## The Problem
Let's create a synthetic dataset where the underlying function changes based on a class label:
```math
y = 2\cos(x_2 + 0.1) + x_1^2 - 3.2 \ \ \ \ \text{[class 1]} \\
\text{OR} \\
y = 2\cos(x_2 + 1.5) + x_1^2 - 0.5 \ \ \ \ \text{[class 2]}
```
We will need to simultaneously learn the symbolic expression and per-class parameters!
=#
using SymbolicRegression
using Random: MersenneTwister
using Zygote
using MLJBase: machine, fit!, predict, report
using Test

rng = MersenneTwister(0)
X = NamedTuple{(:x1, :x2, :x3, :x4, :x5)}(ntuple(_ -> randn(rng, Float32, 30), Val(5)))
X = (; X..., classes=rand(rng, 1:2, 30))
p1 = [0.0f0, 3.2f0]
p2 = [1.5f0, 0.5f0]
#=
Now, we generate synthetic data, with these 2 different classes.
Note that the `class` feature is given special treatment by the [`SRRegressor`](@ref)
as a categorical variable:
=#

X = let rng = MersenneTwister(0), n = 30
(; x1=randn(rng, n), x2=randn(rng, n), class=rand(rng, 1:2, n))
end

#=
Now, we generate target values using the true model that
has class-dependent parameters:
=#
y = let P1 = [0.1, 1.5], P2 = [3.2, 0.5]
[2 * cos(x2 + P1[class]) + x1^2 - P2[class] for (x1, x2, class) in zip(X.x1, X.x2, X.class)]
end

y = [
2 * cos(X.x4[i] + p1[X.classes[i]]) + X.x1[i]^2 - p2[X.classes[i]] for
i in eachindex(X.classes)
]
#=
## Setting up the Search
stop_at = Ref(1e-4)
We'll configure the symbolic regression search to:
- Use parameterized expressions with up to 2 parameters
- Use Zygote.jl for automatic differentiation during parameter optimization (important for parametric expressions, since the optimization problem is higher-dimensional)
=#

stop_at = Ref(1e-4) #src

model = SRRegressor(;
niterations=100,
@@ -25,12 +66,38 @@ model = SRRegressor(;
expression_type=ParametricExpression,
expression_options=(; max_parameters=2),
autodiff_backend=:Zygote,
parallelism=:multithreading,
early_stop_condition=(loss, _) -> loss < stop_at[],
)
early_stop_condition=(loss, _) -> loss < stop_at[], #src
);

#=
Now, let's set up the machine and fit it:
=#

mach = machine(model, X, y)

#=
At this point, you would run:
```julia
fit!(mach)
```
You can extract the best expression and parameters with:
```julia
report(mach).equations[end]
```
## Key Takeaways
1. [`ParametricExpression`](@ref)s allow us to discover symbolic expressions with optimizable parameters
2. The parameters can capture class-dependent variations in the underlying model
This approach is particularly useful when you suspect your data follows a common
functional form, but with varying parameters across different conditions or classes!
=#
#literate_end

fit!(mach)
idx1 = lastindex(report(mach).equations)
ypred1 = predict(mach, (data=X, idx=idx1))
36 changes: 22 additions & 14 deletions examples/template_expression.jl
@@ -6,33 +6,41 @@ using Test: @test
options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
operators = options.operators
variable_names = (i -> "x$i").(1:3)
x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3)

structure = TemplateStructure{(:f, :g1, :g2)}(;
combine_vectors=e -> map((f, g1, g2) -> (f + g1, f + g2), e.f, e.g1, e.g2),
combine_strings=e -> "( $(e.f) + $(e.g1), $(e.f) + $(e.g2) )",
variable_constraints=(; f=[1, 2], g1=[3], g2=[3]),
x1, x2, x3 =
(i -> ComposableExpression(Node(Float64; feature=i); operators, variable_names)).(1:3)

structure = TemplateStructure{(:f, :g1, :g2)}(
((; f, g1, g2), (x1, x2, x3)) -> let
_f = f(x1, x2)
_g1 = g1(x3)
_g2 = g2(x3)
_out1 = _f + _g1
_out2 = _f + _g2
ValidVector(map(tuple, _out1.x, _out2.x), _out1.valid && _out2.valid)
end,
)

st_expr = TemplateExpression((; f=x1, g1=x3, g2=x3); structure, operators, variable_names)

X = rand(100, 3) .* 10
x1 = rand(100)
x2 = rand(100)
x3 = rand(100)

# Our dataset is a vector of 2-tuples
y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))]
y = [(sin(x1[i]) + x3[i]^2, sin(x1[i]) + x3[i]) for i in eachindex(x1, x2, x3)]

model = SRRegressor(;
binary_operators=(+, *),
unary_operators=(sin,),
maxsize=15,
maxsize=20,
expression_type=TemplateExpression,
expression_options=(; structure),
# The elementwise loss needs to operate directly on each row of `y`:
elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2,
early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7,
early_stop_condition=(loss, complexity) -> loss < 1e-6 && complexity <= 7,
)

mach = machine(model, X, y)
mach = machine(model, [x1 x2 x3], y)
fit!(mach)

# Check the performance of the model
@@ -48,6 +56,6 @@ best_f = get_contents(best_expr).f
best_g1 = get_contents(best_expr).g1
best_g2 = get_contents(best_expr).g2

@test best_f(X') ≈ (@. sin(X[:, 1]))
@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3])
@test best_g2(X') ≈ (@. X[:, 3])
@test best_f(x1, x2) ≈ @. sin.(x1)
@test best_g1(x3) ≈ (@. x3 * x3)
@test best_g2(x3) ≈ (@. x3)