From 495efdcb14d85cedd8a44d7b628331bd6ca3e091 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Mon, 8 Aug 2022 16:23:32 -0700 Subject: [PATCH] only implement Message for iterators Signed-off-by: Andrew Whitehead --- src/hash_to_curve/expand_msg.rs | 70 +++++---------------------------- tests/expand_msg.rs | 18 ++++++++- tests/hash_to_curve_g1.rs | 8 ++-- tests/hash_to_curve_g2.rs | 9 +++-- 4 files changed, 38 insertions(+), 67 deletions(-) diff --git a/src/hash_to_curve/expand_msg.rs b/src/hash_to_curve/expand_msg.rs index 4e816fd5..33a3d9f9 100644 --- a/src/hash_to_curve/expand_msg.rs +++ b/src/hash_to_curve/expand_msg.rs @@ -102,71 +102,21 @@ pub trait Message { /// /// The parameters to successive calls to `f` are treated as a /// single concatenated octet string. - fn consume(self, f: impl FnMut(&[u8])); + fn input_message(self, f: impl FnMut(&[u8])); } -impl Message for &[u8] { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self) - } -} - -impl Message for &[u8; N] { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self) - } -} - -impl Message for &str { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_bytes()) - } -} - -impl Message for &[&[u8]] { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { +impl Message for I +where + M: AsRef<[u8]>, + I: IntoIterator, +{ + fn input_message(self, mut f: impl FnMut(&[u8])) { for msg in self { - f(msg); + f(msg.as_ref()) } } } -#[cfg(feature = "alloc")] -impl Message for Vec { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_slice()) - } -} - -#[cfg(feature = "alloc")] -impl Message for &Vec { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_slice()) - } -} - -#[cfg(feature = "alloc")] -impl Message for alloc::string::String { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_bytes()) - } -} - -#[cfg(feature = "alloc")] -impl Message for &alloc::string::String { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_bytes()) - } -} - /// A trait for message expansion methods supported by hash-to-curve. pub trait ExpandMessage { /// Initializes a message expander. @@ -230,7 +180,7 @@ where let dst = ExpandMsgDst::for_xof::(dst); let mut hash = H::default(); - message.consume(|m| hash.update(m)); + message.input_message(|m| hash.update(m)); let reader = hash .chain((len_in_bytes as u16).to_be_bytes()) .chain(dst.data()) @@ -294,7 +244,7 @@ where let dst = ExpandMsgDst::for_xmd::(dst); let mut hash_b_0 = H::default().chain(GenericArray::::BlockSize>::default()); - message.consume(|m| hash_b_0.update(m)); + message.input_message(|m| hash_b_0.update(m)); let b_0 = hash_b_0 .chain((len_in_bytes as u16).to_be_bytes()) .chain([0u8]) diff --git a/tests/expand_msg.rs b/tests/expand_msg.rs index ff8a5528..dd1892d2 100644 --- a/tests/expand_msg.rs +++ b/tests/expand_msg.rs @@ -4,6 +4,22 @@ use hex_literal::hex; use sha2::{Sha256, Sha512}; use sha3::{Shake128, Shake256}; +#[test] +fn test_expand_message_parts() { + const EXPAND_LEN: usize = 16; + let mut b1 = [0u8; EXPAND_LEN]; + let mut b2 = [0u8; EXPAND_LEN]; + as ExpandMessage>::init_expand::<_, U32>( + [b"sig" as &[u8], b"nature"], + &[], + EXPAND_LEN, + ) + .read_into(&mut b1); + as ExpandMessage>::init_expand::<_, U32>([b"signature"], &[], EXPAND_LEN) + .read_into(&mut b2); + assert_eq!(b1, b2); +} + struct TestCase { msg: &'static [u8], dst: &'static [u8], @@ -16,7 +32,7 @@ impl TestCase { pub fn run(self) { let mut buf = [0u8; 128]; let output = &mut buf[..self.len_in_bytes]; - E::init_expand::<_, U32>(self.msg, self.dst, self.len_in_bytes).read_into(output); + E::init_expand::<_, U32>([self.msg], self.dst, self.len_in_bytes).read_into(output); if output != self.uniform_bytes { panic!( "Failed: expand_message.\n\ diff --git a/tests/hash_to_curve_g1.rs b/tests/hash_to_curve_g1.rs index 85108939..410888ee 100644 --- a/tests/hash_to_curve_g1.rs +++ b/tests/hash_to_curve_g1.rs @@ -96,8 +96,10 @@ fn hash_to_curve_works_for_draft16_testvectors_g1_sha256_ro() { ]; for case in cases { - let g = - >>::hash_to_curve(case.msg, case.dst); + let g = >>::hash_to_curve( + [case.msg], + case.dst, + ); let aff = G1Affine::from(g); let g_uncompressed = aff.to_uncompressed(); case.check_output(&g_uncompressed); @@ -175,7 +177,7 @@ fn encode_to_curve_works_for_draft16_testvectors_g1_sha256_nu() { for case in cases { let g = >>::encode_to_curve( - case.msg, case.dst, + [case.msg], case.dst, ); let aff = G1Affine::from(g); let g_uncompressed = aff.to_uncompressed(); diff --git a/tests/hash_to_curve_g2.rs b/tests/hash_to_curve_g2.rs index 46e40bf6..95c97d1d 100644 --- a/tests/hash_to_curve_g2.rs +++ b/tests/hash_to_curve_g2.rs @@ -116,8 +116,10 @@ fn hash_to_curve_works_for_draft16_testvectors_g2_sha256_ro() { ]; for case in cases { - let g = - >>::hash_to_curve(case.msg, case.dst); + let g = >>::hash_to_curve( + [case.msg], + case.dst, + ); let aff = G2Affine::from(g); let g_uncompressed = aff.to_uncompressed(); case.check_output(&g_uncompressed); @@ -215,7 +217,8 @@ fn encode_to_curve_works_for_draft16_testvectors_g2_sha256_nu() { for case in cases { let g = >>::encode_to_curve( - case.msg, case.dst, + [case.msg], + case.dst, ); let aff = G2Affine::from(g); let g_uncompressed = aff.to_uncompressed();