diff --git a/packages/storey/src/containers/column.rs b/packages/storey/src/containers/column.rs index df39eed..a9f3a77 100644 --- a/packages/storey/src/containers/column.rs +++ b/packages/storey/src/containers/column.rs @@ -383,10 +383,8 @@ where /// access.set(1, &9001).unwrap(); /// assert_eq!(access.get(1).unwrap(), Some(9001)); /// ``` - pub fn set(&mut self, id: u32, value: &T) -> Result<(), UpdateError> { - self.storage - .get(&encode_id(id)) - .ok_or(UpdateError::NotFound)?; + pub fn set(&mut self, id: u32, value: &T) -> Result<(), SetError> { + self.storage.get(&encode_id(id)).ok_or(SetError::NotFound)?; let bytes = value.encode()?; @@ -395,6 +393,44 @@ where Ok(()) } + /// Update the value associated with the given ID by applying a function to it. + /// + /// The provided function is called with the current value, if it exists, and should return the + /// new value. If the function returns `None`, the value is removed. + /// + /// # Example + /// ``` + /// # use mocks::encoding::TestEncoding; + /// # use mocks::backend::TestStorage; + /// use storey::containers::Column; + /// + /// let mut storage = TestStorage::new(); + /// let column = Column::::new(0); + /// let mut access = column.access(&mut storage); + /// + /// access.push(&1337).unwrap(); + /// assert_eq!(access.get(1).unwrap(), Some(1337)); + /// + /// access.update(1, |value| value.map(|v| v + 1)).unwrap(); + /// assert_eq!(access.get(1).unwrap(), Some(1338)); + /// ``` + pub fn update( + &mut self, + id: u32, + f: F, + ) -> Result<(), UpdateError> + where + F: FnOnce(Option) -> Option, + { + let new_value = f(self.get(id).map_err(UpdateError::Decode)?); + match new_value { + Some(value) => self.set(id, &value).map_err(UpdateError::Set), + None => self + .remove(id) + .map_err(|_| UpdateError::Set(SetError::NotFound)), + } + } + /// Remove the value associated with the given ID. /// /// This operation leaves behind an empty slot in the column. The ID is not reused. @@ -445,19 +481,27 @@ impl From for PushError { } #[derive(Debug, PartialEq, Eq, Clone, Copy, Error)] -pub enum UpdateError { +pub enum SetError { #[error("not found")] NotFound, #[error("{0}")] EncodingError(E), } -impl From for UpdateError { +impl From for SetError { fn from(e: E) -> Self { - UpdateError::EncodingError(e) + SetError::EncodingError(e) } } +#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)] +pub enum UpdateError { + #[error("decode error: {0}")] + Decode(D), + #[error("set error: {0}")] + Set(SetError), +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, Error)] pub enum RemoveError { #[error("inconsistent state")] @@ -497,7 +541,7 @@ mod tests { assert_eq!(access.len().unwrap(), 2); access.remove(1).unwrap(); - assert_eq!(access.set(1, &9001), Err(UpdateError::NotFound)); + assert_eq!(access.set(1, &9001), Err(SetError::NotFound)); access.set(2, &9001).unwrap(); assert_eq!(access.get(1).unwrap(), None); @@ -535,6 +579,28 @@ mod tests { assert_eq!(access.len().unwrap(), 1); } + #[test] + fn update() { + let mut storage = TestStorage::new(); + + let column = Column::::new(0); + let mut access = column.access(&mut storage); + + access.push(&1337).unwrap(); + access.push(&42).unwrap(); + access.push(&9001).unwrap(); + access.remove(2).unwrap(); + + access.update(1, |value| value.map(|v| v + 1)).unwrap(); + assert_eq!(access.get(1).unwrap(), Some(1338)); + + access.update(2, |value| value.map(|v| v + 1)).unwrap(); + assert_eq!(access.get(2).unwrap(), None); + + access.update(3, |value| value.map(|v| v + 1)).unwrap(); + assert_eq!(access.get(3).unwrap(), Some(9002)); + } + #[test] fn iteration() { let mut storage = TestStorage::new();