Skip to content

Commit

Permalink
Merge pull request #348 from MilesCranmer/rotate-operator
Browse files Browse the repository at this point in the history
Implement tree rotation operator
  • Loading branch information
MilesCranmer authored Oct 6, 2024
2 parents 749cc34 + da12afd commit 4e34473
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
test:
name: Julia ${{ matrix.julia-version }}-${{ matrix.os }}-${{ matrix.test }}-${{ github.event_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 120
timeout-minutes: 240
strategy:
fail-fast: false
matrix:
Expand Down
31 changes: 31 additions & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,37 @@ function create_utils_benchmark()
s
end

if isdefined(SymbolicRegression.MutationFunctionsModule, :randomly_rotate_tree!)
suite["randomly_rotate_tree_x10"] = @benchmarkable(
foreach(trees) do tree
SymbolicRegression.MutationFunctionsModule.randomly_rotate_tree!(tree)
end,
setup = (
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for
i in 1:($ntrees)
]
)
)
end

suite["insert_random_op_x10"] = @benchmarkable(
foreach(trees) do tree
SymbolicRegression.MutationFunctionsModule.insert_random_op(
tree, $options, nfeatures
)
end,
setup = (
T = Float64;
nfeatures = 3;
trees = [
gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees)
]
)
)

ntrees = 10
options = Options(;
unary_operators=[sin, cos],
Expand Down
7 changes: 6 additions & 1 deletion src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ using ..MutationFunctionsModule:
delete_random_op!,
crossover_trees,
form_random_connection!,
break_random_connection!
break_random_connection!,
randomly_rotate_tree!
using ..ConstantOptimizationModule: optimize_constants
using ..RecorderModule: @recorder

Expand Down Expand Up @@ -259,6 +260,10 @@ function next_generation(
tree = break_random_connection!(tree)
@recorder tmp_recorder["type"] = "break_connection"
is_success_always_possible = true
elseif mutation_choice == :rotate_tree
tree = randomly_rotate_tree!(tree)
@recorder tmp_recorder["type"] = "rotate_tree"
is_success_always_possible = true
else
error("Unknown mutation choice: $mutation_choice")
end
Expand Down
106 changes: 106 additions & 0 deletions src/MutationFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,4 +438,110 @@ function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_r
return tree
end

function is_valid_rotation_node(node::AbstractNode)
return (node.degree > 0 && node.l.degree > 0) || (node.degree == 2 && node.r.degree > 0)
end

function randomly_rotate_tree!(ex::AbstractExpression, rng::AbstractRNG=default_rng())
tree = get_contents(ex)
rotated_tree = randomly_rotate_tree!(tree, rng)
return with_contents(ex, rotated_tree)
end
function randomly_rotate_tree!(tree::AbstractNode, rng::AbstractRNG=default_rng())
num_rotation_nodes = count(is_valid_rotation_node, tree)

# Return the tree if no valid nodes are found
if num_rotation_nodes == 0
return tree
end

root_is_valid_rotation_node = is_valid_rotation_node(tree)

# Now, we decide if we want to rotate at the root, or at a random node
rotate_at_root = root_is_valid_rotation_node && rand(rng) < 1.0 / num_rotation_nodes

subtree_parent = if rotate_at_root
tree
else
rand(
rng,
NodeSampler(;
tree,
filter=t -> (
(t.degree > 0 && is_valid_rotation_node(t.l)) ||
(t.degree == 2 && is_valid_rotation_node(t.r))
),
),
)
end

subtree_side = if rotate_at_root
:n
elseif subtree_parent.degree == 1
:l
else
if is_valid_rotation_node(subtree_parent.l) &&
(!is_valid_rotation_node(subtree_parent.r) || rand(rng, Bool))
:l
else
:r
end
end

subtree_root = if rotate_at_root
tree
elseif subtree_side == :l
subtree_parent.l
else
subtree_parent.r
end

# Perform the rotation
# (reference: https://web.archive.org/web/20230326202118/https://upload.wikimedia.org/wikipedia/commons/1/15/Tree_Rotations.gif)
right_rotation_valid = subtree_root.l.degree > 0
left_rotation_valid = subtree_root.degree == 2 && subtree_root.r.degree > 0

right_rotation = right_rotation_valid && (!left_rotation_valid || rand(rng, Bool))
if right_rotation
node_5 = subtree_root
node_3 = leftmost(node_5)
node_4 = rightmost(node_3)

set_leftmost!(node_5, node_4)
set_rightmost!(node_3, node_5)
if rotate_at_root
return node_3 # new root
elseif subtree_side == :l
subtree_parent.l = node_3
else
subtree_parent.r = node_3
end
else # left rotation
node_3 = subtree_root
node_5 = rightmost(node_3)
node_4 = leftmost(node_5)

set_rightmost!(node_3, node_4)
set_leftmost!(node_5, node_3)
if rotate_at_root
return node_5 # new root
elseif subtree_side == :l
subtree_parent.l = node_5
else
subtree_parent.r = node_5
end
end

return tree
end

#! format: off
# These functions provide an easier way to work with unary nodes, by
# simply letting `.r` fall back to `.l` if the node is a unary operator.
leftmost(node::AbstractNode) = node.l
rightmost(node::AbstractNode) = node.degree == 1 ? node.l : node.r
set_leftmost!(node::AbstractNode, l::AbstractNode) = (node.l = l)
set_rightmost!(node::AbstractNode, r::AbstractNode) = node.degree == 1 ? (node.l = r) : (node.r = r)
#! format: on

end
2 changes: 2 additions & 0 deletions src/MutationWeights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ will be normalized to sum to 1.0 after initialization.
- `mutate_constant::Float64`: How often to mutate a constant.
- `mutate_operator::Float64`: How often to mutate an operator.
- `swap_operands::Float64`: How often to swap the operands of a binary operator.
- `rotate_tree::Float64`: How often to perform a tree rotation at a random node.
- `add_node::Float64`: How often to append a node to the tree.
- `insert_node::Float64`: How often to insert a node into the tree.
- `delete_node::Float64`: How often to delete a node from the tree.
Expand All @@ -31,6 +32,7 @@ Base.@kwdef mutable struct MutationWeights
mutate_constant::Float64 = 0.048
mutate_operator::Float64 = 0.47
swap_operands::Float64 = 0.1
rotate_tree::Float64 = 0.3
add_node::Float64 = 0.79
insert_node::Float64 = 5.1
delete_node::Float64 = 1.7
Expand Down
4 changes: 0 additions & 4 deletions test/LocalPreferences.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
[DynamicExpressions]
instability_check = "error"
instability_check_codegen = "min"

[SymbolicRegression]
instability_check = "error"
instability_check_codegen = "min"
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ end
include("test_crossover.jl")
end

include("test_rotation.jl")

# TODO: This is another very slow test
@testitem "Test NaN detection in evaluator" tags = [:part1] begin
include("test_nan_detection.jl")
Expand Down
70 changes: 70 additions & 0 deletions test/test_rotation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
@testitem "Basic `randomly_rotate_tree!`" tags = [:part1] begin
using SymbolicRegression
using SymbolicRegression.MutationFunctionsModule: randomly_rotate_tree!

# Create a simple binary tree structure directly
options = Options(; binary_operators=(+, *, -, /), unary_operators=(cos, exp))
x1, x2, x3 = (Node(; feature=1), Node(; feature=2), Node(; feature=3))

# No-op:
@test randomly_rotate_tree!(x1) === x1

expr = 1.5 * x1 + x2

# (+) -> ((*) -> (1.5, x1), x2)
# Should get rotated to
# (*) -> (1.5, (+) -> (x1, x2))

@test randomly_rotate_tree!(copy(expr)) == 1.5 * (x1 + x2)

# The only rotation option on this tree is to rotate back:
@test randomly_rotate_tree!(randomly_rotate_tree!(copy(expr))) == expr
end

@testitem "Complex `randomly_rotate_tree!`" tags = [:part1] begin
using SymbolicRegression
using SymbolicRegression.MutationFunctionsModule: randomly_rotate_tree!

# Create a simple binary tree structure directly
options = Options(; binary_operators=(+, *, -, /), unary_operators=(cos, exp))
x1, x2, x3 = (Node(; feature=1), Node(; feature=2), Node(; feature=3))

expr = (1.5 * x1) + (2.5 / x3)

# Multiple rotations possible:
# (+) -> ((*) -> (1.5, x1), (/) -> (2.5, x3))
# This can either get rotated to
# (*) -> (1.5, (+) -> (x1, (/) -> (2.5, x3)))
# OR
# (/) -> ((+) -> ((*) -> (1.5, x1), 2.5), x3)

outs = Set([randomly_rotate_tree!(copy(expr)) for _ in 1:100])

@test outs == Set([((1.5 * x1) + 2.5) / x3, 1.5 * (x1 + (2.5 / x3))])

# If we have a unary operator in the mix, both of these options are valid (with
# the unary operator moved in). We also have a third option that rotates with
# the unary operator acting as a pivot.

expr = (1.5 * exp(x1)) + (2.5 / x3)
outs = Set([randomly_rotate_tree!(copy(expr)) for _ in 1:300])
@test outs == Set([
((1.5 * exp(x1)) + 2.5) / x3,
1.5 * (exp(x1) + (2.5 / x3)),
exp(1.5 * x1) + (2.5 / x3),
])
# Basically this third option does a rotation on the `*`:
# (*) -> (1.5, (exp) -> (x1,))
# to
# (exp) -> ((*) -> (1.5, x1),)

# Or, if the unary operator is at the top:
expr = exp((1.5 * x1) + (2.5 / x3))
outs = Set([randomly_rotate_tree!(copy(expr)) for _ in 1:300])
@test outs == Set([
exp(((1.5 * x1) + 2.5) / x3),
exp(1.5 * (x1 + (2.5 / x3))),
# Rotate with `exp` as the *root*:
(1.5 * x1) + exp(2.5 / x3),
])
end

0 comments on commit 4e34473

Please sign in to comment.