Skip to content

Commit

Permalink
Tree: various speedups (#887)
Browse files Browse the repository at this point in the history
* Tree: various speedups

- make dataclass non-frozen
- use mutate() for cases where a Map is modified multiple times
- remove asserts for cases that would fail immediately anyway

* make frozen depend on __debug__, restore an assert

* Improve depth()

Co-authored-by: Alexandru Fikl <alexfikl@gmail.com>

* opt ancestors

---------

Co-authored-by: Alexandru Fikl <alexfikl@gmail.com>
  • Loading branch information
matthiasdiener and alexfikl authored Dec 2, 2024
1 parent f113be0 commit 2b41e84
Showing 1 changed file with 19 additions and 31 deletions.
50 changes: 19 additions & 31 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
NodeT = TypeVar("NodeT", bound=Hashable)


@dataclass(frozen=True)
# Not frozen when optimizations are enabled because it is slower.
# Tree objects are immutable, and offer no way to mutate the tree.
@dataclass(frozen=__debug__) # type: ignore[literal-required]
class Tree(Generic[NodeT]):
"""
An immutable tree containing nodes of type :class:`NodeT`.
Expand Down Expand Up @@ -95,57 +97,42 @@ def ancestors(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns a :class:`tuple` of nodes that are ancestors of *node*.
"""
assert node in self

if self.is_root(node):
parent = self.parent(node)
if parent is None:
# => root
return ()

parent = self._child_to_parent[node]
assert parent is not None

return (parent, *self.ancestors(parent))

def parent(self, node: NodeT) -> NodeT | None:
"""
Returns the parent of *node*.
"""
assert node in self

return self._child_to_parent[node]

def children(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns the children of *node*.
"""
assert node in self

return self._parent_to_children[node]

@memoize_method
def depth(self, node: NodeT) -> int:
"""
Returns the depth of *node*, with the root having depth 0.
"""
assert node in self

if self.is_root(node):
# => None
return 0

parent_of_node = self.parent(node)
assert parent_of_node is not None
if parent_of_node is None:
return 0

return 1 + self.depth(parent_of_node)

def is_root(self, node: NodeT) -> bool:
assert node in self

"""Return *True* if *node* is the root of the tree."""
return self.parent(node) is None

def is_leaf(self, node: NodeT) -> bool:
assert node in self

"""Return *True* if *node* has no children."""
return len(self.children(node)) == 0

def __contains__(self, node: NodeT) -> bool:
Expand All @@ -162,9 +149,11 @@ def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]:

siblings = self._parent_to_children[parent]

return Tree((self._parent_to_children
.set(parent, (*siblings, node))
.set(node, ())),
_parent_to_children_mut = self._parent_to_children.mutate()
_parent_to_children_mut[parent] = (*siblings, node)
_parent_to_children_mut[node] = ()

return Tree(_parent_to_children_mut.finish(),
self._child_to_parent.set(node, parent))

def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:
Expand Down Expand Up @@ -234,13 +223,12 @@ def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]:
parents_new_children = tuple(frozenset(siblings) - frozenset([node]))
new_parents_children = (*self.children(new_parent), node)

new_child_to_parent = self._child_to_parent.set(node, new_parent)
new_parent_to_children = (self._parent_to_children
.set(parent, parents_new_children)
.set(new_parent, new_parents_children))
_parent_to_children_mut = self._parent_to_children.mutate()
_parent_to_children_mut[parent] = parents_new_children
_parent_to_children_mut[new_parent] = new_parents_children

return Tree(new_parent_to_children,
new_child_to_parent)
return Tree(_parent_to_children_mut.finish(),
self._child_to_parent.set(node, new_parent))

def __str__(self) -> str:
"""
Expand Down

0 comments on commit 2b41e84

Please sign in to comment.