diff --git a/Project.toml b/Project.toml index 3988b027..1b658198 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ Dates = "1" DifferentiationInterface = "0.5, 0.6" DispatchDoctor = "^0.4.17" Distributed = "<0.0.1, 1" -DynamicExpressions = "1.6.0" +DynamicExpressions = "~1.8" DynamicQuantities = "1" Enzyme = "0.12, 0.13" JSON3 = "1" diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index f47ab9a5..0247cf10 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -112,6 +112,15 @@ end function CO.count_constants_for_optimization(ex::AbstractComposableExpression) return CO.count_constants_for_optimization(convert(Expression, ex)) end +function DE.allocate_container( + prototype::ComposableExpression, n::Union{Nothing,Integer}=nothing +) + return (; tree=DE.allocate_container(get_contents(prototype), n)) +end +function DE.copy_into!(dest::NamedTuple, src::ComposableExpression) + new_tree = DE.copy_into!(dest.tree, get_contents(src)) + return DE.with_contents(src, new_tree) +end @implements( ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] diff --git a/src/Mutate.jl b/src/Mutate.jl index e3ab8993..cc343351 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -2,11 +2,13 @@ module MutateModule using DynamicExpressions: AbstractExpression, + copy_into!, get_tree, preserve_sharing, count_scalar_constants, simplify_tree!, - combine_operators + combine_operators, + allocate_container using ..CoreModule: AbstractOptions, AbstractMutationWeights, @@ -187,13 +189,14 @@ function next_generation( successful_mutation = false attempts = 0 max_attempts = 10 + node_storage = allocate_container(member.tree) ############################################# # Mutations ############################################# local tree while (!successful_mutation) && attempts < max_attempts - tree = copy(member.tree) + tree = copy_into!(node_storage, member.tree) mutation_result = _dispatch_mutations!( tree, @@ -238,7 +241,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy(member.tree), + copy_into!(node_storage, member.tree), beforeScore, beforeLoss, options, @@ -267,7 +270,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy(member.tree), + copy_into!(node_storage, member.tree), beforeScore, beforeLoss, options, @@ -310,7 +313,7 @@ function next_generation( mutation_accepted = false return ( PopMember( - copy(member.tree), + copy_into!(node_storage, member.tree), beforeScore, beforeLoss, options,