From 5973611b6bbd00c543c9e508830f84e2d3adb4e9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 6 Jan 2024 15:16:09 +0000 Subject: [PATCH 01/54] Add numerical logging for TensorBoardLogger integration --- Project.toml | 1 + src/MLJInterface.jl | 3 +++ src/SymbolicRegression.jl | 29 +++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/Project.toml b/Project.toml index 3d8334072..a59417ee1 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 434b9ac18..7ab28b5a4 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -1,6 +1,7 @@ module MLJInterfaceModule using Optim: Optim +using Logging: AbstractLogger import MLJModelInterface as MMI import DynamicExpressions: eval_tree_array, string_tree, Node import DynamicQuantities: @@ -40,6 +41,7 @@ function modelexpr(model_name::Symbol) procs::Union{Vector{Int},Nothing} = nothing addprocs_function::Union{Function,Nothing} = nothing heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing + logger::Union{AbstractLogger,Nothing} = nothing runtests::Bool = true loss_type::L = Nothing selection_method::Function = choose_best @@ -175,6 +177,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) X_units=X_units_clean, y_units=y_units_clean, verbosity=verbosity, + logger=m.logger, # Help out with inference: v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), ) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 38f28f645..f874ea63a 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -66,6 +66,7 @@ export Population, using Distributed import Printf: @printf, @sprintf import PackageExtensionCompat: @require_extensions +using Logging: AbstractLogger, with_logger using Pkg: Pkg import TOML: parsefile import Random: seed!, shuffle! 
@@ -351,6 +352,7 @@ function equation_search( return_state::Union{Bool,Nothing}=nothing, loss_type::Type{L}=Nothing, verbosity::Union{Integer,Nothing}=nothing, + logger::Union{AbstractLogger,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, @@ -397,6 +399,7 @@ function equation_search( saved_state=saved_state, return_state=return_state, verbosity=verbosity, + logger=logger, progress=progress, v_dim_out=Val(DIM_OUT), ) @@ -434,6 +437,7 @@ function equation_search( saved_state=nothing, return_state::Union{Bool,Nothing}=nothing, verbosity::Union{Int,Nothing}=nothing, + logger::Union{AbstractLogger,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, v_dim_out::Val{DIM_OUT}=Val(nothing), ) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} @@ -555,6 +559,7 @@ function equation_search( runtests, saved_state, _verbosity, + logger, _progress, Val(_return_state), ) @@ -573,6 +578,7 @@ function _equation_search( runtests::Bool, saved_state, verbosity, + logger, progress, ::Val{RETURN_STATE}, ) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},PARALLELISM,RETURN_STATE,DIM_OUT} @@ -951,6 +957,29 @@ function _equation_search( PARALLELISM, ) end + if logger !== nothing + with_logger(logger) do + if nout == 1 + dominating = calculate_pareto_frontier(only(hallOfFame)) + best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) + @info( + "search_state", + num_evals = sum(sum, num_evals), + pareto_front = dominating, + best_loss = best_loss, + ) + else + dominating = calculate_pareto_frontier.(hallOfFame) + best_loss = (d -> length(d) > 0 ? d[end].loss : L(Inf)).(dominating) + @info( + "search_state", + num_evals = sum(sum, num_evals), + pareto_front = dominating, + best_loss = best_loss, + ) + end + end + end end sleep(1e-6) From 05d8a6219a130f4bc3330693178f5903f306f7dc Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 6 Jan 2024 16:41:07 +0000 Subject: [PATCH 02/54] Create `logging_callback` for user to define --- src/SearchUtils.jl | 28 +++++++++++++++++++++++ src/SymbolicRegression.jl | 47 ++++++++++++++++++--------------------- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index a94044103..6a22fb7c2 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -5,6 +5,8 @@ module SearchUtilsModule import Printf: @printf, @sprintf using Distributed +using Logging: with_logger +using DynamicExpressions: string_tree import StatsBase: mean import ..UtilsModule: subscriptify @@ -402,4 +404,30 @@ function update_hall_of_fame!( end end +function default_logging_callback( + logger; options, num_evals, hall_of_fame, datasets::Vector{D}, kws... +) where {T,L,D<:Dataset{T,L}} + with_logger(logger) do + @info("search_state", num_evals = sum(sum, num_evals)) + for (i_hof, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets)) + dominating = calculate_pareto_frontier(hof) + best_loss = length(dominating) > 0 ? 
dominating[end].loss : L(Inf) + losses = L[member.loss for member in dominating] + complexities = Int[compute_complexity(member, options) for member in dominating] + equations = String[ + string_tree(member.tree, options; variable_names=dataset.variable_names) for + member in dominating + ] + @info( + "search_state_$(i_hof)", + best_loss = best_loss, + equations = equations, + losses = losses, + complexities = complexities, + log_step_increment = 0, + ) + end + end +end + end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index f874ea63a..4d2919006 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -235,7 +235,8 @@ import .SearchUtilsModule: load_saved_population, construct_datasets, get_cur_maxsize, - update_hall_of_fame! + update_hall_of_fame!, + default_logging_callback include("deprecates.jl") include("Configure.jl") @@ -353,6 +354,7 @@ function equation_search( loss_type::Type{L}=Nothing, verbosity::Union{Integer,Nothing}=nothing, logger::Union{AbstractLogger,Nothing}=nothing, + logging_callback::Union{Function,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, @@ -400,6 +402,7 @@ function equation_search( return_state=return_state, verbosity=verbosity, logger=logger, + logging_callback=logging_callback, progress=progress, v_dim_out=Val(DIM_OUT), ) @@ -438,6 +441,7 @@ function equation_search( return_state::Union{Bool,Nothing}=nothing, verbosity::Union{Int,Nothing}=nothing, logger::Union{AbstractLogger,Nothing}=nothing, + logging_callback::Union{Function,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, v_dim_out::Val{DIM_OUT}=Val(nothing), ) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} @@ -544,6 +548,11 @@ function equation_search( else `` end + _logging_callback = if logging_callback === nothing && logger !== nothing + (; kws...) -> default_logging_callback(logger; kws...) + else + logging_callback + end # Underscores here mean that we have mutated the variable return _equation_search( @@ -559,7 +568,7 @@ function equation_search( runtests, saved_state, _verbosity, - logger, + _logging_callback, _progress, Val(_return_state), ) @@ -578,7 +587,7 @@ function _equation_search( runtests::Bool, saved_state, verbosity, - logger, + logging_callback, progress, ::Val{RETURN_STATE}, ) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},PARALLELISM,RETURN_STATE,DIM_OUT} @@ -957,28 +966,16 @@ function _equation_search( PARALLELISM, ) end - if logger !== nothing - with_logger(logger) do - if nout == 1 - dominating = calculate_pareto_frontier(only(hallOfFame)) - best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) - @info( - "search_state", - num_evals = sum(sum, num_evals), - pareto_front = dominating, - best_loss = best_loss, - ) - else - dominating = calculate_pareto_frontier.(hallOfFame) - best_loss = (d -> length(d) > 0 ? d[end].loss : L(Inf)).(dominating) - @info( - "search_state", - num_evals = sum(sum, num_evals), - pareto_front = dominating, - best_loss = best_loss, - ) - end - end + if logging_callback !== nothing + logging_callback(; + options, + num_evals, + hall_of_fame=hallOfFame, + worker_assignment, + cycles_remaining, + populations=returnPops, + datasets=datasets, + ) end end sleep(1e-6) From 90799ca04fe3a1b6925de99855d4bf791902ae6a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 6 Jan 2024 16:54:39 +0000 Subject: [PATCH 03/54] Give error message for both `logger` and `logging_callback` passed. 
--- src/SymbolicRegression.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 4d2919006..2102ef789 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -468,6 +468,11 @@ function equation_search( error( "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", ) + logging_callback !== nothing && + logger !== nothing && + error( + "You cannot set both `logging_callback` and `logger`. Instead, simply use your logger within the `logging_callback`.", + ) # TODO: Still not type stable. Should be able to pass `Val{return_state}`. _return_state = if options.return_state === nothing From 2d851e2bbb350bc3ce851ad36352a283facdffa3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 6 Jan 2024 16:57:59 +0000 Subject: [PATCH 04/54] Clean up default logger --- src/SearchUtils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 6a22fb7c2..4dbecf3c9 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -409,7 +409,7 @@ function default_logging_callback( ) where {T,L,D<:Dataset{T,L}} with_logger(logger) do @info("search_state", num_evals = sum(sum, num_evals)) - for (i_hof, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets)) + for (i, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets)) dominating = calculate_pareto_frontier(hof) best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) losses = L[member.loss for member in dominating] @@ -419,7 +419,7 @@ function default_logging_callback( member in dominating ] @info( - "search_state_$(i_hof)", + "search_state_$(i)", best_loss = best_loss, equations = equations, losses = losses, From de69b29e1a8d14e975beafbc886cf82e4cd56a5c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 6 Jan 2024 17:01:49 +0000 Subject: [PATCH 05/54] Fix unbound parameter --- src/SearchUtils.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 4dbecf3c9..42b4f04f2 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -404,9 +404,8 @@ function update_hall_of_fame!( end end -function default_logging_callback( - logger; options, num_evals, hall_of_fame, datasets::Vector{D}, kws... -) where {T,L,D<:Dataset{T,L}} +function default_logging_callback(logger; options, num_evals, hall_of_fame, datasets, _...) + L = typeof(first(datasets).baseline_loss) with_logger(logger) do @info("search_state", num_evals = sum(sum, num_evals)) for (i, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets)) From 62fd0ef0132644603ad554e39750e2cc782fbdc8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 6 Jan 2024 17:05:00 +0000 Subject: [PATCH 06/54] Remove unused import --- src/SymbolicRegression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 2102ef789..bb45a5eb1 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -66,7 +66,7 @@ export Population, using Distributed import Printf: @printf, @sprintf import PackageExtensionCompat: @require_extensions -using Logging: AbstractLogger, with_logger +using Logging: AbstractLogger using Pkg: Pkg import TOML: parsefile import Random: seed!, shuffle! 
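A short usage sketch of the interface as it stands after PATCH 06 (illustrative only; the callback name and the data below are hypothetical, not part of the patch series). `equation_search` now accepts either a `logger::AbstractLogger`, which is routed through `default_logging_callback`, or a user-defined `logging_callback`, but not both (PATCH 03 raises an error in that case). A custom callback names only the keywords it needs and absorbs the rest:

    using SymbolicRegression

    # Log only the total evaluation count; `kws...` absorbs the remaining
    # keywords passed at the call site in PATCH 02 (`options`, `hall_of_fame`,
    # `worker_assignment`, `cycles_remaining`, `populations`, `datasets`):
    my_logging_callback(; num_evals, kws...) =
        @info "search_state" total_evals = sum(sum, num_evals)

    X = randn(Float32, 2, 100)  # hypothetical data: 2 features, 100 rows
    y = 2 .* X[1, :] .- X[2, :]
    equation_search(X, y; niterations=10, logging_callback=my_logging_callback)
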
From bb2e16b34c3ccb7fe95710ebf4d2381a7b2329f6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 6 Jan 2024 23:27:42 +0000 Subject: [PATCH 07/54] Log all complexities over time --- src/MLJInterface.jl | 4 ++++ src/SearchUtils.jl | 20 +++++++++++--------- src/SymbolicRegression.jl | 9 ++++++++- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 594db3bf4..8d964aea3 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -41,6 +41,8 @@ function modelexpr(model_name::Symbol) addprocs_function::Union{Function,Nothing} = nothing heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing logger::Union{AbstractLogger,Nothing} = nothing + logging_callback::Union{Function,Nothing} = nothing + log_every_n::Int = 1 runtests::Bool = true loss_type::L = Nothing selection_method::Function = choose_best @@ -177,6 +179,8 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) y_units=y_units_clean, verbosity=verbosity, logger=m.logger, + logging_callback=m.logging_callback, + log_every_n=m.log_every_n, # Help out with inference: v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), ) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index b92041384..ddab034fd 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -407,7 +407,7 @@ end function default_logging_callback(logger; options, num_evals, hall_of_fame, datasets, _...) L = typeof(first(datasets).baseline_loss) with_logger(logger) do - @info("search_state", num_evals = sum(sum, num_evals)) + d = Dict() for (i, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets)) dominating = calculate_pareto_frontier(hof) best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) @@ -417,15 +417,17 @@ function default_logging_callback(logger; options, num_evals, hall_of_fame, data string_tree(member.tree, options; variable_names=dataset.variable_names) for member in dominating ] - @info( - "search_state_$(i)", - best_loss = best_loss, - equations = equations, - losses = losses, - complexities = complexities, - log_step_increment = 0, - ) + d[string(i)] = Dict() + d[string(i)]["best_loss"] = best_loss + d[string(i)]["equations"] = Dict() + for (complexity, loss, equation) in zip(complexities, losses, equations) + d[string(i)]["equations"][string(complexity)] = Dict( + "loss" => loss, "equation" => equation + ) + end end + d["num_evals"] = sum(sum, num_evals) + @info("search_state", data = d) end end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index eb6688279..582c2bc42 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -355,6 +355,7 @@ function equation_search( verbosity::Union{Integer,Nothing}=nothing, logger::Union{AbstractLogger,Nothing}=nothing, logging_callback::Union{Function,Nothing}=nothing, + log_every_n::Int=1, progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, @@ -403,6 +404,7 @@ function equation_search( verbosity=verbosity, logger=logger, logging_callback=logging_callback, + log_every_n=log_every_n, progress=progress, v_dim_out=Val(DIM_OUT), ) @@ -442,6 +444,7 @@ function equation_search( verbosity::Union{Int,Nothing}=nothing, logger::Union{AbstractLogger,Nothing}=nothing, logging_callback::Union{Function,Nothing}=nothing, + log_every_n::Int=1, progress::Union{Bool,Nothing}=nothing, v_dim_out::Val{DIM_OUT}=Val(nothing), ) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} @@ -574,6 +577,7 @@ function equation_search( saved_state, _verbosity, 
_logging_callback, + log_every_n, _progress, Val(_return_state), ) @@ -593,6 +597,7 @@ function _equation_search( saved_state, verbosity, logging_callback, + log_every_n, progress, ::Val{RETURN_STATE}, ) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},PARALLELISM,RETURN_STATE,DIM_OUT} @@ -800,6 +805,7 @@ function _equation_search( ) end + log_step = 0 last_print_time = time() last_speed_recording_time = time() num_evals_last = sum(sum, num_evals) @@ -971,7 +977,7 @@ function _equation_search( PARALLELISM, ) end - if logging_callback !== nothing + if logging_callback !== nothing && log_step % log_every_n == 0 logging_callback(; options, num_evals, @@ -982,6 +988,7 @@ function _equation_search( datasets=datasets, ) end + log_step += 1 end sleep(1e-6) From 7f0b1134cd8f20aaeab52a731a3fdca9f49bf9c8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Jan 2024 00:16:02 +0000 Subject: [PATCH 08/54] Add Pareto curve plotting to logger --- Project.toml | 4 ++++ ext/SymbolicRegressionPlotsExt.jl | 39 +++++++++++++++++++++++++++++++ src/SearchUtils.jl | 8 +++++++ src/SymbolicRegression.jl | 3 ++- 4 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 ext/SymbolicRegressionPlotsExt.jl diff --git a/Project.toml b/Project.toml index a59417ee1..b659cba31 100644 --- a/Project.toml +++ b/Project.toml @@ -29,10 +29,12 @@ Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [weakdeps] JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [extensions] SymbolicRegressionJSON3Ext = "JSON3" +SymbolicRegressionPlotsExt = "Plots" SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" [compat] @@ -48,6 +50,7 @@ MacroTools = "0.4, 0.5" Optim = "0.19, 1.1 - 1.7.6" PackageExtensionCompat = "1" Pkg = "1" +Plots = "1" PrecompileTools = "1" ProgressBars = "1.4" Reexport = "1" @@ -64,6 +67,7 @@ JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" diff --git a/ext/SymbolicRegressionPlotsExt.jl b/ext/SymbolicRegressionPlotsExt.jl new file mode 100644 index 000000000..f6bd629b2 --- /dev/null +++ b/ext/SymbolicRegressionPlotsExt.jl @@ -0,0 +1,39 @@ +module SymbolicRegressionPlotsExt + +import Plots: plot +using DynamicExpressions: Node +using SymbolicRegression: HallOfFame, Options, string_tree, sr_plot +using SymbolicRegression.MLJInterfaceModule: AbstractSRRegressor +using SymbolicRegression.HallOfFameModule: format_hall_of_fame + +function plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) + return sr_plot(hall_of_fame, options; variable_names, kws...) +end + +function sr_plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) + (; trees, losses, complexities) = format_hall_of_fame(hall_of_fame, options) + return sr_plot(trees, losses, complexities, options; variable_names, kws...) 
+end + +function sr_plot( + trees::Vector{N}, + losses::Vector{L}, + complexities::Vector{Int}, + options::Options; + variable_names=nothing, + kws..., +) where {T,L,N<:Node{T}} + tree_strings = [string_tree(tree, options; variable_names) for tree in trees] + return plot( + complexities, + losses; + label=nothing, + xlabel="Complexity", + ylabel="Loss", + title="Hall of Fame", + xlims=(0, options.maxsize), + kws..., + ) +end + +end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ddab034fd..3c620cafa 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -404,6 +404,10 @@ function update_hall_of_fame!( end end +function sr_plot(args...; kws...) + return nothing +end + function default_logging_callback(logger; options, num_evals, hall_of_fame, datasets, _...) L = typeof(first(datasets).baseline_loss) with_logger(logger) do @@ -411,6 +415,7 @@ function default_logging_callback(logger; options, num_evals, hall_of_fame, data for (i, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets)) dominating = calculate_pareto_frontier(hof) best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) + trees = [member.tree for member in dominating] losses = L[member.loss for member in dominating] complexities = Int[compute_complexity(member, options) for member in dominating] equations = String[ @@ -420,6 +425,9 @@ function default_logging_callback(logger; options, num_evals, hall_of_fame, data d[string(i)] = Dict() d[string(i)]["best_loss"] = best_loss d[string(i)]["equations"] = Dict() + d[string(i)]["plot"] = sr_plot( + trees, losses, complexities, options; variable_names=dataset.variable_names + ) for (complexity, loss, equation) in zip(complexities, losses, equations) d[string(i)]["equations"][string(complexity)] = Dict( "loss" => loss, "equation" => equation diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 582c2bc42..42b90c4b7 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -236,7 +236,8 @@ using .SearchUtilsModule: construct_datasets, get_cur_maxsize, update_hall_of_fame!, - default_logging_callback + default_logging_callback, + sr_plot include("deprecates.jl") include("Configure.jl") From b4484bae9f10b6fc7acea0decb1b260d565ad305 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Jan 2024 00:25:29 +0000 Subject: [PATCH 09/54] More descriptive name for plotting utils --- ext/SymbolicRegressionPlotsExt.jl | 12 ++++++------ src/SearchUtils.jl | 7 ++++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ext/SymbolicRegressionPlotsExt.jl b/ext/SymbolicRegressionPlotsExt.jl index f6bd629b2..35c177942 100644 --- a/ext/SymbolicRegressionPlotsExt.jl +++ b/ext/SymbolicRegressionPlotsExt.jl @@ -2,20 +2,19 @@ module SymbolicRegressionPlotsExt import Plots: plot using DynamicExpressions: Node -using SymbolicRegression: HallOfFame, Options, string_tree, sr_plot -using SymbolicRegression.MLJInterfaceModule: AbstractSRRegressor +using SymbolicRegression: HallOfFame, Options, string_tree, default_sr_plot using SymbolicRegression.HallOfFameModule: format_hall_of_fame function plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) - return sr_plot(hall_of_fame, options; variable_names, kws...) + return default_sr_plot(hall_of_fame, options; variable_names, kws...) end -function sr_plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) +function default_sr_plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) 
(; trees, losses, complexities) = format_hall_of_fame(hall_of_fame, options) - return sr_plot(trees, losses, complexities, options; variable_names, kws...) + return default_sr_plot(trees, losses, complexities, options; variable_names, kws...) end -function sr_plot( +function default_sr_plot( trees::Vector{N}, losses::Vector{L}, complexities::Vector{Int}, @@ -32,6 +31,7 @@ function sr_plot( ylabel="Loss", title="Hall of Fame", xlims=(0, options.maxsize), + yscale=:log10, kws..., ) end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 3c620cafa..104952ccb 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -404,8 +404,9 @@ function update_hall_of_fame!( end end -function sr_plot(args...; kws...) - return nothing +# Defined by Plots extension +function default_sr_plot(args...; kws...) + return "Load the Plots package to use this function." end function default_logging_callback(logger; options, num_evals, hall_of_fame, datasets, _...) @@ -425,7 +426,7 @@ function default_logging_callback(logger; options, num_evals, hall_of_fame, data d[string(i)] = Dict() d[string(i)]["best_loss"] = best_loss d[string(i)]["equations"] = Dict() - d[string(i)]["plot"] = sr_plot( + d[string(i)]["plot"] = default_sr_plot( trees, losses, complexities, options; variable_names=dataset.variable_names ) for (complexity, loss, equation) in zip(complexities, losses, equations) From 910e15a13ff7212f08f6dca3772315961c80a397 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Jan 2024 00:26:06 +0000 Subject: [PATCH 10/54] Fix `using` -> `import` --- ext/SymbolicRegressionPlotsExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/SymbolicRegressionPlotsExt.jl b/ext/SymbolicRegressionPlotsExt.jl index 35c177942..2d1bdb3cb 100644 --- a/ext/SymbolicRegressionPlotsExt.jl +++ b/ext/SymbolicRegressionPlotsExt.jl @@ -1,8 +1,10 @@ module SymbolicRegressionPlotsExt import Plots: plot +import SymbolicRegression: default_sr_plot + using DynamicExpressions: Node -using SymbolicRegression: HallOfFame, Options, string_tree, default_sr_plot +using SymbolicRegression: HallOfFame, Options, string_tree using SymbolicRegression.HallOfFameModule: format_hall_of_fame function plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) From 36015dec0ae1193615a98c92370278451dc23b92 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 7 Jan 2024 03:25:12 +0000 Subject: [PATCH 11/54] Fix import --- ext/SymbolicRegressionPlotsExt.jl | 4 +++- src/SymbolicRegression.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/SymbolicRegressionPlotsExt.jl b/ext/SymbolicRegressionPlotsExt.jl index 2d1bdb3cb..7e82a3ee6 100644 --- a/ext/SymbolicRegressionPlotsExt.jl +++ b/ext/SymbolicRegressionPlotsExt.jl @@ -11,7 +11,9 @@ function plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing return default_sr_plot(hall_of_fame, options; variable_names, kws...) end -function default_sr_plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) +function default_sr_plot( + hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws... +) (; trees, losses, complexities) = format_hall_of_fame(hall_of_fame, options) return default_sr_plot(trees, losses, complexities, options; variable_names, kws...) 
end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 42b90c4b7..354ec3280 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -237,7 +237,7 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame!, default_logging_callback, - sr_plot + default_sr_plot include("deprecates.jl") include("Configure.jl") From afcc6495c758a5478f6ef8ce677925f60ed9f2aa Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jan 2024 00:18:39 +0000 Subject: [PATCH 12/54] Move to using RecipesBase.jl instead --- Project.toml | 5 +---- .../PlotRecipes.jl | 13 +++++-------- src/SymbolicRegression.jl | 1 + 3 files changed, 7 insertions(+), 12 deletions(-) rename ext/SymbolicRegressionPlotsExt.jl => src/PlotRecipes.jl (79%) diff --git a/Project.toml b/Project.toml index b659cba31..f47495fd9 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -29,12 +30,10 @@ Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [weakdeps] JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [extensions] SymbolicRegressionJSON3Ext = "JSON3" -SymbolicRegressionPlotsExt = "Plots" SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" [compat] @@ -50,7 +49,6 @@ MacroTools = "0.4, 0.5" Optim = "0.19, 1.1 - 1.7.6" PackageExtensionCompat = "1" Pkg = "1" -Plots = "1" PrecompileTools = "1" ProgressBars = "1.4" Reexport = "1" @@ -67,7 +65,6 @@ JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" diff --git a/ext/SymbolicRegressionPlotsExt.jl b/src/PlotRecipes.jl similarity index 79% rename from ext/SymbolicRegressionPlotsExt.jl rename to src/PlotRecipes.jl index 7e82a3ee6..048dc6264 100644 --- a/ext/SymbolicRegressionPlotsExt.jl +++ b/src/PlotRecipes.jl @@ -1,11 +1,8 @@ -module SymbolicRegressionPlotsExt +module PlotRecipesModule -import Plots: plot -import SymbolicRegression: default_sr_plot - -using DynamicExpressions: Node -using SymbolicRegression: HallOfFame, Options, string_tree -using SymbolicRegression.HallOfFameModule: format_hall_of_fame +using DynamicExpressions: Node, string_tree +using ..CoreModule: Options +using ..HallOfFameModule: HallOfFame, format_hall_of_fame function plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) return default_sr_plot(hall_of_fame, options; variable_names, kws...) 
@@ -40,4 +37,4 @@ function default_sr_plot( ) end -end +end \ No newline at end of file diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 354ec3280..2d8f26724 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -159,6 +159,7 @@ include("SingleIteration.jl") include("ProgressBars.jl") include("Migration.jl") include("SearchUtils.jl") +include("PlotRecipes.jl") using .CoreModule: MAX_DEGREE, From bf7297ae9d26a19a8f1377e57cf3324b48853c19 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jan 2024 01:14:15 +0000 Subject: [PATCH 13/54] Apply RecipesBase compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index f47495fd9..28a3a5b4b 100644 --- a/Project.toml +++ b/Project.toml @@ -51,6 +51,7 @@ PackageExtensionCompat = "1" Pkg = "1" PrecompileTools = "1" ProgressBars = "1.4" +RecipesBase = "1" Reexport = "1" SpecialFunctions = "0.10.1, 1, 2" StatsBase = "0.33, 0.34" From cd3697ed2dedb7607926b0f683b83fb5b9ce020e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jan 2024 01:14:54 +0000 Subject: [PATCH 14/54] Proper types for MLJ interface --- src/MLJInterface.jl | 124 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 101 insertions(+), 23 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 8d964aea3..3e032bd2e 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -89,9 +89,84 @@ function get_options(::AbstractSRRegressor) end eval(modelexpr(:SRRegressor)) eval(modelexpr(:MultitargetSRRegressor)) +""" + SRFitResultTypes + +A struct referencing types in the `SRFitResult` struct, +to be used in type inference during MLJ.update to speed up iterative fits. +""" +Base.@kwdef struct SRFitResultTypes{ + _T,_X_t,_y_t,_w_t,_state,_X_units,_y_units,_X_units_clean,_y_units_clean +} + T::Type{_T} = Any + X_t::Type{_X_t} = Any + y_t::Type{_y_t} = Any + w_t::Type{_w_t} = Any + state::Type{_state} = Any + X_units::Type{_X_units} = Any + y_units::Type{_y_units} = Any + X_units_clean::Type{_X_units_clean} = Any + y_units_clean::Type{_y_units_clean} = Any +end + +""" + SRFitResult + +A struct containing the result of a fit of an `SRRegressor` or `MultitargetSRRegressor`. 
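Beyond the search `state` and the fitted `options`, it records the originating `model`, variable names, dimensional units, and an `SRFitResultTypes` cache used to speed up type inference during iterative re-fits.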
+""" +Base.@kwdef struct SRFitResult{ + M<:AbstractSRRegressor, + S, + O<:Options, + XD<:Union{Vector{<:AbstractDimensions},Nothing}, + YD<:Union{AbstractDimensions,Vector{<:AbstractDimensions},Nothing}, + TYPES<:SRFitResultTypes, +} + model::M + state::S + num_targets::Int + options::O + variable_names::Vector{String} + y_variable_names::Union{Vector{String},Nothing} + y_is_table::Bool + X_units::XD + y_units::YD + types::TYPES +end + +# Define a simpler print method for this: +function Base.show(io::IO, ::MIME"text/plain", fitresult::SRFitResult) + print(io, "SRFitResult for $(fitresult.model):") + print(io, "\n") + print(io, " state:\n") + print(io, " [1]: $(typeof(fitresult.state[1])) with ") + print(io, "$(length(fitresult.state[1])) × $(length(fitresult.state[1][1])) ") + print(io, "populations of $(fitresult.state[1][1][1].n) members\n") + print(io, " [2]: $(typeof(fitresult.state[2])) ") + if fitresult.model isa SRRegressor + print(io, "with $(sum(fitresult.state[2].exists)) saved expressions") + else + print(io, "with $(map(s -> sum(s.exists), fitresult.state[2])) saved expressions") + end + print(io, "\n") + print(io, " num_targets: $(fitresult.num_targets)") + print(io, "\n") + print(io, " variable_names: $(fitresult.variable_names)") + print(io, "\n") + print(io, " y_variable_names: $(fitresult.y_variable_names)") + print(io, "\n") + print(io, " X_units: $(fitresult.X_units)") + print(io, "\n") + print(io, " y_units: $(fitresult.y_units)") + print(io, "\n") + return nothing +end + # Cleaning already taken care of by `Options` and `equation_search` function full_report( - m::AbstractSRRegressor, fitresult; v_with_strings::Val{with_strings}=Val(true) + m::AbstractSRRegressor, + fitresult::SRFitResult; + v_with_strings::Val{with_strings}=Val(true), ) where {with_strings} _, hof = fitresult.state # TODO: Adjust baseline loss @@ -123,25 +198,23 @@ function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing) return MMI.update(m, verbosity, nothing, nothing, X, y, w) end function MMI.update( - m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing + m::AbstractSRRegressor, + verbosity, + old_fitresult::Union{SRFitResult,Nothing}, + old_cache, + X, + y, + w=nothing, ) options = old_fitresult === nothing ? get_options(m) : old_fitresult.options return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) end -function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) +function _update( + m, verbosity, old_fitresult::Union{SRFitResult,Nothing}, old_cache, X, y, w, options +) # To speed up iterative fits, we cache the types: types = if old_fitresult === nothing - (; - T=Any, - X_t=Any, - y_t=Any, - w_t=Any, - state=Any, - X_units=Any, - y_units=Any, - X_units_clean=Any, - y_units_clean=Any, - ) + SRFitResultTypes() else old_fitresult.types end @@ -184,7 +257,8 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) # Help out with inference: v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), ) - fitresult = (; + fitresult = SRFitResult(; + model=m, state=search_state, num_targets=isa(m, SRRegressor) ? 
1 : size(y_t, 1), options=options, @@ -193,7 +267,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) y_is_table=MMI.istable(y), X_units=X_units_clean, y_units=y_units_clean, - types=( + types=SRFitResultTypes(; T=hof_eltype(search_state[2]), X_t=typeof(X_t), y_t=typeof(y_t), @@ -204,7 +278,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) X_units_clean=typeof(X_units_clean), y_units_clean=typeof(y_units_clean), ), - )::(old_fitresult === nothing ? Any : typeof(old_fitresult)) + )::(old_fitresult === nothing ? SRFitResult : typeof(old_fitresult)) return (fitresult, nothing, full_report(m, fitresult)) end hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}} = T @@ -253,7 +327,7 @@ function format_input_for(::MultitargetSRRegressor, y, ::Type{D}) where {D} ) return get_matrix_and_info(y, D) end -function validate_variable_names(variable_names, fitresult) +function validate_variable_names(variable_names, fitresult::SRFitResult) @assert( variable_names == fitresult.variable_names, "Variable names do not match fitted regressor." @@ -294,13 +368,15 @@ end end end -function prediction_fallback(::Type{T}, m::SRRegressor, Xnew_t, fitresult) where {T} +function prediction_fallback( + ::Type{T}, m::SRRegressor, Xnew_t, fitresult::SRFitResult +) where {T} prediction_warn() out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)) return wrap_units(out, fitresult.y_units, nothing) end function prediction_fallback( - ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult, prototype + ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult::SRFitResult, prototype ) where {T} prediction_warn() out_cols = [ @@ -340,7 +416,7 @@ function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D} return compat_ustrip(v)::AbstractVector, dims end -function MMI.fitted_params(m::AbstractSRRegressor, fitresult) +function MMI.fitted_params(m::AbstractSRRegressor, fitresult::SRFitResult) report = full_report(m, fitresult) return (; best_idx=report.best_idx, @@ -349,7 +425,7 @@ function MMI.fitted_params(m::AbstractSRRegressor, fitresult) ) end -function MMI.predict(m::SRRegressor, fitresult, Xnew) +function MMI.predict(m::SRRegressor, fitresult::SRFitResult{<:SRRegressor}, Xnew) params = full_report(m, fitresult; v_with_strings=Val(false)) Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type) T = promote_type(eltype(Xnew_t), fitresult.types.T) @@ -367,7 +443,9 @@ function MMI.predict(m::SRRegressor, fitresult, Xnew) return wrap_units(out, fitresult.y_units, nothing) end end -function MMI.predict(m::MultitargetSRRegressor, fitresult, Xnew) +function MMI.predict( + m::MultitargetSRRegressor, fitresult::SRFitResult{<:MultitargetSRRegressor}, Xnew +) params = full_report(m, fitresult; v_with_strings=Val(false)) prototype = MMI.istable(Xnew) ? 
Xnew : nothing Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type) From c13d49b3a7ed53381873a0bb6b46a05d6ae04102 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jan 2024 01:20:18 +0000 Subject: [PATCH 15/54] Move all printing methods to single file --- src/MLJInterface.jl | 28 ------------------- src/OptionsStruct.jl | 24 ---------------- src/Printing.jl | 58 +++++++++++++++++++++++++++++++++++++++ src/SymbolicRegression.jl | 1 + 4 files changed, 59 insertions(+), 52 deletions(-) create mode 100644 src/Printing.jl diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 3e032bd2e..07dcfd966 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -134,34 +134,6 @@ Base.@kwdef struct SRFitResult{ types::TYPES end -# Define a simpler print method for this: -function Base.show(io::IO, ::MIME"text/plain", fitresult::SRFitResult) - print(io, "SRFitResult for $(fitresult.model):") - print(io, "\n") - print(io, " state:\n") - print(io, " [1]: $(typeof(fitresult.state[1])) with ") - print(io, "$(length(fitresult.state[1])) × $(length(fitresult.state[1][1])) ") - print(io, "populations of $(fitresult.state[1][1][1].n) members\n") - print(io, " [2]: $(typeof(fitresult.state[2])) ") - if fitresult.model isa SRRegressor - print(io, "with $(sum(fitresult.state[2].exists)) saved expressions") - else - print(io, "with $(map(s -> sum(s.exists), fitresult.state[2])) saved expressions") - end - print(io, "\n") - print(io, " num_targets: $(fitresult.num_targets)") - print(io, "\n") - print(io, " variable_names: $(fitresult.variable_names)") - print(io, "\n") - print(io, " y_variable_names: $(fitresult.y_variable_names)") - print(io, "\n") - print(io, " X_units: $(fitresult.X_units)") - print(io, "\n") - print(io, " y_units: $(fitresult.y_units)") - print(io, "\n") - return nothing -end - # Cleaning already taken care of by `Options` and `equation_search` function full_report( m::AbstractSRRegressor, diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index dad2c46ca..de5d59af2 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -197,28 +197,4 @@ struct Options{CT,OP<:AbstractOperatorEnum,use_recorder,OPT<:Optim.Options,W} define_helper_functions::Bool end -function Base.print(io::IO, options::Options) - return print( - io, - "Options(" * - "binops=$(options.operators.binops), " * - "unaops=$(options.operators.unaops), " - # Fill in remaining fields automatically: - * - join( - [ - if fieldname in (:optimizer_options, :mutation_weights) - "$(fieldname)=..." - else - "$(fieldname)=$(getfield(options, fieldname))" - end for - fieldname in fieldnames(Options) if fieldname ∉ [:operators, :nuna, :nbin] - ], - ", ", - ) * - ")", - ) -end -Base.show(io::IO, ::MIME"text/plain", options::Options) = Base.print(io, options) - end diff --git a/src/Printing.jl b/src/Printing.jl new file mode 100644 index 000000000..07f3ecb9b --- /dev/null +++ b/src/Printing.jl @@ -0,0 +1,58 @@ +"""Defines printing methods of exported types (aside from expressions themselves)""" +module PrintingModule + +using ..CoreModule: Options +using ..MLJModelInterface: SRRegressor, SRFitResult + +function Base.print(io::IO, options::Options) + return print( + io, + "Options(" * + "binops=$(options.operators.binops), " * + "unaops=$(options.operators.unaops), " + # Fill in remaining fields automatically: + * + join( + [ + if fieldname in (:optimizer_options, :mutation_weights) + "$(fieldname)=..." 
+ else + "$(fieldname)=$(getfield(options, fieldname))" + end for + fieldname in fieldnames(Options) if fieldname ∉ [:operators, :nuna, :nbin] + ], + ", ", + ) * + ")", + ) +end +Base.show(io::IO, ::MIME"text/plain", options::Options) = Base.print(io, options) + +function Base.show(io::IO, ::MIME"text/plain", fitresult::SRFitResult) + print(io, "SRFitResult for $(fitresult.model):") + print(io, "\n") + print(io, " state:\n") + print(io, " [1]: $(typeof(fitresult.state[1])) with ") + print(io, "$(length(fitresult.state[1])) × $(length(fitresult.state[1][1])) ") + print(io, "populations of $(fitresult.state[1][1][1].n) members\n") + print(io, " [2]: $(typeof(fitresult.state[2])) ") + if fitresult.model isa SRRegressor + print(io, "with $(sum(fitresult.state[2].exists)) saved expressions") + else + print(io, "with $(map(s -> sum(s.exists), fitresult.state[2])) saved expressions") + end + print(io, "\n") + print(io, " num_targets: $(fitresult.num_targets)") + print(io, "\n") + print(io, " variable_names: $(fitresult.variable_names)") + print(io, "\n") + print(io, " y_variable_names: $(fitresult.y_variable_names)") + print(io, "\n") + print(io, " X_units: $(fitresult.X_units)") + print(io, "\n") + print(io, " y_units: $(fitresult.y_units)") + print(io, "\n") + return nothing +end + +end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 2d8f26724..03376af71 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -159,6 +159,7 @@ include("SingleIteration.jl") include("ProgressBars.jl") include("Migration.jl") include("SearchUtils.jl") +include("Printing.jl") include("PlotRecipes.jl") using .CoreModule: From 59a373e436dc2a1c3f5513d55da3873770a3bcb8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jan 2024 01:50:32 +0000 Subject: [PATCH 16/54] Idiomatic plotting recipes for Pareto curve --- src/PlotRecipes.jl | 48 ++++++++++++++++++--------------------- src/Printing.jl | 2 +- src/SymbolicRegression.jl | 5 ++-- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/PlotRecipes.jl b/src/PlotRecipes.jl index 048dc6264..06bcdc878 100644 --- a/src/PlotRecipes.jl +++ b/src/PlotRecipes.jl @@ -1,40 +1,36 @@ module PlotRecipesModule +using RecipesBase: @recipe using DynamicExpressions: Node, string_tree using ..CoreModule: Options using ..HallOfFameModule: HallOfFame, format_hall_of_fame +using ..MLJInterfaceModule: SRFitResult, SRRegressor -function plot(hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws...) - return default_sr_plot(hall_of_fame, options; variable_names, kws...) +@recipe function default_sr_plot(fitresult::SRFitResult{<:SRRegressor}) + return fitresult.state[2], fitresult.options end -function default_sr_plot( - hall_of_fame::HallOfFame, options::Options; variable_names=nothing, kws... -) +# TODO: Add variable names +@recipe function default_sr_plot(hall_of_fame::HallOfFame, options::Options) (; trees, losses, complexities) = format_hall_of_fame(hall_of_fame, options) - return default_sr_plot(trees, losses, complexities, options; variable_names, kws...) 
+ return (trees, losses, complexities, options) end -function default_sr_plot( - trees::Vector{N}, - losses::Vector{L}, - complexities::Vector{Int}, - options::Options; - variable_names=nothing, - kws..., +@recipe function default_sr_plot( + trees::Vector{N}, losses::Vector{L}, complexities::Vector{Int}, options::Options ) where {T,L,N<:Node{T}} - tree_strings = [string_tree(tree, options; variable_names) for tree in trees] - return plot( - complexities, - losses; - label=nothing, - xlabel="Complexity", - ylabel="Loss", - title="Hall of Fame", - xlims=(0, options.maxsize), - yscale=:log10, - kws..., - ) + tree_strings = [string_tree(tree, options) for tree in trees] + + xlabel --> "Complexity" + ylabel --> "Loss" + + xlims --> (0.5, options.maxsize + 1) + + xscale --> :log10 + yscale --> :log10 + + # Data for plotting: + return complexities, losses end -end \ No newline at end of file +end diff --git a/src/Printing.jl b/src/Printing.jl index 07f3ecb9b..6d7c35eeb 100644 --- a/src/Printing.jl +++ b/src/Printing.jl @@ -2,7 +2,7 @@ module PrintingModule using ..CoreModule: Options -using ..MLJModelInterface: SRRegressor, SRFitResult +using ..MLJInterfaceModule: SRRegressor, SRFitResult function Base.print(io::IO, options::Options) return print( diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 03376af71..e810ac46d 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -159,8 +159,6 @@ include("SingleIteration.jl") include("ProgressBars.jl") include("Migration.jl") include("SearchUtils.jl") -include("Printing.jl") -include("PlotRecipes.jl") using .CoreModule: MAX_DEGREE, @@ -1120,6 +1118,9 @@ end include("MLJInterface.jl") using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor +include("Printing.jl") +include("PlotRecipes.jl") + function __init__() @require_extensions end From 506412e27a73399d586df76a73c8f037bda682a8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jan 2024 02:42:12 +0000 Subject: [PATCH 17/54] Plot convex hull --- src/PlotRecipes.jl | 60 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/src/PlotRecipes.jl b/src/PlotRecipes.jl index 06bcdc878..2e2429321 100644 --- a/src/PlotRecipes.jl +++ b/src/PlotRecipes.jl @@ -1,6 +1,6 @@ module PlotRecipesModule -using RecipesBase: @recipe +using RecipesBase: @recipe, @series using DynamicExpressions: Node, string_tree using ..CoreModule: Options using ..HallOfFameModule: HallOfFame, format_hall_of_fame @@ -20,6 +20,22 @@ end trees::Vector{N}, losses::Vector{L}, complexities::Vector{Int}, options::Options ) where {T,L,N<:Node{T}} tree_strings = [string_tree(tree, options) for tree in trees] + log_losses = @. log10(losses + eps(L)) + log_complexities = @. 
log10(complexities) + # Add an upper right corner to this for the convex hull calculation: + push!(log_losses, maximum(log_losses)) + push!(log_complexities, maximum(log_complexities)) + + xy = cat(log_complexities, log_losses; dims=2) + log_hull = convex_hull(xy) + + # Add the first point again to close the hull: + push!(log_hull, log_hull[1]) + + # Then remove the first two points for visualization + log_hull = log_hull[3:end] + + hull = [10 .^ row for row in log_hull] xlabel --> "Complexity" ylabel --> "Loss" @@ -29,8 +45,46 @@ end xscale --> :log10 yscale --> :log10 - # Data for plotting: - return complexities, losses + # Main complexity/loss plot: + @series begin + label --> "Pareto Front" + + complexities, losses + end + + # Add on a convex hull: + @series begin + label --> "Convex Hull" + color --> :lightgray + + first.(hull), last.(hull) + end +end + +"""Uses gift wrapping algorithm to create a convex hull.""" +function convex_hull(xy) + cur_point = xy[sortperm(xy[:, 1])[1], :] + hull = typeof(cur_point)[] + while true + push!(hull, cur_point) + end_point = xy[1, :] + for candidate_point in eachrow(xy) + if end_point == cur_point || isleftof(candidate_point, (cur_point, end_point)) + end_point = candidate_point + end + end + cur_point = end_point + if end_point == hull[1] + break + end + end + return hull +end + +function isleftof(point, line) + (start_point, end_point) = line + return (end_point[1] - start_point[1]) * (point[2] - start_point[2]) - + (end_point[2] - start_point[2]) * (point[1] - start_point[1]) > 0 end end From 4b6a3cd44f30bfa92e2fe1e53718ca85c61df088 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jan 2024 02:49:50 +0000 Subject: [PATCH 18/54] Fix for Julia 1.6 --- src/PlotRecipes.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/PlotRecipes.jl b/src/PlotRecipes.jl index 2e2429321..42599ca3c 100644 --- a/src/PlotRecipes.jl +++ b/src/PlotRecipes.jl @@ -12,8 +12,8 @@ end # TODO: Add variable names @recipe function default_sr_plot(hall_of_fame::HallOfFame, options::Options) - (; trees, losses, complexities) = format_hall_of_fame(hall_of_fame, options) - return (trees, losses, complexities, options) + out = format_hall_of_fame(hall_of_fame, options) + return (out.trees, out.losses, out.complexities, options) end @recipe function default_sr_plot( From c1ebb67b132c4bc820bbb9a5a1b9505987709b6b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 20 Mar 2024 21:06:06 +0000 Subject: [PATCH 19/54] Move logging utilities to `src/Logigng.jl` --- src/Logging.jl | 58 +++++++++++++++++++++++++++++++++++++++ src/MLJInterface.jl | 8 ++++-- src/SearchUtils.jl | 36 ------------------------ src/SymbolicRegression.jl | 16 +++-------- 4 files changed, 67 insertions(+), 51 deletions(-) create mode 100644 src/Logging.jl diff --git a/src/Logging.jl b/src/Logging.jl new file mode 100644 index 000000000..152260037 --- /dev/null +++ b/src/Logging.jl @@ -0,0 +1,58 @@ +module LoggingModule + +using Base: AbstractLogger + +using ..CoreModule: Options, Dataset +using ..SearchUtilsModule: SearchState, RuntimeOptions + +# Defined by Plots extension +function default_sr_plot(args...; kws...) + return "Load the Plots package to use this function." 
+end + +function default_logging_callback( + logger::AbstractLogger; + log_step::Integer, + state::SearchState, + datasets::AbstractVector{<:Dataset{T,L}}, + ropt::RuntimeOptions, + options::Options, +) where {T,L} + data = let d = Dict() + for i in eachindex(datasets, state.halls_of_fame) + dominating = calculate_pareto_frontier(state.halls_of_fame[i]) + best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) + trees = [member.tree for member in dominating] + losses = L[member.loss for member in dominating] + complexities = Int[compute_complexity(member, options) for member in dominating] + equations = String[ + string_tree( + member.tree, options; variable_names=datasets[i].variable_names + ) for member in dominating + ] + is = string(i) + d[is] = Dict() + d[is]["best_loss"] = best_loss + d[is]["equations"] = Dict() + d[is]["plot"] = default_sr_plot( + trees, + losses, + complexities, + options; + variable_names=datasets[i].variable_names, + ) + for i_eqn in eachindex(complexities, losses, equations) + d[is]["equations"][string(complexities[i_eqn])] = Dict( + "loss" => losses[i_eqn], "equation" => equations[i_eqn] + ) + end + end + d["num_evals"] = sum(sum, state.num_evals) + d + end + with_logger(logger) do + @info("search_state", data = data) + end +end + +end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 8ac7e6205..cb1671328 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -331,8 +331,8 @@ wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v) wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v) function prediction_fallback( - ::Type{T}, ::SRRegressor, Xnew_t, fitresult::SRFitResult -, _) where {T} + ::Type{T}, ::SRRegressor, Xnew_t, fitresult::SRFitResult, _ +) where {T} prediction_warn() out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)) return wrap_units(out, fitresult.y_units, nothing) @@ -416,7 +416,9 @@ function MMI.predict(m::SRRegressor, fitresult::SRFitResult{<:SRRegressor}, Xnew return prediction_fallback(T, m, X_t, fitresult, prototype) end end -function MMI.predict(m::MultitargetSRRegressor, fitresult::SRFitResult{<:MultitargetSRRegressor}, Xnew) +function MMI.predict( + m::MultitargetSRRegressor, fitresult::SRFitResult{<:MultitargetSRRegressor}, Xnew +) params = full_report(m, fitresult; v_with_strings=Val(false)) prototype = MMI.istable(Xnew) ? Xnew : nothing Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ef85bff56..b11dfb367 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -488,40 +488,4 @@ function update_hall_of_fame!( end end -# Defined by Plots extension -function default_sr_plot(args...; kws...) - return "Load the Plots package to use this function." -end - -function default_logging_callback(logger; options, num_evals, hall_of_fame, datasets, _...) - L = typeof(first(datasets).baseline_loss) - with_logger(logger) do - d = Dict() - for (i, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets)) - dominating = calculate_pareto_frontier(hof) - best_loss = length(dominating) > 0 ? 
dominating[end].loss : L(Inf) - trees = [member.tree for member in dominating] - losses = L[member.loss for member in dominating] - complexities = Int[compute_complexity(member, options) for member in dominating] - equations = String[ - string_tree(member.tree, options; variable_names=dataset.variable_names) for - member in dominating - ] - d[string(i)] = Dict() - d[string(i)]["best_loss"] = best_loss - d[string(i)]["equations"] = Dict() - d[string(i)]["plot"] = default_sr_plot( - trees, losses, complexities, options; variable_names=dataset.variable_names - ) - for (complexity, loss, equation) in zip(complexities, losses, equations) - d[string(i)]["equations"][string(complexity)] = Dict( - "loss" => loss, "equation" => equation - ) - end - end - d["num_evals"] = sum(sum, num_evals) - @info("search_state", data = d) - end -end - end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index fc6aa9ada..72fece425 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -171,6 +171,7 @@ include("SingleIteration.jl") include("ProgressBars.jl") include("Migration.jl") include("SearchUtils.jl") +include("Logging.jl") using .CoreModule: MAX_DEGREE, @@ -253,9 +254,8 @@ using .SearchUtilsModule: load_saved_population, construct_datasets, get_cur_maxsize, - update_hall_of_fame!, - default_logging_callback, - default_sr_plot + update_hall_of_fame! +using .LoggingModule: default_logging_callback, default_sr_plot include("deprecates.jl") include("Configure.jl") @@ -1036,15 +1036,7 @@ function _main_search_loop!( ) end if ropt.logging_callback !== nothing && log_step % ropt.log_every_n == 0 - ropt.logging_callback(; - options, - state.num_evals, - hall_of_fame=state.halls_of_fame, - state.worker_assignment, - state.cycles_remaining, - populations=state.last_pops, - datasets=datasets, - ) + ropt.logging_callback(; log_step, state, datasets, ropt, options) end log_step += 1 end From bfbd3eaf3382a1579f7e95e9a2774a0fed31b6c7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 20 Mar 2024 21:16:21 +0000 Subject: [PATCH 20/54] Fix merge errors --- src/Logging.jl | 2 + src/MLJInterface.jl | 125 ++++++++++---------------------------------- src/SearchUtils.jl | 2 - 3 files changed, 31 insertions(+), 98 deletions(-) diff --git a/src/Logging.jl b/src/Logging.jl index 152260037..4d985972e 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -1,6 +1,8 @@ module LoggingModule using Base: AbstractLogger +using Logging: with_logger +using DynamicExpressions: string_tree using ..CoreModule: Options, Dataset using ..SearchUtilsModule: SearchState, RuntimeOptions diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index cb1671328..49b72f9ac 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -1,7 +1,6 @@ module MLJInterfaceModule using Optim: Optim -using Logging: AbstractLogger using LineSearches: LineSearches using MLJModelInterface: MLJModelInterface as MMI using DynamicExpressions: eval_tree_array, string_tree, AbstractExpressionNode, Node @@ -40,9 +39,6 @@ function modelexpr(model_name::Symbol) procs::Union{Vector{Int},Nothing} = nothing addprocs_function::Union{Function,Nothing} = nothing heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing - logger::Union{AbstractLogger,Nothing} = nothing - logging_callback::Union{Function,Nothing} = nothing - log_every_n::Int = 1 runtests::Bool = true loss_type::L = Nothing selection_method::Function = choose_best @@ -89,56 +85,9 @@ function get_options(::AbstractSRRegressor) end eval(modelexpr(:SRRegressor)) 
eval(modelexpr(:MultitargetSRRegressor)) -""" - SRFitResultTypes - -A struct referencing types in the `SRFitResult` struct, -to be used in type inference during MLJ.update to speed up iterative fits. -""" -Base.@kwdef struct SRFitResultTypes{ - _T,_X_t,_y_t,_w_t,_state,_X_units,_y_units,_X_units_clean,_y_units_clean -} - T::Type{_T} = Any - X_t::Type{_X_t} = Any - y_t::Type{_y_t} = Any - w_t::Type{_w_t} = Any - state::Type{_state} = Any - X_units::Type{_X_units} = Any - y_units::Type{_y_units} = Any - X_units_clean::Type{_X_units_clean} = Any - y_units_clean::Type{_y_units_clean} = Any -end - -""" - SRFitResult - -A struct containing the result of a fit of an `SRRegressor` or `MultitargetSRRegressor`. -""" -Base.@kwdef struct SRFitResult{ - M<:AbstractSRRegressor, - S, - O<:Options, - XD<:Union{Vector{<:AbstractDimensions},Nothing}, - YD<:Union{AbstractDimensions,Vector{<:AbstractDimensions},Nothing}, - TYPES<:SRFitResultTypes, -} - model::M - state::S - num_targets::Int - options::O - variable_names::Vector{String} - y_variable_names::Union{Vector{String},Nothing} - y_is_table::Bool - X_units::XD - y_units::YD - types::TYPES -end - # Cleaning already taken care of by `Options` and `equation_search` function full_report( - m::AbstractSRRegressor, - fitresult::SRFitResult; - v_with_strings::Val{with_strings}=Val(true), + m::AbstractSRRegressor, fitresult; v_with_strings::Val{with_strings}=Val(true) ) where {with_strings} _, hof = fitresult.state # TODO: Adjust baseline loss @@ -170,23 +119,25 @@ function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing) return MMI.update(m, verbosity, nothing, nothing, X, y, w) end function MMI.update( - m::AbstractSRRegressor, - verbosity, - old_fitresult::Union{SRFitResult,Nothing}, - old_cache, - X, - y, - w=nothing, + m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing ) options = old_fitresult === nothing ? get_options(m) : old_fitresult.options return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) end -function _update( - m, verbosity, old_fitresult::Union{SRFitResult,Nothing}, old_cache, X, y, w, options -) +function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) # To speed up iterative fits, we cache the types: types = if old_fitresult === nothing - SRFitResultTypes() + (; + T=Any, + X_t=Any, + y_t=Any, + w_t=Any, + state=Any, + X_units=Any, + y_units=Any, + X_units_clean=Any, + y_units_clean=Any, + ) else old_fitresult.types end @@ -223,14 +174,10 @@ function _update( X_units=X_units_clean, y_units=y_units_clean, verbosity=verbosity, - logger=m.logger, - logging_callback=m.logging_callback, - log_every_n=m.log_every_n, # Help out with inference: v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), ) - fitresult = SRFitResult(; - model=m, + fitresult = (; state=search_state, num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1), options=options, @@ -239,7 +186,7 @@ function _update( y_is_table=MMI.istable(y), X_units=X_units_clean, y_units=y_units_clean, - types=SRFitResultTypes(; + types=( T=hof_eltype(search_state[2]), X_t=typeof(X_t), y_t=typeof(y_t), @@ -250,7 +197,7 @@ function _update( X_units_clean=typeof(X_units_clean), y_units_clean=typeof(y_units_clean), ), - )::(old_fitresult === nothing ? SRFitResult : typeof(old_fitresult)) + )::(old_fitresult === nothing ? 
Any : typeof(old_fitresult)) return (fitresult, nothing, full_report(m, fitresult)) end hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}} = T @@ -299,7 +246,7 @@ function format_input_for(::MultitargetSRRegressor, y, ::Type{D}) where {D} ) return get_matrix_and_info(y, D) end -function validate_variable_names(variable_names, fitresult::SRFitResult) +function validate_variable_names(variable_names, fitresult) @assert( variable_names == fitresult.variable_names, "Variable names do not match fitted regressor." @@ -330,15 +277,13 @@ wrap_units(v, ::Nothing, ::Nothing) = v wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v) wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v) -function prediction_fallback( - ::Type{T}, ::SRRegressor, Xnew_t, fitresult::SRFitResult, _ -) where {T} +function prediction_fallback(::Type{T}, ::SRRegressor, Xnew_t, fitresult, _) where {T} prediction_warn() out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)) return wrap_units(out, fitresult.y_units, nothing) end function prediction_fallback( - ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult::SRFitResult, prototype + ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult, prototype ) where {T} prediction_warn() out_cols = [ @@ -378,7 +323,7 @@ function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D} return compat_ustrip(v)::AbstractVector, dims end -function MMI.fitted_params(m::AbstractSRRegressor, fitresult::SRFitResult) +function MMI.fitted_params(m::AbstractSRRegressor, fitresult) report = full_report(m, fitresult) return (; best_idx=report.best_idx, @@ -398,27 +343,15 @@ function eval_tree_mlj( end end -function MMI.predict(m::SRRegressor, fitresult::SRFitResult{<:SRRegressor}, Xnew) - params = full_report(m, fitresult; v_with_strings=Val(false)) - Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type) - T = promote_type(eltype(Xnew_t), fitresult.types.T) - if length(params.equations) == 0 - return prediction_fallback(T, m, Xnew_t, fitresult) - end - X_units_clean = clean_units(X_units) - validate_variable_names(variable_names, fitresult) - validate_units(X_units_clean, fitresult.X_units) - eq = params.equations[params.best_idx] - out, completed = eval_tree_array(eq, Xnew_t, fitresult.options) - if !completed - return prediction_fallback(T, m, Xnew_t, fitresult) - else - return prediction_fallback(T, m, X_t, fitresult, prototype) +function MMI.predict(m::M, fitresult, Xnew; idx=nothing) where {M<:AbstractSRRegressor} + if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data)) + @assert( + haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2, + "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`." + ) + return MMI.predict(m, fitresult, Xnew.data; idx=Xnew.idx) end -end -function MMI.predict( - m::MultitargetSRRegressor, fitresult::SRFitResult{<:MultitargetSRRegressor}, Xnew -) + params = full_report(m, fitresult; v_with_strings=Val(false)) prototype = MMI.istable(Xnew) ? 
Xnew : nothing Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index b11dfb367..fa9923878 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -5,8 +5,6 @@ module SearchUtilsModule using Printf: @printf, @sprintf using Distributed -using Logging: with_logger -using DynamicExpressions: string_tree using StatsBase: mean using DynamicExpressions: AbstractExpressionNode From 5c7d4c325b39a19933c87afe4224421634c63384 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 20 Mar 2024 21:33:22 +0000 Subject: [PATCH 21/54] Fix merge errors --- src/MLJInterface.jl | 119 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 92 insertions(+), 27 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 49b72f9ac..b8e87840b 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -2,6 +2,7 @@ module MLJInterfaceModule using Optim: Optim using LineSearches: LineSearches +using Logging: AbstractLogger using MLJModelInterface: MLJModelInterface as MMI using DynamicExpressions: eval_tree_array, string_tree, AbstractExpressionNode, Node using DynamicQuantities: @@ -39,6 +40,9 @@ function modelexpr(model_name::Symbol) procs::Union{Vector{Int},Nothing} = nothing addprocs_function::Union{Function,Nothing} = nothing heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing + logger::Union{AbstractLogger,Nothing} = nothing + logging_callback::Union{Function,Nothing} = nothing + log_every_n::Int = 1 runtests::Bool = true loss_type::L = Nothing selection_method::Function = choose_best @@ -85,9 +89,56 @@ function get_options(::AbstractSRRegressor) end eval(modelexpr(:SRRegressor)) eval(modelexpr(:MultitargetSRRegressor)) +""" + SRFitResultTypes + +A struct referencing types in the `SRFitResult` struct, +to be used in type inference during MLJ.update to speed up iterative fits. +""" +Base.@kwdef struct SRFitResultTypes{ + _T,_X_t,_y_t,_w_t,_state,_X_units,_y_units,_X_units_clean,_y_units_clean +} + T::Type{_T} = Any + X_t::Type{_X_t} = Any + y_t::Type{_y_t} = Any + w_t::Type{_w_t} = Any + state::Type{_state} = Any + X_units::Type{_X_units} = Any + y_units::Type{_y_units} = Any + X_units_clean::Type{_X_units_clean} = Any + y_units_clean::Type{_y_units_clean} = Any +end + +""" + SRFitResult + +A struct containing the result of a fit of an `SRRegressor` or `MultitargetSRRegressor`. 
+""" +Base.@kwdef struct SRFitResult{ + M<:AbstractSRRegressor, + S, + O<:Options, + XD<:Union{Vector{<:AbstractDimensions},Nothing}, + YD<:Union{AbstractDimensions,Vector{<:AbstractDimensions},Nothing}, + TYPES<:SRFitResultTypes, +} + model::M + state::S + num_targets::Int + options::O + variable_names::Vector{String} + y_variable_names::Union{Vector{String},Nothing} + y_is_table::Bool + X_units::XD + y_units::YD + types::TYPES +end + # Cleaning already taken care of by `Options` and `equation_search` function full_report( - m::AbstractSRRegressor, fitresult; v_with_strings::Val{with_strings}=Val(true) + m::AbstractSRRegressor, + fitresult::SRFitResult; + v_with_strings::Val{with_strings}=Val(true), ) where {with_strings} _, hof = fitresult.state # TODO: Adjust baseline loss @@ -119,25 +170,23 @@ function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing) return MMI.update(m, verbosity, nothing, nothing, X, y, w) end function MMI.update( - m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing + m::AbstractSRRegressor, + verbosity, + old_fitresult::Union{SRFitResult,Nothing}, + old_cache, + X, + y, + w=nothing, ) options = old_fitresult === nothing ? get_options(m) : old_fitresult.options return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) end -function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) +function _update( + m, verbosity, old_fitresult::Union{SRFitResult,Nothing}, old_cache, X, y, w, options +) # To speed up iterative fits, we cache the types: types = if old_fitresult === nothing - (; - T=Any, - X_t=Any, - y_t=Any, - w_t=Any, - state=Any, - X_units=Any, - y_units=Any, - X_units_clean=Any, - y_units_clean=Any, - ) + SRFitResultTypes() else old_fitresult.types end @@ -174,10 +223,14 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) X_units=X_units_clean, y_units=y_units_clean, verbosity=verbosity, + logger=m.logger, + logging_callback=m.logging_callback, + log_every_n=m.log_every_n, # Help out with inference: v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), ) - fitresult = (; + fitresult = SRFitResult(; + model=m, state=search_state, num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1), options=options, @@ -186,7 +239,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) y_is_table=MMI.istable(y), X_units=X_units_clean, y_units=y_units_clean, - types=( + types=SRFitResultTypes(; T=hof_eltype(search_state[2]), X_t=typeof(X_t), y_t=typeof(y_t), @@ -197,7 +250,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) X_units_clean=typeof(X_units_clean), y_units_clean=typeof(y_units_clean), ), - )::(old_fitresult === nothing ? Any : typeof(old_fitresult)) + )::(old_fitresult === nothing ? SRFitResult : typeof(old_fitresult)) return (fitresult, nothing, full_report(m, fitresult)) end hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}} = T @@ -246,7 +299,7 @@ function format_input_for(::MultitargetSRRegressor, y, ::Type{D}) where {D} ) return get_matrix_and_info(y, D) end -function validate_variable_names(variable_names, fitresult) +function validate_variable_names(variable_names, fitresult::SRFitResult) @assert( variable_names == fitresult.variable_names, "Variable names do not match fitted regressor." @@ -272,18 +325,30 @@ function prediction_warn() @warn "Evaluation failed either due to NaNs detected or due to unfinished search. Using 0s for prediction." 
end -wrap_units(v, ::Nothing, ::Integer) = v -wrap_units(v, ::Nothing, ::Nothing) = v -wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v) -wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v) +@inline function wrap_units(v, y_units, i::Integer) + if y_units === nothing + return v + else + return (yi -> Quantity(yi, y_units[i])).(v) + end +end +@inline function wrap_units(v, y_units, ::Nothing) + if y_units === nothing + return v + else + return (yi -> Quantity(yi, y_units)).(v) + end +end -function prediction_fallback(::Type{T}, ::SRRegressor, Xnew_t, fitresult, _) where {T} +function prediction_fallback( + ::Type{T}, m::SRRegressor, Xnew_t, fitresult::SRFitResult, _ +) where {T} prediction_warn() out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)) return wrap_units(out, fitresult.y_units, nothing) end function prediction_fallback( - ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult, prototype + ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult::SRFitResult, prototype ) where {T} prediction_warn() out_cols = [ @@ -291,11 +356,11 @@ function prediction_fallback( fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)), fitresult.y_units, i ) for i in 1:(fitresult.num_targets) ] - out_matrix = hcat(out_cols...) + out_matrix = reduce(hcat, out_cols) if !fitresult.y_is_table return out_matrix else - return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype) + return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype=prototype) end end @@ -323,7 +388,7 @@ function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D} return compat_ustrip(v)::AbstractVector, dims end -function MMI.fitted_params(m::AbstractSRRegressor, fitresult) +function MMI.fitted_params(m::AbstractSRRegressor, fitresult::SRFitResult) report = full_report(m, fitresult) return (; best_idx=report.best_idx, From 98bf25b15a9f124d509f548f0d1dae129b70464f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 20 Mar 2024 23:01:01 +0000 Subject: [PATCH 22/54] Allow user to specify different logging rate for plots --- Project.toml | 4 +++- src/Logging.jl | 29 ++++++++++++++++------------- src/MLJInterface.jl | 2 +- src/SearchUtils.jl | 4 ++-- src/SymbolicRegression.jl | 15 ++++++++++----- test/test_logging.jl | 36 ++++++++++++++++++++++++++++++++++++ test/unittest.jl | 4 ++++ 7 files changed, 72 insertions(+), 22 deletions(-) create mode 100644 test/test_logging.jl diff --git a/Project.toml b/Project.toml index a9083a57f..72c322d22 100644 --- a/Project.toml +++ b/Project.toml @@ -70,11 +70,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "SafeTestsets", "Aqua", "Bumper", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"] +test = ["Test", "SafeTestsets", "Aqua", "Bumper", "ForwardDiff", "JSON3", "LinearAlgebra", "LoopVectorization", "MLJBase", "MLJTestInterface", "Plots", "Suppressor", "SymbolicUtils", 
"TensorBoardLogger", "Zygote"] diff --git a/src/Logging.jl b/src/Logging.jl index 4d985972e..46062321b 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -5,6 +5,8 @@ using Logging: with_logger using DynamicExpressions: string_tree using ..CoreModule: Options, Dataset +using ..ComplexityModule: compute_complexity +using ..HallOfFameModule: calculate_pareto_frontier using ..SearchUtilsModule: SearchState, RuntimeOptions # Defined by Plots extension @@ -20,7 +22,7 @@ function default_logging_callback( ropt::RuntimeOptions, options::Options, ) where {T,L} - data = let d = Dict() + data = let d = Dict{String,Union{Dict{String,Any},Float64}}() for i in eachindex(datasets, state.halls_of_fame) dominating = calculate_pareto_frontier(state.halls_of_fame[i]) best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) @@ -33,19 +35,20 @@ function default_logging_callback( ) for member in dominating ] is = string(i) - d[is] = Dict() + d[is] = Dict{String,Any}() d[is]["best_loss"] = best_loss - d[is]["equations"] = Dict() - d[is]["plot"] = default_sr_plot( - trees, - losses, - complexities, - options; - variable_names=datasets[i].variable_names, - ) - for i_eqn in eachindex(complexities, losses, equations) - d[is]["equations"][string(complexities[i_eqn])] = Dict( - "loss" => losses[i_eqn], "equation" => equations[i_eqn] + d[is]["equations"] = Dict([ + string(complexities[i_eqn]) => + Dict("loss" => losses[i_eqn], "equation" => equations[i_eqn]) for + i_eqn in eachindex(complexities, losses, equations) + ]) + if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 + d[is]["plot"] = default_sr_plot( + trees, + losses, + complexities, + options; + variable_names=datasets[i].variable_names, ) end end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index b8e87840b..7408945ec 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -42,7 +42,7 @@ function modelexpr(model_name::Symbol) heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing logger::Union{AbstractLogger,Nothing} = nothing logging_callback::Union{Function,Nothing} = nothing - log_every_n::Int = 1 + log_every_n::Union{Integer,NamedTuple} = 1 runtests::Bool = true loss_type::L = Nothing selection_method::Function = choose_best diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index fa9923878..45b63d024 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -26,7 +26,7 @@ rather than set within `Options`. This is to differentiate between parameters that relate to processing and the duration of the search, and parameters dealing with the search hyperparameters itself. 
""" -Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} +Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE,NT<:NamedTuple} niterations::Int64 total_cycles::Int64 numprocs::Int64 @@ -37,7 +37,7 @@ Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} verbosity::Int64 progress::Bool logging_callback::Union{Function,Nothing} - log_every_n::Int64 + log_every_n::NT end function Base.getproperty(roptions::RuntimeOptions{P,D,R}, name::Symbol) where {P,D,R} if name == :parallelism diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 72fece425..7a53cc284 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -374,7 +374,7 @@ function equation_search( verbosity::Union{Integer,Nothing}=nothing, logger::Union{AbstractLogger,Nothing}=nothing, logging_callback::Union{Function,Nothing}=nothing, - log_every_n::Int=1, + log_every_n::Union{Integer,NamedTuple}=1, progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, @@ -463,7 +463,7 @@ function equation_search( verbosity::Union{Int,Nothing}=nothing, logger::Union{AbstractLogger,Nothing}=nothing, logging_callback::Union{Function,Nothing}=nothing, - log_every_n::Int=1, + log_every_n::Union{Integer,NamedTuple}=1, progress::Union{Bool,Nothing}=nothing, v_dim_out::Val{DIM_OUT}=Val(nothing), ) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} @@ -581,11 +581,16 @@ function equation_search( else logging_callback end + _log_every_n = if log_every_n isa Integer + (; scalars=log_every_n, plots=0) + else + log_every_n + end # Underscores here mean that we have mutated the variable return _equation_search( datasets, - RuntimeOptions{concurrency,dim_out,_return_state}(; + RuntimeOptions{concurrency,dim_out,_return_state,typeof(_log_every_n)}(; niterations=niterations, total_cycles=options.populations * niterations, numprocs=_numprocs, @@ -596,7 +601,7 @@ function equation_search( verbosity=_verbosity, progress=_progress, logging_callback=_logging_callback, - log_every_n=log_every_n, + log_every_n=_log_every_n, ), options, saved_state, @@ -1035,7 +1040,7 @@ function _main_search_loop!( ropt.parallelism, ) end - if ropt.logging_callback !== nothing && log_step % ropt.log_every_n == 0 + if ropt.logging_callback !== nothing && log_step % ropt.log_every_n.scalars == 0 ropt.logging_callback(; log_step, state, datasets, ropt, options) end log_step += 1 diff --git a/test/test_logging.jl b/test/test_logging.jl new file mode 100644 index 000000000..dd43f5195 --- /dev/null +++ b/test/test_logging.jl @@ -0,0 +1,36 @@ +using Test +using SymbolicRegression +using TensorBoardLogger +using Logging +using MLJBase +using Plots + +mktempdir() do dir + logger = TBLogger(dir, tb_overwrite; min_level=Logging.Info) + + niterations = 4 + populations = 36 + log_every_n = (; scalars=2, plots=10) + model = SRRegressor(; + binary_operators=[+, -, *, mod], + unary_operators=[], + maxsize=40, + niterations, + populations, + log_every_n, + logger, + ) + + X = (a=rand(500), b=rand(500)) + y = @. 
2 * cos(X.a * 23.5) - X.b^2 + mach = machine(model, X, y) + + fit!(mach) + + b = TensorBoardLogger.steps(logger) + @test length(b) == (niterations * populations//log_every_n.scalars) + 1 + + files_and_dirs = readdir(dir) + @test length(files_and_dirs) == 1 + @test occursin(r"events\.out\.tfevents", only(files_and_dirs)) +end diff --git a/test/unittest.jl b/test/unittest.jl index f4cd99cba..45ad1721d 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -102,3 +102,7 @@ end @safetestset "Dataset" begin include("test_dataset.jl") end + +@safetestset "Logging" begin + include("test_logging.jl") +end From 8134157fb4b328abcb6c7fb67e13cde7291bafef Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 20 Mar 2024 23:48:14 +0000 Subject: [PATCH 23/54] Add infiltrator for testing --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 72c322d22..cb7e582ae 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" +Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" From 7e62ea73fa3bfaa45483ca0c3d38eec7eb21c96f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 21 Mar 2024 00:14:17 +0000 Subject: [PATCH 24/54] Include convex hull area as logged metric --- src/Logging.jl | 125 +++++++++++++++++++++++++++++++++++---------- src/PlotRecipes.jl | 27 +--------- 2 files changed, 100 insertions(+), 52 deletions(-) diff --git a/src/Logging.jl b/src/Logging.jl index 46062321b..a724e7264 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -3,6 +3,7 @@ module LoggingModule using Base: AbstractLogger using Logging: with_logger using DynamicExpressions: string_tree +using Infiltrator: @infiltrate using ..CoreModule: Options, Dataset using ..ComplexityModule: compute_complexity @@ -22,42 +23,114 @@ function default_logging_callback( ropt::RuntimeOptions, options::Options, ) where {T,L} - data = let d = Dict{String,Union{Dict{String,Any},Float64}}() + nout = length(datasets) + data = let d = Ref(Dict{String,Any}()) for i in eachindex(datasets, state.halls_of_fame) + cur_out = Dict{String,Any}() + + #### Summaries dominating = calculate_pareto_frontier(state.halls_of_fame[i]) - best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf) trees = [member.tree for member in dominating] losses = L[member.loss for member in dominating] complexities = Int[compute_complexity(member, options) for member in dominating] - equations = String[ - string_tree( - member.tree, options; variable_names=datasets[i].variable_names - ) for member in dominating - ] - is = string(i) - d[is] = Dict{String,Any}() - d[is]["best_loss"] = best_loss - d[is]["equations"] = Dict([ - string(complexities[i_eqn]) => - Dict("loss" => losses[i_eqn], "equation" => equations[i_eqn]) for - i_eqn in eachindex(complexities, losses, equations) - ]) - if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 - d[is]["plot"] = default_sr_plot( - trees, - losses, - complexities, - options; - variable_names=datasets[i].variable_names, - ) + + cur_out["min_loss"] = length(dominating) > 0 ? dominating[end].loss : L(Inf) + # @infiltrate + cur_out["pareto_volume"] = if length(dominating) > 1 + log_losses = @. 
log10(losses + eps(L)) + log_complexities = @. log10(complexities) + + # Add a point equal to the best loss and largest possible complexity, + 1 + push!(log_losses, minimum(log_losses)) + push!(log_complexities, log10(options.maxsize + 1)) + + # Add a point to connect things: + push!(log_losses, maximum(log_losses)) + push!(log_complexities, maximum(log_complexities)) + + xy = cat(log_complexities, log_losses; dims=2) + hull = convex_hull(xy) + convex_hull_area(hull) + else + 0.0 + end + + #### Full Pareto front + cur_out["equations"] = let + equations = String[ + string_tree( + member.tree, + options; + variable_names=datasets[i].variable_names, + ) for member in dominating + ] + Dict([ + "complexity=" * string(complexities[i_eqn]) => Dict( + "loss" => losses[i_eqn], "equation" => equations[i_eqn] + ) for i_eqn in eachindex(complexities, losses, equations) + ]) + end + cur_out["plot"] = + if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 + default_sr_plot( + trees, + losses, + complexities, + options; + variable_names=datasets[i].variable_names, + ) + else + nothing + end + + if nout == 1 + d[] = cur_out + else + d[]["out_$(i)"] = cur_out end end - d["num_evals"] = sum(sum, state.num_evals) - d + d[]["num_evals"] = sum(sum, state.num_evals) + d[] end with_logger(logger) do - @info("search_state", data = data) + @info("search", data = data) + end +end + +"""Uses gift wrapping algorithm to create a convex hull.""" +function convex_hull(xy) + cur_point = xy[sortperm(xy[:, 1])[1], :] + hull = typeof(cur_point)[] + while true + push!(hull, cur_point) + end_point = xy[1, :] + for candidate_point in eachrow(xy) + if end_point == cur_point || isleftof(candidate_point, (cur_point, end_point)) + end_point = candidate_point + end + end + cur_point = end_point + if end_point == hull[1] + break + end + end + return hull +end + +function isleftof(point, line) + (start_point, end_point) = line + return (end_point[1] - start_point[1]) * (point[2] - start_point[2]) - + (end_point[2] - start_point[2]) * (point[1] - start_point[1]) > 0 +end + +"""Computes the area within a convex hull.""" +function convex_hull_area(hull) + area = 0.0 + for i in eachindex(hull) + j = i == lastindex(hull) ? 
firstindex(hull) : nextind(hull, i) + area += (hull[i][1] * hull[j][2] - hull[j][1] * hull[i][2]) end + return abs(area) / 2 end end diff --git a/src/PlotRecipes.jl b/src/PlotRecipes.jl index 42599ca3c..ec29f171d 100644 --- a/src/PlotRecipes.jl +++ b/src/PlotRecipes.jl @@ -5,6 +5,7 @@ using DynamicExpressions: Node, string_tree using ..CoreModule: Options using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..MLJInterfaceModule: SRFitResult, SRRegressor +using ..LoggingModule: convex_hull @recipe function default_sr_plot(fitresult::SRFitResult{<:SRRegressor}) return fitresult.state[2], fitresult.options @@ -61,30 +62,4 @@ end end end -"""Uses gift wrapping algorithm to create a convex hull.""" -function convex_hull(xy) - cur_point = xy[sortperm(xy[:, 1])[1], :] - hull = typeof(cur_point)[] - while true - push!(hull, cur_point) - end_point = xy[1, :] - for candidate_point in eachrow(xy) - if end_point == cur_point || isleftof(candidate_point, (cur_point, end_point)) - end_point = candidate_point - end - end - cur_point = end_point - if end_point == hull[1] - break - end - end - return hull -end - -function isleftof(point, line) - (start_point, end_point) = line - return (end_point[1] - start_point[1]) * (point[2] - start_point[2]) - - (end_point[2] - start_point[2]) * (point[1] - start_point[1]) > 0 -end - end From ee55eeebcf9e8ff5a1c7d89ae8f90f24eeca7cd8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 21 Mar 2024 00:14:32 +0000 Subject: [PATCH 25/54] Log the distribution of complexities --- src/Logging.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/Logging.jl b/src/Logging.jl index a724e7264..a336353a8 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -28,6 +28,17 @@ function default_logging_callback( for i in eachindex(datasets, state.halls_of_fame) cur_out = Dict{String,Any}() + #### Population diagnostics + cur_out["population"] = Dict([ + "complexities" => let + complexities = Int[] + for pop in state.last_pops[i], member in pop.members + push!(complexities, compute_complexity(member, options)) + end + [count(==(c), complexities) for c in 1:(options.maxsize)] + end, + ]) + #### Summaries dominating = calculate_pareto_frontier(state.halls_of_fame[i]) trees = [member.tree for member in dominating] From 332adc40dad5df376e59c22ea6c67c6e89ba5068 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 21 Mar 2024 00:39:20 +0000 Subject: [PATCH 26/54] Fix plotting part of logging --- src/Logging.jl | 10 +++------- src/PlotRecipes.jl | 2 ++ src/SymbolicRegression.jl | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/Logging.jl b/src/Logging.jl index a336353a8..464afb68b 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -4,17 +4,13 @@ using Base: AbstractLogger using Logging: with_logger using DynamicExpressions: string_tree using Infiltrator: @infiltrate +using RecipesBase: plot using ..CoreModule: Options, Dataset using ..ComplexityModule: compute_complexity using ..HallOfFameModule: calculate_pareto_frontier using ..SearchUtilsModule: SearchState, RuntimeOptions -# Defined by Plots extension -function default_sr_plot(args...; kws...) - return "Load the Plots package to use this function." 
-end - function default_logging_callback( logger::AbstractLogger; log_step::Integer, @@ -35,7 +31,7 @@ function default_logging_callback( for pop in state.last_pops[i], member in pop.members push!(complexities, compute_complexity(member, options)) end - [count(==(c), complexities) for c in 1:(options.maxsize)] + complexities end, ]) @@ -83,7 +79,7 @@ function default_logging_callback( end cur_out["plot"] = if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 - default_sr_plot( + plot( trees, losses, complexities, diff --git a/src/PlotRecipes.jl b/src/PlotRecipes.jl index ec29f171d..cfadd937b 100644 --- a/src/PlotRecipes.jl +++ b/src/PlotRecipes.jl @@ -7,6 +7,8 @@ using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..MLJInterfaceModule: SRFitResult, SRRegressor using ..LoggingModule: convex_hull +function default_sr_plot end + @recipe function default_sr_plot(fitresult::SRFitResult{<:SRRegressor}) return fitresult.state[2], fitresult.options end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 7a53cc284..b2786925a 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -255,7 +255,7 @@ using .SearchUtilsModule: construct_datasets, get_cur_maxsize, update_hall_of_fame! -using .LoggingModule: default_logging_callback, default_sr_plot +using .LoggingModule: default_logging_callback include("deprecates.jl") include("Configure.jl") From 935a8eb22dff2d4f13185498614932ebc853c7fd Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 21 Mar 2024 00:46:36 +0000 Subject: [PATCH 27/54] Remove infiltrator --- Project.toml | 1 - src/Logging.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index cb7e582ae..72c322d22 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" -Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" diff --git a/src/Logging.jl b/src/Logging.jl index 464afb68b..1acfd1c85 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -3,7 +3,6 @@ module LoggingModule using Base: AbstractLogger using Logging: with_logger using DynamicExpressions: string_tree -using Infiltrator: @infiltrate using RecipesBase: plot using ..CoreModule: Options, Dataset From fe053c3cc376fd87a9a6f234d9faddeb0d4f430b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 16 Jun 2024 02:37:25 +0100 Subject: [PATCH 28/54] build: add missing Logging compat --- Project.toml | 1 + test/Project.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 3d9e26ae1..9b1f31638 100644 --- a/Project.toml +++ b/Project.toml @@ -45,6 +45,7 @@ DynamicExpressions = "0.16" DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14" JSON3 = "1" LineSearches = "7" +Logging = "1" LossFunctions = "0.10, 0.11" MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10" MacroTools = "0.4, 0.5" diff --git a/test/Project.toml b/test/Project.toml index 46b408ce7..e389cb600 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" 
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" From 78d08d2965a41213ec9612ff615a65ab54cace34 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 16 Jun 2024 02:39:15 +0100 Subject: [PATCH 29/54] test: fix logging test --- test/runtests.jl | 8 ++++---- test/test_logging.jl | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index cb8b90a96..35abdd28b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -107,10 +107,6 @@ end include("test_utils.jl") end -@testitem "Test logging" tags = [:unit] begin - include("test_logging.jl") -end - @testitem "Test units" tags = [:integration] begin include("test_units.jl") end @@ -167,6 +163,10 @@ end include("test_abstract_numbers.jl") end +@testitem "Test logging" tags = [:integration] begin + include("test_logging.jl") +end + @testitem "Aqua tests" tags = [:integration, :aqua] begin include("test_aqua.jl") end diff --git a/test/test_logging.jl b/test/test_logging.jl index dd43f5195..e0512e414 100644 --- a/test/test_logging.jl +++ b/test/test_logging.jl @@ -4,6 +4,7 @@ using TensorBoardLogger using Logging using MLJBase using Plots +include("test_params.jl") mktempdir() do dir logger = TBLogger(dir, tb_overwrite; min_level=Logging.Info) From afc771188783fc1078caf765817e588964e67ce2 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 16 Jun 2024 02:50:48 +0100 Subject: [PATCH 30/54] refactor: move RecipesBase to extension --- Project.toml | 4 ++- .../SymbolicRegressionRecipesBaseExt.jl | 27 ++++++++++++++---- src/Logging.jl | 28 +++++++++++-------- 3 files changed, 40 insertions(+), 19 deletions(-) rename src/PlotRecipes.jl => ext/SymbolicRegressionRecipesBaseExt.jl (68%) diff --git a/Project.toml b/Project.toml index 9b1f31638..0149ba7b7 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -30,10 +29,12 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [extensions] SymbolicRegressionJSON3Ext = "JSON3" +SymbolicRegressionRecipesBaseExt = "RecipesBase" SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" [compat] @@ -66,4 +67,5 @@ julia = "1.6" [extras] JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" diff --git a/src/PlotRecipes.jl b/ext/SymbolicRegressionRecipesBaseExt.jl similarity index 68% rename from src/PlotRecipes.jl rename to ext/SymbolicRegressionRecipesBaseExt.jl index cfadd937b..4a73dda86 100644 --- a/src/PlotRecipes.jl +++ b/ext/SymbolicRegressionRecipesBaseExt.jl @@ -1,11 +1,26 @@ -module PlotRecipesModule +module SymbolicRegressionRecipesBaseExt -using RecipesBase: @recipe, @series +using RecipesBase: @recipe, @series, plot using DynamicExpressions: Node, string_tree -using ..CoreModule: Options -using ..HallOfFameModule: 
HallOfFame, format_hall_of_fame -using ..MLJInterfaceModule: SRFitResult, SRRegressor -using ..LoggingModule: convex_hull +using SymbolicRegression.CoreModule: Options +using SymbolicRegression.HallOfFameModule: HallOfFame, format_hall_of_fame +using SymbolicRegression.MLJInterfaceModule: SRFitResult, SRRegressor +using SymbolicRegression.LoggingModule: convex_hull + +import SymbolicRegression.LoggingModule: add_plot_to_log! + +function add_plot_to_log!( + log; trees, losses, complexities, options, variable_names, log_step, ropt +) + if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 + log["plot"] = plot( + trees, losses, complexities, options; variable_names=variable_names + ) + else + nothing + end + return nothing +end function default_sr_plot end diff --git a/src/Logging.jl b/src/Logging.jl index 1acfd1c85..7b3de9906 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -10,6 +10,10 @@ using ..ComplexityModule: compute_complexity using ..HallOfFameModule: calculate_pareto_frontier using ..SearchUtilsModule: SearchState, RuntimeOptions +function add_plot_to_log!(log; kws...) + return nothing +end + function default_logging_callback( logger::AbstractLogger; log_step::Integer, @@ -76,18 +80,18 @@ function default_logging_callback( ) for i_eqn in eachindex(complexities, losses, equations) ]) end - cur_out["plot"] = - if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 - plot( - trees, - losses, - complexities, - options; - variable_names=datasets[i].variable_names, - ) - else - nothing - end + + # Will get method created by RecipesBase extension + add_plot_to_log!( + cur_out; + trees, + losses, + complexities, + options, + datasets[i].variable_names, + log_step, + ropt, + ) if nout == 1 d[] = cur_out From a491a64c650e7bb0689708b0e174e1afff070e6a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 16 Jun 2024 02:57:19 +0100 Subject: [PATCH 31/54] test: add stabilization to MLJInterface --- src/Logging.jl | 1 - src/SymbolicRegression.jl | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Logging.jl b/src/Logging.jl index 7b3de9906..84b1947ce 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -3,7 +3,6 @@ module LoggingModule using Base: AbstractLogger using Logging: with_logger using DynamicExpressions: string_tree -using RecipesBase: plot using ..CoreModule: Options, Dataset using ..ComplexityModule: compute_complexity diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 0df4d3bb1..8ec8ca7e2 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -1163,11 +1163,11 @@ end return (out_pop, best_seen, record, num_evals) end -include("MLJInterface.jl") -using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor - -include("Printing.jl") -include("PlotRecipes.jl") +@stable default_mode = "disable" begin + include("MLJInterface.jl") + using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor + include("Printing.jl") +end function __init__() @require_extensions From f4c9111c52172ffae2b4d1128d2f30be2cfb049d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 16 Jun 2024 03:00:42 +0100 Subject: [PATCH 32/54] test: reset variable names when needed --- test/test_params.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_params.jl b/test/test_params.jl index 7aafd7f48..c75105365 100644 --- a/test/test_params.jl +++ b/test/test_params.jl @@ -1,4 +1,5 @@ using SymbolicRegression: L2DistLoss, MutationWeights +using DynamicExpressions: set_default_variable_names! 
using Optim: Optim using LineSearches: LineSearches using Test: Test @@ -69,3 +70,5 @@ const default_params = ( test_info(_, x) = error("Test failed: $x") test_info(_, ::Test.Pass) = nothing test_info(f::F, ::Test.Fail) where {F} = f() + +set_default_variable_names!(["x$(i)" for i in 1:100]) From 5a759d5ff5f02989b100248f8a984dc65cfd9d60 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 16 Jun 2024 03:18:48 +0100 Subject: [PATCH 33/54] refactor: prevent method overwrite --- ext/SymbolicRegressionRecipesBaseExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl index 4a73dda86..ca6d3bb5a 100644 --- a/ext/SymbolicRegressionRecipesBaseExt.jl +++ b/ext/SymbolicRegressionRecipesBaseExt.jl @@ -10,7 +10,7 @@ using SymbolicRegression.LoggingModule: convex_hull import SymbolicRegression.LoggingModule: add_plot_to_log! function add_plot_to_log!( - log; trees, losses, complexities, options, variable_names, log_step, ropt + log::Dict; trees, losses, complexities, options, variable_names, log_step, ropt ) if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 log["plot"] = plot( From 6b778e8894c29b436667737f5cb8b1fc212a9406 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 17:37:25 -0500 Subject: [PATCH 34/54] refactor: despecialize --- ext/SymbolicRegressionRecipesBaseExt.jl | 16 +++++++++++++--- src/Logging.jl | 6 +++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl index 303d50600..fefef7b1e 100644 --- a/ext/SymbolicRegressionRecipesBaseExt.jl +++ b/ext/SymbolicRegressionRecipesBaseExt.jl @@ -10,7 +10,14 @@ using SymbolicRegression.LoggingModule: convex_hull import SymbolicRegression.LoggingModule: add_plot_to_log! function add_plot_to_log!( - log::Dict; trees, losses, complexities, options, variable_names, log_step, ropt + log::Dict; + trees, + losses, + complexities, + @nospecialize(options), + variable_names, + log_step, + ropt, ) if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 log["plot"] = plot( @@ -29,13 +36,16 @@ function default_sr_plot end end # TODO: Add variable names -@recipe function default_sr_plot(hall_of_fame::HallOfFame, options::Options) +@recipe function default_sr_plot(hall_of_fame::HallOfFame, @nospecialize(options::Options)) out = format_hall_of_fame(hall_of_fame, options) return (out.trees, out.losses, out.complexities, options) end @recipe function default_sr_plot( - trees::Vector{N}, losses::Vector{L}, complexities::Vector{Int}, options::Options + trees::Vector{N}, + losses::Vector{L}, + complexities::Vector{Int}, + @nospecialize(options::Options) ) where {T,L,N<:AbstractExpression{T}} tree_strings = [string_tree(tree, options) for tree in trees] log_losses = @. 
log10(losses + eps(L)) diff --git a/src/Logging.jl b/src/Logging.jl index e98661a31..27eb4234d 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -16,10 +16,10 @@ end function SU.default_logging_callback( logger::AbstractLogger; log_step::Integer, - state::SearchState, + @nospecialize(state::SearchState), datasets::AbstractVector{<:Dataset{T,L}}, - ropt::RuntimeOptions, - options::Options, + @nospecialize(ropt::RuntimeOptions), + @nospecialize(options::Options), ) where {T,L} nout = length(datasets) data = let d = Ref(Dict{String,Any}()) From bce1df50b572f971af5c936e7058daca26e2598e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 18:41:22 -0500 Subject: [PATCH 35/54] refactor: clean up logging interface --- ext/SymbolicRegressionRecipesBaseExt.jl | 21 +- src/Logging.jl | 258 ++++++++++++++++-------- src/SearchUtils.jl | 29 +-- src/SymbolicRegression.jl | 18 +- test/runtests.jl | 5 +- test/test_logging.jl | 61 +++--- 6 files changed, 221 insertions(+), 171 deletions(-) diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl index fefef7b1e..bdd26e839 100644 --- a/ext/SymbolicRegressionRecipesBaseExt.jl +++ b/ext/SymbolicRegressionRecipesBaseExt.jl @@ -9,24 +9,11 @@ using SymbolicRegression.LoggingModule: convex_hull import SymbolicRegression.LoggingModule: add_plot_to_log! -function add_plot_to_log!( - log::Dict; - trees, - losses, - complexities, - @nospecialize(options), - variable_names, - log_step, - ropt, +function add_plot_to_log!(; + trees, losses, complexities, @nospecialize(options), variable_names ) - if ropt.log_every_n.plots > 0 && log_step % ropt.log_every_n.plots == 0 - log["plot"] = plot( - trees, losses, complexities, options; variable_names=variable_names - ) - else - nothing - end - return nothing + plot_result = plot(trees, losses, complexities, options; variable_names=variable_names) + return Dict{String,Any}("plot" => plot_result) end function default_sr_plot end diff --git a/src/Logging.jl b/src/Logging.jl index 27eb4234d..2edee6e2c 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -1,109 +1,193 @@ module LoggingModule using Base: AbstractLogger -using Logging: with_logger +using Logging: Logging as LG using DynamicExpressions: string_tree -using ..CoreModule: Options, Dataset +using ..UtilsModule: @ignore +using ..CoreModule: AbstractOptions, Dataset +using ..PopulationModule: Population +using ..HallOfFameModule: HallOfFame using ..ComplexityModule: compute_complexity using ..HallOfFameModule: calculate_pareto_frontier -using ..SearchUtilsModule: SearchState, RuntimeOptions, SearchUtilsModule as SU +using ..SearchUtilsModule: AbstractSearchState, AbstractRuntimeOptions -function add_plot_to_log!(log; kws...) - return nothing +import ..SearchUtilsModule: logging_callback! + +""" + AbstractSRLogger <: AbstractLogger + +Abstract type for symbolic regression loggers. Subtypes must implement: + +- `get_logger(logger)`: Return the underlying logger +- `logging_callback!(logger; kws...)`: Callback function for logging. + Called with the current state, datasets, runtime options, and options. If you wish to + reduce the logging frequency, you can increment and monitor a counter within this + function. +""" +abstract type AbstractSRLogger <: LG.AbstractLogger end + +function get_logger end + +""" + SRLogger(logger::AbstractLogger; log_every_n::Integer=1) + +A logger for symbolic regression that wraps another logger. 
+ +# Arguments +- `logger`: The base logger to wrap +- `log_interval_scalars`: Number of steps between logging events for scalars. Default is 1 (log every step). +- `log_interval_plots`: Number of steps between logging events for plots. Default is 0 (never log plots). +""" +Base.@kwdef struct SRLogger{L<:AbstractLogger} <: AbstractSRLogger + logger::L + log_interval_scalars::Int = 1 + log_interval_plots::Int = 0 + _log_step::Base.RefValue{Int} = Base.RefValue(0) end +SRLogger(logger::AbstractLogger; kws...) = SRLogger(; logger, kws...) -function SU.default_logging_callback( - logger::AbstractLogger; - log_step::Integer, - @nospecialize(state::SearchState), +get_logger(logger::SRLogger) = logger.logger +function should_log(logger::SRLogger) + return should_log(logger, Val(:scalars)) || should_log(logger, Val(:plots)) +end +function should_log(logger::SRLogger, ::Val{:scalars}) + return logger.log_interval_scalars > 0 && + logger._log_step[] % logger.log_interval_scalars == 0 +end +function should_log(logger::SRLogger, ::Val{:plots}) + return logger.log_interval_plots > 0 && + logger._log_step[] % logger.log_interval_plots == 0 +end + +#! format: off +LG.with_logger(f::Function, logger::AbstractSRLogger) = LG.with_logger(f, get_logger(logger)) +#! format: on + +# Will get method created by RecipesBase extension +function add_plot_to_log! end +@ignore add_plot_to_log!(; kws...) = nothing + +""" + logging_callback!(logger::AbstractSRLogger; kws...) + +Default logging callback for SymbolicRegression. Logs the current state of the search, +and adds a plot of the current Pareto front to the logger. + +To override the default logging behavior, create a new type `MyLogger <: AbstractSRLogger` +and define a method for `SymbolicRegression.logging_callback`. +""" +function logging_callback!( + logger::AbstractSRLogger; + @nospecialize(state::AbstractSearchState), datasets::AbstractVector{<:Dataset{T,L}}, - @nospecialize(ropt::RuntimeOptions), - @nospecialize(options::Options), + @nospecialize(ropt::AbstractRuntimeOptions), + @nospecialize(options::AbstractOptions), ) where {T,L} - nout = length(datasets) - data = let d = Ref(Dict{String,Any}()) - for i in eachindex(datasets, state.halls_of_fame) - cur_out = Dict{String,Any}() - - #### Population diagnostics - cur_out["population"] = Dict([ - "complexities" => let - complexities = Int[] - for pop in state.last_pops[i], member in pop.members - push!(complexities, compute_complexity(member, options)) - end - complexities - end, - ]) - - #### Summaries - dominating = calculate_pareto_frontier(state.halls_of_fame[i]) - trees = [member.tree for member in dominating] - losses = L[member.loss for member in dominating] - complexities = Int[compute_complexity(member, options) for member in dominating] - - cur_out["min_loss"] = length(dominating) > 0 ? dominating[end].loss : L(Inf) - # @infiltrate - cur_out["pareto_volume"] = if length(dominating) > 1 - log_losses = @. log10(losses + eps(L)) - log_complexities = @. 
log10(complexities) - - # Add a point equal to the best loss and largest possible complexity, + 1 - push!(log_losses, minimum(log_losses)) - push!(log_complexities, log10(options.maxsize + 1)) - - # Add a point to connect things: - push!(log_losses, maximum(log_losses)) - push!(log_complexities, maximum(log_complexities)) - - xy = cat(log_complexities, log_losses; dims=2) - hull = convex_hull(xy) - convex_hull_area(hull) - else - 0.0 - end - - #### Full Pareto front - cur_out["equations"] = let - equations = String[ - string_tree( - member.tree, - options; - variable_names=datasets[i].variable_names, - ) for member in dominating - ] - Dict([ - "complexity=" * string(complexities[i_eqn]) => Dict( - "loss" => losses[i_eqn], "equation" => equations[i_eqn] - ) for i_eqn in eachindex(complexities, losses, equations) - ]) - end + log_step = logger._log_step[] + if should_log(logger) + data = log_payload(logger, state, datasets, options) + LG.with_logger(logger) do + @info("search", data = data) + end + end + logger._log_step[] += 1 + return nothing +end - # Will get method created by RecipesBase extension - add_plot_to_log!( - cur_out; - trees, - losses, - complexities, - options, - datasets[i].variable_names, - log_step, - ropt, +function log_payload( + logger::AbstractSRLogger, + @nospecialize(state::AbstractSearchState), + datasets::AbstractVector{<:Dataset{T,L}}, + @nospecialize(options::AbstractOptions), +) where {T,L} + d = Ref(Dict{String,Any}()) + should_log_scalars = should_log(logger, Val(:scalars)) + should_log_plots = should_log(logger, Val(:plots)) + for i in eachindex(datasets, state.halls_of_fame) + out = Dict{String,Any}() + if should_log_scalars + out = merge( + out, + _log_scalars( + state.last_pops[i], state.halls_of_fame[i], datasets[i], options + ), ) + end + if should_log_plots + out = merge( + out, + add_plot_to_log!(; trees, losses, complexities, datasets[i].variable_names), + ) + end + if length(datasets) == 1 + d[] = out + else + d[]["output$(i)"] = out + end + end + d[]["num_evals"] = sum(sum, state.num_evals) + return d[] +end - if nout == 1 - d[] = cur_out - else - d[]["out_$(i)"] = cur_out +function _log_scalars( + @nospecialize(pops::AbstractVector{<:Population}), + @nospecialize(hall_of_fame::HallOfFame{T,L}), + dataset::Dataset{T,L}, + @nospecialize(options::AbstractOptions), +) where {T,L} + out = Dict{String,Any}() + + #### Population diagnostics + out["population"] = Dict([ + "complexities" => let + complexities = Int[] + for pop in pops, member in pop.members + push!(complexities, compute_complexity(member, options)) end + complexities end - d[]["num_evals"] = sum(sum, state.num_evals) - d[] + ]) + + #### Summaries + dominating = calculate_pareto_frontier(hall_of_fame) + trees = [member.tree for member in dominating] + losses = L[member.loss for member in dominating] + complexities = Int[compute_complexity(member, options) for member in dominating] + + out["min_loss"] = length(dominating) > 0 ? dominating[end].loss : L(Inf) + out["pareto_volume"] = if length(dominating) > 1 + log_losses = @. log10(losses + eps(L)) + log_complexities = @. 
log10(complexities) + + # Add a point equal to the best loss and largest possible complexity, + 1 + push!(log_losses, minimum(log_losses)) + push!(log_complexities, log10(options.maxsize + 1)) + + # Add a point to connect things: + push!(log_losses, maximum(log_losses)) + push!(log_complexities, maximum(log_complexities)) + + xy = cat(log_complexities, log_losses; dims=2) + hull = convex_hull(xy) + convex_hull_area(hull) + else + 0.0 end - with_logger(logger) do - @info("search", data = data) + + #### Full Pareto front + out["equations"] = let + equations = String[ + string_tree(member.tree, options; variable_names=dataset.variable_names) for + member in dominating + ] + Dict([ + "complexity=" * string(complexities[i_eqn]) => + Dict("loss" => losses[i_eqn], "equation" => equations[i_eqn]) for + i_eqn in eachindex(complexities, losses, equations) + ]) end + return out end """Uses gift wrapping algorithm to create a convex hull.""" diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index b9c823756..abcb971e2 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -21,7 +21,7 @@ using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ProgressBarsModule: WrappedProgressBar, manually_iterate!, barlen using ..AdaptiveParsimonyModule: RunningSearchStatistics -function default_logging_callback end +function logging_callback! end """ AbstractRuntimeOptions @@ -42,15 +42,14 @@ can customize runtime behaviors by passing it to `equation_search`. abstract type AbstractRuntimeOptions end """ - RuntimeOptions{N,PARALLELISM,DIM_OUT,RETURN_STATE,NT} <: AbstractRuntimeOptions + RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE,LOGGER} <: AbstractRuntimeOptions Parameters for a search that are passed to `equation_search` directly, rather than set within `Options`. This is to differentiate between parameters that relate to processing and the duration of the search, and parameters dealing with the search hyperparameters itself. """ -struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE,NT<:NamedTuple} <: - AbstractRuntimeOptions +struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE,LOGGER} <: AbstractRuntimeOptions niterations::Int64 numprocs::Int64 init_procs::Union{Vector{Int},Nothing} @@ -59,8 +58,7 @@ struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE,NT<:NamedTuple} <: runtests::Bool verbosity::Int64 progress::Bool - logging_callback::Union{Function,Nothing} - log_every_n::NT + logger::Union{AbstractLogger,Nothing} parallelism::Val{PARALLELISM} dim_out::Val{DIM_OUT} return_state::Val{RETURN_STATE} @@ -97,9 +95,7 @@ end verbosity::Union{Int,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, v_dim_out::Val{DIM_OUT}=Val(nothing), - logger::Union{AbstractLogger,Nothing}=nothing, - logging_callback::Union{Function,Nothing}=nothing, - log_every_n::Union{Integer,NamedTuple}=1, + logger=nothing, # Defined from options options_return_state::Val{ORS}=Val(nothing), options_verbosity::Union{Integer,Nothing}=nothing, @@ -158,16 +154,6 @@ end numprocs end end - _logging_callback = if logging_callback === nothing && logger !== nothing - (; kws...) -> default_logging_callback(logger; kws...) - else - logging_callback - end - _log_every_n = if log_every_n isa Integer - (; scalars=log_every_n, plots=0) - else - log_every_n - end _return_state = VRS <: Val ? first(VRS.parameters) : something(ORS, return_state, false) dim_out = something(DIM_OUT, nout > 1 ? 
2 : 1) @@ -190,7 +176,7 @@ end `` end - return RuntimeOptions{concurrency,dim_out,_return_state,typeof(_log_every_n)}( + return RuntimeOptions{concurrency,dim_out,_return_state,typeof(logger)}( niterations, _numprocs, procs, @@ -199,8 +185,7 @@ end runtests, _verbosity, _progress, - _logging_callback, - _log_every_n, + logger, Val(concurrency), Val(dim_out), Val(_return_state), diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 97140a914..685233ee5 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -22,6 +22,7 @@ export Population, EvalOptions, SRRegressor, MultitargetSRRegressor, + SRLogger, #Functions: equation_search, @@ -83,7 +84,6 @@ export Population, using Distributed using Printf: @printf, @sprintf -using Logging: AbstractLogger using Pkg: Pkg using TOML: parsefile using Random: seed!, shuffle! @@ -310,7 +310,9 @@ using .SearchUtilsModule: construct_datasets, save_to_file, get_cur_maxsize, - update_hall_of_fame! + update_hall_of_fame!, + logging_callback! +using .LoggingModule: AbstractSRLogger, SRLogger, get_logger using .TemplateExpressionModule: TemplateExpression, TemplateStructure using .TemplateExpressionModule: TemplateExpression, TemplateStructure, ValidVector using .ComposableExpressionModule: ComposableExpression @@ -434,9 +436,7 @@ function equation_search( run_id::Union{String,Nothing}=nothing, loss_type::Type{L}=Nothing, verbosity::Union{Integer,Nothing}=nothing, - logger::Union{AbstractLogger,Nothing}=nothing, - logging_callback::Union{Function,Nothing}=nothing, - log_every_n::Union{Integer,NamedTuple}=1, + logger::Union{AbstractSRLogger,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, @@ -485,8 +485,6 @@ function equation_search( run_id=run_id, verbosity=verbosity, logger=logger, - logging_callback=logging_callback, - log_every_n=log_every_n, progress=progress, v_dim_out=Val(DIM_OUT), ) @@ -796,7 +794,6 @@ function _main_search_loop!( nothing end - log_step = 0 last_print_time = time() last_speed_recording_time = time() num_evals_last = sum(sum, state.num_evals) @@ -955,10 +952,9 @@ function _main_search_loop!( ropt.parallelism, ) end - if ropt.logging_callback !== nothing && log_step % ropt.log_every_n.scalars == 0 - ropt.logging_callback(; log_step, state, datasets, ropt, options) + if ropt.logger !== nothing + logging_callback!(ropt.logger; state, datasets, ropt, options) end - log_step += 1 end yield() diff --git a/test/runtests.jl b/test/runtests.jl index b897a0a02..d02a1b9be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -160,10 +160,7 @@ end include("test_abstract_numbers.jl") end -@testitem "Test logging" tags = [:part1, :integration] begin - include("test_logging.jl") -end - +include("test_logging.jl") include("test_pretty_printing.jl") include("test_expression_builder.jl") include("test_composable_expression.jl") diff --git a/test/test_logging.jl b/test/test_logging.jl index e0512e414..da0d232fe 100644 --- a/test/test_logging.jl +++ b/test/test_logging.jl @@ -1,37 +1,38 @@ -using Test -using SymbolicRegression -using TensorBoardLogger -using Logging -using MLJBase -using Plots -include("test_params.jl") +@testitem "Test logging" tags = [:part1, :integration] begin + using SymbolicRegression, TensorBoardLogger, Logging, MLJBase, Plots -mktempdir() do dir - logger = TBLogger(dir, tb_overwrite; min_level=Logging.Info) + include("test_params.jl") - niterations = 4 - populations = 36 - log_every_n = (; scalars=2, plots=10) - model = 
SRRegressor(; - binary_operators=[+, -, *, mod], - unary_operators=[], - maxsize=40, - niterations, - populations, - log_every_n, - logger, - ) + mktempdir() do dir + logger = SRLogger(; + logger=TBLogger(dir, tb_overwrite; min_level=Logging.Info), + log_interval_scalars=2, + log_interval_plots=10, + ) - X = (a=rand(500), b=rand(500)) - y = @. 2 * cos(X.a * 23.5) - X.b^2 - mach = machine(model, X, y) + niterations = 4 + populations = 36 + model = SRRegressor(; + binary_operators=[+, -, *, mod], + unary_operators=[], + maxsize=40, + niterations, + populations, + log_every_n, + logger, + ) - fit!(mach) + X = (a=rand(500), b=rand(500)) + y = @. 2 * cos(X.a * 23.5) - X.b^2 + mach = machine(model, X, y) - b = TensorBoardLogger.steps(logger) - @test length(b) == (niterations * populations//log_every_n.scalars) + 1 + fit!(mach) - files_and_dirs = readdir(dir) - @test length(files_and_dirs) == 1 - @test occursin(r"events\.out\.tfevents", only(files_and_dirs)) + b = TensorBoardLogger.steps(logger) + @test length(b) == (niterations * populations//log_every_n.scalars) + 1 + + files_and_dirs = readdir(dir) + @test length(files_and_dirs) == 1 + @test occursin(r"events\.out\.tfevents", only(files_and_dirs)) + end end From 87cdffe6a5c5c717144f41666fd2a92f6ba7fb55 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 18:48:51 -0500 Subject: [PATCH 36/54] fix: MLJ with logger --- src/MLJInterface.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index ba75f579b..9c5f87cb5 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -31,6 +31,7 @@ using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore +using ..LoggingModule: AbstractSRLogger import ..equation_search @@ -57,9 +58,7 @@ function modelexpr(model_name::Symbol) procs::Union{Vector{Int},Nothing} = nothing addprocs_function::Union{Function,Nothing} = nothing heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing - logger::Union{AbstractLogger,Nothing} = nothing - logging_callback::Union{Function,Nothing} = nothing - log_every_n::Union{Integer,NamedTuple} = 1 + logger::Union{AbstractSRLogger,Nothing} = nothing runtests::Bool = true run_id::Union{String,Nothing} = nothing loss_type::L = Nothing @@ -270,8 +269,6 @@ function _update( verbosity=verbosity, extra=isnothing(class) ? (;) : (; class), logger=m.logger, - logging_callback=m.logging_callback, - log_every_n=m.log_every_n, # Help out with inference: v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), ) From fbe7fbdccc8641388acfb18ecb1e6c3ee5c7bd52 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 19:00:11 -0500 Subject: [PATCH 37/54] test: fix logging test --- ext/SymbolicRegressionRecipesBaseExt.jl | 4 ++-- src/Logging.jl | 15 +++++++++++---- test/test_logging.jl | 5 ++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl index bdd26e839..c0d83f489 100644 --- a/ext/SymbolicRegressionRecipesBaseExt.jl +++ b/ext/SymbolicRegressionRecipesBaseExt.jl @@ -10,9 +10,9 @@ using SymbolicRegression.LoggingModule: convex_hull import SymbolicRegression.LoggingModule: add_plot_to_log! 
function add_plot_to_log!(; - trees, losses, complexities, @nospecialize(options), variable_names + hall_of_fame::HallOfFame, @nospecialize(options::Options), variable_names ) - plot_result = plot(trees, losses, complexities, options; variable_names=variable_names) + plot_result = plot(hall_of_fame, options; variable_names=variable_names) return Dict{String,Any}("plot" => plot_result) end diff --git a/src/Logging.jl b/src/Logging.jl index 2edee6e2c..71faf6bcf 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -109,15 +109,22 @@ function log_payload( if should_log_scalars out = merge( out, - _log_scalars( - state.last_pops[i], state.halls_of_fame[i], datasets[i], options + _log_scalars(; + pops=state.last_pops[i], + hall_of_fame=state.halls_of_fame[i], + dataset=datasets[i], + options, ), ) end if should_log_plots out = merge( out, - add_plot_to_log!(; trees, losses, complexities, datasets[i].variable_names), + add_plot_to_log!(; + hall_of_fame=state.halls_of_fame[i], + options, + variable_names=datasets[i].variable_names, + ), ) end if length(datasets) == 1 @@ -130,7 +137,7 @@ function log_payload( return d[] end -function _log_scalars( +function _log_scalars(; @nospecialize(pops::AbstractVector{<:Population}), @nospecialize(hall_of_fame::HallOfFame{T,L}), dataset::Dataset{T,L}, diff --git a/test/test_logging.jl b/test/test_logging.jl index da0d232fe..c5bb29fad 100644 --- a/test/test_logging.jl +++ b/test/test_logging.jl @@ -18,7 +18,6 @@ maxsize=40, niterations, populations, - log_every_n, logger, ) @@ -28,8 +27,8 @@ fit!(mach) - b = TensorBoardLogger.steps(logger) - @test length(b) == (niterations * populations//log_every_n.scalars) + 1 + b = TensorBoardLogger.steps(logger.logger) + @test length(b) == (niterations * populations//2) + 1 files_and_dirs = readdir(dir) @test length(files_and_dirs) == 1 From 94ec4631dfbd28964c7fde0f816bf88bbca6dc13 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 19:38:11 -0500 Subject: [PATCH 38/54] test: for convex hull accuracy --- src/Logging.jl | 7 ++++--- test/test_logging.jl | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/Logging.jl b/src/Logging.jl index 71faf6bcf..fefdda915 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -60,9 +60,9 @@ function should_log(logger::SRLogger, ::Val{:plots}) logger._log_step[] % logger.log_interval_plots == 0 end -#! format: off -LG.with_logger(f::Function, logger::AbstractSRLogger) = LG.with_logger(f, get_logger(logger)) -#! format: on +function LG.with_logger(f::Function, logger::AbstractSRLogger) + return LG.with_logger(f, get_logger(logger)) +end # Will get method created by RecipesBase extension function add_plot_to_log! 
end @@ -199,6 +199,7 @@ end """Uses gift wrapping algorithm to create a convex hull.""" function convex_hull(xy) + @assert size(xy, 2) == 2 cur_point = xy[sortperm(xy[:, 1])[1], :] hull = typeof(cur_point)[] while true diff --git a/test/test_logging.jl b/test/test_logging.jl index c5bb29fad..3bf9b751e 100644 --- a/test/test_logging.jl +++ b/test/test_logging.jl @@ -35,3 +35,23 @@ @test occursin(r"events\.out\.tfevents", only(files_and_dirs)) end end +@testitem "Test convex hull calculation" tags = [:part1] begin + using SymbolicRegression.LoggingModule: convex_hull, convex_hull_area + + # Create a Pareto front with an interior point that should be ignored + # Points: (0,0), (0,2), (2,0), and (1,1) which is inside the triangle + points = [ + 0.0 0.0 # vertex 1 + 0.0 2.0 # vertex 2 + 2.0 0.0 # vertex 3 + 1.0 1.0 # interior point that should be ignored + ] + hull = convex_hull(points) + + @test length(hull) == 3 + @test hull == [[0.0, 0.0], [0.0, 2.0], [2.0, 0.0]] + + # Expected area = 1/2 * base * height = 1/2 * 2 * 2 = 2 + area = convex_hull_area(hull) + @test isapprox(area, 2.0, atol=1e-10) +end From 22c46578ab2ff67c61c15d4fb842e397fa0c1da4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 19:46:48 -0500 Subject: [PATCH 39/54] refactor: fix jet error --- ext/SymbolicRegressionRecipesBaseExt.jl | 4 ++-- src/Logging.jl | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl index c0d83f489..d7f074599 100644 --- a/ext/SymbolicRegressionRecipesBaseExt.jl +++ b/ext/SymbolicRegressionRecipesBaseExt.jl @@ -7,9 +7,9 @@ using SymbolicRegression.HallOfFameModule: HallOfFame, format_hall_of_fame using SymbolicRegression.MLJInterfaceModule: SRFitResult, SRRegressor using SymbolicRegression.LoggingModule: convex_hull -import SymbolicRegression.LoggingModule: add_plot_to_log! +import SymbolicRegression.LoggingModule: make_plot -function add_plot_to_log!(; +function make_plot(; hall_of_fame::HallOfFame, @nospecialize(options::Options), variable_names ) plot_result = plot(hall_of_fame, options; variable_names=variable_names) diff --git a/src/Logging.jl b/src/Logging.jl index fefdda915..dee268235 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -65,8 +65,9 @@ function LG.with_logger(f::Function, logger::AbstractSRLogger) end # Will get method created by RecipesBase extension -function add_plot_to_log! end -@ignore add_plot_to_log!(; kws...) = nothing +function make_plot(; kws...) + return error("Please load `Plots` or another plotting package.") +end """ logging_callback!(logger::AbstractSRLogger; kws...) 
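As a hedged sketch of the override pattern this docstring describes (assuming `AbstractSRLogger` and `logging_callback!` remain accessible from the top-level `SymbolicRegression` namespace — they are internal names, not exported — and using the keyword names from the internal call site `logging_callback!(ropt.logger; state, datasets, ropt, options)` shown earlier in `_main_search_loop!`):

```julia
using SymbolicRegression
using SymbolicRegression: AbstractSRLogger  # assumption: internal, unexported name

# Hypothetical logger that only counts how many times the callback fires.
mutable struct CountingLogger <: AbstractSRLogger
    n::Int
end
CountingLogger() = CountingLogger(0)

# Replace the default logging behavior for this subtype; the keyword
# names mirror the internal call site, so dispatch picks this method.
function SymbolicRegression.logging_callback!(
    logger::CountingLogger; state, datasets, ropt, options
)
    logger.n += 1
    return nothing
end
```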
@@ -120,7 +121,7 @@ function log_payload(
         if should_log_plots
             out = merge(
                 out,
-                add_plot_to_log!(;
+                make_plot(;
                     hall_of_fame=state.halls_of_fame[i],
                     options,
                     variable_names=datasets[i].variable_names,

From 5ed8de33edfc4dbc345ea3486e750d998bb64bc8 Mon Sep 17 00:00:00 2001
From: MilesCranmer
Date: Fri, 8 Nov 2024 20:16:40 -0500
Subject: [PATCH 40/54] test: harder test of convex hull

---
 ext/SymbolicRegressionRecipesBaseExt.jl |  2 +-
 src/Logging.jl                          |  9 ++-----
 test/test_logging.jl                    | 34 ++++++++++++++++---------
 3 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl
index d7f074599..d8473ebe2 100644
--- a/ext/SymbolicRegressionRecipesBaseExt.jl
+++ b/ext/SymbolicRegressionRecipesBaseExt.jl
@@ -9,7 +9,7 @@ using SymbolicRegression.LoggingModule: convex_hull

 import SymbolicRegression.LoggingModule: make_plot

-function make_plot(;
+function make_plot(
     hall_of_fame::HallOfFame, @nospecialize(options::Options), variable_names
 )
     plot_result = plot(hall_of_fame, options; variable_names=variable_names)
diff --git a/src/Logging.jl b/src/Logging.jl
index dee268235..d5d42bbfd 100644
--- a/src/Logging.jl
+++ b/src/Logging.jl
@@ -65,7 +65,7 @@ function LG.with_logger(f::Function, logger::AbstractSRLogger)
 end

 # Will get method created by RecipesBase extension
-function make_plot(; kws...)
+function make_plot(args...)
     return error("Please load `Plots` or another plotting package.")
 end

@@ -120,12 +120,7 @@ function log_payload(
         end
         if should_log_plots
             out = merge(
-                out,
-                make_plot(;
-                    hall_of_fame=state.halls_of_fame[i],
-                    options,
-                    variable_names=datasets[i].variable_names,
-                ),
+                out, make_plot(state.halls_of_fame[i], options, datasets[i].variable_names)
             )
         end
         if length(datasets) == 1
diff --git a/test/test_logging.jl b/test/test_logging.jl
index 3bf9b751e..2d76f08b4 100644
--- a/test/test_logging.jl
+++ b/test/test_logging.jl
@@ -39,19 +39,29 @@ end
     using SymbolicRegression.LoggingModule: convex_hull, convex_hull_area

     # Create a Pareto front with an interior point that should be ignored
-    # Points: (0,0), (0,2), (2,0), and (1,1) which is inside the triangle
-    points = [
-        0.0 0.0 # vertex 1
-        0.0 2.0 # vertex 2
-        2.0 0.0 # vertex 3
-        1.0 1.0 # interior point that should be ignored
-    ]
-    hull = convex_hull(points)
-
-    @test length(hull) == 3
-    @test hull == [[0.0, 0.0], [0.0, 2.0], [2.0, 0.0]]
+    log_complexities = [1.0, 2.0, 3.0, 4.0]
+    log_losses = [4.0, 3.0, 3.0, 2.5]
+
+    # Add a point to connect things at lower right corner
+    push!(log_complexities, 5.0)
+    push!(log_losses, 2.5)
+
+    # Add a point to connect things at upper right corner
+    push!(log_losses, 4.0)
+    push!(log_complexities, 5.0)
+
+    xy = cat(log_complexities, log_losses; dims=2)
+    hull = convex_hull(xy)
+    @test length(hull) == 5
+    @test hull == [[1.0, 4.0], [5.0, 4.0], [5.0, 2.5], [4.0, 2.5], [2.0, 3.0]]

-    # Expected area = 1/2 * base * height = 1/2 * 2 * 2 = 2
     area = convex_hull_area(hull)
-    @test isapprox(area, 2.0, atol=1e-10)
+    true_area = (
+        1 * (4.0 - 2.5) # lower right rectangle
+        + 2.0 * (4.0 - 3.0) # block to the slight left and up
+        + 1.0 * (4.0 - 3.0) / 2 # top left triangle
+        + 2.0 * (3.0 - 2.5) / 2 # bottom triangle
+    )
+    @test isapprox(area, true_area, atol=1e-10)

From 0f5d1bfb469e3997df0395157ca80a158c2d39dc Mon Sep 17 00:00:00 2001
From: MilesCranmer
Date: Fri, 8 Nov 2024 20:36:56 -0500
Subject: [PATCH 41/54] feat: make prefix better for template expressions

---
 src/HallOfFame.jl         | 17 ++++++++++-------
src/TemplateExpression.jl | 7 ++++++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 44acf2ea7..7d870f78b 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -151,13 +151,7 @@ function string_dominating_pareto_curve( y_sym_units=dataset.y_sym_units, pretty, ) - y_prefix = dataset.y_variable_name - unit_str = format_dimensions(dataset.y_sym_units) - y_prefix *= unit_str - if dataset.y_sym_units === nothing && dataset.X_sym_units !== nothing - y_prefix *= WILDCARD_UNIT_STRING - end - prefix = y_prefix * " = " + prefix = make_prefix(tree, options, dataset) eqn_string = prefix * eqn_string stats_columns_string = @sprintf("%-10d %-8.3e %-8.3e ", complexity, loss, score) left_cols_width = length(stats_columns_string) @@ -172,6 +166,15 @@ function string_dominating_pareto_curve( print(buffer, '─'^(terminal_width - 1)) return dump_buffer(buffer) end +function make_prefix(::AbstractExpression, ::AbstractOptions, dataset::Dataset) + y_prefix = dataset.y_variable_name + unit_str = format_dimensions(dataset.y_sym_units) + y_prefix *= unit_str + if dataset.y_sym_units === nothing && dataset.X_sym_units !== nothing + y_prefix *= WILDCARD_UNIT_STRING + end + return y_prefix * " = " +end function wrap_equation_string(eqn_string, left_cols_width, terminal_width) dots = "..." diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 39ceab3d6..b8ebe0e6c 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -29,6 +29,7 @@ using ..ConstantOptimizationModule: ConstantOptimizationModule as CO using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..MutationFunctionsModule: MutationFunctionsModule as MF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..HallOfFameModule: HallOfFameModule as HOF using ..DimensionalAnalysisModule: DimensionalAnalysisModule as DA using ..CheckConstraintsModule: CheckConstraintsModule as CC using ..ComplexityModule: ComplexityModule @@ -305,7 +306,7 @@ function DE.string_tree( prefix = if !pretty || length(function_keys) == 1 "" elseif k == first(function_keys) - "╭ " + "┬ " elseif k == last(function_keys) "╰ " else @@ -320,6 +321,10 @@ function DE.string_tree( ) return annotatedstring(join(strings, pretty ? 
styled"\n" : "; ")) end +function HOF.make_prefix(::TemplateExpression, ::AbstractOptions, ::Dataset) + return "" +end + @stable( default_mode = "disable", default_union_limit = 2, From bf05304d7d8fa324fd02384403473b7077672c7a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 20:38:15 -0500 Subject: [PATCH 42/54] deps: remove `[extras]` again --- Project.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/Project.toml b/Project.toml index 7c692471d..4d3dd9b63 100644 --- a/Project.toml +++ b/Project.toml @@ -73,8 +73,3 @@ StyledStrings = "1" SymbolicUtils = "0.19, ^1.0.5, 2, 3" TOML = "<0.0.1, 1" julia = "1.10" - -[extras] -JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" From ef09e782b5a77f9a03f6101eb02d4e66e9113c2b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 20:46:46 -0500 Subject: [PATCH 43/54] refactor: remove unused printing --- ext/SymbolicRegressionRecipesBaseExt.jl | 4 -- src/OptionsStruct.jl | 26 +++++++++++ src/Printing.jl | 60 ------------------------- src/SymbolicRegression.jl | 1 - 4 files changed, 26 insertions(+), 65 deletions(-) delete mode 100644 src/Printing.jl diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl index d8473ebe2..43083a942 100644 --- a/ext/SymbolicRegressionRecipesBaseExt.jl +++ b/ext/SymbolicRegressionRecipesBaseExt.jl @@ -18,10 +18,6 @@ end function default_sr_plot end -@recipe function default_sr_plot(fitresult::SRFitResult{<:SRRegressor}) - return fitresult.state[2], fitresult.options -end - # TODO: Add variable names @recipe function default_sr_plot(hall_of_fame::HallOfFame, @nospecialize(options::Options)) out = format_hall_of_fame(hall_of_fame, options) diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index ffd368531..22871606d 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -257,6 +257,32 @@ struct Options{ use_recorder::Bool end +function Base.print(io::IO, @nospecialize(options::Options)) + return print( + io, + "Options(" * + "binops=$(options.operators.binops), " * + "unaops=$(options.operators.unaops), " + # Fill in remaining fields automatically: + * + join( + [ + if fieldname in (:optimizer_options, :mutation_weights) + "$(fieldname)=..." + else + "$(fieldname)=$(getfield(options, fieldname))" + end for + fieldname in fieldnames(Options) if fieldname ∉ [:operators, :nuna, :nbin] + ], + ", ", + ) * + ")", + ) +end +function Base.show(io::IO, ::MIME"text/plain", @nospecialize(options::Options)) + return Base.print(io, options) +end + specialized_options(options::AbstractOptions) = options @unstable function specialized_options(options::Options) return _specialized_options(options, options.operators) diff --git a/src/Printing.jl b/src/Printing.jl deleted file mode 100644 index c10a0be7a..000000000 --- a/src/Printing.jl +++ /dev/null @@ -1,60 +0,0 @@ -"""Defines printing methods of exported types (aside from expressions themselves)""" -module PrintingModule - -using ..CoreModule: Options -using ..MLJInterfaceModule: SRRegressor, SRFitResult - -function Base.print(io::IO, @nospecialize(options::Options)) - return print( - io, - "Options(" * - "binops=$(options.operators.binops), " * - "unaops=$(options.operators.unaops), " - # Fill in remaining fields automatically: - * - join( - [ - if fieldname in (:optimizer_options, :mutation_weights) - "$(fieldname)=..." 
- else - "$(fieldname)=$(getfield(options, fieldname))" - end for - fieldname in fieldnames(Options) if fieldname ∉ [:operators, :nuna, :nbin] - ], - ", ", - ) * - ")", - ) -end -function Base.show(io::IO, ::MIME"text/plain", @nospecialize(options::Options)) - return Base.print(io, options) -end - -function Base.show(io::IO, ::MIME"text/plain", @nospecialize(fitresult::SRFitResult)) - print(io, "SRFitResult for $(fitresult.model):") - print(io, "\n") - print(io, " state:\n") - print(io, " [1]: $(typeof(fitresult.state[1])) with ") - print(io, "$(length(fitresult.state[1])) × $(length(fitresult.state[1][1])) ") - print(io, "populations of $(fitresult.state[1][1][1].n) members\n") - print(io, " [2]: $(typeof(fitresult.state[2])) ") - if fitresult.model isa SRRegressor - print(io, "with $(sum(fitresult.state[2].exists)) saved expressions") - else - print(io, "with $(map(s -> sum(s.exists), fitresult.state[2])) saved expressions") - end - print(io, "\n") - print(io, " num_targets: $(fitresult.num_targets)") - print(io, "\n") - print(io, " variable_names: $(fitresult.variable_names)") - print(io, "\n") - print(io, " y_variable_names: $(fitresult.y_variable_names)") - print(io, "\n") - print(io, " X_units: $(fitresult.X_units)") - print(io, "\n") - print(io, " y_units: $(fitresult.y_units)") - print(io, "\n") - return nothing -end - -end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 685233ee5..30607d94b 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -1142,7 +1142,6 @@ end include("MLJInterface.jl") using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor -include("Printing.jl") # Hack to get static analysis to work from within tests: @ignore include("../test/runtests.jl") From 4f51f6052079894e66d5a0f0612f0ea9290ca9d2 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 21:07:24 -0500 Subject: [PATCH 44/54] style: tweak corner output --- src/TemplateExpression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index b8ebe0e6c..75ebba2bc 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -306,7 +306,7 @@ function DE.string_tree( prefix = if !pretty || length(function_keys) == 1 "" elseif k == first(function_keys) - "┬ " + "╭ " elseif k == last(function_keys) "╰ " else From a31d13a3214e93c5b38cd5da641e6c95594948ab Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 21:25:53 -0500 Subject: [PATCH 45/54] docs: document logging interface --- docs/src/api.md | 26 ++++++++++++++++++++++++++ docs/src/examples.md | 39 ++++++++++++++++++++++++++++++++++++++- src/SymbolicRegression.jl | 5 ++++- 3 files changed, 68 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index b7e0f7e28..3abae7b8d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -60,3 +60,29 @@ Note that use of this function requires `SymbolicUtils.jl` to be installed and l ```@docs calculate_pareto_frontier ``` + +## Logging + +```@docs +SRLogger +``` + +The `SRLogger` allows you to track the progress of symbolic regression searches. +It can wrap any `AbstractLogger` that implements the Julia logging interface, +such as from TensorBoardLogger.jl or Wandb.jl. 
+ +```julia +using TensorBoardLogger + +logger = SRLogger( + TBLogger("logs/run1", tb_overwrite), # Base logger to use + log_interval_scalars=2, # Log scalar metrics every 2 steps + # log_interval_plots=0 # Log plots steps (requires Plots.jl to be loaded) + # ^ Set to greater than 0 to enable logging of plots +) + +model = SRRegressor(; + logger=logger, + kws... +) +``` diff --git a/docs/src/examples.md b/docs/src/examples.md index 10daa5a5d..50ea9d1d5 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -384,7 +384,44 @@ You can even output custom structs - see the more detailed [Template Expression Be sure to also check out the [Parametric Expression example](examples/parametric_expression.md). -## 9. Additional features +## 9. Logging with TensorBoard + +You can track the progress of symbolic regression searches using TensorBoard or other logging backends. Here's an example using `TensorBoardLogger` and wrapping it with [`SRLogger`](@ref): + +```julia +using SymbolicRegression +using TensorBoardLogger +using MLJ + +logger = SRLogger(TBLogger("logs/sr_run")) + +# Create and fit model with logger +model = SRRegressor( + binary_operators=[+, -, *], + maxsize=40, + niterations=100, + logger=logger +) + +X = (a=rand(500), b=rand(500)) +y = @. 2 * cos(X.a * 23.5) - X.b^2 + +mach = machine(model, X, y) +fit!(mach) +``` + +You can then view the logs with: + +```bash +tensorboard --logdir logs +``` + +The TensorBoard interface will show +the loss curves over time (at each complexity), as well +as the Pareto frontier volume which can be used as an overall metric +of the search performance. + +## 10. Additional features For the many other features available in SymbolicRegression.jl, check out the API page for `Options`. You might also find it useful diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 30607d94b..b843f7d25 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -370,7 +370,7 @@ which is useful for debugging and profiling. a distributed run manually with `procs = addprocs()` and `@everywhere`, pass the `procs` to this keyword argument. - `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing - (`parallelism=:multithreading`), and are not passing `procs` manually, + (`parallelism=:multiprocessing`), and are not passing `procs` manually, then they will be allocated dynamically using `addprocs`. However, you may also pass a custom function to use instead of `addprocs`. This function should take a single positional argument, @@ -399,6 +399,9 @@ which is useful for debugging and profiling. Note that if you pass complex data `::Complex{L}`, then the loss type will automatically be set to `L`. - `verbosity`: Whether to print debugging statements or not. +- `logger::Union{AbstractSRLogger,Nothing}=nothing`: An optional logger to record + the progress of the search. You can use an `SRLogger` to wrap a custom logger, + or pass `nothing` to disable logging. - `progress`: Whether to use a progress bar output. Only available for single target output. 
- `X_units::Union{AbstractVector,Nothing}=nothing`: The units of the dataset, From ea91a613a3671448cdd8d6968ec7aa14499b1ada Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 23:20:06 -0500 Subject: [PATCH 46/54] fix: MLJ warm start --- examples/parameterized_function.jl | 9 +++++---- src/MLJInterface.jl | 7 ++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/parameterized_function.jl b/examples/parameterized_function.jl index 13d5fa370..21f94fdef 100644 --- a/examples/parameterized_function.jl +++ b/examples/parameterized_function.jl @@ -101,14 +101,15 @@ functional form, but with varying parameters across different conditions or clas fit!(mach) idx1 = lastindex(report(mach).equations) ypred1 = predict(mach, (data=X, idx=idx1)) -loss1 = sum(i -> abs2(ypred1[i] - y[i]), eachindex(y)) +loss1 = sum(i -> abs2(ypred1[i] - y[i]), eachindex(y)) / length(y) # Should keep all parameters -stop_at[] = 1e-5 +stop_at[] = loss1 * 0.999 +mach.model.niterations *= 2 fit!(mach) idx2 = lastindex(report(mach).equations) ypred2 = predict(mach, (data=X, idx=idx2)) -loss2 = sum(i -> abs2(ypred2[i] - y[i]), eachindex(y)) +loss2 = sum(i -> abs2(ypred2[i] - y[i]), eachindex(y)) / length(y) # Should get better: -@test loss1 >= loss2 +@test loss1 > loss2 diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 9c5f87cb5..3606fd318 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -141,6 +141,7 @@ Base.@kwdef struct SRFitResult{ } model::M state::S + niterations::Int num_targets::Int options::O variable_names::Vector{String} @@ -246,10 +247,13 @@ function _update( else w end + niterations = + m.niterations - (old_fitresult === nothing ? 0 : old_fitresult.niterations) + @assert niterations >= 0 search_state::types.state = equation_search( X_t, y_t; - niterations=m.niterations, + niterations=niterations, weights=w_t, variable_names=variable_names, display_variable_names=display_variable_names, @@ -275,6 +279,7 @@ function _update( fitresult = SRFitResult(; model=m, state=search_state, + niterations=niterations, num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1), options=options, variable_names=variable_names, From 68e3f403f5ed5bb688f3cde615430021ecb0eff0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 23:25:28 -0500 Subject: [PATCH 47/54] docs: describe MLJ change --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22fa6f814..c38f4e38c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -405,6 +405,7 @@ A custom run ID can be specified via the new `run_id` parameter passed to `equat - Option to force dimensionless constants when fitting with dimensional constraints, via the `dimensionless_constants_only` option. - Default `maxsize` increased from 20 to 30. - Default `niterations` increased from 10 to 50, as many users seem to be unaware that this is small (and meant for testing), even in publications. I think this 50 is still low, but it should be a more accurate default for those who don't tune. +- `MLJ.fit!(mach)` now records the number of iterations used, and, should `mach.model.niterations` be changed after the fit, the number of iterations passed to `equation_search` will be reduced accordingly. 
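For illustration, a minimal sketch of this warm-start pattern (the dataset, operators, and iteration counts below are arbitrary placeholders, not from this changelog):

```julia
using MLJ, SymbolicRegression

X = (a=rand(100),)
y = @. 2.0 * X.a + 1.0

model = SRRegressor(; binary_operators=[+, *], niterations=10)
mach = machine(model, X, y)
fit!(mach)  # runs all 10 iterations

# Raising `niterations` after the fit triggers a warm start:
mach.model.niterations = 15
fit!(mach)  # runs only the 5 additional iterations, resuming from the saved state
```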
### Update Guide From 29e361592b60cdbf324d4930583a50006a80232c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 23:46:57 -0500 Subject: [PATCH 48/54] fix: update MLJ niterations --- src/MLJInterface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 3606fd318..801b0f95d 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -279,7 +279,8 @@ function _update( fitresult = SRFitResult(; model=m, state=search_state, - niterations=niterations, + niterations=niterations + + (old_fitresult === nothing ? 0 : old_fitresult.niterations), num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1), options=options, variable_names=variable_names, From ff62fae4652f235c0deec0d81fcc9cd90c304071 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 8 Nov 2024 23:53:39 -0500 Subject: [PATCH 49/54] test: fix test_units.jl --- test/test_units.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_units.jl b/test/test_units.jl index da7f45fa3..9a9d0338b 100644 --- a/test/test_units.jl +++ b/test/test_units.jl @@ -298,8 +298,7 @@ end @test minimum(report.losses[2]) < 1e-7 # Repeat with second run: - mach.model.niterations = 0 - MLJ.fit!(mach) + MLJ.fit!(mach) # (Will run with 0 iterations) report = MLJ.report(mach) @test minimum(report.losses[1]) < 1e-7 @test minimum(report.losses[2]) < 1e-7 From e40cc01feab37fe606e6c0d4959dc826616533eb Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 9 Nov 2024 11:00:21 -0500 Subject: [PATCH 50/54] feat!: remove plotting utility --- Project.toml | 3 - ext/SymbolicRegressionRecipesBaseExt.jl | 75 ------------------------- src/Logging.jl | 58 ++++++------------- test/Project.toml | 1 - test/test_logging.jl | 6 +- 5 files changed, 19 insertions(+), 124 deletions(-) delete mode 100644 ext/SymbolicRegressionRecipesBaseExt.jl diff --git a/Project.toml b/Project.toml index 4d3dd9b63..805cd6472 100644 --- a/Project.toml +++ b/Project.toml @@ -33,13 +33,11 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [extensions] SymbolicRegressionEnzymeExt = "Enzyme" SymbolicRegressionJSON3Ext = "JSON3" -SymbolicRegressionRecipesBaseExt = "RecipesBase" SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" [compat] @@ -65,7 +63,6 @@ PrecompileTools = "1" Printf = "<0.0.1, 1" ProgressMeter = "1.10" Random = "<0.0.1, 1" -RecipesBase = "1" Reexport = "1" SpecialFunctions = "0.10.1, 1, 2" StatsBase = "0.33, 0.34" diff --git a/ext/SymbolicRegressionRecipesBaseExt.jl b/ext/SymbolicRegressionRecipesBaseExt.jl deleted file mode 100644 index 43083a942..000000000 --- a/ext/SymbolicRegressionRecipesBaseExt.jl +++ /dev/null @@ -1,75 +0,0 @@ -module SymbolicRegressionRecipesBaseExt - -using RecipesBase: @recipe, @series, plot -using DynamicExpressions: AbstractExpression, string_tree -using SymbolicRegression.CoreModule: Options -using SymbolicRegression.HallOfFameModule: HallOfFame, format_hall_of_fame -using SymbolicRegression.MLJInterfaceModule: SRFitResult, SRRegressor -using SymbolicRegression.LoggingModule: convex_hull - -import SymbolicRegression.LoggingModule: make_plot - -function make_plot( - hall_of_fame::HallOfFame, @nospecialize(options::Options), variable_names -) - plot_result = plot(hall_of_fame, options; variable_names=variable_names) - return Dict{String,Any}("plot" => 
plot_result) -end - -function default_sr_plot end - -# TODO: Add variable names -@recipe function default_sr_plot(hall_of_fame::HallOfFame, @nospecialize(options::Options)) - out = format_hall_of_fame(hall_of_fame, options) - return (out.trees, out.losses, out.complexities, options) -end - -@recipe function default_sr_plot( - trees::Vector{N}, - losses::Vector{L}, - complexities::Vector{Int}, - @nospecialize(options::Options) -) where {T,L,N<:AbstractExpression{T}} - tree_strings = [string_tree(tree, options) for tree in trees] - log_losses = @. log10(losses + eps(L)) - log_complexities = @. log10(complexities) - # Add an upper right corner to this for the convex hull calculation: - push!(log_losses, maximum(log_losses)) - push!(log_complexities, maximum(log_complexities)) - - xy = cat(log_complexities, log_losses; dims=2) - log_hull = convex_hull(xy) - - # Add the first point again to close the hull: - push!(log_hull, log_hull[1]) - - # Then remove the first two points for visualization - log_hull = log_hull[3:end] - - hull = [10 .^ row for row in log_hull] - - xlabel --> "Complexity" - ylabel --> "Loss" - - xlims --> (0.5, options.maxsize + 1) - - xscale --> :log10 - yscale --> :log10 - - # Main complexity/loss plot: - @series begin - label --> "Pareto Front" - - complexities, losses - end - - # Add on a convex hull: - @series begin - label --> "Convex Hull" - color --> :lightgray - - first.(hull), last.(hull) - end -end - -end diff --git a/src/Logging.jl b/src/Logging.jl index d5d42bbfd..98acbd531 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -36,44 +36,34 @@ A logger for symbolic regression that wraps another logger. # Arguments - `logger`: The base logger to wrap -- `log_interval_scalars`: Number of steps between logging events for scalars. Default is 1 (log every step). -- `log_interval_plots`: Number of steps between logging events for plots. Default is 0 (never log plots). +- `log_interval`: Number of steps between logging events. Default is 1 (log every step). """ Base.@kwdef struct SRLogger{L<:AbstractLogger} <: AbstractSRLogger logger::L - log_interval_scalars::Int = 1 - log_interval_plots::Int = 0 + log_interval::Int = 1 _log_step::Base.RefValue{Int} = Base.RefValue(0) end SRLogger(logger::AbstractLogger; kws...) = SRLogger(; logger, kws...) -get_logger(logger::SRLogger) = logger.logger -function should_log(logger::SRLogger) - return should_log(logger, Val(:scalars)) || should_log(logger, Val(:plots)) +function get_logger(logger::SRLogger) + return logger.logger end -function should_log(logger::SRLogger, ::Val{:scalars}) - return logger.log_interval_scalars > 0 && - logger._log_step[] % logger.log_interval_scalars == 0 +function should_log(logger::SRLogger) + return logger.log_interval > 0 && logger._log_step[] % logger.log_interval == 0 end -function should_log(logger::SRLogger, ::Val{:plots}) - return logger.log_interval_plots > 0 && - logger._log_step[] % logger.log_interval_plots == 0 +function increment_log_step!(logger::SRLogger) + logger._log_step[] += 1 + return nothing end function LG.with_logger(f::Function, logger::AbstractSRLogger) return LG.with_logger(f, get_logger(logger)) end -# Will get method created by RecipesBase extension -function make_plot(args...) - return error("Please load `Plots` or another plotting package.") -end - """ logging_callback!(logger::AbstractSRLogger; kws...) -Default logging callback for SymbolicRegression. Logs the current state of the search, -and adds a plot of the current Pareto front to the logger. 
+Default logging callback for SymbolicRegression. To override the default logging behavior, create a new type `MyLogger <: AbstractSRLogger` and define a method for `SymbolicRegression.logging_callback`. @@ -85,14 +75,13 @@ function logging_callback!( @nospecialize(ropt::AbstractRuntimeOptions), @nospecialize(options::AbstractOptions), ) where {T,L} - log_step = logger._log_step[] if should_log(logger) data = log_payload(logger, state, datasets, options) LG.with_logger(logger) do @info("search", data = data) end end - logger._log_step[] += 1 + increment_log_step!(logger) return nothing end @@ -103,26 +92,13 @@ function log_payload( @nospecialize(options::AbstractOptions), ) where {T,L} d = Ref(Dict{String,Any}()) - should_log_scalars = should_log(logger, Val(:scalars)) - should_log_plots = should_log(logger, Val(:plots)) for i in eachindex(datasets, state.halls_of_fame) - out = Dict{String,Any}() - if should_log_scalars - out = merge( - out, - _log_scalars(; - pops=state.last_pops[i], - hall_of_fame=state.halls_of_fame[i], - dataset=datasets[i], - options, - ), - ) - end - if should_log_plots - out = merge( - out, make_plot(state.halls_of_fame[i], options, datasets[i].variable_names) - ) - end + out = _log_scalars(; + pops=state.last_pops[i], + hall_of_fame=state.halls_of_fame[i], + dataset=datasets[i], + options, + ) if length(datasets) == 1 d[] = out else diff --git a/test/Project.toml b/test/Project.toml index af65d6514..8775bad76 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,7 +19,6 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" diff --git a/test/test_logging.jl b/test/test_logging.jl index 2d76f08b4..1ce360ef8 100644 --- a/test/test_logging.jl +++ b/test/test_logging.jl @@ -1,13 +1,11 @@ @testitem "Test logging" tags = [:part1, :integration] begin - using SymbolicRegression, TensorBoardLogger, Logging, MLJBase, Plots + using SymbolicRegression, TensorBoardLogger, Logging, MLJBase include("test_params.jl") mktempdir() do dir logger = SRLogger(; - logger=TBLogger(dir, tb_overwrite; min_level=Logging.Info), - log_interval_scalars=2, - log_interval_plots=10, + logger=TBLogger(dir, tb_overwrite; min_level=Logging.Info), log_interval=2 ) niterations = 4 From 729f98c562dded1a6f6a98c3c0076bf121cb91aa Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 9 Nov 2024 14:40:05 -0500 Subject: [PATCH 51/54] refactor: redundant signature --- src/Logging.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Logging.jl b/src/Logging.jl index 98acbd531..707eb31b6 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -27,8 +27,6 @@ Abstract type for symbolic regression loggers. 
Subtypes must implement: """ abstract type AbstractSRLogger <: LG.AbstractLogger end -function get_logger end - """ SRLogger(logger::AbstractLogger; log_every_n::Integer=1) From 7e466ca3914312f1030a2d081899262c578ecff4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 9 Nov 2024 14:47:52 -0500 Subject: [PATCH 52/54] docs: describe new logger option --- CHANGELOG.md | 32 +++++++++++++++++++++++++++++++- docs/src/api.md | 7 +------ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c38f4e38c..6063b567f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Summary of major recent changes, described in more detail below: - `AbstractSearchState`, for holding custom metadata during searches. - `AbstractOptions` and `AbstractRuntimeOptions`, for customizing pretty much everything else in the library via multiple dispatch. Please make an issue/PR if you would like any particular internal functions be declared `public` to enable stability across versions for your tool. - Many of these were motivated to modularize the implementation of [LaSR](https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl), an LLM-guided version of SymbolicRegression.jl, so it can sit as a modular layer on top of SymbolicRegression.jl. +- [Added TensorBoardLogger.jl and other logging integrations via `SRLogger`](#added-tensorboardloggerjl-and-other-logging-integrations-via-srlogger) - Fundamental improvements to the underlying evolutionary algorithm - New mutation operators introduced, `swap_operands` and `rotate_tree` – both of which seem to help kick the evolution out of local optima. - New hyperparameter defaults created, based on a Pareto front volume calculation, rather than simply accuracy of the best expression. @@ -366,7 +367,36 @@ Base.propertynames(options::MyOptions) = (NEW_OPTIONS_KEYS..., fieldnames(Symbol These new abstractions provide users with greater flexibility in defining the structure and behavior of expressions, nodes, and the search process itself. These are also of course used as the basis for alternate behavior such as `ParametricExpression` and `TemplateExpression`. -### Fundamental improvements to the underlying evolutionary algorithm +### Added TensorBoardLogger.jl and other logging integrations via `SRLogger` + +You can now track the progress of symbolic regression searches using `TensorBoardLogger.jl`, `Wandb.jl`, or other logging backends. + +This is done by wrapping any `AbstractLogger` with the new `SRLogger` type, and passing it to the `logger` option in `SRRegressor` +or `equation_search`: + +```julia +using SymbolicRegression +using TensorBoardLogger + +logger = SRLogger( + TBLogger("logs/run"), + log_interval=2, # Log every 2 steps +) + +model = SRRegressor(; + binary_operators=[+, -, *], + logger=logger, +) +``` + +The logger will track: + +- Loss curves over time at each complexity level +- Population statistics (distribution of complexities) +- Pareto frontier volume (can be used as an overall metric of search performance) +- Full equations at each complexity level + +This works with any logger that implements the Julia logging interface. ### Support for Zygote.jl and Enzyme.jl within the constant optimizer, specified using the `autodiff_backend` option diff --git a/docs/src/api.md b/docs/src/api.md index 3abae7b8d..6c0de1a54 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -74,12 +74,7 @@ such as from TensorBoardLogger.jl or Wandb.jl. 
```julia using TensorBoardLogger -logger = SRLogger( - TBLogger("logs/run1", tb_overwrite), # Base logger to use - log_interval_scalars=2, # Log scalar metrics every 2 steps - # log_interval_plots=0 # Log plots steps (requires Plots.jl to be loaded) - # ^ Set to greater than 0 to enable logging of plots -) +logger = SRLogger(TBLogger("logs/run"), log_interval=2) model = SRRegressor(; logger=logger, From 66dac138c6a84c0b42b33033de286a4d76d6098d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 9 Nov 2024 14:50:08 -0500 Subject: [PATCH 53/54] refactor: fix subtyping --- src/Logging.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Logging.jl b/src/Logging.jl index 707eb31b6..885755980 100644 --- a/src/Logging.jl +++ b/src/Logging.jl @@ -25,7 +25,7 @@ Abstract type for symbolic regression loggers. Subtypes must implement: reduce the logging frequency, you can increment and monitor a counter within this function. """ -abstract type AbstractSRLogger <: LG.AbstractLogger end +abstract type AbstractSRLogger <: AbstractLogger end """ SRLogger(logger::AbstractLogger; log_every_n::Integer=1) From 8c37c2695315108b3b689fb01b3b6bb1f0fcce95 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 9 Nov 2024 15:06:15 -0500 Subject: [PATCH 54/54] test: logging to both TBLogger and SimpleLogger --- test/Project.toml | 1 + test/test_logging.jl | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 8775bad76..298d4e8db 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" diff --git a/test/test_logging.jl b/test/test_logging.jl index 1ce360ef8..9bfa5e7f5 100644 --- a/test/test_logging.jl +++ b/test/test_logging.jl @@ -1,12 +1,18 @@ @testitem "Test logging" tags = [:part1, :integration] begin - using SymbolicRegression, TensorBoardLogger, Logging, MLJBase + using SymbolicRegression + using TensorBoardLogger: TensorBoardLogger, TBLogger + using Logging: Logging, SimpleLogger + using LoggingExtras: TeeLogger + using MLJBase: machine, fit! include("test_params.jl") mktempdir() do dir - logger = SRLogger(; - logger=TBLogger(dir, tb_overwrite; min_level=Logging.Info), log_interval=2 - ) + buf = IOBuffer() + simple_logger = SimpleLogger(buf) + tb_logger = TBLogger(dir, TensorBoardLogger.tb_overwrite) + + logger = SRLogger(; logger=TeeLogger(simple_logger, tb_logger), log_interval=2) niterations = 4 populations = 36 @@ -25,12 +31,16 @@ fit!(mach) - b = TensorBoardLogger.steps(logger.logger) + # Check TensorBoardLogger + b = TensorBoardLogger.steps(tb_logger) @test length(b) == (niterations * populations//2) + 1 - files_and_dirs = readdir(dir) @test length(files_and_dirs) == 1 @test occursin(r"events\.out\.tfevents", only(files_and_dirs)) + + # Check SimpleLogger + s = String(take!(buf)) + @test occursin(r"search\s*\n\s*│\s*data\s*=\s*", s) end end @testitem "Test convex hull calculation" tags = [:part1] begin
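As a closing, hedged illustration of the pattern this test exercises: users can fan search logs out to multiple sinks by wrapping a `TeeLogger` in `SRLogger` (the paths and interval below are arbitrary choices):

```julia
using SymbolicRegression
using TensorBoardLogger: TBLogger
using LoggingExtras: TeeLogger
using Logging: ConsoleLogger

# Send every 4th logging event to both TensorBoard and the terminal.
logger = SRLogger(;
    logger=TeeLogger(TBLogger("logs/run"), ConsoleLogger()),
    log_interval=4,
)

model = SRRegressor(; binary_operators=[+, -, *], logger=logger)
```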