Skip to content

Commit

Permalink
Merge branch 'master' into testitem
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jun 16, 2024
2 parents 3d09913 + 7451580 commit 0bd0081
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 97 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"
[compat]
Compat = "^4.2"
Dates = "1"
Distributed = "1"
DispatchDoctor = "0.4"
Distributed = "1"
DynamicExpressions = "0.16"
DynamicQuantities = "0.10, 0.11, 0.12, 0.13"
JSON3 = "1"
Expand Down
31 changes: 17 additions & 14 deletions src/Complexity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ module ComplexityModule
using DynamicExpressions: AbstractExpressionNode, count_nodes, tree_mapreduce
using ..CoreModule: Options, ComplexityMapping

function past_complexity_limit(
tree::AbstractExpressionNode, options::Options{CT}, limit
)::Bool where {CT}
function past_complexity_limit(tree::AbstractExpressionNode, options::Options, limit)::Bool
return compute_complexity(tree, options) > limit
end

Expand All @@ -17,27 +15,32 @@ However, it could use the custom settings in options.complexity_mapping
if these are defined.
"""
function compute_complexity(
tree::AbstractExpressionNode, options::Options{CT}; break_sharing=Val(false)
)::Int where {CT}
tree::AbstractExpressionNode, options::Options; break_sharing=Val(false)
)::Int
if options.complexity_mapping.use
raw_complexity = _compute_complexity(tree, options; break_sharing)
raw_complexity = _compute_complexity(
tree, options.complexity_mapping; break_sharing
)
return round(Int, raw_complexity)
else
return count_nodes(tree; break_sharing)
end
end

function _compute_complexity(
tree::AbstractExpressionNode, options::Options{CT}; break_sharing=Val(false)
tree::AbstractExpressionNode, cmap::ComplexityMapping{CT}; break_sharing=Val(false)
)::CT where {CT}
cmap = options.complexity_mapping
constant_complexity = cmap.constant_complexity
variable_complexity = cmap.variable_complexity
unaop_complexities = cmap.unaop_complexities
binop_complexities = cmap.binop_complexities
return tree_mapreduce(
t -> t.constant ? constant_complexity : variable_complexity,
t -> t.degree == 1 ? unaop_complexities[t.op] : binop_complexities[t.op],
let vc = cmap.variable_complexity, cc = cmap.constant_complexity
if vc isa AbstractVector
t -> t.constant ? cc : @inbounds(vc[t.feature])
else
t -> t.constant ? cc : vc
end
end,
let uc = cmap.unaop_complexities, bc = cmap.binop_complexities
t -> t.degree == 1 ? @inbounds(uc[t.op]) : @inbounds(bc[t.op])
end,
+,
tree,
CT;
Expand Down
72 changes: 12 additions & 60 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators
Input this in the form of, e.g., [(^) => 3, sin => 2].
- `complexity_of_constants`: What complexity should be assigned to use of a constant.
By default, this is 1.
- `complexity_of_variables`: What complexity should be assigned to each variable.
- `complexity_of_variables`: What complexity should be assigned to use of a variable,
which can also be a vector indicating different per-variable complexity.
By default, this is 1.
- `alpha`: The probability of accepting an equation mutation
during regularized evolution is given by exp(-delta_loss/(alpha * T)),
Expand Down Expand Up @@ -376,8 +377,8 @@ https://github.com/MilesCranmer/PySR/discussions/115.
$(OPTION_DESCRIPTIONS)
"""
@unstable @save_kwargs DEFAULT_OPTIONS function Options(;
binary_operators=[+, -, /, *],
unary_operators=[],
binary_operators=Function[+, -, /, *],
unary_operators=Function[],
constraints=nothing,
elementwise_loss::Union{Function,SupervisedLoss,Nothing}=nothing,
loss_function::Union{Function,Nothing}=nothing,
Expand All @@ -386,7 +387,7 @@ $(OPTION_DESCRIPTIONS)
topn::Integer=12, #samples to return per population
complexity_of_operators=nothing,
complexity_of_constants::Union{Nothing,Real}=nothing,
complexity_of_variables::Union{Nothing,Real}=nothing,
complexity_of_variables::Union{Nothing,Real,AbstractVector}=nothing,
parsimony::Real=0.0032,
dimensional_constraint_penalty::Union{Nothing,Real}=nothing,
dimensionless_constants_only::Bool=false,
Expand Down Expand Up @@ -645,62 +646,13 @@ $(OPTION_DESCRIPTIONS)
una_constraints, bin_constraints, unary_operators, binary_operators, nuna, nbin
)

# Define the complexities of everything.
use_complexity_mapping = (
complexity_of_constants !== nothing ||
complexity_of_variables !== nothing ||
complexity_of_operators !== nothing
complexity_mapping = ComplexityMapping(
complexity_of_operators,
complexity_of_variables,
complexity_of_constants,
binary_operators,
unary_operators,
)
complexity_mapping = if use_complexity_mapping
if complexity_of_operators === nothing
complexity_of_operators = Dict()
else
# Convert to dict:
complexity_of_operators = Dict(complexity_of_operators)
end

# Get consistent type:
promoted_type = promote_type(
if (complexity_of_variables !== nothing)
typeof(complexity_of_variables)
else
Int
end,
if (complexity_of_constants !== nothing)
typeof(complexity_of_constants)
else
Int
end,
(x -> typeof(x)).(values(complexity_of_operators))...,
)

# If not in dict, then just set it to 1.
binop_complexities = promoted_type[
(haskey(complexity_of_operators, op) ? complexity_of_operators[op] : 1) #
for op in binary_operators
]
unaop_complexities = promoted_type[
(haskey(complexity_of_operators, op) ? complexity_of_operators[op] : 1) #
for op in unary_operators
]

variable_complexity = (
(complexity_of_variables !== nothing) ? complexity_of_variables : 1
)
constant_complexity = (
(complexity_of_constants !== nothing) ? complexity_of_constants : 1
)

ComplexityMapping(;
binop_complexities=binop_complexities,
unaop_complexities=unaop_complexities,
variable_complexity=variable_complexity,
constant_complexity=constant_complexity,
)
else
ComplexityMapping(false)
end
# Finish defining complexities

if maxdepth === nothing
maxdepth = maxsize
Expand Down Expand Up @@ -772,7 +724,7 @@ $(OPTION_DESCRIPTIONS)
@assert print_precision > 0

options = Options{
eltype(complexity_mapping),
typeof(complexity_mapping),
operator_specialization(typeof(operators)),
node_type,
turbo,
Expand Down
119 changes: 99 additions & 20 deletions src/OptionsStruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,108 @@ using LossFunctions: SupervisedLoss

import ..MutationWeightsModule: MutationWeights

"""This struct defines how complexity is calculated."""
struct ComplexityMapping{T<:Real}
use::Bool # Whether we use custom complexity, or just use 1 for everythign.
binop_complexities::Vector{T} # Complexity of each binary operator.
unaop_complexities::Vector{T} # Complexity of each unary operator.
variable_complexity::T # Complexity of using a variable.
constant_complexity::T # Complexity of using a constant.
"""
This struct defines how complexity is calculated.
# Fields
- `use`: Shortcut indicating whether we use custom complexities,
or just use 1 for everything.
- `binop_complexities`: Complexity of each binary operator.
- `unaop_complexities`: Complexity of each unary operator.
- `variable_complexity`: Complexity of using a variable.
- `constant_complexity`: Complexity of using a constant.
"""
struct ComplexityMapping{T<:Real,VC<:Union{T,AbstractVector{T}}}
use::Bool
binop_complexities::Vector{T}
unaop_complexities::Vector{T}
variable_complexity::VC
constant_complexity::T
end

Base.eltype(::ComplexityMapping{T}) where {T} = T

function ComplexityMapping(use::Bool)
return ComplexityMapping{Int}(use, zeros(Int, 0), zeros(Int, 0), 1, 1)
end

"""Promote type when defining complexity mapping."""
function ComplexityMapping(;
binop_complexities::Vector{T1},
unaop_complexities::Vector{T2},
variable_complexity::T3,
variable_complexity::Union{T3,AbstractVector{T3}},
constant_complexity::T4,
) where {T1<:Real,T2<:Real,T3<:Real,T4<:Real}
promoted_T = promote_type(T1, T2, T3, T4)
return ComplexityMapping{promoted_T}(
T = promote_type(T1, T2, T3, T4)
vc = map(T, variable_complexity)
return ComplexityMapping{T,typeof(vc)}(
true,
binop_complexities,
unaop_complexities,
variable_complexity,
constant_complexity,
map(T, binop_complexities),
map(T, unaop_complexities),
vc,
T(constant_complexity),
)
end

function ComplexityMapping(
::Nothing, ::Nothing, ::Nothing, binary_operators, unary_operators
)
# If no customization provided, then we simply
# turn off the complexity mapping
use = false
return ComplexityMapping{Int,Int}(use, zeros(Int, 0), zeros(Int, 0), 0, 0)
end
function ComplexityMapping(
complexity_of_operators,
complexity_of_variables,
complexity_of_constants,
binary_operators,
unary_operators,
)
_complexity_of_operators = if complexity_of_operators === nothing
Dict{Function,Int64}()
else
# Convert to dict:
Dict(complexity_of_operators)
end

VAR_T = if (complexity_of_variables !== nothing)
if complexity_of_variables isa AbstractVector
eltype(complexity_of_variables)
else
typeof(complexity_of_variables)
end
else
Int
end
CONST_T = if (complexity_of_constants !== nothing)
typeof(complexity_of_constants)
else
Int
end
OP_T = eltype(_complexity_of_operators).parameters[2]

T = promote_type(VAR_T, CONST_T, OP_T)

# If not in dict, then just set it to 1.
binop_complexities = T[
(haskey(_complexity_of_operators, op) ? _complexity_of_operators[op] : one(T)) #
for op in binary_operators
]
unaop_complexities = T[
(haskey(_complexity_of_operators, op) ? _complexity_of_operators[op] : one(T)) #
for op in unary_operators
]

variable_complexity = if complexity_of_variables !== nothing
map(T, complexity_of_variables)
else
one(T)
end
constant_complexity = if complexity_of_constants !== nothing
map(T, complexity_of_constants)
else
one(T)
end

return ComplexityMapping(;
binop_complexities, unaop_complexities, variable_complexity, constant_complexity
)
end

Expand All @@ -48,12 +121,18 @@ else
end

struct Options{
CT,OP<:AbstractOperatorEnum,N<:AbstractExpressionNode,_turbo,_bumper,_return_state,W
CM<:ComplexityMapping,
OP<:AbstractOperatorEnum,
N<:AbstractExpressionNode,
_turbo,
_bumper,
_return_state,
W,
}
operators::OP
bin_constraints::Vector{Tuple{Int,Int}}
una_constraints::Vector{Int}
complexity_mapping::ComplexityMapping{CT}
complexity_mapping::CM
tournament_selection_n::Int
tournament_selection_p::Float32
tournament_selection_weights::W
Expand Down
4 changes: 2 additions & 2 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ function equation_search(
)
end

@noinline function _equation_search(
@stable default_mode = "disable" @noinline function _equation_search(
datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state
) where {D<:Dataset}
_validate_options(datasets, ropt, options)
Expand Down Expand Up @@ -631,7 +631,7 @@ function _validate_options(
end
return nothing
end
function _create_workers(
@stable default_mode = "disable" function _create_workers(
datasets::Vector{D}, ropt::RuntimeOptions, options::Options
) where {T,L,D<:Dataset{T,L}}
stdin_reader = watch_stream(stdin)
Expand Down
4 changes: 4 additions & 0 deletions test/LocalPreferences.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[DynamicExpressions]
instability_check = "error"
instability_check_codegen = "min"

[SymbolicRegression]
instability_check = "error"
instability_check_codegen = "min"
12 changes: 12 additions & 0 deletions test/test_complexity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,16 @@ options = make_options(;
)
@test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + 1

# Custom variables
options = make_options(;
complexity_of_variables=[1, 2, 3], complexity_of_operators=[(+) => 5, (*) => 2]
)
x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3]
tree = x1 + x2 * x3
@test compute_complexity(tree, options) == 1 + 5 + 2 + 2 + 3
options = make_options(;
complexity_of_variables=2, complexity_of_operators=[(+) => 5, (*) => 2]
)
@test compute_complexity(tree, options) == 2 + 5 + 2 + 2 + 2

println("Passed.")

0 comments on commit 0bd0081

Please sign in to comment.