Skip to content

Commit

Permalink
feat: add Column::update
Browse files Browse the repository at this point in the history
  • Loading branch information
uint committed Nov 13, 2024
1 parent ea76d6d commit f19496f
Showing 1 changed file with 74 additions and 8 deletions.
82 changes: 74 additions & 8 deletions packages/storey/src/containers/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,8 @@ where
/// access.set(1, &9001).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(9001));
/// ```
pub fn set(&mut self, id: u32, value: &T) -> Result<(), UpdateError<E::EncodeError>> {
self.storage
.get(&encode_id(id))
.ok_or(UpdateError::NotFound)?;
pub fn set(&mut self, id: u32, value: &T) -> Result<(), SetError<E::EncodeError>> {
self.storage.get(&encode_id(id)).ok_or(SetError::NotFound)?;

let bytes = value.encode()?;

Expand All @@ -395,6 +393,44 @@ where
Ok(())
}

/// Update the value associated with the given ID by applying a function to it.
///
/// The provided function is called with the current value, if it exists, and should return the
/// new value. If the function returns `None`, the value is removed.
///
/// # Example
/// ```
/// # use mocks::encoding::TestEncoding;
/// # use mocks::backend::TestStorage;
/// use storey::containers::Column;
///
/// let mut storage = TestStorage::new();
/// let column = Column::<u64, TestEncoding>::new(0);
/// let mut access = column.access(&mut storage);
///
/// access.push(&1337).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1337));
///
/// access.update(1, |value| value.map(|v| v + 1)).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1338));
/// ```
pub fn update<F>(
&mut self,
id: u32,
f: F,
) -> Result<(), UpdateError<E::DecodeError, E::EncodeError>>
where
F: FnOnce(Option<T>) -> Option<T>,
{
let new_value = f(self.get(id).map_err(UpdateError::Decode)?);
match new_value {
Some(value) => self.set(id, &value).map_err(UpdateError::Set),
None => self
.remove(id)
.map_err(|_| UpdateError::Set(SetError::NotFound)),
}
}

/// Remove the value associated with the given ID.
///
/// This operation leaves behind an empty slot in the column. The ID is not reused.
Expand Down Expand Up @@ -445,19 +481,27 @@ impl<E> From<E> for PushError<E> {
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum UpdateError<E> {
pub enum SetError<E> {
#[error("not found")]
NotFound,
#[error("{0}")]
EncodingError(E),
}

impl<E> From<E> for UpdateError<E> {
impl<E> From<E> for SetError<E> {
fn from(e: E) -> Self {
UpdateError::EncodingError(e)
SetError::EncodingError(e)
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum UpdateError<D, E> {
#[error("decode error: {0}")]
Decode(D),
#[error("set error: {0}")]
Set(SetError<E>),
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum RemoveError {
#[error("inconsistent state")]
Expand Down Expand Up @@ -497,7 +541,7 @@ mod tests {
assert_eq!(access.len().unwrap(), 2);

access.remove(1).unwrap();
assert_eq!(access.set(1, &9001), Err(UpdateError::NotFound));
assert_eq!(access.set(1, &9001), Err(SetError::NotFound));
access.set(2, &9001).unwrap();

assert_eq!(access.get(1).unwrap(), None);
Expand Down Expand Up @@ -535,6 +579,28 @@ mod tests {
assert_eq!(access.len().unwrap(), 1);
}

#[test]
fn update() {
let mut storage = TestStorage::new();

let column = Column::<u64, TestEncoding>::new(0);
let mut access = column.access(&mut storage);

access.push(&1337).unwrap();
access.push(&42).unwrap();
access.push(&9001).unwrap();
access.remove(2).unwrap();

access.update(1, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(1).unwrap(), Some(1338));

access.update(2, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(2).unwrap(), None);

access.update(3, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(3).unwrap(), Some(9002));
}

#[test]
fn iteration() {
let mut storage = TestStorage::new();
Expand Down

0 comments on commit f19496f

Please sign in to comment.