Skip to content

Commit

Permalink
Merge pull request #321 from MilesCranmer/dispatch-doctor
Browse files Browse the repository at this point in the history
Use DispatchDoctor.jl to wrap entire package with `@stable`
  • Loading branch information
MilesCranmer authored Jun 7, 2024
2 parents 71447ee + 108d317 commit ea03242
Show file tree
Hide file tree
Showing 15 changed files with 73 additions and 43 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.24.4"
[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
Expand Down Expand Up @@ -37,6 +38,7 @@ SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"
Compat = "^4.2"
Dates = "1"
Distributed = "1"
DispatchDoctor = "0.4"
DynamicExpressions = "0.16"
DynamicQuantities = "0.10, 0.11, 0.12, 0.13"
JSON3 = "1"
Expand Down
3 changes: 2 additions & 1 deletion src/Configure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ function import_module_on_workers(procs, filename::String, options::Options, ver
@info "Importing SymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
end
@everywhere procs Core.eval(Core.Main, $expr)
return verbosity > 0 && @info "Finished!"
verbosity > 0 && @info "Finished!"
return nothing
end

function test_module_on_workers(procs, options::Options, verbosity)
Expand Down
15 changes: 10 additions & 5 deletions src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@ which speed up evaluation significantly.
function eval_tree_array(
tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
A = expected_array_type(X)
return eval_tree_array(
tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
)::Tuple{expected_array_type(tree, X, options.operators),Bool}
)::Tuple{A,Bool}
end

# Improve type inference by telling Julia the expected array returned
function expected_array_type(::AbstractExpressionNode, X::AbstractArray, ::OperatorEnum)
function expected_array_type(X::AbstractArray)
return typeof(similar(X, axes(X, 2)))
end

Expand All @@ -89,7 +90,8 @@ respect to `x1`.
function eval_diff_tree_array(
tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int
)
return eval_diff_tree_array(tree, X, options.operators, direction)
A = expected_array_type(X)
return eval_diff_tree_array(tree, X, options.operators, direction)::Tuple{A,A,Bool}
end

"""
Expand All @@ -116,7 +118,9 @@ to every constant in the expression.
function eval_grad_tree_array(
tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
return eval_grad_tree_array(tree, X, options.operators; kws...)
A = expected_array_type(X)
M = typeof(X) # TODO: This won't work with StaticArrays!
return eval_grad_tree_array(tree, X, options.operators; kws...)::Tuple{A,M,Bool}
end

"""
Expand All @@ -127,7 +131,8 @@ Evaluate an expression tree in a way that can be auto-differentiated.
function differentiable_eval_tree_array(
tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
return differentiable_eval_tree_array(tree, X, options.operators; kws...)
A = expected_array_type(X)
return differentiable_eval_tree_array(tree, X, options.operators; kws...)::Tuple{A,Bool}
end

const WILDCARD_UNIT_STRING = "[?]"
Expand Down
3 changes: 2 additions & 1 deletion src/InterfaceDynamicQuantities.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module InterfaceDynamicQuantitiesModule

using DispatchDoctor: @unstable
using DynamicQuantities:
UnionAbstractQuantity,
AbstractDimensions,
Expand Down Expand Up @@ -70,7 +71,7 @@ end
Recursively finds the dimension type from an array, or,
if no quantity is found, returns the default type.
"""
function get_dimensions_type(A::AbstractArray, default::Type{D}) where {D}
@unstable function get_dimensions_type(A::AbstractArray, default::Type{D}) where {D}
i = findfirst(a -> isa(a, UnionAbstractQuantity), A)
if i === nothing
return D
Expand Down
4 changes: 2 additions & 2 deletions src/Options.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module OptionsModule

using DispatchDoctor: @unstable
using Optim: Optim
using Dates: Dates
using StatsBase: StatsBase
Expand Down Expand Up @@ -374,8 +375,7 @@ https://github.com/MilesCranmer/PySR/discussions/115.
# Arguments
$(OPTION_DESCRIPTIONS)
"""
function Options end
@save_kwargs DEFAULT_OPTIONS function Options(;
@unstable @save_kwargs DEFAULT_OPTIONS function Options(;
binary_operators=[+, -, /, *],
unary_operators=[],
constraints=nothing,
Expand Down
4 changes: 3 additions & 1 deletion src/PopMember.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module PopMemberModule

using DispatchDoctor: @unstable

using DynamicExpressions: AbstractExpressionNode, copy_node, count_nodes
using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
import ..ComplexityModule: compute_complexity
Expand All @@ -25,7 +27,7 @@ function Base.setproperty!(member::PopMember, field::Symbol, value)
field == :tree && setfield!(member, :complexity, -1)
return setfield!(member, field, value)
end
function Base.getproperty(member::PopMember, field::Symbol)
@unstable @inline function Base.getproperty(member::PopMember, field::Symbol)
field == :complexity && throw(
error("Don't access `.complexity` directly. Use `compute_complexity` instead.")
)
Expand Down
9 changes: 5 additions & 4 deletions src/Population.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module PopulationModule

using StatsBase: StatsBase
using Random: randperm
using DispatchDoctor: @unstable
using DynamicExpressions: AbstractExpressionNode, Node, string_tree
using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
using ..ComplexityModule: compute_complexity
Expand Down Expand Up @@ -67,23 +68,23 @@ end
Create random population and score them on the dataset.
"""
function Population(
@unstable function Population(
X::AbstractMatrix{T},
y::AbstractVector{T};
population_size=nothing,
nlength::Int=3,
options::Options,
nfeatures::Int,
loss_type::Type=Nothing,
loss_type::Type{L}=Nothing,
npop=nothing,
) where {T<:DATA_TYPE}
) where {T<:DATA_TYPE,L}
@assert (population_size !== nothing) (npop !== nothing)
population_size = if npop === nothing
population_size
else
npop
end
dataset = Dataset(X, y; loss_type=loss_type)
dataset = Dataset(X, y, L)
update_baseline_loss!(dataset, options)
return Population(
dataset; population_size=population_size, options=options, nfeatures=nfeatures
Expand Down
5 changes: 4 additions & 1 deletion src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module SearchUtilsModule
using Printf: @printf, @sprintf
using Distributed
using StatsBase: mean
using DispatchDoctor: @unstable

using DynamicExpressions: AbstractExpressionNode, string_tree
using ..UtilsModule: subscriptify
Expand Down Expand Up @@ -37,7 +38,9 @@ Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE}
verbosity::Int64
progress::Bool
end
function Base.getproperty(roptions::RuntimeOptions{P,D,R}, name::Symbol) where {P,D,R}
@unstable @inline function Base.getproperty(
roptions::RuntimeOptions{P,D,R}, name::Symbol
) where {P,D,R}
if name == :parallelism
return P
elseif name == :dim_out
Expand Down
54 changes: 30 additions & 24 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,31 @@ function deprecate_varmap(variable_names, varMap, func_name)
return variable_names
end

include("Utils.jl")
include("InterfaceDynamicQuantities.jl")
include("Core.jl")
include("InterfaceDynamicExpressions.jl")
include("Recorder.jl")
include("Complexity.jl")
include("DimensionalAnalysis.jl")
include("CheckConstraints.jl")
include("AdaptiveParsimony.jl")
include("MutationFunctions.jl")
include("LossFunctions.jl")
include("PopMember.jl")
include("ConstantOptimization.jl")
include("Population.jl")
include("HallOfFame.jl")
include("Mutate.jl")
include("RegularizedEvolution.jl")
include("SingleIteration.jl")
include("ProgressBars.jl")
include("Migration.jl")
include("SearchUtils.jl")
using DispatchDoctor: @stable

@stable default_mode = "disable" begin
include("Utils.jl")
include("InterfaceDynamicQuantities.jl")
include("Core.jl")
include("InterfaceDynamicExpressions.jl")
include("Recorder.jl")
include("Complexity.jl")
include("DimensionalAnalysis.jl")
include("CheckConstraints.jl")
include("AdaptiveParsimony.jl")
include("MutationFunctions.jl")
include("LossFunctions.jl")
include("PopMember.jl")
include("ConstantOptimization.jl")
include("Population.jl")
include("HallOfFame.jl")
include("Mutate.jl")
include("RegularizedEvolution.jl")
include("SingleIteration.jl")
include("ProgressBars.jl")
include("Migration.jl")
include("SearchUtils.jl")
end

using .CoreModule:
MAX_DEGREE,
Expand Down Expand Up @@ -255,8 +259,10 @@ using .SearchUtilsModule:
get_cur_maxsize,
update_hall_of_fame!

include("deprecates.jl")
include("Configure.jl")
@stable default_mode = "disable" begin
include("deprecates.jl")
include("Configure.jl")
end

"""
equation_search(X, y[; kws...])
Expand Down Expand Up @@ -1076,7 +1082,7 @@ function _format_output(state::SearchState, ropt::RuntimeOptions)
end
end

function _dispatch_s_r_cycle(
@stable default_mode = "disable" function _dispatch_s_r_cycle(
in_pop::Population{T,L,N},
dataset::Dataset,
options::Options;
Expand Down
5 changes: 3 additions & 2 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ const max_ops = 8192
const vals = ntuple(Val, max_ops)

"""Return the bottom k elements of x, and their indices."""
bottomk_fast(x, k) = _bottomk_dispatch(x, vals[k])
bottomk_fast(x::AbstractVector{T}, k) where {T} =
_bottomk_dispatch(x, vals[k])::Tuple{Vector{T},Vector{Int}}

function _bottomk_dispatch(x::AbstractVector{T}, ::Val{k}) where {T,k}
if k == 1
Expand Down Expand Up @@ -179,7 +180,7 @@ function _save_kwargs(log_variable::Symbol, fdef::Expr)
return true
end
return quote
$fdef
$(Base).@__doc__ $fdef
const $log_variable = $kwargs
end
end
Expand Down
3 changes: 3 additions & 0 deletions test/LocalPreferences.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[SymbolicRegression]
instability_check = "error"
instability_check_codegen = "min"
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
Expand Down
2 changes: 1 addition & 1 deletion test/test_abstract_numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ for T in (ComplexF16, ComplexF32, ComplexF64)
elementwise_loss=(prediction, target) -> abs2(prediction - target),
)

dataset = Dataset(X, y; loss_type=L)
dataset = Dataset(X, y, L)
hof = if T == ComplexF16
equation_search([dataset]; options=options, niterations=1_000_000_000)
else
Expand Down
5 changes: 4 additions & 1 deletion test/test_dataset.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SymbolicRegression
using Test
using DispatchDoctor: allow_unstable

@testset "Dataset construction" begin
# Promotion of types:
Expand All @@ -9,6 +10,8 @@ using Test
end

@testset "With deprecated kwarg" begin
dataset = Dataset(randn(ComplexF32, 3, 32), randn(ComplexF32, 32); loss_type=Float64)
dataset = allow_unstable() do
Dataset(randn(ComplexF32, 3, 32), randn(ComplexF32, 32); loss_type=Float64)
end
@test dataset isa Dataset{ComplexF32,Float64}
end
1 change: 1 addition & 0 deletions test/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using SymbolicRegression:
using Test
using Random: MersenneTwister
using Suppressor: @capture_err
using LoopVectorization
include("test_params.jl")

@testset "Generic operator tests" begin
Expand Down

0 comments on commit ea03242

Please sign in to comment.