From f4c0d7c51b83fa52d39be23f5359db93c126bc44 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 27 Nov 2024 23:12:35 +0000 Subject: [PATCH] feat: allow argument-less TemplateExpression parts --- src/ComposableExpression.jl | 5 +++++ src/TemplateExpression.jl | 2 +- test/test_composable_expression.jl | 24 ++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 866b56e5d..1b4dd44ab 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -171,6 +171,11 @@ function (ex::AbstractComposableExpression)( return ValidVector(eval_tree_array(ex, X)) end end +function (ex::AbstractComposableExpression{T})() where {T} + X = Matrix{T}(undef, 0, 1) # Value is irrelevant as it won't be used + out, _ = eval_tree_array(ex, X) # TODO: The valid is not used; not sure how to incorporate + return only(out)::T +end function (ex::AbstractComposableExpression)( x::AbstractComposableExpression, _xs::Vararg{AbstractComposableExpression,N} ) where {N} diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index f5aa90237..45c8750f3 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -82,7 +82,7 @@ function _record_composable_expression!(variable_constraints, ::Val{k}, args...) elseif vc != length(args) throw(ArgumentError("Inconsistent number of arguments passed to $k")) end - return first(args) + return isempty(args) ? 0.0 : first(args) end """Infers number of features used by each subexpression, by passing in test data.""" diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index fc17a6b5b..094be557d 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -294,3 +294,27 @@ end ((; f, g), (x1, x2)) -> f(x1) ) end + +@testitem "Test argument-less template structure" tags = [:part2] begin + using SymbolicRegression + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = ["x1", "x2"] + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + c1 = ComposableExpression(Node{Float64}(; val=3.0); operators, variable_names) + + # We can evaluate an expression with no arguments: + @test c1() == 3.0 + @test typeof(c1()) === Float64 + + # Create a structure where f takes no arguments and g takes two + structure = TemplateStructure{(:f, :g)}(((; f, g), (x1, x2)) -> f() + g(x1, x2)) + + @test structure.num_features == (; f=0, g=2) + + X = [1.0 2.0]' + expr = TemplateExpression((; f=c1, g=x1 + x2); structure, operators, variable_names) + @test expr(X) ≈ [6.0] # 3 + (1 + 2) +end