Skip to content

Commit

Permalink
More tweaks to support custom root nodes.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkNahabedian committed Nov 26, 2024
1 parent ecaa63d commit 204f4bd
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function askc(a::Aggregator, source)
value(a)
end

function askc(a::Aggregator, kb::ReteRootNode, t::Type)
function askc(a::Aggregator, kb::AbstractReteRootNode, t::Type)
askc(x -> aggregate(a, x), kb, t)
value(a)
end
Expand Down Expand Up @@ -62,7 +62,7 @@ function aggregate(a::Collector, thing)
end

# We can infer the query type from the Collector:
function askc(a::Collector{T}, kb::ReteRootNode) where T
function askc(a::Collector{T}, kb::AbstractReteRootNode) where T
askc(x -> aggregate(a, x), kb, T)
value(a)
end
Expand Down
14 changes: 8 additions & 6 deletions src/memory_nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ is_memory_for_type(node::IsaMemoryNode, typ::Type)::Bool =
typ == typeof(node).parameters[1]


function receive(node::IsaMemoryNode, fact)
function receive(node::AbstractMemoryNode, fact)
# Ignore facts not relevant to this memory node.
end

Expand All @@ -73,14 +73,14 @@ function askc(continuation::Function, node::IsaMemoryNode)
end

"""
askc(continuation::Function, root::ReteRootNode, t::Type)
askc(continuation::Function, root::AbstractReteRootNode, t::Type)
calls `continuation` on every fact of the specified type (or its
subtypes) that are stored in the network rooted at `root`.
Assumes all memory nodes are direct outputs of `root`.
"""
function askc(continuation::Function, root::ReteRootNode, t::Type)
function askc(continuation::Function, root::AbstractReteRootNode, t::Type)
for o in root.outputs
if o isa IsaMemoryNode
if length(typeof(o).parameters) == 1
Expand All @@ -100,8 +100,9 @@ If there's a memory node in the Rete represented by `root` that stores
objects of the specified type then return it. Otherwise return
nothing.
"""
function find_memory_for_type(root::ReteRootNode,
function find_memory_for_type(root::AbstractReteRootNode,
typ::Type)::Union{Nothing, AbstractMemoryNode}

for o in root.outputs
if is_memory_for_type(o, typ)
return o
Expand All @@ -112,7 +113,7 @@ end


"""
ensure_memory_node(root::ReteRootNode, typ::Type)::IsaTypeNode
ensure_memory_node(root::AbstractReteRootNode, typ::Type)::IsaTypeNode
Find a memory node for the specified type, or make one and add it
to the network.
Expand All @@ -121,7 +122,8 @@ The default is to make an IsaMemoryNode. Specialize this function for
a `Type` to control what type of memory node should be used for that
type.
"""
function ensure_memory_node(root::ReteRootNode, typ::Type)::AbstractMemoryNode
function ensure_memory_node(root::AbstractReteRootNode,
typ::Type)::AbstractMemoryNode
n = find_memory_for_type(root, typ)
if n !== nothing
return n
Expand Down
25 changes: 17 additions & 8 deletions src/root_nodes.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# The root node of a Rete

export CanInstallRulesTrait, ReteRootNode
export CanInstallRulesTrait, AbstractReteRootNode, ReteRootNode


"""
Expand All @@ -26,26 +26,35 @@ Installs the rule or rule group into the Rete rooted at `root`.
install(root::T, rule::Type) where T =
install(CanInstallRulesTrait(T), root, rule)

function install(::CanInstallRulesTrait, root, rule_group::Type)
function install(trait::CanInstallRulesTrait, root, rule_group::Type)
if isconcretetype(rule_group)
install(root, rule_group())
install(trait, root, rule_group())
else
for r in subtypes(rule_group)
install(root, r)
install(trait, root, r)
end
end
end


"""
AbstractReteRootNode
AbstractReteRootNode is the abstract supertype for all root nodes of a
Rete.
"""
abstract type AbstractReteRootNode <: AbstractReteNode end


"""
ReteRootNode
ReteRootNode serves as the root node of a Rete network.
If you need a specialized root node for your application, see
[`CanInstallRulesTrait`](@ref).
[`AbstractReteRootNode`](@ref) and [`CanInstallRulesTrait`](@ref).
"""
struct ReteRootNode <: AbstractReteNode
struct ReteRootNode <: AbstractReteRootNode
inputs::Set{AbstractReteNode}
outputs::Set{AbstractReteNode}
label::String
Expand All @@ -57,15 +66,15 @@ struct ReteRootNode <: AbstractReteNode
end
end

CanInstallRulesTrait(::Type{<:ReteRootNode}) = CanInstallRulesTrait()
CanInstallRulesTrait(::Type{ReteRootNode}) = CanInstallRulesTrait()

inputs(node::ReteRootNode) = node.inputs

outputs(node::ReteRootNode) = node.outputs

label(node::ReteRootNode) = node.label

function receive(node::ReteRootNode, fact)
function receive(node::AbstractReteRootNode, fact)
emit(node, fact)
end

3 changes: 2 additions & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ macro rule(call, body)
install_method = []
if !rule_decls.custom_install
push!(install_method,
:(function Rete.install(root::ReteRootNode, ::$rule_name)
:(function Rete.install(::CanInstallRulesTrait,
root, ::$rule_name)
join = JoinNode($rule_name_str,
$(length(input_exprs)),
$rule_name())
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ export kb_counts, kb_stats, copy_facts


"""
kb_counts(root::ReteRootNode)
kb_counts(root::AbstractReteRootNode)
Returns a `Dict{Type, Int}` of the number of facts of each type.
"""
function kb_counts(root::ReteRootNode)
function kb_counts(root::AbstractReteRootNode)
result = Dict{Type, Int}()
function walk(node)
if node isa IsaMemoryNode
Expand Down Expand Up @@ -57,14 +57,14 @@ kb_stats(node) = kb_stats(stdout, node)


"""
copy_facts(from_kb::ReteRootNode, to_kb::ReteRootNode, fact_types)
copy_facts(from_kb::AbstractReteRootNode, to_kb::AbstractReteRootNode, fact_types)
Copues facts if the specified `fact_type` from `from_kb` to `to_kb`.
for multiple fact types, one can broadcast over a collection of fact
types.
"""
function copy_facts(from_kb::ReteRootNode, to_kb::ReteRootNode,
function copy_facts(from_kb::AbstractReteRootNode, to_kb::AbstractReteRootNode,
fact_type)
askc(from_kb, fact_type) do fact
receive(to_kb, fact)
Expand Down
6 changes: 3 additions & 3 deletions test/test_rule_decls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ end
got_with = false
got_without = false
for m in methods(install)
if hasproperty(m.sig, :parameters)
if m.sig.parameters[3] == RuleWithoutCustomInstall
if hasproperty(m.sig, :parameters) && length(m.sig.parameters) >= 4
if m.sig.parameters[4] == RuleWithoutCustomInstall
got_without = true
elseif m.sig.parameters[3] == RuleWithCustomInstall
elseif m.sig.parameters[4] == RuleWithCustomInstall
got_with = true
end
end
Expand Down

0 comments on commit 204f4bd

Please sign in to comment.