Skip to content

Commit

Permalink
only implement Message for iterators
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
  • Loading branch information
andrewwhitehead authored and str4d committed Jul 21, 2024
1 parent 16f8d99 commit 39f36f1
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 67 deletions.
70 changes: 10 additions & 60 deletions src/hash_to_curve/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,71 +102,21 @@ pub trait Message {
///
/// The parameters to successive calls to `f` are treated as a
/// single concatenated octet string.
fn consume(self, f: impl FnMut(&[u8]));
fn input_message(self, f: impl FnMut(&[u8]));
}

impl Message for &[u8] {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self)
}
}

impl<const N: usize> Message for &[u8; N] {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self)
}
}

impl Message for &str {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_bytes())
}
}

impl Message for &[&[u8]] {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
impl<M, I> Message for I
where
M: AsRef<[u8]>,
I: IntoIterator<Item = M>,
{
fn input_message(self, mut f: impl FnMut(&[u8])) {
for msg in self {
f(msg);
f(msg.as_ref())
}
}
}

#[cfg(feature = "alloc")]
impl Message for Vec<u8> {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_slice())
}
}

#[cfg(feature = "alloc")]
impl Message for &Vec<u8> {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_slice())
}
}

#[cfg(feature = "alloc")]
impl Message for alloc::string::String {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_bytes())
}
}

#[cfg(feature = "alloc")]
impl Message for &alloc::string::String {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_bytes())
}
}

/// A trait for message expansion methods supported by hash-to-curve.
pub trait ExpandMessage {
/// Initializes a message expander.
Expand Down Expand Up @@ -230,7 +180,7 @@ where

let dst = ExpandMsgDst::for_xof::<H, L>(dst);
let mut hash = H::default();
message.consume(|m| hash.update(m));
message.input_message(|m| hash.update(m));
let reader = hash
.chain((len_in_bytes as u16).to_be_bytes())
.chain(dst.data())
Expand Down Expand Up @@ -294,7 +244,7 @@ where
let dst = ExpandMsgDst::for_xmd::<H>(dst);
let mut hash_b_0 =
H::default().chain(GenericArray::<u8, <H as BlockInput>::BlockSize>::default());
message.consume(|m| hash_b_0.update(m));
message.input_message(|m| hash_b_0.update(m));
let b_0 = hash_b_0
.chain((len_in_bytes as u16).to_be_bytes())
.chain([0u8])
Expand Down
18 changes: 17 additions & 1 deletion tests/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@ use hex_literal::hex;
use sha2::{Sha256, Sha512};
use sha3::{Shake128, Shake256};

#[test]
fn test_expand_message_parts() {
const EXPAND_LEN: usize = 16;
let mut b1 = [0u8; EXPAND_LEN];
let mut b2 = [0u8; EXPAND_LEN];
<ExpandMsgXmd<Sha256> as ExpandMessage>::init_expand::<_, U32>(
[b"sig" as &[u8], b"nature"],
&[],
EXPAND_LEN,
)
.read_into(&mut b1);
<ExpandMsgXmd<Sha256> as ExpandMessage>::init_expand::<_, U32>([b"signature"], &[], EXPAND_LEN)
.read_into(&mut b2);
assert_eq!(b1, b2);
}

struct TestCase {
msg: &'static [u8],
dst: &'static [u8],
Expand All @@ -16,7 +32,7 @@ impl TestCase {
pub fn run<E: ExpandMessage>(self) {
let mut buf = [0u8; 128];
let output = &mut buf[..self.len_in_bytes];
E::init_expand::<_, U32>(self.msg, self.dst, self.len_in_bytes).read_into(output);
E::init_expand::<_, U32>([self.msg], self.dst, self.len_in_bytes).read_into(output);
if output != self.uniform_bytes {
panic!(
"Failed: expand_message.\n\
Expand Down
9 changes: 6 additions & 3 deletions tests/hash_to_curve_g1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ fn hash_to_curve_works_for_draft16_testvectors_g1_sha256_ro() {
];

for case in cases {
let g =
<G1Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::hash_to_curve(case.msg, case.dst);
let g = <G1Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::hash_to_curve(
[case.msg],
case.dst,
);
let aff = G1Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
case.check_output(&g_uncompressed);
Expand Down Expand Up @@ -175,7 +177,8 @@ fn encode_to_curve_works_for_draft16_testvectors_g1_sha256_nu() {

for case in cases {
let g = <G1Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::encode_to_curve(
case.msg, case.dst,
[case.msg],
case.dst,
);
let aff = G1Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
Expand Down
9 changes: 6 additions & 3 deletions tests/hash_to_curve_g2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ fn hash_to_curve_works_for_draft16_testvectors_g2_sha256_ro() {
];

for case in cases {
let g =
<G2Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::hash_to_curve(case.msg, case.dst);
let g = <G2Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::hash_to_curve(
[case.msg],
case.dst,
);
let aff = G2Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
case.check_output(&g_uncompressed);
Expand Down Expand Up @@ -215,7 +217,8 @@ fn encode_to_curve_works_for_draft16_testvectors_g2_sha256_nu() {

for case in cases {
let g = <G2Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::encode_to_curve(
case.msg, case.dst,
[case.msg],
case.dst,
);
let aff = G2Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
Expand Down

0 comments on commit 39f36f1

Please sign in to comment.