diff --git a/Cargo.lock b/Cargo.lock index c1b891707..872d8cab0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,28 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-prng" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49cccd49cb7034d6ee7db9ac3549bb3fb38ff17179d93b726efb974cc9ddafa9" +dependencies = [ + "aes", + "byteorder", + "rand", +] + [[package]] name = "ahash" version = "0.8.11" @@ -907,6 +929,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.16" @@ -2220,6 +2252,15 @@ version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.13" @@ -2318,6 +2359,24 @@ dependencies = [ "zeroize", ] +[[package]] +name = "iris-mpc-cpu" +version = "0.1.0" +dependencies = [ + "aes-prng", + "bytemuck", + "bytes", + "eyre", + "iris-mpc-common", + "num-traits", + "rand", + "rayon", + "serde", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "iris-mpc-gpu" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e0de04056..2bf1dc80c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "iris-mpc", + "iris-mpc-cpu", "iris-mpc-gpu", "iris-mpc-common", "iris-mpc-upgrade", diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml new file mode 100644 index 000000000..e0f4a8061 --- /dev/null +++ b/iris-mpc-cpu/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "iris-mpc-cpu" +version = "0.1.0" +edition = "2021" +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + + +[dependencies] +aes-prng = "0.2" +bytes = "1.7" +bytemuck.workspace = true +eyre.workspace = true +iris-mpc-common = { path = "../iris-mpc-common" } +num-traits = "0.2" +rand.workspace = true +rayon.workspace = true +serde.workspace = true +tokio.workspace = true +thiserror = "1.0" +tracing.workspace = true diff --git a/iris-mpc-cpu/src/error.rs b/iris-mpc-cpu/src/error.rs new file mode 100644 index 000000000..63ec7041d --- /dev/null +++ b/iris-mpc-cpu/src/error.rs @@ -0,0 +1,60 @@ +use rayon::ThreadPoolBuildError; +use thiserror::Error; + +/// An Error enum capturing the errors produced by this crate. +#[derive(Error, Debug)] +pub enum Error { + /// Config Error + #[error("Invalid Configuration")] + Config, + /// Type conversion error + #[error("Conversion error")] + Conversion, + /// Error from the color_eyre crate + #[error(transparent)] + Eyre(#[from] eyre::Report), + /// Invalid party id provided + #[error("Invalid Party id {0}")] + Id(usize), + /// Message size is invalid + #[error("Message size is invalid")] + InvalidMessageSize, + /// Size is invalid + #[error("Size is invalid")] + InvalidSize, + /// A IO error has orccured + #[error(transparent)] + IO(#[from] std::io::Error), + /// JMP verify failed + #[error("JMP verify failed")] + JmpVerify, + /// Mask HW is to small + #[error("Mask HW is to small")] + MaskHW, + /// Not enough triples + #[error("Not enough triples")] + NotEnoughTriples, + /// Invalid number of parties + #[error("Invalid number of parties {0}")] + NumParty(usize), + /// Verify failed + #[error("Verify failed")] + Verify, + #[error(transparent)] + ThreadPoolBuildError(#[from] ThreadPoolBuildError), + /// Some other error has occurred. + #[error("Err: {0}")] + Other(String), +} + +impl From for Error { + fn from(mes: String) -> Self { + Self::Other(mes) + } +} + +impl From<&str> for Error { + fn from(mes: &str) -> Self { + Self::Other(mes.to_owned()) + } +} diff --git a/iris-mpc-cpu/src/lib.rs b/iris-mpc-cpu/src/lib.rs new file mode 100644 index 000000000..db9182fc1 --- /dev/null +++ b/iris-mpc-cpu/src/lib.rs @@ -0,0 +1,5 @@ +pub(crate) mod error; +pub(crate) mod networks; +pub(crate) mod protocol; +pub(crate) mod shares; +pub(crate) mod utils; diff --git a/iris-mpc-cpu/src/networks/mod.rs b/iris-mpc-cpu/src/networks/mod.rs new file mode 100644 index 000000000..6274c758e --- /dev/null +++ b/iris-mpc-cpu/src/networks/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod network_trait; +pub(crate) mod test_network; diff --git a/iris-mpc-cpu/src/networks/network_trait.rs b/iris-mpc-cpu/src/networks/network_trait.rs new file mode 100644 index 000000000..bcd2605a3 --- /dev/null +++ b/iris-mpc-cpu/src/networks/network_trait.rs @@ -0,0 +1,43 @@ +use crate::error::Error; +use bytes::{Bytes, BytesMut}; +use iris_mpc_common::id::PartyID; + +pub type IoError = std::io::Error; + +#[allow(async_fn_in_trait)] +pub trait NetworkTrait: Send + Sync { + fn get_id(&self) -> PartyID; + + async fn shutdown(self) -> Result<(), IoError>; + + async fn send(&mut self, id: PartyID, data: Bytes) -> Result<(), IoError>; + async fn send_next_id(&mut self, data: Bytes) -> Result<(), IoError>; + async fn send_prev_id(&mut self, data: Bytes) -> Result<(), IoError>; + + async fn receive(&mut self, id: PartyID) -> Result; + async fn receive_prev_id(&mut self) -> Result; + async fn receive_next_id(&mut self) -> Result; + + async fn broadcast(&mut self, data: Bytes) -> Result, IoError>; + //======= sync world ========= + fn blocking_send(&mut self, id: PartyID, data: Bytes) -> Result<(), IoError>; + fn blocking_send_next_id(&mut self, data: Bytes) -> Result<(), IoError>; + fn blocking_send_prev_id(&mut self, data: Bytes) -> Result<(), IoError>; + + fn blocking_receive(&mut self, id: PartyID) -> Result; + fn blocking_receive_prev_id(&mut self) -> Result; + fn blocking_receive_next_id(&mut self) -> Result; + + fn blocking_broadcast(&mut self, data: Bytes) -> Result, IoError>; +} + +#[allow(async_fn_in_trait)] +pub trait NetworkEstablisher { + fn get_id(&self) -> PartyID; + fn get_num_parties(&self) -> usize; + async fn open_channel(&mut self) -> Result; + async fn shutdown(self) -> Result<(), Error>; + //======= sync world ========= + fn print_connection_stats(&self, out: &mut impl std::io::Write) -> std::io::Result<()>; + fn get_send_receive(&self, i: usize) -> std::io::Result<(u64, u64)>; +} diff --git a/iris-mpc-cpu/src/networks/test_network.rs b/iris-mpc-cpu/src/networks/test_network.rs new file mode 100644 index 000000000..1c20d68c7 --- /dev/null +++ b/iris-mpc-cpu/src/networks/test_network.rs @@ -0,0 +1,358 @@ +use super::network_trait::{NetworkEstablisher, NetworkTrait}; +use crate::error::Error; +use bytes::{Bytes, BytesMut}; +use iris_mpc_common::id::PartyID; +use std::{ + collections::VecDeque, + io, + io::{Error as IOError, ErrorKind as IOErrorKind}, +}; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; + +pub struct TestNetwork3p { + p1_p2_sender: UnboundedSender, + p1_p3_sender: UnboundedSender, + p2_p3_sender: UnboundedSender, + p2_p1_sender: UnboundedSender, + p3_p1_sender: UnboundedSender, + p3_p2_sender: UnboundedSender, + p1_p2_receiver: UnboundedReceiver, + p1_p3_receiver: UnboundedReceiver, + p2_p3_receiver: UnboundedReceiver, + p2_p1_receiver: UnboundedReceiver, + p3_p1_receiver: UnboundedReceiver, + p3_p2_receiver: UnboundedReceiver, +} + +impl Default for TestNetwork3p { + fn default() -> Self { + Self::new() + } +} + +impl TestNetwork3p { + pub fn new() -> Self { + // AT Most 1 message is buffered before they are read so this should be fine + let p1_p2 = mpsc::unbounded_channel(); + let p1_p3 = mpsc::unbounded_channel(); + let p2_p3 = mpsc::unbounded_channel(); + let p2_p1 = mpsc::unbounded_channel(); + let p3_p1 = mpsc::unbounded_channel(); + let p3_p2 = mpsc::unbounded_channel(); + + Self { + p1_p2_sender: p1_p2.0, + p1_p3_sender: p1_p3.0, + p2_p1_sender: p2_p1.0, + p2_p3_sender: p2_p3.0, + p3_p1_sender: p3_p1.0, + p3_p2_sender: p3_p2.0, + p1_p2_receiver: p1_p2.1, + p1_p3_receiver: p1_p3.1, + p2_p1_receiver: p2_p1.1, + p2_p3_receiver: p2_p3.1, + p3_p1_receiver: p3_p1.1, + p3_p2_receiver: p3_p2.1, + } + } + + pub fn get_party_networks(self) -> [PartyTestNetwork; 3] { + let party1 = PartyTestNetwork { + id: PartyID::ID0, + send_prev: self.p1_p3_sender, + recv_prev: self.p3_p1_receiver, + send_next: self.p1_p2_sender, + recv_next: self.p2_p1_receiver, + stats: [0; 4], + }; + + let party2 = PartyTestNetwork { + id: PartyID::ID1, + send_prev: self.p2_p1_sender, + recv_prev: self.p1_p2_receiver, + send_next: self.p2_p3_sender, + recv_next: self.p3_p2_receiver, + stats: [0; 4], + }; + + let party3 = PartyTestNetwork { + id: PartyID::ID2, + send_prev: self.p3_p2_sender, + recv_prev: self.p2_p3_receiver, + send_next: self.p3_p1_sender, + recv_next: self.p1_p3_receiver, + stats: [0; 4], + }; + + [party1, party2, party3] + } +} + +pub struct TestNetworkEstablisher { + id: PartyID, + test_network: VecDeque, +} + +pub struct PartyTestNetwork { + id: PartyID, + send_prev: UnboundedSender, + send_next: UnboundedSender, + recv_prev: UnboundedReceiver, + recv_next: UnboundedReceiver, + stats: [usize; 4], // [sent_prev, sent_next, recv_prev, recv_next] +} + +impl From> for TestNetworkEstablisher { + fn from(net: VecDeque) -> Self { + Self { + id: net.front().unwrap().id, + test_network: net, + } + } +} + +impl PartyTestNetwork { + pub const NUM_PARTIES: usize = 3; +} + +impl NetworkEstablisher for TestNetworkEstablisher { + fn get_id(&self) -> PartyID { + self.id + } + + fn get_num_parties(&self) -> usize { + 3 + } + + async fn open_channel(&mut self) -> Result { + self.test_network + .pop_front() + .ok_or(Error::Other("test config error".to_owned())) + } + + async fn shutdown(self) -> Result<(), Error> { + Ok(()) + } + + fn get_send_receive(&self, _: usize) -> std::io::Result<(u64, u64)> { + unreachable!() + } + + fn print_connection_stats(&self, _: &mut impl std::io::Write) -> std::io::Result<()> { + unreachable!() + } +} + +impl NetworkTrait for PartyTestNetwork { + async fn shutdown(self) -> Result<(), IOError> { + Ok(()) + } + + async fn send(&mut self, id: PartyID, data: Bytes) -> std::io::Result<()> { + tracing::trace!("send_id {}->{}: {:?}", self.id, id, data); + let res = if id == self.id.next_id() { + self.stats[1] += data.len(); + self.send_next + .send(data) + .map_err(|_| IOError::new(IOErrorKind::Other, "Send failed")) + } else if id == self.id.prev_id() { + self.stats[0] += data.len(); + self.send_prev + .send(data) + .map_err(|_| IOError::new(IOErrorKind::Other, "Send failed")) + } else { + Err(IOError::new(io::ErrorKind::Other, "Invalid ID")) + }; + + tracing::trace!("send_id {}->{}: done", self.id, id); + res + } + + async fn receive(&mut self, id: PartyID) -> std::io::Result { + tracing::trace!("recv_id {}<-{}: ", self.id, id); + let buf = if id == self.id.prev_id() { + let data = self + .recv_prev + .recv() + .await + .ok_or_else(|| IOError::new(IOErrorKind::Other, "Receive failed"))?; + self.stats[2] += data.len(); + data + } else if id == self.id.next_id() { + let data = self + .recv_next + .recv() + .await + .ok_or_else(|| IOError::new(IOErrorKind::Other, "Receive failed"))?; + self.stats[3] += data.len(); + data + } else { + return Err(io::Error::new(io::ErrorKind::Other, "Invalid ID")); + }; + tracing::trace!("recv_id {}<-{}: done", self.id, id); + + Ok(BytesMut::from(buf.as_ref())) + } + + async fn broadcast(&mut self, data: Bytes) -> Result, io::Error> { + let mut result = Vec::with_capacity(3); + for id in 0..3 { + if id != usize::from(self.id) { + self.send(PartyID::try_from(id).unwrap(), data.clone()) + .await?; + } + } + for id in 0..3 { + if id == usize::from(self.id) { + result.push(BytesMut::from(data.as_ref())); + } else { + result.push(self.receive(PartyID::try_from(id).unwrap()).await?); + } + } + Ok(result) + } + + fn get_id(&self) -> PartyID { + self.id + } + + async fn send_next_id(&mut self, data: Bytes) -> Result<(), IOError> { + tracing::trace!("send {}->{}: {:?}", self.id, self.id.next_id(), data); + self.stats[1] += data.len(); + let res = self + .send_next + .send(data) + .map_err(|_| IOError::new(IOErrorKind::Other, "Send failed")); + tracing::trace!("send {}->{}: done", self.id, self.id.next_id()); + res + } + + async fn send_prev_id(&mut self, data: Bytes) -> Result<(), IOError> { + tracing::trace!("send {}->{}: {:?}", self.id, self.id.prev_id(), data); + self.stats[0] += data.len(); + let res = self + .send_prev + .send(data) + .map_err(|_| IOError::new(IOErrorKind::Other, "Send failed")); + tracing::trace!("send {}->{}: done", self.id, self.id.prev_id()); + res + } + + async fn receive_prev_id(&mut self) -> Result { + tracing::trace!("recv {}<-{}: ", self.id, self.id.prev_id()); + let buf = self + .recv_prev + .recv() + .await + .ok_or_else(|| IOError::new(IOErrorKind::Other, "Receive failed"))?; + self.stats[2] += buf.len(); + + tracing::trace!("recv {}<-{}: done", self.id, self.id.prev_id()); + Ok(BytesMut::from(buf.as_ref())) + } + + async fn receive_next_id(&mut self) -> Result { + tracing::trace!("recv {}<-{}: ", self.id, self.id.next_id()); + let buf = self + .recv_next + .recv() + .await + .ok_or_else(|| IOError::new(IOErrorKind::Other, "Receive failed"))?; + self.stats[3] += buf.len(); + + tracing::trace!("recv {}<-{}: done", self.id, self.id.next_id()); + Ok(BytesMut::from(buf.as_ref())) + } + + fn blocking_send( + &mut self, + id: PartyID, + data: Bytes, + ) -> Result<(), super::network_trait::IoError> { + tracing::trace!("send_id {}->{}: {:?}", self.id, id, data); + let res = if id == self.id.next_id() { + self.blocking_send_next_id(data) + } else { + self.blocking_send_prev_id(data) + }; + tracing::trace!("send_id {}->{}: done", self.id, id); + res + } + + fn blocking_send_next_id(&mut self, data: Bytes) -> Result<(), super::network_trait::IoError> { + tracing::trace!("send {}->{}: {:?}", self.id, self.id.next_id(), data); + self.stats[1] += data.len(); + let res = self + .send_next + .send(data) + .map_err(|_| IOError::new(IOErrorKind::Other, "Send failed")); + tracing::trace!("send {}->{}: done", self.id, self.id.next_id()); + res + } + + fn blocking_send_prev_id(&mut self, data: Bytes) -> Result<(), super::network_trait::IoError> { + tracing::trace!("send {}->{}: {:?}", self.id, self.id.prev_id(), data); + self.stats[0] += data.len(); + let res = self + .send_prev + .send(data) + .map_err(|_| IOError::new(IOErrorKind::Other, "Send failed")); + tracing::trace!("send {}->{}: done", self.id, self.id.prev_id()); + res + } + + fn blocking_receive(&mut self, id: PartyID) -> Result { + tracing::trace!("recv_id {}<-{}: ", self.id, id); + let buf = if id == self.id.next_id() { + self.blocking_receive_next_id() + } else { + self.blocking_receive_prev_id() + }; + tracing::trace!("recv_id {}<-{}: done", self.id, id); + buf + } + + fn blocking_receive_prev_id(&mut self) -> Result { + tracing::trace!("recv {}<-{}: ", self.id, self.id.prev_id()); + let buf = self + .recv_prev + .blocking_recv() + .ok_or_else(|| IOError::new(IOErrorKind::Other, "Receive failed"))?; + self.stats[2] += buf.len(); + + tracing::trace!("recv {}<-{}: done", self.id, self.id.prev_id()); + Ok(BytesMut::from(buf.as_ref())) + } + + fn blocking_receive_next_id(&mut self) -> Result { + tracing::trace!("recv {}<-{}: ", self.id, self.id.next_id()); + let buf = self + .recv_next + .blocking_recv() + .ok_or_else(|| IOError::new(IOErrorKind::Other, "Receive failed"))?; + self.stats[3] += buf.len(); + + tracing::trace!("recv {}<-{}: done", self.id, self.id.next_id()); + Ok(BytesMut::from(buf.as_ref())) + } + + fn blocking_broadcast( + &mut self, + data: Bytes, + ) -> Result, super::network_trait::IoError> { + let mut result = Vec::with_capacity(3); + for id in 0..3 { + if id != usize::from(self.id) { + self.blocking_send(PartyID::try_from(id).unwrap(), data.clone())? + } + } + for id in 0..3 { + if id == usize::from(self.id) { + result.push(BytesMut::from(data.as_ref())); + } else { + result.push(self.blocking_receive(PartyID::try_from(id).unwrap())?); + } + } + Ok(result) + } +} diff --git a/iris-mpc-cpu/src/protocol/binary.rs b/iris-mpc-cpu/src/protocol/binary.rs new file mode 100644 index 000000000..a107ab7b1 --- /dev/null +++ b/iris-mpc-cpu/src/protocol/binary.rs @@ -0,0 +1,310 @@ +use super::iris::WorkerThread; +use crate::{ + error::Error, + networks::network_trait::NetworkTrait, + protocol::iris::{A, A_BITS, B_BITS}, + shares::{ + bit::Bit, + int_ring::IntRing2k, + ring_impl::RingElement, + share::Share, + vecshare::{SliceShare, VecShare}, + }, + utils::Utils, +}; +use iris_mpc_common::id::PartyID; +use num_traits::{One, Zero}; +use rand::{distributions::Standard, prelude::Distribution, Rng}; + +impl WorkerThread { + fn bit_inject_ot_2round_sender( + &mut self, + input: VecShare, + ) -> Result, Error> + where + Standard: Distribution, + { + let len = input.len(); + let mut m0 = Vec::with_capacity(len); + let mut m1 = Vec::with_capacity(len); + let mut shares = VecShare::with_capacity(len); + + for inp in input.into_iter() { + let (a, b) = inp.get_ab(); + // new shares + let (c3, c2) = self.prf.gen_rands::>(); + // mask of the ot + let w0 = self.prf.get_my_prf().gen::>(); + let w1 = self.prf.get_my_prf().gen::>(); + + shares.push(Share::new(c3, c2)); + let c = c3 + c2; + let xor = RingElement(T::from((a ^ b).convert().convert())); + let m0_ = xor - c; + let m1_ = (xor ^ RingElement::one()) - c; + m0.push(m0_ ^ w0); + m1.push(m1_ ^ w1); + } + self.network + .blocking_send_prev_id(Utils::ring_slice_to_bytes(&m0))?; + self.network + .blocking_send_prev_id(Utils::ring_slice_to_bytes(&m1))?; + + Ok(shares) + } + + fn bit_inject_ot_2round_receiver( + &mut self, + input: VecShare, + ) -> Result, Error> + where + Standard: Distribution, + { + let len = input.len(); + let m0_bytes = self.network.blocking_receive_next_id()?; + let m1_bytes = self.network.blocking_receive_next_id()?; + let wc_bytes = self.network.blocking_receive_prev_id()?; + + let m0 = Utils::ring_iter_from_bytes(m0_bytes, len)?; + let m1 = Utils::ring_iter_from_bytes(m1_bytes, len)?; + + let wc = Utils::ring_iter_from_bytes(wc_bytes, len)?; + + let mut shares = VecShare::with_capacity(len); + let mut send = Vec::with_capacity(len); + + for ((inp, wc), (m0, m1)) in input.into_iter().zip(wc).zip(m0.zip(m1)) { + // new share + let c2 = self.prf.get_my_prf().gen::>(); + + let choice = inp.get_b().convert().convert(); + let xor = if choice { wc ^ m1 } else { wc ^ m0 }; + + send.push(xor); + shares.push(Share::new(c2, xor)); + } + + // Reshare to Helper + self.network + .blocking_send_prev_id(Utils::ring_slice_to_bytes(&send))?; + + Ok(shares) + } + + fn bit_inject_ot_2round_helper( + &mut self, + input: VecShare, + ) -> Result, Error> + where + Standard: Distribution, + { + let len = input.len(); + let mut wc = Vec::with_capacity(len); + let mut shares = VecShare::with_capacity(len); + + for inp in input.into_iter() { + // new share + let c3 = self.prf.get_prev_prf().gen::>(); + shares.push(Share::new(RingElement::zero(), c3)); + + // mask of the ot + let w0 = self.prf.get_prev_prf().gen::>(); + let w1 = self.prf.get_prev_prf().gen::>(); + + let choice = inp.get_a().convert().convert(); + if choice { + wc.push(w1); + } else { + wc.push(w0); + } + } + self.network + .blocking_send_next_id(Utils::ring_slice_to_bytes(&wc))?; + + // Receive Reshare + let c1_bytes = self.network.blocking_receive_next_id()?; + let c1 = Utils::ring_iter_from_bytes(c1_bytes, len)?; + + for (s, c1) in shares.iter_mut().zip(c1) { + s.a = c1; + } + Ok(shares) + } + + // TODO this is inbalanced, so a real implementation should actually rotate + // parties around + pub(crate) fn bit_inject_ot_2round( + &mut self, + input: VecShare, + ) -> Result, Error> + where + Standard: Distribution, + { + let res = match self.get_party_id() { + PartyID::ID0 => { + // OT Helper + self.bit_inject_ot_2round_helper(input)? + } + PartyID::ID1 => { + // OT Receiver + self.bit_inject_ot_2round_receiver(input)? + } + PartyID::ID2 => { + // OT Sender + self.bit_inject_ot_2round_sender(input)? + } + }; + Ok(res) + } + + pub(crate) fn mul_lift_2k(vals: SliceShare) -> VecShare + where + u32: From, + { + VecShare::new_vec( + vals.iter() + .map(|val| { + let a = (u32::from(val.a.0)) << K; + let b = (u32::from(val.b.0)) << K; + Share::new(RingElement(a), RingElement(b)) + }) + .collect(), + ) + } + + pub(crate) fn a2b_pre(&self, x: Share) -> (Share, Share, Share) { + let (a, b) = x.get_ab(); + + let mut x1 = Share::zero(); + let mut x2 = Share::zero(); + let mut x3 = Share::zero(); + + match self.network.get_id() { + PartyID::ID0 => { + x1.a = a; + x3.b = b; + } + PartyID::ID1 => { + x2.a = a; + x1.b = b; + } + PartyID::ID2 => { + x3.a = a; + x2.b = b; + } + } + (x1, x2, x3) + } + + // Extracts bit at position K + pub fn extract_msb( + &mut self, + x_: VecShare, + ) -> Result, Error> { + // let truncate_len = x_.len(); + let x = x_.transpose_pack_u64_with_len::(); + + let len = x.len(); + + let mut x1 = Vec::with_capacity(len); + let mut x2 = Vec::with_capacity(len); + let mut x3 = Vec::with_capacity(len); + + for x_ in x.into_iter() { + let len_ = x_.len(); + let mut x1_ = VecShare::with_capacity(len_); + let mut x2_ = VecShare::with_capacity(len_); + let mut x3_ = VecShare::with_capacity(len_); + for x__ in x_.into_iter() { + let (x1__, x2__, x3__) = self.a2b_pre(x__); + x1_.push(x1__); + x2_.push(x2__); + x3_.push(x3__); + } + x1.push(x1_); + x2.push(x2_); + x3.push(x3_); + } + + self.binary_add_3_get_msb(x1, x2, x3) + } + + pub(crate) fn binary_add_3_get_msb( + &mut self, + x1: Vec>, + x2: Vec>, + mut x3: Vec>, + // truncate_len: usize, + ) -> Result, Error> + where + Standard: Distribution, + { + let len = x1.len(); + debug_assert!(len == x2.len() && len == x3.len()); + + // Full adder to get 2 * c and s + let mut x2x3 = x2; + transposed_pack_xor_assign(&mut x2x3, &x3); + let s = transposed_pack_xor(&x1, &x2x3); + let mut x1x3 = x1; + transposed_pack_xor_assign(&mut x1x3, &x3); + // 2 * c + x1x3.pop().expect("Enough elements present"); + x2x3.pop().expect("Enough elements present"); + x3.pop().expect("Enough elements present"); + let mut c = self.transposed_pack_and(x1x3, x2x3)?; + transposed_pack_xor_assign(&mut c, &x3); + + // Add 2c + s via a ripple carry adder + // LSB of c is 0 + // First round: half adder can be skipped due to LSB of c being 0 + let mut a = s; + let mut b = c; + + // First full adder (carry is 0) + let mut c = self.and_many(a[1].as_slice(), b[0].as_slice())?; + + // For last round + let mut a_msb = a.pop().expect("Enough elements present"); + let b_msb = b.pop().expect("Enough elements present"); + + // 2 -> k-1 + for (a_, b_) in a.iter_mut().skip(2).zip(b.iter_mut().skip(1)) { + *a_ ^= c.as_slice(); + *b_ ^= c.as_slice(); + let tmp_c = self.and_many(a_.as_slice(), b_.as_slice())?; + c ^= tmp_c; + } + + a_msb ^= b_msb; + a_msb ^= c; + + // Extract bits for outputs + let res = a_msb; + // let mut res = a_msb.convert_to_bits(); + // res.truncate(truncate_len); + + Ok(res) + } + + // Compute code_dots > a/b * mask_dots + // via MSB(a * mask_dots - b * code_dots) + pub fn compare_threshold_masked_many( + &mut self, + code_dots: VecShare, + mask_dots: VecShare, + ) -> Result, Error> { + debug_assert!(A_BITS as u64 <= B_BITS); + let len = code_dots.len(); + assert_eq!(len, mask_dots.len()); + + let y = Self::mul_lift_2k::<_, B_BITS>(code_dots.as_slice()); + let mut x = self.lift::<{ B_BITS as usize }>(mask_dots)?; + for x_ in x.iter_mut() { + *x_ *= A as u32; + } + + x.sub_assign(y); + self.extract_msb::<{ u16::K + B_BITS as usize }>(x) + } +} diff --git a/iris-mpc-cpu/src/protocol/iris.rs b/iris-mpc-cpu/src/protocol/iris.rs new file mode 100644 index 000000000..65c70e744 --- /dev/null +++ b/iris-mpc-cpu/src/protocol/iris.rs @@ -0,0 +1,52 @@ +use super::prf::{Prf, PrfSeed}; +use crate::{error::Error, networks::network_trait::NetworkTrait}; +use bytes::{Buf, Bytes, BytesMut}; +use iris_mpc_common::id::PartyID; + +pub(crate) const IRIS_CODE_SIZE: usize = + iris_mpc_common::iris_db::iris::IrisCodeArray::IRIS_CODE_SIZE; +pub(crate) const MATCH_THRESHOLD_RATIO: f64 = iris_mpc_common::iris_db::iris::MATCH_THRESHOLD_RATIO; +pub(crate) const B_BITS: u64 = 16; +pub(crate) const B: u64 = 1 << B_BITS; +pub(crate) const A: u64 = ((1. - 2. * MATCH_THRESHOLD_RATIO) * B as f64) as u64; +pub(crate) const A_BITS: u32 = u64::BITS - A.leading_zeros(); + +pub(crate) struct WorkerThread { + pub(crate) id: usize, + pub(crate) network: N, + pub(crate) prf: Prf, +} + +impl WorkerThread { + pub(crate) fn create(id: usize, network: N) -> Self { + Self { + id, + network, + prf: Prf::default(), + } + } + + pub fn get_party_id(&self) -> PartyID { + self.network.get_id() + } + + pub(crate) fn bytes_to_seed(mut bytes: BytesMut) -> Result { + if bytes.len() != std::mem::size_of::() { + Err(Error::InvalidMessageSize) + } else { + let mut their_seed: PrfSeed = PrfSeed::default(); + bytes.copy_to_slice(&mut their_seed); + Ok(their_seed) + } + } + + pub(crate) async fn setup_prf(&mut self) -> Result<(), Error> { + let seed = Prf::gen_seed(); + let data = Bytes::from_iter(seed.into_iter()); + self.network.send_next_id(data).await?; + let response = self.network.receive_prev_id().await?; + let their_seed = Self::bytes_to_seed(response)?; + self.prf = Prf::new(seed, their_seed); + Ok(()) + } +} diff --git a/iris-mpc-cpu/src/protocol/mod.rs b/iris-mpc-cpu/src/protocol/mod.rs new file mode 100644 index 000000000..fd5bfc452 --- /dev/null +++ b/iris-mpc-cpu/src/protocol/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod binary; +pub(crate) mod iris; +pub(crate) mod prf; diff --git a/iris-mpc-cpu/src/protocol/prf.rs b/iris-mpc-cpu/src/protocol/prf.rs new file mode 100644 index 000000000..2f9d1328a --- /dev/null +++ b/iris-mpc-cpu/src/protocol/prf.rs @@ -0,0 +1,66 @@ +use crate::shares::{int_ring::IntRing2k, ring_impl::RingElement}; +use aes_prng::AesRng; +use rand::{distributions::Standard, prelude::Distribution, Rng, SeedableRng}; + +pub type PrfSeed = ::Seed; + +pub struct Prf { + my_prf: AesRng, + prev_prf: AesRng, +} + +impl Default for Prf { + fn default() -> Self { + Self { + my_prf: AesRng::from_entropy(), + prev_prf: AesRng::from_entropy(), + } + } +} + +impl Prf { + pub fn new(my_key: PrfSeed, next_key: PrfSeed) -> Self { + Self { + my_prf: AesRng::from_seed(my_key), + prev_prf: AesRng::from_seed(next_key), + } + } + + pub fn get_my_prf(&mut self) -> &mut AesRng { + &mut self.my_prf + } + + pub fn get_prev_prf(&mut self) -> &mut AesRng { + &mut self.prev_prf + } + + pub fn gen_seed() -> PrfSeed { + let mut rng = AesRng::from_entropy(); + rng.gen::() + } + + pub fn gen_rands(&mut self) -> (T, T) + where + Standard: Distribution, + { + let a = self.my_prf.gen::(); + let b = self.prev_prf.gen::(); + (a, b) + } + + pub fn gen_zero_share(&mut self) -> RingElement + where + Standard: Distribution, + { + let (a, b) = self.gen_rands::>(); + a - b + } + + pub fn gen_binary_zero_share(&mut self) -> RingElement + where + Standard: Distribution, + { + let (a, b) = self.gen_rands::>(); + a ^ b + } +} diff --git a/iris-mpc-cpu/src/shares/bit.rs b/iris-mpc-cpu/src/shares/bit.rs new file mode 100644 index 000000000..1af8ae6ad --- /dev/null +++ b/iris-mpc-cpu/src/shares/bit.rs @@ -0,0 +1,391 @@ +use crate::error::Error; +use num_traits::{ + AsPrimitive, One, WrappingAdd, WrappingMul, WrappingNeg, WrappingShl, WrappingShr, WrappingSub, + Zero, +}; +use rand::{distributions::Standard, prelude::Distribution, Rng}; +use serde::{Deserialize, Serialize}; +use std::ops::{ + Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul, Neg, Not, Rem, Shl, + Shr, Sub, +}; + +/// Bit is a sharable wrapper for a boolean value +#[derive( + Copy, + Clone, + Debug, + Default, + Eq, + PartialEq, + PartialOrd, + Ord, + Serialize, + Deserialize, + bytemuck::NoUninit, + bytemuck::AnyBitPattern, +)] +#[repr(transparent)] +/// This transparent is important due to some typecasts! +pub struct Bit(pub(super) u8); + +impl AsPrimitive for Bit { + fn as_(self) -> Self { + self + } +} + +impl std::fmt::Display for Bit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + 1 => write!(f, "1"), + 0 => write!(f, "0"), + _ => unreachable!(), + } + } +} + +impl Bit { + pub fn new(value: bool) -> Self { + Self(value as u8) + } + + pub fn convert(self) -> bool { + debug_assert!(self.0 == 0 || self.0 == 1); + self.0 == 1 + } +} + +impl TryFrom for Bit { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Bit(0)), + 1 => Ok(Bit(1)), + _ => Err(Error::Conversion), + } + } +} + +impl TryFrom for Bit { + type Error = Error; + + fn try_from(value: usize) -> Result { + match value { + 0 => Ok(Bit(0)), + 1 => Ok(Bit(1)), + _ => Err(Error::Conversion), + } + } +} + +impl Add for Bit { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline(always)] + fn add(self, rhs: Self) -> Self::Output { + self ^ rhs + } +} + +impl Add<&Bit> for Bit { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline(always)] + fn add(self, rhs: &Self) -> Self::Output { + self ^ rhs + } +} + +impl WrappingAdd for Bit { + #[inline(always)] + fn wrapping_add(&self, rhs: &Self) -> Self::Output { + *self ^ *rhs + } +} + +impl Sub for Bit { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline(always)] + fn sub(self, rhs: Self) -> Self::Output { + self ^ rhs + } +} + +impl Sub<&Bit> for Bit { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline(always)] + fn sub(self, rhs: &Self) -> Self::Output { + self ^ rhs + } +} + +impl WrappingSub for Bit { + #[inline(always)] + fn wrapping_sub(&self, rhs: &Self) -> Self::Output { + *self ^ *rhs + } +} + +impl Neg for Bit { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self::Output { + self + } +} + +impl WrappingNeg for Bit { + #[inline(always)] + fn wrapping_neg(&self) -> Self { + -*self + } +} + +impl BitXor for Bit { + type Output = Self; + + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self::Output { + Bit(self.0 ^ rhs.0) + } +} + +impl BitXor<&Bit> for Bit { + type Output = Self; + + #[inline(always)] + fn bitxor(self, rhs: &Self) -> Self::Output { + Bit(self.0 ^ rhs.0) + } +} + +impl BitXorAssign for Bit { + #[inline(always)] + fn bitxor_assign(&mut self, rhs: Self) { + self.0 ^= rhs.0; + } +} + +impl BitXorAssign<&Bit> for Bit { + #[inline(always)] + fn bitxor_assign(&mut self, rhs: &Self) { + self.0 ^= rhs.0; + } +} + +impl BitOr for Bit { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: Self) -> Self::Output { + Bit(self.0 | rhs.0) + } +} + +impl BitOr<&Bit> for Bit { + type Output = Self; + + #[inline(always)] + fn bitor(self, rhs: &Self) -> Self::Output { + Bit(self.0 | rhs.0) + } +} + +impl BitOrAssign for Bit { + #[inline(always)] + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.0; + } +} + +impl BitOrAssign<&Bit> for Bit { + #[inline(always)] + fn bitor_assign(&mut self, rhs: &Self) { + self.0 |= rhs.0; + } +} + +impl Not for Bit { + type Output = Self; + + #[inline(always)] + fn not(self) -> Self { + Self(self.0 ^ 1) + } +} + +impl BitAnd for Bit { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: Self) -> Self::Output { + Bit(self.0 & rhs.0) + } +} + +impl BitAnd<&Bit> for Bit { + type Output = Self; + + #[inline(always)] + fn bitand(self, rhs: &Self) -> Self::Output { + Bit(self.0 & rhs.0) + } +} + +impl BitAndAssign for Bit { + #[inline(always)] + fn bitand_assign(&mut self, rhs: Self) { + self.0 &= rhs.0; + } +} + +impl BitAndAssign<&Bit> for Bit { + #[inline(always)] + fn bitand_assign(&mut self, rhs: &Self) { + self.0 &= rhs.0; + } +} + +impl Mul for Bit { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline(always)] + fn mul(self, rhs: Self) -> Self::Output { + self & rhs + } +} + +impl Mul<&Bit> for Bit { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline(always)] + fn mul(self, rhs: &Self) -> Self::Output { + self & rhs + } +} + +impl WrappingMul for Bit { + #[inline(always)] + fn wrapping_mul(&self, rhs: &Self) -> Self::Output { + *self & *rhs + } +} + +impl Zero for Bit { + #[inline(always)] + fn zero() -> Self { + Self(0) + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self.0 == 0 + } +} + +impl One for Bit { + #[inline(always)] + fn one() -> Self { + Self(1) + } +} + +impl From for u8 { + #[inline(always)] + fn from(other: Bit) -> Self { + other.0 + } +} + +impl From for Bit { + #[inline(always)] + fn from(other: bool) -> Self { + Bit(other as u8) + } +} + +impl From for bool { + #[inline(always)] + fn from(other: Bit) -> Self { + other.0 == 1 + } +} + +impl Shl for Bit { + type Output = Self; + + fn shl(self, rhs: usize) -> Self { + if rhs == 0 { + self + } else { + Self(0) + } + } +} + +impl WrappingShl for Bit { + #[inline(always)] + fn wrapping_shl(&self, rhs: u32) -> Self { + *self << rhs as usize + } +} + +impl Shr for Bit { + type Output = Self; + + fn shr(self, rhs: usize) -> Self { + if rhs == 0 { + self + } else { + Self(0) + } + } +} + +impl WrappingShr for Bit { + #[inline(always)] + fn wrapping_shr(&self, rhs: u32) -> Self { + *self >> rhs as usize + } +} + +impl Distribution for Standard { + #[inline(always)] + fn sample(&self, rng: &mut R) -> Bit { + Bit(rng.gen::() as u8) + } +} + +impl AsRef for Bit { + fn as_ref(&self) -> &Bit { + self + } +} + +impl From for u128 { + fn from(val: Bit) -> Self { + u128::from(val.0) + } +} + +impl Rem for Bit { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + match rhs { + Bit(0) => panic!("Division by zero"), + Bit(1) => Bit(0), + _ => unreachable!(), + } + } +} diff --git a/iris-mpc-cpu/src/shares/int_ring.rs b/iris-mpc-cpu/src/shares/int_ring.rs new file mode 100644 index 000000000..5e5da3a37 --- /dev/null +++ b/iris-mpc-cpu/src/shares/int_ring.rs @@ -0,0 +1,120 @@ +use super::bit::Bit; +use num_traits::{ + AsPrimitive, One, WrappingAdd, WrappingMul, WrappingNeg, WrappingShl, WrappingShr, WrappingSub, + Zero, +}; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::Debug, + ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Neg, Not, Rem}, +}; + +pub trait IntRing2k: + std::fmt::Display + + Serialize + + for<'a> Deserialize<'a> + + Default + + WrappingAdd + + WrappingSub + + WrappingMul + + WrappingNeg + + WrappingShl + + WrappingShr + + Not + + BitXor + + BitAnd + + BitOr + + BitXorAssign + + BitAndAssign + + BitOrAssign + + PartialEq + + From + + Into + + Copy + + Debug + + Zero + + One + + Sized + + Send + + Sync + + Rem + + 'static + + bytemuck::NoUninit + + bytemuck::AnyBitPattern +{ + type Signed: Neg + From + AsPrimitive; + const K: usize; + const BYTES: usize; + + /// a += b + #[inline(always)] + fn wrapping_add_assign(&mut self, rhs: &Self) { + *self = self.wrapping_add(rhs); + } + + /// a -= b + #[inline(always)] + fn wrapping_sub_assign(&mut self, rhs: &Self) { + *self = self.wrapping_sub(rhs); + } + + /// a = -a + #[inline(always)] + fn wrapping_neg_inplace(&mut self) { + *self = self.wrapping_neg(); + } + + /// a*= b + #[inline(always)] + fn wrapping_mul_assign(&mut self, rhs: &Self) { + *self = self.wrapping_mul(rhs); + } + + /// a <<= b + #[inline(always)] + fn wrapping_shl_assign(&mut self, rhs: u32) { + *self = self.wrapping_shl(rhs); + } + + /// a >>= b + #[inline(always)] + fn wrapping_shr_assign(&mut self, rhs: u32) { + *self = self.wrapping_shr(rhs); + } +} + +impl IntRing2k for Bit { + type Signed = Bit; + const K: usize = 1; + const BYTES: usize = 1; +} + +impl IntRing2k for u8 { + type Signed = i8; + const K: usize = Self::BITS as usize; + const BYTES: usize = Self::K / 8; +} + +impl IntRing2k for u16 { + type Signed = i16; + const K: usize = Self::BITS as usize; + const BYTES: usize = Self::K / 8; +} + +impl IntRing2k for u32 { + type Signed = i32; + const K: usize = Self::BITS as usize; + const BYTES: usize = Self::K / 8; +} + +impl IntRing2k for u64 { + type Signed = i64; + const K: usize = Self::BITS as usize; + const BYTES: usize = Self::K / 8; +} + +impl IntRing2k for u128 { + type Signed = i128; + const K: usize = Self::BITS as usize; + const BYTES: usize = Self::K / 8; +} diff --git a/iris-mpc-cpu/src/shares/mod.rs b/iris-mpc-cpu/src/shares/mod.rs new file mode 100644 index 000000000..45a424f60 --- /dev/null +++ b/iris-mpc-cpu/src/shares/mod.rs @@ -0,0 +1,6 @@ +pub(crate) mod bit; +pub(crate) mod int_ring; +pub(crate) mod ring_impl; +pub(crate) mod share; +pub(crate) mod vecshare; +pub(crate) mod vecshare_bittranspose; diff --git a/iris-mpc-cpu/src/shares/ring_impl.rs b/iris-mpc-cpu/src/shares/ring_impl.rs new file mode 100644 index 000000000..16d7501c8 --- /dev/null +++ b/iris-mpc-cpu/src/shares/ring_impl.rs @@ -0,0 +1,452 @@ +use super::{bit::Bit, int_ring::IntRing2k}; +use num_traits::{One, Zero}; +use rand::{ + distributions::{Distribution, Standard}, + Rng, +}; +use serde::{Deserialize, Serialize}; +use std::{ + marker::PhantomData, + mem::ManuallyDrop, + ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul, + MulAssign, Neg, Not, Rem, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, + }, +}; + +#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, PartialOrd, Eq, Ord)] +#[serde(bound = "")] +#[repr(transparent)] +pub struct RingElement(pub T); + +pub struct BitIter<'a, T: IntRing2k> { + bits: &'a RingElement, + index: usize, + _marker: std::marker::PhantomData, +} + +impl Iterator for BitIter<'_, T> { + type Item = Bit; + + fn next(&mut self) -> Option { + if self.index >= T::K { + None + } else { + let bit = ((self.bits.0 >> self.index) & T::one()) == T::one(); + self.index += 1; + Some(Bit(bit as u8)) + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = T::K - self.index; + (len, Some(len)) + } +} + +impl RingElement { + /// Safe because RingElement has repr(transparent) + pub fn convert_slice(vec: &[Self]) -> &[T] { + // SAFETY: RingElement has repr(transparent) + unsafe { &*(vec as *const [Self] as *const [T]) } + } + + /// Safe because RingElement has repr(transparent) + pub fn convert_vec(vec: Vec) -> Vec { + let me = ManuallyDrop::new(vec); + // SAFETY: RingElement has repr(transparent) + unsafe { Vec::from_raw_parts(me.as_ptr() as *mut T, me.len(), me.capacity()) } + } + + /// Safe because RingElement has repr(transparent) + pub fn convert_slice_rev(vec: &[T]) -> &[Self] { + // SAFETY: RingElement has repr(transparent) + unsafe { &*(vec as *const [T] as *const [Self]) } + } + + /// Safe because RingElement has repr(transparent) + pub fn convert_vec_rev(vec: Vec) -> Vec { + let me = ManuallyDrop::new(vec); + // SAFETY: RingElement has repr(transparent) + unsafe { Vec::from_raw_parts(me.as_ptr() as *mut Self, me.len(), me.capacity()) } + } + + pub fn convert(self) -> T { + self.0 + } + + pub(crate) fn bit_iter(&self) -> BitIter<'_, T> { + BitIter { + bits: self, + index: 0, + _marker: PhantomData, + } + } + + pub fn get_bit(&self, index: usize) -> Self { + RingElement((self.0 >> index) & T::one()) + } + + #[cfg(test)] + pub(crate) fn get_bit_as_bit(&self, index: usize) -> RingElement { + let bit = ((self.0 >> index) & T::one()) == T::one(); + RingElement(Bit(bit as u8)) + } + + pub fn upgrade_to_128(self) -> RingElement { + RingElement(self.0.into()) + } +} + +impl std::fmt::Display for RingElement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self.0, f) + } +} + +impl Add for RingElement { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0.wrapping_add(&rhs.0)) + } +} + +impl Add<&Self> for RingElement { + type Output = Self; + + fn add(self, rhs: &Self) -> Self::Output { + Self(self.0.wrapping_add(&rhs.0)) + } +} + +impl AddAssign for RingElement { + fn add_assign(&mut self, rhs: Self) { + self.0.wrapping_add_assign(&rhs.0) + } +} + +impl AddAssign<&Self> for RingElement { + fn add_assign(&mut self, rhs: &Self) { + self.0.wrapping_add_assign(&rhs.0) + } +} + +impl Sub for RingElement { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0.wrapping_sub(&rhs.0)) + } +} + +impl Sub<&Self> for RingElement { + type Output = Self; + + fn sub(self, rhs: &Self) -> Self::Output { + Self(self.0.wrapping_sub(&rhs.0)) + } +} + +impl SubAssign for RingElement { + fn sub_assign(&mut self, rhs: Self) { + self.0.wrapping_sub_assign(&rhs.0) + } +} + +impl SubAssign<&Self> for RingElement { + fn sub_assign(&mut self, rhs: &Self) { + self.0.wrapping_sub_assign(&rhs.0) + } +} + +impl Mul for RingElement { + type Output = Self; + + fn mul(self, rhs: T) -> Self::Output { + Self(self.0.wrapping_mul(&rhs)) + } +} + +impl Mul<&T> for RingElement { + type Output = Self; + + fn mul(self, rhs: &T) -> Self::Output { + Self(self.0.wrapping_mul(rhs)) + } +} + +impl Mul for RingElement { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self(self.0.wrapping_mul(&rhs.0)) + } +} + +impl Mul<&Self> for RingElement { + type Output = Self; + + fn mul(self, rhs: &Self) -> Self::Output { + Self(self.0.wrapping_mul(&rhs.0)) + } +} + +impl MulAssign for RingElement { + fn mul_assign(&mut self, rhs: Self) { + self.0.wrapping_mul_assign(&rhs.0) + } +} + +impl MulAssign<&Self> for RingElement { + fn mul_assign(&mut self, rhs: &Self) { + self.0.wrapping_mul_assign(&rhs.0) + } +} + +impl MulAssign for RingElement { + fn mul_assign(&mut self, rhs: T) { + self.0.wrapping_mul_assign(&rhs) + } +} + +impl MulAssign<&T> for RingElement { + fn mul_assign(&mut self, rhs: &T) { + self.0.wrapping_mul_assign(rhs) + } +} + +impl Zero for RingElement { + fn zero() -> Self { + Self(T::zero()) + } + + fn is_zero(&self) -> bool { + self.0.is_zero() + } +} + +impl One for RingElement { + fn one() -> Self { + Self(T::one()) + } +} + +impl Neg for RingElement { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(self.0.wrapping_neg()) + } +} + +impl Distribution> for Standard +where + Standard: Distribution, +{ + #[inline(always)] + fn sample(&self, rng: &mut R) -> RingElement { + RingElement(rng.gen()) + } +} + +impl Not for RingElement { + type Output = Self; + + fn not(self) -> Self { + Self(!self.0) + } +} + +impl BitXor for RingElement { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + RingElement(self.0 ^ rhs.0) + } +} + +impl BitXor<&Self> for RingElement { + type Output = Self; + + fn bitxor(self, rhs: &Self) -> Self::Output { + RingElement(self.0 ^ rhs.0) + } +} + +impl BitXorAssign for RingElement { + fn bitxor_assign(&mut self, rhs: Self) { + self.0 ^= rhs.0; + } +} + +impl BitXorAssign<&Self> for RingElement { + fn bitxor_assign(&mut self, rhs: &Self) { + self.0 ^= rhs.0; + } +} + +impl BitOr for RingElement { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + RingElement(self.0 | rhs.0) + } +} + +impl BitOr<&Self> for RingElement { + type Output = Self; + + fn bitor(self, rhs: &Self) -> Self::Output { + RingElement(self.0 | rhs.0) + } +} + +impl BitOrAssign for RingElement { + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.0; + } +} + +impl BitOrAssign<&Self> for RingElement { + fn bitor_assign(&mut self, rhs: &Self) { + self.0 |= rhs.0; + } +} + +impl BitAnd for RingElement { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + RingElement(self.0 & rhs.0) + } +} + +impl BitAnd for RingElement { + type Output = Self; + + fn bitand(self, rhs: T) -> Self::Output { + RingElement(self.0 & rhs) + } +} + +impl BitAnd<&Self> for RingElement { + type Output = Self; + + fn bitand(self, rhs: &Self) -> Self::Output { + RingElement(self.0 & rhs.0) + } +} + +impl BitAndAssign for RingElement { + fn bitand_assign(&mut self, rhs: Self) { + self.0 &= rhs.0; + } +} + +impl BitAndAssign<&Self> for RingElement { + fn bitand_assign(&mut self, rhs: &Self) { + self.0 &= rhs.0; + } +} + +impl Shl for RingElement { + type Output = Self; + + fn shl(self, rhs: u32) -> Self::Output { + RingElement(self.0.wrapping_shl(rhs)) + } +} + +impl ShlAssign for RingElement { + fn shl_assign(&mut self, rhs: u32) { + self.0.wrapping_shl_assign(rhs) + } +} + +impl Shr for RingElement { + type Output = Self; + + fn shr(self, rhs: u32) -> Self::Output { + RingElement(self.0.wrapping_shr(rhs)) + } +} + +impl ShrAssign for RingElement { + fn shr_assign(&mut self, rhs: u32) { + self.0.wrapping_shr_assign(rhs) + } +} + +impl Rem for RingElement { + type Output = Self; + + fn rem(self, rhs: T) -> Self::Output { + RingElement(self.0 % rhs) + } +} + +#[cfg(test)] +mod unsafe_test { + use super::*; + use aes_prng::AesRng; + use rand::{Rng, SeedableRng}; + + const ELEMENTS: usize = 100; + + fn conversion_test() + where + Standard: Distribution, + { + let mut rng = AesRng::from_entropy(); + let t_vec: Vec = (0..ELEMENTS).map(|_| rng.gen()).collect(); + let rt_vec: Vec> = + (0..ELEMENTS).map(|_| rng.gen::>()).collect(); + + // Convert vec to vec> + let t_conv = RingElement::convert_vec_rev(t_vec.to_owned()); + assert_eq!(t_conv.len(), t_vec.len()); + for (a, b) in t_conv.iter().zip(t_vec.iter()) { + assert_eq!(a.0, *b) + } + + // Convert slice vec to vec> + let t_conv = RingElement::convert_slice_rev(&t_vec); + assert_eq!(t_conv.len(), t_vec.len()); + for (a, b) in t_conv.iter().zip(t_vec.iter()) { + assert_eq!(a.0, *b) + } + + // Convert vec> to vec + let rt_conv = RingElement::convert_vec(rt_vec.to_owned()); + assert_eq!(rt_conv.len(), rt_vec.len()); + for (a, b) in rt_conv.iter().zip(rt_vec.iter()) { + assert_eq!(*a, b.0) + } + + // Convert slice vec> to vec + let rt_conv = RingElement::convert_slice(&rt_vec); + assert_eq!(rt_conv.len(), rt_vec.len()); + for (a, b) in rt_conv.iter().zip(rt_vec.iter()) { + assert_eq!(*a, b.0) + } + } + + macro_rules! test_impl { + ($([$ty:ty,$fn:ident]),*) => ($( + #[test] + fn $fn() { + conversion_test::<$ty>(); + } + )*) + } + + test_impl! { + [Bit, bit_test], + [u8, u8_test], + [u16, u16_test], + [u32, u32_test], + [u64, u64_test], + [u128, u128_test] + } +} diff --git a/iris-mpc-cpu/src/shares/share.rs b/iris-mpc-cpu/src/shares/share.rs new file mode 100644 index 000000000..60ec555b9 --- /dev/null +++ b/iris-mpc-cpu/src/shares/share.rs @@ -0,0 +1,317 @@ +use super::{int_ring::IntRing2k, ring_impl::RingElement}; +use iris_mpc_common::id::PartyID; +use num_traits::Zero; +use serde::{Deserialize, Serialize}; +use std::ops::{ + Add, AddAssign, BitAnd, BitXor, BitXorAssign, Mul, MulAssign, Neg, Not, Shl, Shr, Sub, + SubAssign, +}; + +#[derive(Clone, Debug, PartialEq, Default, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct Share { + pub a: RingElement, + pub b: RingElement, +} + +impl Share { + pub fn new(a: RingElement, b: RingElement) -> Self { + Self { a, b } + } + + pub(crate) fn sub_from_const(&self, other: T, id: PartyID) -> Self { + let mut a = -self; + a.add_assign_const(other, id); + a + } + + pub fn add_assign_const(&mut self, other: T, id: PartyID) { + match id { + PartyID::ID0 => self.a += RingElement(other), + PartyID::ID1 => self.b += RingElement(other), + PartyID::ID2 => {} + } + } + + pub fn get_a(self) -> RingElement { + self.a + } + + pub fn get_b(self) -> RingElement { + self.b + } + + pub fn get_ab(self) -> (RingElement, RingElement) { + (self.a, self.b) + } + + pub fn get_ab_ref(&self) -> (RingElement, RingElement) { + (self.a, self.b) + } +} + +impl Add<&Self> for Share { + type Output = Self; + + fn add(self, rhs: &Self) -> Self::Output { + Share { + a: self.a + rhs.a, + b: self.b + rhs.b, + } + } +} + +impl Add for Share { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Share { + a: self.a + rhs.a, + b: self.b + rhs.b, + } + } +} + +impl Sub for Share { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Share { + a: self.a - rhs.a, + b: self.b - rhs.b, + } + } +} + +impl Sub<&Self> for Share { + type Output = Self; + + fn sub(self, rhs: &Self) -> Self::Output { + Share { + a: self.a - rhs.a, + b: self.b - rhs.b, + } + } +} + +impl AddAssign for Share { + fn add_assign(&mut self, rhs: Self) { + self.a += rhs.a; + self.b += rhs.b; + } +} + +impl AddAssign<&Self> for Share { + fn add_assign(&mut self, rhs: &Self) { + self.a += rhs.a; + self.b += rhs.b; + } +} + +impl SubAssign for Share { + fn sub_assign(&mut self, rhs: Self) { + self.a -= rhs.a; + self.b -= rhs.b; + } +} + +impl SubAssign<&Self> for Share { + fn sub_assign(&mut self, rhs: &Self) { + self.a -= rhs.a; + self.b -= rhs.b; + } +} + +impl Mul> for Share { + type Output = Self; + + fn mul(self, rhs: RingElement) -> Self::Output { + Share { + a: self.a * rhs, + b: self.b * rhs, + } + } +} + +impl Mul for Share { + type Output = Self; + + fn mul(self, rhs: T) -> Self::Output { + self * RingElement(rhs) + } +} + +impl Mul for &Share { + type Output = Share; + + fn mul(self, rhs: T) -> Self::Output { + Share { + a: self.a * rhs, + b: self.b * rhs, + } + } +} + +impl MulAssign for Share { + fn mul_assign(&mut self, rhs: T) { + self.a *= rhs; + self.b *= rhs; + } +} + +/// This is only the local part of the multiplication (so without randomness and +/// without communication)! +impl Mul for &Share { + type Output = RingElement; + + fn mul(self, rhs: Self) -> Self::Output { + self.a * rhs.a + self.b * rhs.a + self.a * rhs.b + } +} + +impl BitXor for &Share { + type Output = Share; + + fn bitxor(self, rhs: Self) -> Self::Output { + Share { + a: self.a ^ rhs.a, + b: self.b ^ rhs.b, + } + } +} + +impl BitXor for Share { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + Share { + a: self.a ^ rhs.a, + b: self.b ^ rhs.b, + } + } +} + +impl BitXor<&Self> for Share { + type Output = Self; + + fn bitxor(self, rhs: &Self) -> Self::Output { + Share { + a: self.a ^ rhs.a, + b: self.b ^ rhs.b, + } + } +} + +impl BitXorAssign<&Self> for Share { + fn bitxor_assign(&mut self, rhs: &Self) { + self.a ^= rhs.a; + self.b ^= rhs.b; + } +} + +impl BitXorAssign for Share { + fn bitxor_assign(&mut self, rhs: Self) { + self.a ^= rhs.a; + self.b ^= rhs.b; + } +} + +/// This is only the local part of the AND (so without randomness and without +/// communication)! +impl BitAnd for &Share { + type Output = RingElement; + + fn bitand(self, rhs: Self) -> Self::Output { + (self.a & rhs.a) ^ (self.b & rhs.a) ^ (self.a & rhs.b) + } +} + +impl BitAnd<&RingElement> for &Share { + type Output = Share; + + fn bitand(self, rhs: &RingElement) -> Self::Output { + Share { + a: self.a & rhs, + b: self.b & rhs, + } + } +} + +impl BitAnd for Share { + type Output = Share; + + fn bitand(self, rhs: T) -> Self::Output { + Share { + a: self.a & rhs, + b: self.b & rhs, + } + } +} + +impl Zero for Share { + fn zero() -> Self { + Self { + a: RingElement::zero(), + b: RingElement::zero(), + } + } + + fn is_zero(&self) -> bool { + self.a.is_zero() && self.b.is_zero() + } +} + +impl Neg for Share { + type Output = Self; + + fn neg(self) -> Self::Output { + Self { + a: -self.a, + b: -self.b, + } + } +} + +impl Neg for &Share { + type Output = Share; + + fn neg(self) -> Self::Output { + Share { + a: -self.a, + b: -self.b, + } + } +} + +impl Not for &Share { + type Output = Share; + + fn not(self) -> Self::Output { + Share { + a: !self.a, + b: !self.b, + } + } +} + +impl Shr for &Share { + type Output = Share; + + fn shr(self, rhs: u32) -> Self::Output { + Share { + a: self.a >> rhs, + b: self.b >> rhs, + } + } +} + +impl Shl for Share { + type Output = Self; + + fn shl(self, rhs: u32) -> Self::Output { + Self { + a: self.a << rhs, + b: self.b << rhs, + } + } +} diff --git a/iris-mpc-cpu/src/shares/vecshare.rs b/iris-mpc-cpu/src/shares/vecshare.rs new file mode 100644 index 000000000..f0d593159 --- /dev/null +++ b/iris-mpc-cpu/src/shares/vecshare.rs @@ -0,0 +1,406 @@ +use super::{bit::Bit, int_ring::IntRing2k, ring_impl::RingElement, share::Share}; +use bytes::{Buf, BytesMut}; +use num_traits::Zero; +use serde::{Deserialize, Serialize}; +use std::{ + marker::PhantomData, + ops::{AddAssign, BitXor, BitXorAssign, Deref, DerefMut, Not, SubAssign}, +}; + +#[repr(transparent)] +pub struct RingBytesIter { + bytes: BytesMut, + _marker: std::marker::PhantomData, +} + +impl RingBytesIter { + pub fn new(bytes: BytesMut) -> Self { + Self { + bytes, + _marker: PhantomData, + } + } +} + +impl Iterator for RingBytesIter { + type Item = RingElement; + + fn next(&mut self) -> Option { + if self.bytes.remaining() == 0 { + None + } else { + let res = bytemuck::pod_read_unaligned(&self.bytes.chunk()[..T::BYTES]); + self.bytes.advance(T::BYTES); + Some(RingElement(res)) + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.bytes.remaining() / T::BYTES; + (len, Some(len)) + } +} + +#[derive(Clone, Copy, Debug)] +#[repr(transparent)] +pub struct SliceShare<'a, T: IntRing2k> { + shares: &'a [Share], +} + +impl<'a, T: IntRing2k> SliceShare<'a, T> { + pub fn split_at(&self, mid: usize) -> (SliceShare, SliceShare) { + let (a, b) = self.shares.split_at(mid); + (SliceShare { shares: a }, SliceShare { shares: b }) + } + + pub fn chunks(&self, chunk_size: usize) -> impl Iterator> + '_ { + self.shares + .chunks(chunk_size) + .map(|x| SliceShare { shares: x }) + } + + pub fn len(&self) -> usize { + self.shares.len() + } + + pub fn iter(&self) -> std::slice::Iter<'_, Share> { + self.shares.iter() + } + + pub fn is_empty(&self) -> bool { + self.shares.is_empty() + } +} + +#[derive(Debug)] +#[repr(transparent)] +pub struct SliceShareMut<'a, T: IntRing2k> { + shares: &'a mut [Share], +} + +impl<'a, T: IntRing2k> SliceShareMut<'a, T> { + pub fn to_vec(&self) -> VecShare { + VecShare { + shares: self.shares.to_vec(), + } + } + + pub fn to_slice(&self) -> SliceShare { + SliceShare { + shares: self.shares, + } + } +} + +#[derive(Clone, Debug, PartialEq, Default, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(bound = "")] +#[repr(transparent)] +pub struct VecShare { + pub(crate) shares: Vec>, +} + +impl VecShare { + #[cfg(test)] + pub fn new_share(share: Share) -> Self { + let shares = vec![share]; + Self { shares } + } + + pub fn new_vec(shares: Vec>) -> Self { + Self { shares } + } + + pub fn inner(self) -> Vec> { + self.shares + } + + pub fn with_capacity(capacity: usize) -> Self { + let shares = Vec::with_capacity(capacity); + Self { shares } + } + + pub fn extend(&mut self, items: Self) { + self.shares.extend(items.shares); + } + + pub fn extend_from_slice(&mut self, items: SliceShare) { + self.shares.extend_from_slice(items.shares); + } + + pub fn len(&self) -> usize { + self.shares.len() + } + + pub fn iter(&self) -> std::slice::Iter<'_, Share> { + self.shares.iter() + } + + pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, Share> { + self.shares.iter_mut() + } + + pub fn is_empty(&self) -> bool { + self.shares.is_empty() + } + + pub fn push(&mut self, el: Share) { + self.shares.push(el) + } + + pub fn pop(&mut self) -> Option> { + self.shares.pop() + } + + pub fn sum(&self) -> Share { + self.shares.iter().fold(Share::zero(), |a, b| a + b) + } + + pub fn not_inplace(&mut self) { + for x in self.shares.iter_mut() { + *x = !&*x; + } + } + + pub fn split_at(&self, mid: usize) -> (SliceShare, SliceShare) { + let (a, b) = self.shares.split_at(mid); + (SliceShare { shares: a }, SliceShare { shares: b }) + } + + pub fn split_at_mut(&mut self, mid: usize) -> (SliceShareMut, SliceShareMut) { + let (a, b) = self.shares.split_at_mut(mid); + (SliceShareMut { shares: a }, SliceShareMut { shares: b }) + } + + pub fn get_at(self, index: usize) -> Share { + self.shares[index].to_owned() + } + + pub fn from_avec_biter(a: Vec>, b: RingBytesIter) -> Self { + let shares = a + .into_iter() + .zip(b) + .map(|(a_, b_)| Share::new(a_, b_)) + .collect(); + Self { shares } + } + + pub fn from_ab(a: Vec>, b: Vec>) -> Self { + let shares = a + .into_iter() + .zip(b) + .map(|(a_, b_)| Share::new(a_, b_)) + .collect(); + Self { shares } + } + + pub fn flatten(inp: Vec) -> Self { + Self { + shares: inp.into_iter().flat_map(|x| x.shares).collect(), + } + } + + pub fn convert_to_bits(self) -> VecShare { + let mut res = VecShare::with_capacity(T::K * self.shares.len()); + for share in self.shares.into_iter() { + let (a, b) = share.get_ab(); + for (a, b) in a.bit_iter().zip(b.bit_iter()) { + res.push(Share::new(RingElement(a), RingElement(b))); + } + } + res + } + + pub fn truncate(&mut self, len: usize) { + self.shares.truncate(len); + } + + pub fn as_slice(&self) -> SliceShare { + SliceShare { + shares: &self.shares, + } + } + + pub fn as_slice_mut(&mut self) -> SliceShareMut { + SliceShareMut { + shares: &mut self.shares, + } + } +} + +impl VecShare { + pub fn pack(self) -> VecShare { + let outlen = (self.shares.len() + T::K - 1) / T::K; + let mut out = VecShare::with_capacity(outlen); + + for a_ in self.shares.chunks(T::K) { + let mut share_a = RingElement::::zero(); + let mut share_b = RingElement::::zero(); + for (i, bit) in a_.iter().enumerate() { + let (bit_a, bit_b) = bit.to_owned().get_ab(); + share_a |= RingElement(T::from(bit_a.convert().convert()) << i); + share_b |= RingElement(T::from(bit_b.convert().convert()) << i); + } + let share = Share::new(share_a, share_b); + out.push(share); + } + + out + } + + pub fn from_share(share: Share) -> Self { + let (a, b) = share.get_ab(); + let mut res = VecShare::with_capacity(T::K); + for (a, b) in a.bit_iter().zip(b.bit_iter()) { + res.push(Share::new(RingElement(a), RingElement(b))); + } + res + } +} + +impl IntoIterator for VecShare { + type Item = Share; + type IntoIter = std::vec::IntoIter>; + + fn into_iter(self) -> Self::IntoIter { + self.shares.into_iter() + } +} + +impl Not for SliceShare<'_, T> { + type Output = VecShare; + + fn not(self) -> Self::Output { + let mut v = VecShare::with_capacity(self.shares.len()); + for x in self.shares.iter() { + v.push(!x); + } + v + } +} + +impl BitXor for SliceShare<'_, T> { + type Output = VecShare; + + fn bitxor(self, rhs: Self) -> Self::Output { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + let mut v = VecShare::with_capacity(self.shares.len()); + for (x1, x2) in self.shares.iter().zip(rhs.shares.iter()) { + v.push(x1 ^ x2); + } + v + } +} + +impl BitXor for VecShare { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + let mut v = VecShare::with_capacity(self.shares.len()); + for (x1, x2) in self.shares.into_iter().zip(rhs.shares) { + v.push(x1 ^ x2); + } + v + } +} + +impl BitXor> for VecShare { + type Output = Self; + + fn bitxor(self, rhs: SliceShare<'_, T>) -> Self::Output { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + let mut v = VecShare::with_capacity(self.shares.len()); + for (x1, x2) in self.shares.into_iter().zip(rhs.shares.iter()) { + v.push(x1 ^ x2); + } + v + } +} + +impl AddAssign> for VecShare { + fn add_assign(&mut self, rhs: SliceShare<'_, T>) { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + for (x1, x2) in self.shares.iter_mut().zip(rhs.shares.iter()) { + *x1 += x2; + } + } +} + +impl SubAssign> for VecShare { + fn sub_assign(&mut self, rhs: SliceShare<'_, T>) { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + for (x1, x2) in self.shares.iter_mut().zip(rhs.shares.iter()) { + *x1 -= x2; + } + } +} + +impl SubAssign for VecShare { + fn sub_assign(&mut self, rhs: Self) { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + for (x1, x2) in self.shares.iter_mut().zip(rhs.shares.into_iter()) { + *x1 -= x2; + } + } +} + +impl BitXorAssign> for VecShare { + fn bitxor_assign(&mut self, rhs: SliceShare<'_, T>) { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + for (x1, x2) in self.shares.iter_mut().zip(rhs.shares.iter()) { + *x1 ^= x2; + } + } +} + +impl BitXorAssign for VecShare { + fn bitxor_assign(&mut self, rhs: Self) { + debug_assert_eq!(self.shares.len(), rhs.shares.len()); + for (x1, x2) in self.shares.iter_mut().zip(rhs.shares) { + *x1 ^= x2; + } + } +} + +impl<'a, T: IntRing2k> Deref for SliceShare<'a, T> { + type Target = [Share]; + + fn deref(&self) -> &Self::Target { + self.shares + } +} + +impl<'a, T: IntRing2k> Deref for SliceShareMut<'a, T> { + type Target = [Share]; + + fn deref(&self) -> &Self::Target { + self.shares + } +} + +impl<'a, T: IntRing2k> DerefMut for SliceShareMut<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.shares + } +} + +pub trait ChunksOwned: Sized { + fn chunks_owned(self, chunk_size: usize) -> Vec; +} + +impl ChunksOwned for VecShare { + fn chunks_owned(self, chunk_size: usize) -> Vec { + self.shares + .chunks(chunk_size) + .map(|x| Self { + shares: x.to_owned(), + }) + .collect() + } +} + +impl ChunksOwned for Vec { + fn chunks_owned(self, chunk_size: usize) -> Vec { + self.chunks(chunk_size).map(|x| x.to_vec()).collect() + } +} diff --git a/iris-mpc-cpu/src/shares/vecshare_bittranspose.rs b/iris-mpc-cpu/src/shares/vecshare_bittranspose.rs new file mode 100644 index 000000000..15b643fb7 --- /dev/null +++ b/iris-mpc-cpu/src/shares/vecshare_bittranspose.rs @@ -0,0 +1,426 @@ +use super::{ring_impl::RingElement, share::Share, vecshare::VecShare}; + +impl VecShare { + fn share64_from_share16s( + a: &Share, + b: &Share, + c: &Share, + d: &Share, + ) -> Share { + let a_ = (a.a.0 as u64) + | ((b.a.0 as u64) << 16) + | ((c.a.0 as u64) << 32) + | ((d.a.0 as u64) << 48); + let b_ = (a.b.0 as u64) + | ((b.b.0 as u64) << 16) + | ((c.b.0 as u64) << 32) + | ((d.b.0 as u64) << 48); + + Share { + a: RingElement(a_), + b: RingElement(b_), + } + } + + #[allow(clippy::too_many_arguments)] + fn share128_from_share16s( + a: &Share, + b: &Share, + c: &Share, + d: &Share, + e: &Share, + f: &Share, + g: &Share, + h: &Share, + ) -> Share { + let a_ = (a.a.0 as u128) + | ((b.a.0 as u128) << 16) + | ((c.a.0 as u128) << 32) + | ((d.a.0 as u128) << 48) + | ((e.a.0 as u128) << 64) + | ((f.a.0 as u128) << 80) + | ((g.a.0 as u128) << 96) + | ((h.a.0 as u128) << 112); + let b_ = (a.b.0 as u128) + | ((b.b.0 as u128) << 16) + | ((c.b.0 as u128) << 32) + | ((d.b.0 as u128) << 48) + | ((e.b.0 as u128) << 64) + | ((f.b.0 as u128) << 80) + | ((g.b.0 as u128) << 96) + | ((h.b.0 as u128) << 112); + + Share { + a: RingElement(a_), + b: RingElement(b_), + } + } + + fn share_transpose16x128(a: &[Share; 128]) -> [Share; 16] { + let mut j: u32; + let mut k: usize; + let mut m: u128; + let mut t: Share; + + let mut res = core::array::from_fn(|_| Share::default()); + + // pack results into Share128 datatypes + for (i, bb) in res.iter_mut().enumerate() { + *bb = Self::share128_from_share16s( + &a[i], + &a[i + 16], + &a[i + 32], + &a[i + 48], + &a[i + 64], + &a[i + 80], + &a[i + 96], + &a[i + 112], + ); + } + + // version of 128x128 transpose that only does the swaps needed for 16 bits + m = 0x00ff00ff00ff00ff00ff00ff00ff00ff; + j = 8; + while j != 0 { + k = 0; + while k < 16 { + t = ((&res[k] >> j) ^ &res[k + j as usize]) & m; + res[k + j as usize] ^= &t; + res[k] ^= t << j; + k = (k + j as usize + 1) & !(j as usize); + } + j >>= 1; + m = m ^ (m << j); + } + + res + } + + fn share_transpose16x64(a: &[Share; 64]) -> [Share; 16] { + let mut j: u32; + let mut k: usize; + let mut m: u64; + let mut t: Share; + + let mut res = core::array::from_fn(|_| Share::default()); + + // pack results into Share64 datatypes + for (i, bb) in res.iter_mut().enumerate() { + *bb = Self::share64_from_share16s(&a[i], &a[16 + i], &a[32 + i], &a[48 + i]); + } + + // version of 64x64 transpose that only does the swaps needed for 16 bits + m = 0x00ff00ff00ff00ff; + j = 8; + while j != 0 { + k = 0; + while k < 16 { + t = ((&res[k] >> j) ^ &res[k + j as usize]) & m; + res[k + j as usize] ^= &t; + res[k] ^= t << j; + k = (k + j as usize + 1) & !(j as usize); + } + j >>= 1; + m = m ^ (m << j); + } + + res + } + + pub fn transpose_pack_u64(self) -> Vec> { + self.transpose_pack_u64_with_len::<{ u16::BITS as usize }>() + } + + pub fn transpose_pack_u64_with_len(mut self) -> Vec> { + // Pad to multiple of 64 + let len = (self.shares.len() + 63) / 64; + self.shares.resize(len * 64, Share::default()); + + let mut res = (0..L) + .map(|_| VecShare::new_vec(vec![Share::default(); len])) + .collect::>(); + + for (j, x) in self.shares.chunks_exact(64).enumerate() { + let trans = Self::share_transpose16x64(x.try_into().unwrap()); + for (src, des) in trans.into_iter().zip(res.iter_mut()) { + des.shares[j] = src; + } + } + debug_assert_eq!(res.len(), L); + res + } + + pub fn transpose_pack_u128(self) -> Vec> { + self.transpose_pack_u128_with_len::<{ u16::BITS as usize }>() + } + + pub fn transpose_pack_u128_with_len(mut self) -> Vec> { + // Pad to multiple of 128 + let len = (self.shares.len() + 127) / 128; + self.shares.resize(len * 128, Share::default()); + + let mut res = (0..L) + .map(|_| VecShare::new_vec(vec![Share::default(); len])) + .collect::>(); + + for (j, x) in self.shares.chunks_exact(128).enumerate() { + let trans = Self::share_transpose16x128(x.try_into().unwrap()); + for (src, des) in trans.into_iter().zip(res.iter_mut()) { + des.shares[j] = src; + } + } + debug_assert_eq!(res.len(), L); + res + } +} + +impl VecShare { + fn share64_from_share32s(a: &Share, b: &Share) -> Share { + let a_ = (a.a.0 as u64) | ((b.a.0 as u64) << 32); + let b_ = (a.b.0 as u64) | ((b.b.0 as u64) << 32); + + Share { + a: RingElement(a_), + b: RingElement(b_), + } + } + + fn share128_from_share32s( + a: &Share, + b: &Share, + c: &Share, + d: &Share, + ) -> Share { + let a_ = (a.a.0 as u128) + | ((b.a.0 as u128) << 32) + | ((c.a.0 as u128) << 64) + | ((d.a.0 as u128) << 96); + let b_ = (a.b.0 as u128) + | ((b.b.0 as u128) << 32) + | ((c.b.0 as u128) << 64) + | ((d.b.0 as u128) << 96); + + Share { + a: RingElement(a_), + b: RingElement(b_), + } + } + + fn share_transpose32x128(a: &[Share; 128]) -> [Share; 32] { + let mut j: u32; + let mut k: usize; + let mut m: u128; + let mut t: Share; + + let mut res = core::array::from_fn(|_| Share::default()); + + // pack results into Share128 datatypes + for (i, bb) in res.iter_mut().enumerate() { + *bb = Self::share128_from_share32s(&a[i], &a[32 + i], &a[64 + i], &a[96 + i]); + } + + // version of 128x128 transpose that only does the swaps needed for 32 bits + m = 0x0000ffff0000ffff0000ffff0000ffff; + j = 16; + while j != 0 { + k = 0; + while k < 32 { + t = ((&res[k] >> j) ^ &res[k + j as usize]) & m; + res[k + j as usize] ^= &t; + res[k] ^= t << j; + k = (k + j as usize + 1) & !(j as usize); + } + j >>= 1; + m = m ^ (m << j); + } + + res + } + + fn share_transpose32x64(a: &[Share; 64]) -> [Share; 32] { + let mut j: u32; + let mut k: usize; + let mut m: u64; + let mut t: Share; + + let mut res = core::array::from_fn(|_| Share::default()); + + // pack results into Share64 datatypes + for (i, bb) in res.iter_mut().enumerate() { + *bb = Self::share64_from_share32s(&a[i], &a[32 + i]); + } + + // version of 64x64 transpose that only does the swaps needed for 32 bits + m = 0x0000ffff0000ffff; + j = 16; + while j != 0 { + k = 0; + while k < 32 { + t = ((&res[k] >> j) ^ &res[k + j as usize]) & m; + res[k + j as usize] ^= &t; + res[k] ^= t << j; + k = (k + j as usize + 1) & !(j as usize); + } + j >>= 1; + m = m ^ (m << j); + } + + res + } + + pub fn transpose_pack_u64(self) -> Vec> { + self.transpose_pack_u64_with_len::<{ u32::BITS as usize }>() + } + + pub fn transpose_pack_u64_with_len(mut self) -> Vec> { + // Pad to multiple of 64 + let len = (self.shares.len() + 63) / 64; + self.shares.resize(len * 64, Share::default()); + + let mut res = (0..L) + .map(|_| VecShare::new_vec(vec![Share::default(); len])) + .collect::>(); + + for (j, x) in self.shares.chunks_exact(64).enumerate() { + let trans = Self::share_transpose32x64(x.try_into().unwrap()); + for (src, des) in trans.into_iter().zip(res.iter_mut()) { + des.shares[j] = src; + } + } + debug_assert_eq!(res.len(), L); + res + } + + pub fn transpose_pack_u128(self) -> Vec> { + self.transpose_pack_u128_with_len::<{ u32::BITS as usize }>() + } + + pub fn transpose_pack_u128_with_len(mut self) -> Vec> { + // Pad to multiple of 128 + let len = (self.shares.len() + 127) / 128; + self.shares.resize(len * 128, Share::default()); + + let mut res = (0..L) + .map(|_| VecShare::new_vec(vec![Share::default(); len])) + .collect::>(); + + for (j, x) in self.shares.chunks_exact(128).enumerate() { + let trans = Self::share_transpose32x128(x.try_into().unwrap()); + for (src, des) in trans.into_iter().zip(res.iter_mut()) { + des.shares[j] = src; + } + } + debug_assert_eq!(res.len(), L); + res + } +} + +impl VecShare { + fn share128_from_share64s(a: &Share, b: &Share) -> Share { + let a_ = (a.a.0 as u128) | ((b.a.0 as u128) << 64); + let b_ = (a.b.0 as u128) | ((b.b.0 as u128) << 64); + + Share { + a: RingElement(a_), + b: RingElement(b_), + } + } + + fn share_transpose64x128(a: &[Share; 128]) -> [Share; 64] { + let mut j: u32; + let mut k: usize; + let mut m: u128; + let mut t: Share; + + let mut res = core::array::from_fn(|_| Share::default()); + + // pack results into Share128 datatypes + for (i, bb) in res.iter_mut().enumerate() { + *bb = Self::share128_from_share64s(&a[i], &a[i + 64]); + } + + // version of 128x128 transpose that only does the swaps needed for 64 bits + m = 0x00000000ffffffff00000000ffffffff; + j = 32; + while j != 0 { + k = 0; + while k < 64 { + t = ((&res[k] >> j) ^ &res[k + j as usize]) & m; + res[k + j as usize] ^= &t; + res[k] ^= t << j; + k = (k + j as usize + 1) & !(j as usize); + } + j >>= 1; + m = m ^ (m << j); + } + + res + } + + fn share_transpose64x64(a: &mut [Share; 64]) { + let mut j: u32; + let mut k: usize; + let mut m: u64; + let mut t: Share; + + m = 0x00000000ffffffff; + j = 32; + while j != 0 { + k = 0; + while k < 64 { + t = ((&a[k] >> j) ^ &a[k + j as usize]) & m; + a[k + j as usize] ^= &t; + a[k] ^= t << j; + k = (k + j as usize + 1) & !(j as usize); + } + j >>= 1; + m = m ^ (m << j); + } + } + + pub fn transpose_pack_u64(self) -> Vec> { + self.transpose_pack_u64_with_len::<{ u64::BITS as usize }>() + } + + pub fn transpose_pack_u64_with_len(mut self) -> Vec> { + // Pad to multiple of 64 + let len = (self.shares.len() + 63) / 64; + self.shares.resize(len * 64, Share::default()); + + let mut res = (0..L) + .map(|_| VecShare::new_vec(vec![Share::default(); len])) + .collect::>(); + + for (j, x) in self.shares.chunks_exact_mut(64).enumerate() { + Self::share_transpose64x64(x.try_into().unwrap()); + for (src, des) in x.iter().cloned().zip(res.iter_mut()) { + des.shares[j] = src; + } + } + debug_assert_eq!(res.len(), L); + res + } + + pub fn transpose_pack_u128(self) -> Vec> { + self.transpose_pack_u128_with_len::<{ u64::BITS as usize }>() + } + + pub fn transpose_pack_u128_with_len(mut self) -> Vec> { + // Pad to multiple of 128 + let len = (self.shares.len() + 127) / 128; + self.shares.resize(len * 128, Share::default()); + + let mut res = (0..L) + .map(|_| VecShare::new_vec(vec![Share::default(); len])) + .collect::>(); + + for (j, x) in self.shares.chunks_exact(128).enumerate() { + let trans = Self::share_transpose64x128(x.try_into().unwrap()); + for (src, des) in trans.into_iter().zip(res.iter_mut()) { + des.shares[j] = src; + } + } + debug_assert_eq!(res.len(), L); + res + } +} diff --git a/iris-mpc-cpu/src/utils.rs b/iris-mpc-cpu/src/utils.rs new file mode 100644 index 000000000..e6ef08cbd --- /dev/null +++ b/iris-mpc-cpu/src/utils.rs @@ -0,0 +1,26 @@ +use crate::{ + error::Error, + shares::{int_ring::IntRing2k, ring_impl::RingElement, vecshare::RingBytesIter}, +}; +use bytes::{Buf, Bytes, BytesMut}; + +pub struct Utils {} + +impl Utils { + pub fn ring_slice_to_bytes(vec: &[RingElement]) -> Bytes { + let slice = RingElement::convert_slice(vec); + let slice_: &[u8] = bytemuck::cast_slice(slice); + Bytes::copy_from_slice(slice_) + } + + pub fn ring_iter_from_bytes( + bytes: BytesMut, + n: usize, + ) -> Result, Error> { + if bytes.remaining() != n * T::BYTES { + return Err(Error::InvalidSize); + } + + Ok(RingBytesIter::new(bytes)) + } +}