diff --git a/Project.toml b/Project.toml
index e268f51a0..24556292c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.24.4"
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
+DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
 DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
@@ -37,6 +38,7 @@ SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"
 Compat = "^4.2"
 Dates = "1"
 Distributed = "1"
+DispatchDoctor = "0.4"
 DynamicExpressions = "0.16"
 DynamicQuantities = "0.10, 0.11, 0.12, 0.13"
 JSON3 = "1"
diff --git a/src/Configure.jl b/src/Configure.jl
index f2dc10ed2..5ccc08100 100644
--- a/src/Configure.jl
+++ b/src/Configure.jl
@@ -247,7 +247,8 @@ function import_module_on_workers(procs, filename::String, options::Options, verbosity)
         @info "Importing SymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
     end
     @everywhere procs Core.eval(Core.Main, $expr)
-    return verbosity > 0 && @info "Finished!"
+    verbosity > 0 && @info "Finished!"
+    return nothing
 end
 
 function test_module_on_workers(procs, options::Options, verbosity)
diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl
index 20e62f488..2645f1b66 100644
--- a/src/InterfaceDynamicExpressions.jl
+++ b/src/InterfaceDynamicExpressions.jl
@@ -56,13 +56,14 @@ which speed up evaluation significantly.
 function eval_tree_array(
     tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
 )
+    A = expected_array_type(X)
     return eval_tree_array(
         tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
-    )::Tuple{expected_array_type(tree, X, options.operators),Bool}
+    )::Tuple{A,Bool}
 end
 
 # Improve type inference by telling Julia the expected array returned
-function expected_array_type(::AbstractExpressionNode, X::AbstractArray, ::OperatorEnum)
+function expected_array_type(X::AbstractArray)
     return typeof(similar(X, axes(X, 2)))
 end
 
@@ -89,7 +90,8 @@ respect to `x1`.
 function eval_diff_tree_array(
     tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int
 )
-    return eval_diff_tree_array(tree, X, options.operators, direction)
+    A = expected_array_type(X)
+    return eval_diff_tree_array(tree, X, options.operators, direction)::Tuple{A,A,Bool}
 end
 
 """
@@ -116,7 +118,9 @@ to every constant in the expression.
 function eval_grad_tree_array(
     tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
 )
-    return eval_grad_tree_array(tree, X, options.operators; kws...)
+    A = expected_array_type(X)
+    M = typeof(X) # TODO: This won't work with StaticArrays!
+    return eval_grad_tree_array(tree, X, options.operators; kws...)::Tuple{A,M,Bool}
 end
 
 """
@@ -127,7 +131,8 @@ Evaluate an expression tree in a way that can be auto-differentiated.
 """
 function differentiable_eval_tree_array(
     tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
 )
-    return differentiable_eval_tree_array(tree, X, options.operators; kws...)
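Note: the InterfaceDynamicExpressions.jl hunks above all follow one pattern: compute the expected output array type up front (`expected_array_type(X)`, which no longer needs the tree or operators) and assert it on the result of a call whose return type the compiler cannot infer. A minimal standalone sketch of why this helps; only `Base` is assumed, and `opaque_eval`/`eval_asserted` are invented for illustration:

    # `Base.inferencebarrier` hides the result type from inference, standing in
    # for a callee (e.g. a turbo/bumper kernel) the compiler cannot see through.
    expected_array_type(X::AbstractArray) = typeof(similar(X, axes(X, 2)))

    opaque_eval(X) = Base.inferencebarrier(vec(sum(X; dims=1)))  # inferred as `Any`

    function eval_asserted(X::AbstractMatrix{Float64})
        A = expected_array_type(X)  # inferred as Type{Vector{Float64}}
        return opaque_eval(X)::A    # assertion restores a concrete return type
    end
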
+    A = expected_array_type(X)
+    return differentiable_eval_tree_array(tree, X, options.operators; kws...)::Tuple{A,Bool}
 end
 const WILDCARD_UNIT_STRING = "[?]"
diff --git a/src/InterfaceDynamicQuantities.jl b/src/InterfaceDynamicQuantities.jl
index 8bce2be30..34580a3cc 100644
--- a/src/InterfaceDynamicQuantities.jl
+++ b/src/InterfaceDynamicQuantities.jl
@@ -1,5 +1,6 @@
 module InterfaceDynamicQuantitiesModule
 
+using DispatchDoctor: @unstable
 using DynamicQuantities:
     UnionAbstractQuantity,
     AbstractDimensions,
@@ -70,7 +71,7 @@ end
 Recursively finds the dimension type from an array, or,
 if no quantity is found, returns the default type.
 """
-function get_dimensions_type(A::AbstractArray, default::Type{D}) where {D}
+@unstable function get_dimensions_type(A::AbstractArray, default::Type{D}) where {D}
     i = findfirst(a -> isa(a, UnionAbstractQuantity), A)
     if i === nothing
         return D
diff --git a/src/Options.jl b/src/Options.jl
index 239b32ad1..7ab86f454 100644
--- a/src/Options.jl
+++ b/src/Options.jl
@@ -1,5 +1,6 @@
 module OptionsModule
 
+using DispatchDoctor: @unstable
 using Optim: Optim
 using Dates: Dates
 using StatsBase: StatsBase
@@ -374,8 +375,7 @@ https://github.com/MilesCranmer/PySR/discussions/115.
 # Arguments
 $(OPTION_DESCRIPTIONS)
 """
-function Options end
-@save_kwargs DEFAULT_OPTIONS function Options(;
+@unstable @save_kwargs DEFAULT_OPTIONS function Options(;
     binary_operators=[+, -, /, *],
     unary_operators=[],
     constraints=nothing,
diff --git a/src/PopMember.jl b/src/PopMember.jl
index 1215333b7..d47042c20 100644
--- a/src/PopMember.jl
+++ b/src/PopMember.jl
@@ -1,5 +1,7 @@
 module PopMemberModule
 
+using DispatchDoctor: @unstable
+
 using DynamicExpressions: AbstractExpressionNode, copy_node, count_nodes
 using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
 import ..ComplexityModule: compute_complexity
@@ -25,7 +27,7 @@ function Base.setproperty!(member::PopMember, field::Symbol, value)
     field == :tree && setfield!(member, :complexity, -1)
     return setfield!(member, field, value)
 end
-function Base.getproperty(member::PopMember, field::Symbol)
+@unstable @inline function Base.getproperty(member::PopMember, field::Symbol)
     field == :complexity && throw(
         error("Don't access `.complexity` directly. Use `compute_complexity` instead.")
     )
diff --git a/src/Population.jl b/src/Population.jl
index 4b776e826..c6d71729b 100644
--- a/src/Population.jl
+++ b/src/Population.jl
@@ -2,6 +2,7 @@ module PopulationModule
 
 using StatsBase: StatsBase
 using Random: randperm
+using DispatchDoctor: @unstable
 using DynamicExpressions: AbstractExpressionNode, Node, string_tree
 using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
 using ..ComplexityModule: compute_complexity
@@ -67,23 +68,23 @@ end
 
 Create random population and score them on the dataset.
 """
-function Population(
+@unstable function Population(
     X::AbstractMatrix{T},
     y::AbstractVector{T};
     population_size=nothing,
     nlength::Int=3,
     options::Options,
     nfeatures::Int,
-    loss_type::Type=Nothing,
+    loss_type::Type{L}=Nothing,
     npop=nothing,
-) where {T<:DATA_TYPE}
+) where {T<:DATA_TYPE,L}
     @assert (population_size !== nothing) ⊻ (npop !== nothing)
     population_size = if npop === nothing
         population_size
     else
         npop
     end
-    dataset = Dataset(X, y; loss_type=loss_type)
+    dataset = Dataset(X, y, L)
     update_baseline_loss!(dataset, options)
     return Population(
         dataset; population_size=population_size, options=options, nfeatures=nfeatures
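Note: `@unstable` exempts functions that are type-unstable by design: `get_dimensions_type` depends on runtime array contents, the `Base.getproperty` overloads return a different type per field, and `Options`/`Population` build types from keyword values. A rough sketch of how the two macros interact (assumes DispatchDoctor v0.4; `double` and `pick` are invented for illustration):

    using DispatchDoctor: @stable, @unstable

    @stable default_mode = "error" begin
        double(x) = 2x  # checked: infers concretely for concrete inputs

        # Returns a different type per field, so opt it out of the check:
        @unstable pick(nt::NamedTuple, field::Symbol) = getfield(nt, field)
    end

    double(3.0)             # passes the stability check
    pick((a=1, b="2"), :b)  # fine as well: the check is skipped here
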
""" -function Population( +@unstable function Population( X::AbstractMatrix{T}, y::AbstractVector{T}; population_size=nothing, nlength::Int=3, options::Options, nfeatures::Int, - loss_type::Type=Nothing, + loss_type::Type{L}=Nothing, npop=nothing, -) where {T<:DATA_TYPE} +) where {T<:DATA_TYPE,L} @assert (population_size !== nothing) ⊻ (npop !== nothing) population_size = if npop === nothing population_size else npop end - dataset = Dataset(X, y; loss_type=loss_type) + dataset = Dataset(X, y, L) update_baseline_loss!(dataset, options) return Population( dataset; population_size=population_size, options=options, nfeatures=nfeatures diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index b6a8c08bf..4079a977d 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -6,6 +6,7 @@ module SearchUtilsModule using Printf: @printf, @sprintf using Distributed using StatsBase: mean +using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpressionNode, string_tree using ..UtilsModule: subscriptify @@ -37,7 +38,9 @@ Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} verbosity::Int64 progress::Bool end -function Base.getproperty(roptions::RuntimeOptions{P,D,R}, name::Symbol) where {P,D,R} +@unstable @inline function Base.getproperty( + roptions::RuntimeOptions{P,D,R}, name::Symbol +) where {P,D,R} if name == :parallelism return P elseif name == :dim_out diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 84ab4063f..3e782e150 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -149,27 +149,31 @@ function deprecate_varmap(variable_names, varMap, func_name) return variable_names end -include("Utils.jl") -include("InterfaceDynamicQuantities.jl") -include("Core.jl") -include("InterfaceDynamicExpressions.jl") -include("Recorder.jl") -include("Complexity.jl") -include("DimensionalAnalysis.jl") -include("CheckConstraints.jl") -include("AdaptiveParsimony.jl") -include("MutationFunctions.jl") -include("LossFunctions.jl") -include("PopMember.jl") -include("ConstantOptimization.jl") -include("Population.jl") -include("HallOfFame.jl") -include("Mutate.jl") -include("RegularizedEvolution.jl") -include("SingleIteration.jl") -include("ProgressBars.jl") -include("Migration.jl") -include("SearchUtils.jl") +using DispatchDoctor: @stable + +@stable default_mode = "disable" begin + include("Utils.jl") + include("InterfaceDynamicQuantities.jl") + include("Core.jl") + include("InterfaceDynamicExpressions.jl") + include("Recorder.jl") + include("Complexity.jl") + include("DimensionalAnalysis.jl") + include("CheckConstraints.jl") + include("AdaptiveParsimony.jl") + include("MutationFunctions.jl") + include("LossFunctions.jl") + include("PopMember.jl") + include("ConstantOptimization.jl") + include("Population.jl") + include("HallOfFame.jl") + include("Mutate.jl") + include("RegularizedEvolution.jl") + include("SingleIteration.jl") + include("ProgressBars.jl") + include("Migration.jl") + include("SearchUtils.jl") +end using .CoreModule: MAX_DEGREE, @@ -255,8 +259,10 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame! 
-include("deprecates.jl") -include("Configure.jl") +@stable default_mode = "disable" begin + include("deprecates.jl") + include("Configure.jl") +end """ equation_search(X, y[; kws...]) @@ -1076,7 +1082,7 @@ function _format_output(state::SearchState, ropt::RuntimeOptions) end end -function _dispatch_s_r_cycle( +@stable default_mode = "disable" function _dispatch_s_r_cycle( in_pop::Population{T,L,N}, dataset::Dataset, options::Options; diff --git a/src/Utils.jl b/src/Utils.jl index 3b0292e04..08cccceba 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -93,7 +93,8 @@ const max_ops = 8192 const vals = ntuple(Val, max_ops) """Return the bottom k elements of x, and their indices.""" -bottomk_fast(x, k) = _bottomk_dispatch(x, vals[k]) +bottomk_fast(x::AbstractVector{T}, k) where {T} = + _bottomk_dispatch(x, vals[k])::Tuple{Vector{T},Vector{Int}} function _bottomk_dispatch(x::AbstractVector{T}, ::Val{k}) where {T,k} if k == 1 @@ -179,7 +180,7 @@ function _save_kwargs(log_variable::Symbol, fdef::Expr) return true end return quote - $fdef + $(Base).@__doc__ $fdef const $log_variable = $kwargs end end diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml new file mode 100644 index 000000000..6eb148262 --- /dev/null +++ b/test/LocalPreferences.toml @@ -0,0 +1,3 @@ +[SymbolicRegression] +instability_check = "error" +instability_check_codegen = "min" diff --git a/test/Project.toml b/test/Project.toml index 6cbb1770e..07b5a0bab 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" diff --git a/test/test_abstract_numbers.jl b/test/test_abstract_numbers.jl index 4bc495a5f..60155be2c 100644 --- a/test/test_abstract_numbers.jl +++ b/test/test_abstract_numbers.jl @@ -20,7 +20,7 @@ for T in (ComplexF16, ComplexF32, ComplexF64) elementwise_loss=(prediction, target) -> abs2(prediction - target), ) - dataset = Dataset(X, y; loss_type=L) + dataset = Dataset(X, y, L) hof = if T == ComplexF16 equation_search([dataset]; options=options, niterations=1_000_000_000) else diff --git a/test/test_dataset.jl b/test/test_dataset.jl index 389409691..fbff37af3 100644 --- a/test/test_dataset.jl +++ b/test/test_dataset.jl @@ -1,5 +1,6 @@ using SymbolicRegression using Test +using DispatchDoctor: allow_unstable @testset "Dataset construction" begin # Promotion of types: @@ -9,6 +10,8 @@ using Test end @testset "With deprecated kwarg" begin - dataset = Dataset(randn(ComplexF32, 3, 32), randn(ComplexF32, 32); loss_type=Float64) + dataset = allow_unstable() do + Dataset(randn(ComplexF32, 3, 32), randn(ComplexF32, 32); loss_type=Float64) + end @test dataset isa Dataset{ComplexF32,Float64} end diff --git a/test/test_operators.jl b/test/test_operators.jl index 1357f098b..21d519dba 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -21,6 +21,7 @@ using SymbolicRegression: using Test using Random: MersenneTwister using Suppressor: @capture_err +using LoopVectorization include("test_params.jl") @testset "Generic operator tests" begin