diff --git a/src/Core.jl b/src/Core.jl
index c442efc7..bc9210ae 100644
--- a/src/Core.jl
+++ b/src/Core.jl
@@ -33,7 +33,10 @@ using .OperatorsModule:
     safe_log10,
     safe_log1p,
     safe_sqrt,
+    safe_asin,
+    safe_acos,
     safe_acosh,
+    safe_atanh,
     neg,
     greater,
     cond,
diff --git a/src/Operators.jl b/src/Operators.jl
index b38ccf97..350718bd 100644
--- a/src/Operators.jl
+++ b/src/Operators.jl
@@ -5,6 +5,7 @@ using SpecialFunctions: SpecialFunctions
 using DynamicQuantities: UnionAbstractQuantity
 using SpecialFunctions: erf, erfc
 using Base: @deprecate
+using DynamicDiff: ForwardDiff
 using ..ProgramConstantsModule: DATA_TYPE
 using ...UtilsModule: @ignore
 #TODO - actually add these operators to the module!
@@ -19,15 +20,25 @@ gamma(x) = SpecialFunctions.gamma(x)
 atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) * one(x) # == atanh((x + 1) % 2 - 1)
 
+const Dual = ForwardDiff.Dual
+
 # Implicitly defined:
 #binary: mod
 #unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign.
 
+const FloatOrDual = Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}
+
 # Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl
 # Define allowed operators. Any julia operator can also be used.
 # TODO: Add all of these operators to the precompilation.
 # TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore?
 
-function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity}}
+function safe_pow(
+    x::T1, y::T2
+) where {
+    T1<:Union{FloatOrDual,UnionAbstractQuantity},
+    T2<:Union{FloatOrDual,UnionAbstractQuantity},
+}
+    T = promote_type(T1, T2)
     if isinteger(y)
         y < zero(y) && iszero(x) && return T(NaN)
     else
@@ -36,29 +47,32 @@ function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuan
     end
     return x^y
 end
-function safe_log(x::T)::T where {T<:AbstractFloat}
-    x <= zero(x) && return T(NaN)
-    return log(x)
+function safe_log(x::T)::T where {T<:FloatOrDual}
+    return x > zero(x) ? log(x) : T(NaN)
+end
+function safe_log2(x::T)::T where {T<:FloatOrDual}
+    return x > zero(x) ? log2(x) : T(NaN)
+end
+function safe_log10(x::T)::T where {T<:FloatOrDual}
+    return x > zero(x) ? log10(x) : T(NaN)
+end
+function safe_log1p(x::T)::T where {T<:FloatOrDual}
+    return x > -oneunit(x) ? log1p(x) : T(NaN)
 end
-function safe_log2(x::T)::T where {T<:AbstractFloat}
-    x <= zero(x) && return T(NaN)
-    return log2(x)
+function safe_asin(x::T)::T where {T<:FloatOrDual}
+    return -oneunit(x) <= x <= oneunit(x) ? asin(x) : T(NaN)
 end
-function safe_log10(x::T)::T where {T<:AbstractFloat}
-    x <= zero(x) && return T(NaN)
-    return log10(x)
+function safe_acos(x::T)::T where {T<:FloatOrDual}
+    return -oneunit(x) <= x <= oneunit(x) ? acos(x) : T(NaN)
 end
-function safe_log1p(x::T)::T where {T<:AbstractFloat}
-    x <= -oneunit(x) && return T(NaN)
-    return log1p(x)
+function safe_acosh(x::T)::T where {T<:FloatOrDual}
+    return x >= oneunit(x) ? acosh(x) : T(NaN)
 end
-function safe_acosh(x::T)::T where {T<:AbstractFloat}
-    x < oneunit(x) && return T(NaN)
-    return acosh(x)
+function safe_atanh(x::T)::T where {T<:FloatOrDual}
+    return -oneunit(x) <= x <= oneunit(x) ? atanh(x) : T(NaN)
 end
-function safe_sqrt(x::T)::T where {T<:AbstractFloat}
-    x < zero(x) && return T(NaN)
-    return sqrt(x)
+function safe_sqrt(x::T)::T where {T<:FloatOrDual}
+    return x >= zero(x) ? sqrt(x) : T(NaN)
 end
 
 # TODO: Should the above be made more generic, for, e.g., compatibility with units?
@@ -75,6 +89,9 @@ safe_log(x) = log(x)
 safe_log2(x) = log2(x)
 safe_log10(x) = log10(x)
 safe_log1p(x) = log1p(x)
+safe_asin(x) = asin(x)
+safe_acos(x) = acos(x)
+safe_atanh(x) = atanh(x)
 safe_acosh(x) = acosh(x)
 safe_sqrt(x) = sqrt(x)
 
@@ -103,7 +120,10 @@ DE.get_op_name(::typeof(safe_log)) = "log"
 DE.get_op_name(::typeof(safe_log2)) = "log2"
 DE.get_op_name(::typeof(safe_log10)) = "log10"
 DE.get_op_name(::typeof(safe_log1p)) = "log1p"
+DE.get_op_name(::typeof(safe_asin)) = "asin"
+DE.get_op_name(::typeof(safe_acos)) = "acos"
 DE.get_op_name(::typeof(safe_acosh)) = "acosh"
+DE.get_op_name(::typeof(safe_atanh)) = "atanh"
 DE.get_op_name(::typeof(safe_sqrt)) = "sqrt"
 
 # Expression algebra
@@ -112,7 +132,10 @@ DE.declare_operator_alias(::typeof(safe_log), ::Val{1}) = log
 DE.declare_operator_alias(::typeof(safe_log2), ::Val{1}) = log2
 DE.declare_operator_alias(::typeof(safe_log10), ::Val{1}) = log10
 DE.declare_operator_alias(::typeof(safe_log1p), ::Val{1}) = log1p
+DE.declare_operator_alias(::typeof(safe_asin), ::Val{1}) = asin
+DE.declare_operator_alias(::typeof(safe_acos), ::Val{1}) = acos
 DE.declare_operator_alias(::typeof(safe_acosh), ::Val{1}) = acosh
+DE.declare_operator_alias(::typeof(safe_atanh), ::Val{1}) = atanh
 DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt
 
 # Deprecated operations:
@@ -123,13 +146,17 @@ DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt
 @ignore pow(x, y) = safe_pow(x, y)
 @ignore pow_abs(x, y) = safe_pow(x, y)
 
+# Actual mappings used for evaluation
 get_safe_op(op::F) where {F<:Function} = op
 get_safe_op(::typeof(^)) = safe_pow
 get_safe_op(::typeof(log)) = safe_log
 get_safe_op(::typeof(log2)) = safe_log2
 get_safe_op(::typeof(log10)) = safe_log10
 get_safe_op(::typeof(log1p)) = safe_log1p
+get_safe_op(::typeof(asin)) = safe_asin
+get_safe_op(::typeof(acos)) = safe_acos
 get_safe_op(::typeof(sqrt)) = safe_sqrt
 get_safe_op(::typeof(acosh)) = safe_acosh
+get_safe_op(::typeof(atanh)) = safe_atanh
 
 end
diff --git a/src/Options.jl b/src/Options.jl
index d7fc61cf..b6f1dd84 100644
--- a/src/Options.jl
+++ b/src/Options.jl
@@ -23,8 +23,10 @@ using ..OperatorsModule:
     safe_log2,
     safe_log1p,
     safe_sqrt,
+    safe_asin,
+    safe_acos,
     safe_acosh,
-    atanh_clip
+    safe_atanh
 using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutations
 import ..OptionsStructModule: Options
 using ..OptionsStructModule: ComplexityMapping, operator_specialization
@@ -139,7 +141,7 @@ end
     ]
 end
 
-function binopmap(op::F) where {F}
+function binopmap(@nospecialize(op))
     if op == plus
         return +
     elseif op == mult
@@ -155,14 +157,14 @@ function binopmap(op::F) where {F}
     end
     return op
 end
-function inverse_binopmap(op::F) where {F}
+function inverse_binopmap(@nospecialize(op))
     if op == safe_pow
         return ^
     end
     return op
 end
 
-function unaopmap(op::F) where {F}
+function unaopmap(@nospecialize(op))
     if op == log
         return safe_log
     elseif op == log10
@@ -173,14 +175,18 @@ function binopmap(op::F) where {F}
         return safe_log1p
     elseif op == sqrt
         return safe_sqrt
+    elseif op == asin
+        return safe_asin
+    elseif op == acos
+        return safe_acos
     elseif op == acosh
         return safe_acosh
     elseif op == atanh
-        return atanh_clip
+        return safe_atanh
     end
     return op
 end
-function inverse_unaopmap(op::F) where {F}
+function inverse_unaopmap(@nospecialize(op))
     if op == safe_log
         return log
     elseif op == safe_log10
@@ -191,9 +197,13 @@ function inverse_unaopmap(op::F) where {F}
         return log1p
     elseif op == safe_sqrt
         return sqrt
+    elseif op == safe_asin
+        return asin
+    elseif op == safe_acos
+        return acos
     elseif op == safe_acosh
         return acosh
-    elseif op == atanh_clip
+    elseif op == safe_atanh
         return atanh
     end
     return op
diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl
index cc7decf0..fd95b085 100644
--- a/src/SymbolicRegression.jl
+++ b/src/SymbolicRegression.jl
@@ -67,7 +67,10 @@ export Population,
     safe_log2,
     safe_log10,
     safe_log1p,
+    safe_asin,
+    safe_acos,
     safe_acosh,
+    safe_atanh,
     safe_sqrt,
     neg,
     greater,
@@ -247,7 +250,10 @@ using .CoreModule:
     safe_log10,
     safe_log1p,
     safe_sqrt,
+    safe_asin,
+    safe_acos,
     safe_acosh,
+    safe_atanh,
     neg,
     greater,
     cond,
diff --git a/test/runtests.jl b/test/runtests.jl
index d02a1b9b..049ea686 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,9 +11,7 @@ end
 @eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true
 
 # TODO: This is a very slow test
-@testitem "Test custom operators and additional types" tags = [:part2] begin
-    include("test_operators.jl")
-end
+include("test_operators.jl")
 
 @testitem "Test tree construction and scoring" tags = [:part3] begin
     include("test_tree_construction.jl")
diff --git a/test/test_operators.jl b/test/test_operators.jl
index 1221ba6c..b7643e70 100644
--- a/test/test_operators.jl
+++ b/test/test_operators.jl
@@ -1,29 +1,28 @@
-using SymbolicRegression
-using SymbolicRegression:
-    plus,
-    sub,
-    mult,
-    square,
-    cube,
-    safe_pow,
-    safe_log,
-    safe_log2,
-    safe_log10,
-    safe_sqrt,
-    safe_acosh,
-    neg,
-    greater,
-    cond,
-    relu,
-    logical_or,
-    logical_and,
-    gamma
-using Random: MersenneTwister
-using Suppressor: @capture_err
-using LoopVectorization
-include("test_params.jl")
-
-@testset "Generic operator tests" begin
+@testitem "Generic operator tests" tags = [:part2] begin
+    using SymbolicRegression
+    using SymbolicRegression:
+        plus,
+        sub,
+        mult,
+        square,
+        cube,
+        safe_pow,
+        safe_log,
+        safe_log2,
+        safe_log10,
+        safe_sqrt,
+        safe_acosh,
+        safe_atanh,
+        safe_asin,
+        safe_acos,
+        neg,
+        greater,
+        cond,
+        relu,
+        logical_or,
+        logical_and,
+        gamma
+
     types_to_test = [Float16, Float32, Float64, BigFloat]
     for T in types_to_test
         val = T(0.5)
@@ -37,6 +36,12 @@ include("test_params.jl")
         @test abs(safe_log1p(val) - log1p(val)) < 1e-6
         @test abs(safe_acosh(val2) - acosh(val2)) < 1e-6
         @test isnan(safe_acosh(-val2))
+        @test abs(safe_asin(val) - asin(val)) < 1e-6
+        @test isnan(safe_asin(val2))
+        @test abs(safe_acos(val) - acos(val)) < 1e-6
+        @test isnan(safe_acos(val2))
+        @test abs(safe_atanh(val) - atanh(val)) < 1e-6
+        @test isnan(safe_atanh(val2))
         @test neg(-val) == val
         @test safe_sqrt(val) == sqrt(val)
         @test isnan(safe_sqrt(-val))
@@ -70,7 +75,11 @@ include("test_params.jl")
     end
 end
 
-@testset "Test built-in operators pass validation" begin
+@testitem "Built-in operators pass validation" tags = [:part3] begin
+    using SymbolicRegression
+    using SymbolicRegression:
+        plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond
+
     types_to_test = [Float16, Float32, Float64, BigFloat]
     options = Options(;
         binary_operators=[plus, sub, mult, /, ^, greater, logical_or, logical_and, cond],
@@ -83,7 +92,10 @@ end
     end
 end
 
-@testset "Test built-in operators pass validation for complex numbers" begin
+@testitem "Built-in operators pass validation for complex numbers" tags = [:part2] begin
+    using SymbolicRegression
+    using SymbolicRegression: plus, sub, mult, square, cube, neg
+
     types_to_test = [ComplexF16, ComplexF32, ComplexF64]
     options = Options(;
         binary_operators=[plus, sub, mult, /, ^],
@@ -94,7 +106,10 @@ end
     end
 end
 
-@testset "Test incompatibilities are caught" begin
+@testitem "Incompatibilities are caught" tags = [:part3] begin
+    using SymbolicRegression
+    using SymbolicRegression: greater
+
     options = Options(; binary_operators=[greater])
     @test_throws ErrorException SymbolicRegression.assert_operators_well_defined(
         ComplexF64, options
@@ -104,7 +119,9 @@ end
     )
 end
 
-@testset "Operators which return the wrong type should fail" begin
+@testitem "Operators with wrong type fail" tags = [:part2] begin
+    using SymbolicRegression
+
     my_bad_op(x) = 1.0f0
     options = Options(; binary_operators=[], unary_operators=[my_bad_op])
     @test_throws ErrorException SymbolicRegression.assert_operators_well_defined(
@@ -116,10 +133,32 @@ end
     @test_nowarn SymbolicRegression.assert_operators_well_defined(Float32, options)
 end
 
-@testset "Turbo mode should be the same" begin
+@testitem "Turbo mode matches regular mode" tags = [:part3] begin
+    using SymbolicRegression
+    using SymbolicRegression:
+        plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond
+    using Random: MersenneTwister
+    using Suppressor: @capture_err
+    using LoopVectorization: LoopVectorization as _
+    include("test_params.jl")
+
     binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond]
     unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu]
     options = Options(; binary_operators, unary_operators)
+
+    function test_part(tree, Xpart, options)
+        y, completed = eval_tree_array(tree, Xpart, options)
+        completed || return nothing
+        # We capture any warnings about the LoopVectorization not working
+        local y_turbo
+        eval_warnings = @capture_err begin
+            y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true)
+        end
+        test_info(@test(y[1] ≈ y_turbo[1] && eval_warnings == "")) do
+            # Only variables in scope of this helper are logged (the loop-local
+            # `T`, `X`, and `seed` from the old inline version are not visible here).
+            @info tree Xpart y y_turbo eval_warnings
+        end
+    end
+
     for T in (Float32, Float64),
         index_bin in 1:length(binary_operators),
         index_una in 1:length(unary_operators)
@@ -129,16 +168,62 @@ end
         X = rand(MersenneTwister(0), T, 2, 20)
         for seed in 1:20
             Xpart = X[:, [seed]]
-            y, completed = eval_tree_array(tree, Xpart, options)
-            completed || continue
-            local y_turbo
-            # We capture any warnings about the LoopVectorization not working
-            eval_warnings = @capture_err begin
-                y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true)
-            end
-            test_info(@test y[1] ≈ y_turbo[1] && eval_warnings == "") do
-                @info T tree X[:, seed] y y_turbo eval_warnings
-            end
+            test_part(tree, Xpart, options)
         end
     end
 end
+
+@testitem "Safe operators are compatible with ForwardDiff" tags = [:part2] begin
+    using SymbolicRegression
+    using SymbolicRegression:
+        safe_log,
+        safe_log2,
+        safe_log10,
+        safe_log1p,
+        safe_sqrt,
+        safe_asin,
+        safe_acos,
+        safe_atanh,
+        safe_acosh,
+        safe_pow
+    using ForwardDiff
+
+    # Test all safe operators
+    safe_operators = [
+        (safe_log, 2.0, -1.0), # (operator, valid_input, invalid_input)
+        (safe_log2, 2.0, -1.0),
+        (safe_log10, 2.0, -1.0),
+        (safe_log1p, 0.5, -2.0),
+        (safe_sqrt, 2.0, -1.0),
+        (safe_asin, 0.5, 2.0),
+        (safe_acos, 0.5, 2.0),
+        (safe_atanh, 0.5, 2.0),
+        (safe_acosh, 2.0, 0.5),
+    ]
+
+    for (op, valid_x, invalid_x) in safe_operators
+        # Test derivative exists and is correct for valid input
+        deriv = ForwardDiff.derivative(op, valid_x)
+        @test !isnan(deriv)
+        @test !iszero(deriv) # All these operators should have non-zero derivatives at test points
+
+        # Test derivative is 0.0 for invalid input
+        deriv_invalid = ForwardDiff.derivative(op, invalid_x)
+        @test iszero(deriv_invalid)
+    end
+
+    # Test safe_pow separately since it's binary
+    for x in [0.5, 2.0], y in [2.0, 0.5]
+        # Test valid derivatives
+        deriv_x = ForwardDiff.derivative(x -> safe_pow(x, y), x)
+        deriv_y = ForwardDiff.derivative(y -> safe_pow(x, y), y)
+        @test !isnan(deriv_x)
+        @test !isnan(deriv_y)
+        @test !iszero(deriv_x) # Should be non-zero for our test points
+
+        # Test invalid cases return 0.0 derivatives
+        @test iszero(ForwardDiff.derivative(x -> safe_pow(x, -1.0), 0.0)) # 0^(-1)
+        @test iszero(ForwardDiff.derivative(x -> safe_pow(-x, 0.5), 1.0)) # (-x)^0.5
+        @test iszero(ForwardDiff.derivative(x -> safe_pow(x, -0.5), 0.0)) # 0^(-0.5)
+    end
+end
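
Not part of the patch: a minimal usage sketch of the behavior the domain-guarded operators above provide, assuming SymbolicRegression.jl with this change and ForwardDiff.jl are installed. Outside its domain each safe operator returns NaN instead of throwing a DomainError, and (as the new test item checks) ForwardDiff sees a zero derivative on that branch:

    using SymbolicRegression: safe_asin
    using ForwardDiff

    safe_asin(0.5)                          # asin(0.5) ≈ 0.5236
    safe_asin(2.0)                          # NaN, no DomainError thrown
    ForwardDiff.derivative(safe_asin, 0.5)  # 1 / sqrt(1 - 0.5^2) ≈ 1.1547
    ForwardDiff.derivative(safe_asin, 2.0)  # 0.0 (NaN value, zero partials)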