Skip to content

Commit

Permalink
Merge pull request #83 from CosmWasm/75-update-method
Browse files Browse the repository at this point in the history
Add `update` method to `ItemAccess`
  • Loading branch information
uint authored Nov 13, 2024
2 parents b47c699 + f19496f commit a78a034
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 28 deletions.
21 changes: 15 additions & 6 deletions packages/mocks/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,18 @@ use storey_encoding::{Cover, DecodableWithImpl, EncodableWithImpl, Encoding};

pub struct TestEncoding;

#[derive(Debug, PartialEq)]
pub struct MockError;

impl std::fmt::Display for MockError {
fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}

impl Encoding for TestEncoding {
type DecodeError = ();
type EncodeError = ();
type DecodeError = MockError;
type EncodeError = MockError;
}

// This is how we would implement `EncodableWith` and `DecodableWith` for
Expand Down Expand Up @@ -39,16 +48,16 @@ where
// Imagine `MyTestEncoding` is a third-party trait that we don't control.

trait MyTestEncoding: Sized {
fn my_encode(&self) -> Result<Vec<u8>, ()>;
fn my_decode(data: &[u8]) -> Result<Self, ()>;
fn my_encode(&self) -> Result<Vec<u8>, MockError>;
fn my_decode(data: &[u8]) -> Result<Self, MockError>;
}

impl MyTestEncoding for u64 {
fn my_encode(&self) -> Result<Vec<u8>, ()> {
fn my_encode(&self) -> Result<Vec<u8>, MockError> {
Ok(self.to_le_bytes().to_vec())
}

fn my_decode(data: &[u8]) -> Result<Self, ()> {
fn my_decode(data: &[u8]) -> Result<Self, MockError> {
let mut bytes = [0u8; 8];
bytes.copy_from_slice(data);
Ok(u64::from_le_bytes(bytes))
Expand Down
4 changes: 2 additions & 2 deletions packages/storey-encoding/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub trait Encoding {
/// The error type returned when encoding fails.
type EncodeError;
type EncodeError: std::fmt::Display;

/// The error type returned when decoding fails.
type DecodeError;
type DecodeError: std::fmt::Display;
}

pub trait EncodableWith<E: Encoding>: sealed::SealedE<E> {
Expand Down
88 changes: 77 additions & 11 deletions packages/storey/src/containers/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ where
Ok(id)
}

/// Update the value associated with the given ID.
/// Set the value associated with the given ID.
///
/// # Example
/// ```
Expand All @@ -380,13 +380,11 @@ where
/// access.push(&1337).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1337));
///
/// access.update(1, &9001).unwrap();
/// access.set(1, &9001).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(9001));
/// ```
pub fn update(&mut self, id: u32, value: &T) -> Result<(), UpdateError<E::EncodeError>> {
self.storage
.get(&encode_id(id))
.ok_or(UpdateError::NotFound)?;
pub fn set(&mut self, id: u32, value: &T) -> Result<(), SetError<E::EncodeError>> {
self.storage.get(&encode_id(id)).ok_or(SetError::NotFound)?;

let bytes = value.encode()?;

Expand All @@ -395,6 +393,44 @@ where
Ok(())
}

/// Update the value associated with the given ID by applying a function to it.
///
/// The provided function is called with the current value, if it exists, and should return the
/// new value. If the function returns `None`, the value is removed.
///
/// # Example
/// ```
/// # use mocks::encoding::TestEncoding;
/// # use mocks::backend::TestStorage;
/// use storey::containers::Column;
///
/// let mut storage = TestStorage::new();
/// let column = Column::<u64, TestEncoding>::new(0);
/// let mut access = column.access(&mut storage);
///
/// access.push(&1337).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1337));
///
/// access.update(1, |value| value.map(|v| v + 1)).unwrap();
/// assert_eq!(access.get(1).unwrap(), Some(1338));
/// ```
pub fn update<F>(
&mut self,
id: u32,
f: F,
) -> Result<(), UpdateError<E::DecodeError, E::EncodeError>>
where
F: FnOnce(Option<T>) -> Option<T>,
{
let new_value = f(self.get(id).map_err(UpdateError::Decode)?);
match new_value {
Some(value) => self.set(id, &value).map_err(UpdateError::Set),
None => self
.remove(id)
.map_err(|_| UpdateError::Set(SetError::NotFound)),
}
}

/// Remove the value associated with the given ID.
///
/// This operation leaves behind an empty slot in the column. The ID is not reused.
Expand Down Expand Up @@ -445,19 +481,27 @@ impl<E> From<E> for PushError<E> {
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum UpdateError<E> {
pub enum SetError<E> {
#[error("not found")]
NotFound,
#[error("{0}")]
EncodingError(E),
}

impl<E> From<E> for UpdateError<E> {
impl<E> From<E> for SetError<E> {
fn from(e: E) -> Self {
UpdateError::EncodingError(e)
SetError::EncodingError(e)
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum UpdateError<D, E> {
#[error("decode error: {0}")]
Decode(D),
#[error("set error: {0}")]
Set(SetError<E>),
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)]
pub enum RemoveError {
#[error("inconsistent state")]
Expand Down Expand Up @@ -497,8 +541,8 @@ mod tests {
assert_eq!(access.len().unwrap(), 2);

access.remove(1).unwrap();
assert_eq!(access.update(1, &9001), Err(UpdateError::NotFound));
access.update(2, &9001).unwrap();
assert_eq!(access.set(1, &9001), Err(SetError::NotFound));
access.set(2, &9001).unwrap();

assert_eq!(access.get(1).unwrap(), None);
assert_eq!(access.get(2).unwrap(), Some(9001));
Expand Down Expand Up @@ -535,6 +579,28 @@ mod tests {
assert_eq!(access.len().unwrap(), 1);
}

#[test]
fn update() {
let mut storage = TestStorage::new();

let column = Column::<u64, TestEncoding>::new(0);
let mut access = column.access(&mut storage);

access.push(&1337).unwrap();
access.push(&42).unwrap();
access.push(&9001).unwrap();
access.remove(2).unwrap();

access.update(1, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(1).unwrap(), Some(1338));

access.update(2, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(2).unwrap(), None);

access.update(3, |value| value.map(|v| v + 1)).unwrap();
assert_eq!(access.get(3).unwrap(), Some(9002));
}

#[test]
fn iteration() {
let mut storage = TestStorage::new();
Expand Down
59 changes: 58 additions & 1 deletion packages/storey/src/containers/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl<E, T, S> ItemAccess<E, T, S>
where
E: Encoding,
T: EncodableWith<E> + DecodableWith<E>,
S: StorageMut,
S: Storage + StorageMut,
{
/// Set the value of the item.
///
Expand All @@ -234,6 +234,39 @@ where
Ok(())
}

/// Update the value of the item.
///
/// The function `f` is called with the current value of the item, if it exists.
/// If the function returns `Some`, the item is set to the new value.
/// If the function returns `None`, the item is removed.
///
/// # Example
/// ```
/// # use mocks::encoding::TestEncoding;
/// # use mocks::backend::TestStorage;
/// use storey::containers::Item;
///
/// let mut storage = TestStorage::new();
/// let item = Item::<u64, TestEncoding>::new(0);
///
/// item.access(&mut storage).set(&42).unwrap();
/// item.access(&mut storage).update(|value| value.map(|v| v + 1)).unwrap();
/// assert_eq!(item.access(&storage).get().unwrap(), Some(43));
/// ```
pub fn update<F>(&mut self, f: F) -> Result<(), UpdateError<E::DecodeError, E::EncodeError>>
where
F: FnOnce(Option<T>) -> Option<T>,
{
let new_value = f(self.get().map_err(UpdateError::Decode)?);
match new_value {
Some(value) => self.set(&value).map_err(UpdateError::Encode),
None => {
self.remove();
Ok(())
}
}
}

/// Remove the value of the item.
///
/// # Example
Expand All @@ -254,6 +287,14 @@ where
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error)]
pub enum UpdateError<D, E> {
#[error("decode error: {0}")]
Decode(D),
#[error("encode error: {0}")]
Encode(E),
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -276,4 +317,20 @@ mod tests {
assert_eq!(access1.get().unwrap(), None);
assert_eq!(storage.get(&[1]), None);
}

#[test]
fn update() {
let mut storage = TestStorage::new();

let item = Item::<u64, TestEncoding>::new(0);
item.access(&mut storage).set(&42).unwrap();

item.access(&mut storage)
.update(|value| value.map(|v| v + 1))
.unwrap();
assert_eq!(item.access(&storage).get().unwrap(), Some(43));

item.access(&mut storage).update(|_| None).unwrap();
assert_eq!(item.access(&storage).get().unwrap(), None);
}
}
17 changes: 9 additions & 8 deletions packages/storey/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@
//! struct DisplayEncoding;
//!
//! impl Encoding for DisplayEncoding {
//! type DecodeError = ();
//! type EncodeError = ();
//! type DecodeError = String;
//! type EncodeError = String;
//! }
//!
//! impl<T> EncodableWithImpl<DisplayEncoding> for Cover<&T,>
//! where
//! T: std::fmt::Display,
//! {
//! fn encode_impl(self) -> Result<Vec<u8>, ()> {
//! fn encode_impl(self) -> Result<Vec<u8>, String> {
//! Ok(format!("{}", self.0).into_bytes())
//! }
//! }
Expand All @@ -67,17 +67,18 @@
//! struct DisplayEncoding;
//!
//! impl Encoding for DisplayEncoding {
//! type DecodeError = ();
//! type EncodeError = ();
//! type DecodeError = String;
//! type EncodeError = String;
//! }
//!
//! impl<T> DecodableWithImpl<DisplayEncoding> for Cover<T>
//! where
//! T: std::str::FromStr,
//! {
//! fn decode_impl(data: &[u8]) -> Result<Self, ()> {
//! let string = String::from_utf8(data.to_vec()).map_err(|_| ())?;
//! let value = string.parse().map_err(|_| ())?;
//! fn decode_impl(data: &[u8]) -> Result<Self, String> {
//! let string =
//! String::from_utf8(data.to_vec()).map_err(|_| "string isn't UTF-8".to_string())?;
//! let value = string.parse().map_err(|_| "parsing failed".to_string())?;
//! Ok(Cover(value))
//! }
//! }
Expand Down

0 comments on commit a78a034

Please sign in to comment.