diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index 011f37b..fff21b4 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -235,17 +235,17 @@ impl< let (append_result, position, checkpoint_id) = if let Some(subtree) = self.store.last_shard().map_err(ShardTreeError::Storage)? { - match subtree.max_position() { - // If the subtree is full, then construct a successor tree. - Some(pos) if pos == subtree.root_addr.max_position() => { - let addr = subtree.root_addr; - if subtree.root_addr.index() < Self::max_subtree_index() { - LocatedTree::empty(addr.next_at_level()).append(value, retention)? - } else { - return Err(InsertionError::TreeFull.into()); - } + if subtree.root().is_full() { + // If the shard is full, then construct a successor tree. + let addr = subtree.root_addr; + if subtree.root_addr.index() < Self::max_subtree_index() { + LocatedTree::empty(addr.next_at_level()).append(value, retention)? + } else { + return Err(InsertionError::TreeFull.into()); } - _ => subtree.append(value, retention)?, + } else { + // Otherwise, just append to the shard. + subtree.append(value, retention)? } } else { let root_addr = Address::from_parts(Self::subtree_level(), 0); @@ -573,13 +573,7 @@ impl< .map_err(ShardTreeError::Storage)?; if let Some(to_clear) = to_clear { - let pre_clearing_max_position = to_clear.max_position(); let cleared = to_clear.clear_flags(positions); - - // Clearing flags should not modify the max position of leaves represented - // in the shard. - assert!(cleared.max_position() == pre_clearing_max_position); - self.store .put_shard(cleared) .map_err(ShardTreeError::Storage)?; diff --git a/shardtree/src/prunable.rs b/shardtree/src/prunable.rs index 9130a7b..28d73ab 100644 --- a/shardtree/src/prunable.rs +++ b/shardtree/src/prunable.rs @@ -362,10 +362,11 @@ impl LocatedPrunableTree { /// If the tree contains any [`Node::Nil`] nodes that are to the left of filled nodes in the /// tree, this will return an error containing the addresses of those nodes. pub fn right_filled_root(&self) -> Result> { - self.root_hash( - self.max_position() - .map_or_else(|| self.root_addr.position_range_start(), |pos| pos + 1), - ) + let truncate_at = self + .max_position() + .map_or_else(|| self.root_addr.position_range_start(), |pos| pos + 1); + + self.root_hash(truncate_at) } /// Returns the positions of marked leaves in the tree. @@ -949,6 +950,33 @@ impl LocatedPrunableTree { root: go(&to_clear, self.root_addr, &self.root), } } + + #[cfg(test)] + pub(crate) fn flag_positions(&self) -> BTreeMap { + fn go( + root: &PrunableTree, + root_addr: Address, + acc: &mut BTreeMap, + ) { + match &root.0 { + Node::Parent { left, right, .. } => { + let (l_addr, r_addr) = root_addr + .children() + .expect("A parent node cannot appear at level 0"); + go(left, l_addr, acc); + go(right, r_addr, acc); + } + Node::Leaf { value } if value.1 != RetentionFlags::EPHEMERAL => { + acc.insert(root_addr.max_position(), value.1); + } + _ => (), + } + } + + let mut result = BTreeMap::new(); + go(&self.root, self.root_addr, &mut result); + result + } } // We need an applicative functor for Result for this function so that we can correctly @@ -971,13 +999,15 @@ fn accumulate_result_with( #[cfg(test)] mod tests { - use std::collections::BTreeSet; + use std::collections::{BTreeMap, BTreeSet}; use incrementalmerkletree::{Address, Level, Position}; + use proptest::proptest; use super::{LocatedPrunableTree, PrunableTree, RetentionFlags}; use crate::{ error::{InsertionError, QueryError}, + testing::{arb_char_str, arb_prunable_tree}, tree::{ tests::{leaf, nil, parent}, LocatedTree, @@ -1197,4 +1227,34 @@ mod tests { )])) ); } + + proptest! { + #[test] + fn clear_flags( + root in arb_prunable_tree(arb_char_str(), 8, 2^6) + ) { + let root_addr = Address::from_parts(Level::from(7), 0); + let tree = LocatedTree::from_parts(root_addr, root); + + let (to_clear, to_retain) = tree.flag_positions().into_iter().enumerate().fold( + (BTreeMap::new(), BTreeMap::new()), + |(mut to_clear, mut to_retain), (i, (pos, flags))| { + if i % 2 == 0 { + to_clear.insert(pos, flags); + } else { + to_retain.insert(pos, flags); + } + (to_clear, to_retain) + } + ); + + let pre_clearing_max_position = tree.max_position(); + let cleared = tree.clear_flags(to_clear); + + // Clearing flags should not modify the max position of leaves represented + // in the shard. + assert!(cleared.max_position() == pre_clearing_max_position); + assert_eq!(to_retain, cleared.flag_positions()); + } + } } diff --git a/shardtree/src/tree.rs b/shardtree/src/tree.rs index e92fd6c..30abb7c 100644 --- a/shardtree/src/tree.rs +++ b/shardtree/src/tree.rs @@ -148,6 +148,20 @@ impl Tree { matches!(&self.0, Node::Leaf { .. }) } + /// Returns `true` if no additional leaves can be appended to this tree. + /// + /// The tree is considered full if no `Nil` node exists along the right-hand + /// path in a depth-first traversal of this tree. In this case, no additional + /// nodes can be added to the right-hand side of this tree without introducing + /// a new `Parent` node having this tree as its left-hand child. + pub fn is_full(&self) -> bool { + match &self.0 { + Node::Nil => false, + Node::Leaf { .. } | Node::Pruned => true, + Node::Parent { right, .. } => right.is_full(), + } + } + /// Returns a vector of the addresses of [`Node::Nil`] and [`Node::Pruned`] subtree roots /// within this tree. /// @@ -260,19 +274,18 @@ impl LocatedTree { /// Note that no actual leaf value may exist at this position, as it may have previously been /// pruned. pub fn max_position(&self) -> Option { - Self::max_position_internal(self.root_addr, &self.root) - } - - pub(crate) fn max_position_internal(addr: Address, root: &Tree) -> Option { - match &root.0 { - Node::Nil => None, - Node::Leaf { .. } | Node::Pruned => Some(addr.position_range_end() - 1), - Node::Parent { left, right, .. } => { - let (l_addr, r_addr) = addr.children().unwrap(); - Self::max_position_internal(r_addr, right.as_ref()) - .or_else(|| Self::max_position_internal(l_addr, left.as_ref())) + fn go(addr: Address, root: &Tree) -> Option { + match &root.0 { + Node::Nil => None, + Node::Leaf { .. } | Node::Pruned => Some(addr.position_range_end() - 1), + Node::Parent { left, right, .. } => { + let (l_addr, r_addr) = addr.children().unwrap(); + go(r_addr, right.as_ref()).or_else(|| go(l_addr, left.as_ref())) + } } } + + go(self.root_addr, &self.root) } /// Returns the value at the specified position, if any.