Skip to content

Commit

Permalink
Merge pull request #325 from MilesCranmer/testitem
Browse files Browse the repository at this point in the history
Refactor tests to use TestItems.jl
  • Loading branch information
MilesCranmer authored Jun 16, 2024
2 parents 7451580 + 0bd0081 commit 0fc573b
Show file tree
Hide file tree
Showing 52 changed files with 278 additions and 254 deletions.
15 changes: 11 additions & 4 deletions src/DimensionalAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,11 @@ function violates_dimensional_constraints(
end
function violates_dimensional_constraints(
tree::AbstractExpressionNode{T},
X_units::Union{AbstractVector{<:Quantity},Nothing},
X_units::AbstractVector{<:Quantity},
y_units::Union{Quantity,Nothing},
x::AbstractVector{T},
options::Options,
) where {T}
if X_units === nothing && y_units === nothing
return false
end
allow_wildcards = !(options.dimensionless_constants_only)
dimensional_output = violates_dimensional_constraints_dispatch(
tree, X_units, x, options.operators, allow_wildcards
Expand All @@ -217,5 +214,15 @@ function violates_dimensional_constraints(
end
return violates
end
function violates_dimensional_constraints(
::AbstractExpressionNode{T}, ::Nothing, ::Quantity, ::AbstractVector{T}, ::Options
) where {T}
return error("This should never happen. Please submit a bug report.")
end
function violates_dimensional_constraints(
::AbstractExpressionNode{T}, ::Nothing, ::Nothing, ::AbstractVector{T}, ::Options
) where {T}
return false
end

end
4 changes: 2 additions & 2 deletions src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ end
Evaluate an expression tree in a way that can be auto-differentiated.
"""
function differentiable_eval_tree_array(
tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
tree::AbstractExpressionNode, X::AbstractArray, options::Options
)
A = expected_array_type(X)
return differentiable_eval_tree_array(tree, X, options.operators; kws...)::Tuple{A,Bool}
return differentiable_eval_tree_array(tree, X, options.operators)::Tuple{A,Bool}
end

const WILDCARD_UNIT_STRING = "[?]"
Expand Down
4 changes: 2 additions & 2 deletions src/LossFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using ..DimensionalAnalysisModule: violates_dimensional_constraints
function _loss(
x::AbstractArray{T}, y::AbstractArray{T}, loss::LT
) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}}
if LT <: SupervisedLoss
if loss isa SupervisedLoss
return LossFunctions.mean(loss, x, y)
else
l(i) = loss(x[i], y[i])
Expand All @@ -24,7 +24,7 @@ end
function _weighted_loss(
x::AbstractArray{T}, y::AbstractArray{T}, w::AbstractArray{T}, loss::LT
) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}}
if LT <: SupervisedLoss
if loss isa SupervisedLoss
return LossFunctions.sum(loss, x, y, w; normalize=true)
else
l(i) = loss(x[i], y[i], w[i])
Expand Down
4 changes: 2 additions & 2 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function logical_and(x, y)
end

# Deprecated operations:
@deprecate pow safe_pow
@deprecate pow_abs safe_pow
@deprecate pow(x, y) safe_pow(x, y)
@deprecate pow_abs(x, y) safe_pow(x, y)

end
6 changes: 6 additions & 0 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,16 @@ $(OPTION_DESCRIPTIONS)
)
end

is_testing = parse(Bool, get(ENV, "SYMBOLIC_REGRESSION_IS_TESTING", "false"))

if output_file === nothing
# "%Y-%m-%d_%H%M%S.%f"
date_time_str = Dates.format(Dates.now(), "yyyy-mm-dd_HHMMSS.sss")
output_file = "hall_of_fame_" * date_time_str * ".csv"
if is_testing
tmpdir = mktempdir()
output_file = joinpath(tmpdir, output_file)
end
end

nuna = length(unary_operators)
Expand Down
7 changes: 5 additions & 2 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE}
runtests::Bool
verbosity::Int64
progress::Bool
parallelism::Val{PARALLELISM}
dim_out::Val{DIM_OUT}
return_state::Val{RETURN_STATE}
end
@unstable @inline function Base.getproperty(
roptions::RuntimeOptions{P,D,R}, name::Symbol
Expand Down Expand Up @@ -108,8 +111,8 @@ macro sr_spawner(expr, kws...)
@assert all(ex -> ex.head == :(=), kws)
@assert any(ex -> ex.args[1] == :parallelism, kws)
@assert any(ex -> ex.args[1] == :worker_idx, kws)
parallelism = kws[findfirst(ex -> ex.args[1] == :parallelism, kws)].args[2]
worker_idx = kws[findfirst(ex -> ex.args[1] == :worker_idx, kws)].args[2]
parallelism = kws[findfirst(ex -> ex.args[1] == :parallelism, kws)::Int].args[2]
worker_idx = kws[findfirst(ex -> ex.args[1] == :worker_idx, kws)::Int].args[2]
return quote
if $(parallelism) == :serial
$(expr)
Expand Down
5 changes: 4 additions & 1 deletion src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ function equation_search(
# Underscores here mean that we have mutated the variable
return _equation_search(
datasets,
RuntimeOptions{concurrency,dim_out,_return_state}(;
RuntimeOptions(;
niterations=niterations,
total_cycles=options.populations * niterations,
numprocs=_numprocs,
Expand All @@ -580,6 +580,9 @@ function equation_search(
runtests=runtests,
verbosity=_verbosity,
progress=_progress,
parallelism=Val(concurrency),
dim_out=Val(dim_out),
return_state=Val(_return_state),
),
options,
saved_state,
Expand Down
2 changes: 1 addition & 1 deletion src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using MacroTools: splitdef, combinedef

const pseudo_time = Ref(0)

function get_birth_order(; deterministic=false)::Int
function get_birth_order(; deterministic::Bool=false)::Int
"""deterministic gives a birth time with perfect resolution, but is not thread safe."""
if deterministic
global pseudo_time
Expand Down
4 changes: 2 additions & 2 deletions src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size

@deprecate(
gen_random_tree(length::Int, options::Options, nfeatures::Int, t::Type),
gen_random_tree(length, options, nfeatures, t, Node)
gen_random_tree(length, options, nfeatures, t)
)
@deprecate(
gen_random_tree_fixed_size(node_count::Int, options::Options, nfeatures::Int, t::Type),
gen_random_tree_fixed_size(node_count, options, nfeatures, t, Node)
gen_random_tree_fixed_size(node_count, options, nfeatures, t)
)

@deprecate(
Expand Down
2 changes: 1 addition & 1 deletion src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function do_precompilation(::Val{mode}) where {mode}
return_state=false,
verbosity=0,
)
nout == 1 && calculate_pareto_frontier(hof)
nout == 1 && calculate_pareto_frontier(hof::HallOfFame)
end
end
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
53 changes: 0 additions & 53 deletions test/full.jl

This file was deleted.

Loading

0 comments on commit 0fc573b

Please sign in to comment.