From 9c6efe75e762b26143c10fce25d8a3a06b2c53d5 Mon Sep 17 00:00:00 2001 From: Romain Ruetschi Date: Tue, 20 Feb 2024 08:38:04 +0100 Subject: [PATCH 1/4] chore(code): Replace manual impls of common traits with `derive-where` (#165) --- code/Cargo.toml | 1 + code/common/Cargo.toml | 3 +- code/common/src/signed_vote.rs | 36 +---------- code/driver/Cargo.toml | 2 + code/driver/src/error.rs | 18 +----- code/driver/src/output.rs | 67 +------------------ code/round/Cargo.toml | 2 + code/round/src/input.rs | 115 +-------------------------------- code/round/src/output.rs | 69 +------------------- code/round/src/state.rs | 58 +---------------- code/vote/Cargo.toml | 2 + code/vote/src/keeper.rs | 62 ++---------------- code/vote/src/round_votes.rs | 2 +- code/vote/src/round_weights.rs | 2 +- 14 files changed, 28 insertions(+), 411 deletions(-) diff --git a/code/Cargo.toml b/code/Cargo.toml index a76d53a86..f1dff359e 100644 --- a/code/Cargo.toml +++ b/code/Cargo.toml @@ -18,6 +18,7 @@ license = "Apache-2.0" publish = false [workspace.dependencies] +derive-where = "1.2.7" ed25519-consensus = "2.1.0" futures = "0.3" glob = "0.3.0" diff --git a/code/common/Cargo.toml b/code/common/Cargo.toml index b39479cba..90510329e 100644 --- a/code/common/Cargo.toml +++ b/code/common/Cargo.toml @@ -9,4 +9,5 @@ license.workspace = true publish.workspace = true [dependencies] -signature.workspace = true +derive-where.workspace = true +signature.workspace = true diff --git a/code/common/src/signed_vote.rs b/code/common/src/signed_vote.rs index 24226fd84..551a401fa 100644 --- a/code/common/src/signed_vote.rs +++ b/code/common/src/signed_vote.rs @@ -1,8 +1,9 @@ -use core::fmt; +use derive_where::derive_where; use crate::{Context, Signature, Vote}; /// A signed vote, ie. a vote emitted by a validator and signed by its private key. +#[derive_where(Clone, Debug, PartialEq, Eq)] pub struct SignedVote where Ctx: Context, @@ -28,36 +29,3 @@ where self.vote.validator_address() } } - -// NOTE: We have to derive these instances manually, otherwise -// the compiler would infer a Clone/Debug/PartialEq/Eq bound on `Ctx`, -// which may not hold for all contexts. - -impl Clone for SignedVote { - #[cfg_attr(coverage_nightly, coverage(off))] - fn clone(&self) -> Self { - Self { - vote: self.vote.clone(), - signature: self.signature.clone(), - } - } -} - -impl fmt::Debug for SignedVote { - #[cfg_attr(coverage_nightly, coverage(off))] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SignedVote") - .field("vote", &self.vote) - .field("signature", &self.signature) - .finish() - } -} - -impl PartialEq for SignedVote { - #[cfg_attr(coverage_nightly, coverage(off))] - fn eq(&self, other: &Self) -> bool { - self.vote == other.vote && self.signature == other.signature - } -} - -impl Eq for SignedVote {} diff --git a/code/driver/Cargo.toml b/code/driver/Cargo.toml index e72a6303f..8c0aa8c39 100644 --- a/code/driver/Cargo.toml +++ b/code/driver/Cargo.toml @@ -12,3 +12,5 @@ publish.workspace = true malachite-common = { version = "0.1.0", path = "../common" } malachite-round = { version = "0.1.0", path = "../round" } malachite-vote = { version = "0.1.0", path = "../vote" } + +derive-where.workspace = true diff --git a/code/driver/src/error.rs b/code/driver/src/error.rs index dd2ad9959..41f8e203b 100644 --- a/code/driver/src/error.rs +++ b/code/driver/src/error.rs @@ -1,9 +1,11 @@ use core::fmt; +use derive_where::derive_where; + use malachite_common::Context; /// The type of errors that can be yielded by the `Driver`. -#[derive(Clone, Debug)] +#[derive_where(Clone, Debug, PartialEq, Eq)] pub enum Error where Ctx: Context, @@ -27,17 +29,3 @@ where } } } - -impl PartialEq for Error -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Error::ProposerNotFound(addr1), Error::ProposerNotFound(addr2)) => addr1 == addr2, - (Error::ValidatorNotFound(addr1), Error::ValidatorNotFound(addr2)) => addr1 == addr2, - _ => false, - } - } -} diff --git a/code/driver/src/output.rs b/code/driver/src/output.rs index ce4c222fc..3fc85f281 100644 --- a/code/driver/src/output.rs +++ b/code/driver/src/output.rs @@ -1,8 +1,9 @@ -use core::fmt; +use derive_where::derive_where; use malachite_common::{Context, Round, Timeout}; /// Messages emitted by the [`Driver`](crate::Driver) +#[derive_where(Clone, Debug, PartialEq, Eq)] pub enum Output where Ctx: Context, @@ -26,67 +27,3 @@ where /// The timeout tells the proposal builder how long it has to build a value. GetValue(Ctx::Height, Round, Timeout), } - -// NOTE: We have to derive these instances manually, otherwise -// the compiler would infer a Clone/Debug/PartialEq/Eq bound on `Ctx`, -// which may not hold for all contexts. - -impl Clone for Output { - #[cfg_attr(coverage_nightly, coverage(off))] - fn clone(&self) -> Self { - match self { - Output::NewRound(height, round) => Output::NewRound(*height, *round), - Output::Propose(proposal) => Output::Propose(proposal.clone()), - Output::Vote(signed_vote) => Output::Vote(signed_vote.clone()), - Output::Decide(round, value) => Output::Decide(*round, value.clone()), - Output::ScheduleTimeout(timeout) => Output::ScheduleTimeout(*timeout), - Output::GetValue(height, round, timeout) => Output::GetValue(*height, *round, *timeout), - } - } -} - -impl fmt::Debug for Output { - #[cfg_attr(coverage_nightly, coverage(off))] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Output::NewRound(height, round) => write!(f, "NewRound({:?}, {:?})", height, round), - Output::Propose(proposal) => write!(f, "Propose({:?})", proposal), - Output::Vote(signed_vote) => write!(f, "Vote({:?})", signed_vote), - Output::Decide(round, value) => write!(f, "Decide({:?}, {:?})", round, value), - Output::ScheduleTimeout(timeout) => write!(f, "ScheduleTimeout({:?})", timeout), - Output::GetValue(height, round, timeout) => { - write!(f, "GetValue({:?}, {:?}, {:?})", height, round, timeout) - } - } - } -} - -impl PartialEq for Output { - #[cfg_attr(coverage_nightly, coverage(off))] - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Output::NewRound(height, round), Output::NewRound(other_height, other_round)) => { - height == other_height && round == other_round - } - (Output::Propose(proposal), Output::Propose(other_proposal)) => { - proposal == other_proposal - } - (Output::Vote(signed_vote), Output::Vote(other_signed_vote)) => { - signed_vote == other_signed_vote - } - (Output::Decide(round, value), Output::Decide(other_round, other_value)) => { - round == other_round && value == other_value - } - (Output::ScheduleTimeout(timeout), Output::ScheduleTimeout(other_timeout)) => { - timeout == other_timeout - } - ( - Output::GetValue(height, round, timeout), - Output::GetValue(other_height, other_round, other_timeout), - ) => height == other_height && round == other_round && timeout == other_timeout, - _ => false, - } - } -} - -impl Eq for Output {} diff --git a/code/round/Cargo.toml b/code/round/Cargo.toml index d672a7b42..7868fecd5 100644 --- a/code/round/Cargo.toml +++ b/code/round/Cargo.toml @@ -10,3 +10,5 @@ publish.workspace = true [dependencies] malachite-common = { version = "0.1.0", path = "../common" } + +derive-where.workspace = true diff --git a/code/round/src/input.rs b/code/round/src/input.rs index eb3d4c81b..5b5fc0701 100644 --- a/code/round/src/input.rs +++ b/code/round/src/input.rs @@ -1,10 +1,11 @@ //! Inputs to the round state machine. -use core::fmt; +use derive_where::derive_where; use malachite_common::{Context, Round, ValueId}; /// Input to the round state machine. +#[derive_where(Clone, Debug, PartialEq, Eq)] pub enum Input where Ctx: Context, @@ -73,115 +74,3 @@ where /// L65 TimeoutPrecommit, } - -impl Clone for Input { - #[cfg_attr(coverage_nightly, coverage(off))] - fn clone(&self) -> Self { - match self { - Input::NewRound(round) => Input::NewRound(*round), - Input::ProposeValue(value) => Input::ProposeValue(value.clone()), - Input::Proposal(proposal) => Input::Proposal(proposal.clone()), - Input::InvalidProposal => Input::InvalidProposal, - Input::ProposalAndPolkaPrevious(proposal) => { - Input::ProposalAndPolkaPrevious(proposal.clone()) - } - Input::InvalidProposalAndPolkaPrevious(proposal) => { - Input::InvalidProposalAndPolkaPrevious(proposal.clone()) - } - Input::PolkaAny => Input::PolkaAny, - Input::PolkaNil => Input::PolkaNil, - Input::ProposalAndPolkaCurrent(proposal) => { - Input::ProposalAndPolkaCurrent(proposal.clone()) - } - Input::PrecommitAny => Input::PrecommitAny, - Input::ProposalAndPrecommitValue(proposal) => { - Input::ProposalAndPrecommitValue(proposal.clone()) - } - Input::PrecommitValue(value_id) => Input::PrecommitValue(value_id.clone()), - Input::SkipRound(round) => Input::SkipRound(*round), - Input::TimeoutPropose => Input::TimeoutPropose, - Input::TimeoutPrevote => Input::TimeoutPrevote, - Input::TimeoutPrecommit => Input::TimeoutPrecommit, - } - } -} - -impl PartialEq for Input { - #[cfg_attr(coverage_nightly, coverage(off))] - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Input::NewRound(round), Input::NewRound(other_round)) => round == other_round, - (Input::ProposeValue(value), Input::ProposeValue(other_value)) => value == other_value, - (Input::Proposal(proposal), Input::Proposal(other_proposal)) => { - proposal == other_proposal - } - (Input::InvalidProposal, Input::InvalidProposal) => true, - ( - Input::ProposalAndPolkaPrevious(proposal), - Input::ProposalAndPolkaPrevious(other_proposal), - ) => proposal == other_proposal, - ( - Input::InvalidProposalAndPolkaPrevious(proposal), - Input::InvalidProposalAndPolkaPrevious(other_proposal), - ) => proposal == other_proposal, - (Input::PolkaAny, Input::PolkaAny) => true, - (Input::PolkaNil, Input::PolkaNil) => true, - ( - Input::ProposalAndPolkaCurrent(proposal), - Input::ProposalAndPolkaCurrent(other_proposal), - ) => proposal == other_proposal, - (Input::PrecommitAny, Input::PrecommitAny) => true, - ( - Input::ProposalAndPrecommitValue(proposal), - Input::ProposalAndPrecommitValue(other_proposal), - ) => proposal == other_proposal, - (Input::PrecommitValue(value_id), Input::PrecommitValue(other_value_id)) => { - value_id == other_value_id - } - (Input::SkipRound(round), Input::SkipRound(other_round)) => round == other_round, - (Input::TimeoutPropose, Input::TimeoutPropose) => true, - (Input::TimeoutPrevote, Input::TimeoutPrevote) => true, - (Input::TimeoutPrecommit, Input::TimeoutPrecommit) => true, - _ => false, - } - } -} - -impl Eq for Input {} - -impl fmt::Debug for Input -where - Ctx: Context, - Ctx::Value: fmt::Debug, - Ctx::Proposal: fmt::Debug, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Input::NewRound(round) => write!(f, "NewRound({:?})", round), - Input::ProposeValue(value) => write!(f, "ProposeValue({:?})", value), - Input::Proposal(proposal) => write!(f, "Proposal({:?})", proposal), - Input::InvalidProposal => write!(f, "InvalidProposal"), - Input::ProposalAndPolkaPrevious(proposal) => { - write!(f, "ProposalAndPolkaPrevious({:?})", proposal) - } - Input::InvalidProposalAndPolkaPrevious(proposal) => { - write!(f, "InvalidProposalAndPolkaPrevious({:?})", proposal) - } - Input::PolkaAny => write!(f, "PolkaAny"), - Input::PolkaNil => write!(f, "PolkaNil"), - Input::ProposalAndPolkaCurrent(proposal) => { - write!(f, "ProposalAndPolkaCurrent({:?})", proposal) - } - Input::PrecommitAny => write!(f, "PrecommitAny"), - Input::ProposalAndPrecommitValue(proposal) => { - write!(f, "ProposalAndPrecommitValue({:?})", proposal) - } - Input::PrecommitValue(value_id) => write!(f, "PrecommitValue({:?})", value_id), - Input::SkipRound(round) => write!(f, "SkipRound({:?})", round), - Input::TimeoutPropose => write!(f, "TimeoutPropose"), - Input::TimeoutPrevote => write!(f, "TimeoutPrevote"), - Input::TimeoutPrecommit => write!(f, "TimeoutPrecommit"), - } - } -} diff --git a/code/round/src/output.rs b/code/round/src/output.rs index 921d1f95f..46ff981ed 100644 --- a/code/round/src/output.rs +++ b/code/round/src/output.rs @@ -1,12 +1,13 @@ //! Outputs of the round state machine. -use core::fmt; +use derive_where::derive_where; use malachite_common::{Context, NilOrVal, Round, Timeout, TimeoutStep, ValueId}; use crate::state::RoundValue; /// Output of the round state machine. +#[derive_where(Clone, Debug, PartialEq, Eq)] pub enum Output where Ctx: Context, @@ -81,69 +82,3 @@ impl Output { Output::Decision(RoundValue { round, value }) } } - -// NOTE: We have to derive these instances manually, otherwise -// the compiler would infer a Clone/Debug/PartialEq/Eq bound on `Ctx`, -// which may not hold for all contexts. - -impl Clone for Output { - #[cfg_attr(coverage_nightly, coverage(off))] - fn clone(&self) -> Self { - match self { - Output::NewRound(round) => Output::NewRound(*round), - Output::Proposal(proposal) => Output::Proposal(proposal.clone()), - Output::Vote(vote) => Output::Vote(vote.clone()), - Output::ScheduleTimeout(timeout) => Output::ScheduleTimeout(*timeout), - Output::GetValueAndScheduleTimeout(height, round, timeout) => { - Output::GetValueAndScheduleTimeout(*height, *round, *timeout) - } - Output::Decision(round_value) => Output::Decision(round_value.clone()), - } - } -} - -impl fmt::Debug for Output { - #[cfg_attr(coverage_nightly, coverage(off))] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Output::NewRound(round) => write!(f, "NewRound({:?})", round), - Output::Proposal(proposal) => write!(f, "Proposal({:?})", proposal), - Output::Vote(vote) => write!(f, "Vote({:?})", vote), - Output::ScheduleTimeout(timeout) => write!(f, "ScheduleTimeout({:?})", timeout), - Output::GetValueAndScheduleTimeout(height, round, timeout) => { - write!( - f, - "GetValueAndScheduleTimeout({:?}, {:?}, {:?})", - height, round, timeout - ) - } - Output::Decision(round_value) => write!(f, "Decision({:?})", round_value), - } - } -} - -impl PartialEq for Output { - #[cfg_attr(coverage_nightly, coverage(off))] - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Output::NewRound(round), Output::NewRound(other_round)) => round == other_round, - (Output::Proposal(proposal), Output::Proposal(other_proposal)) => { - proposal == other_proposal - } - (Output::Vote(vote), Output::Vote(other_vote)) => vote == other_vote, - (Output::ScheduleTimeout(timeout), Output::ScheduleTimeout(other_timeout)) => { - timeout == other_timeout - } - ( - Output::GetValueAndScheduleTimeout(height, round, timeout), - Output::GetValueAndScheduleTimeout(other_height, other_round, other_timeout), - ) => height == other_height && round == other_round && timeout == other_timeout, - (Output::Decision(round_value), Output::Decision(other_round_value)) => { - round_value == other_round_value - } - _ => false, - } - } -} - -impl Eq for Output {} diff --git a/code/round/src/state.rs b/code/round/src/state.rs index c6b1bf0e8..1ec7612de 100644 --- a/code/round/src/state.rs +++ b/code/round/src/state.rs @@ -1,6 +1,6 @@ //! The state maintained by the round state machine -use core::fmt; +use derive_where::derive_where; use crate::input::Input; use crate::state_machine::Info; @@ -45,6 +45,7 @@ pub enum Step { } /// The state of the consensus state machine +#[derive_where(Clone, Debug, PartialEq, Eq)] pub struct State where Ctx: Context, @@ -124,10 +125,6 @@ where } } -// NOTE: We have to derive these instances manually, otherwise -// the compiler would infer a Clone/Debug/PartialEq/Eq bound on `Ctx`, -// which may not hold for all contexts. - impl Default for State where Ctx: Context, @@ -136,54 +133,3 @@ where Self::new(Ctx::Height::default(), Round::Nil) } } - -impl Clone for State -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn clone(&self) -> Self { - Self { - height: self.height, - round: self.round, - step: self.step, - locked: self.locked.clone(), - valid: self.valid.clone(), - decision: self.decision.clone(), - } - } -} - -impl fmt::Debug for State -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("State") - .field("height", &self.height) - .field("round", &self.round) - .field("step", &self.step) - .field("locked", &self.locked) - .field("valid", &self.valid) - .field("decision", &self.decision) - .finish() - } -} - -impl PartialEq for State -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn eq(&self, other: &Self) -> bool { - self.height == other.height - && self.round == other.round - && self.step == other.step - && self.locked == other.locked - && self.valid == other.valid - && self.decision == other.decision - } -} - -impl Eq for State where Ctx: Context {} diff --git a/code/vote/Cargo.toml b/code/vote/Cargo.toml index 184172208..d81713663 100644 --- a/code/vote/Cargo.toml +++ b/code/vote/Cargo.toml @@ -10,3 +10,5 @@ publish.workspace = true [dependencies] malachite-common = { version = "0.1.0", path = "../common" } + +derive-where.workspace = true diff --git a/code/vote/src/keeper.rs b/code/vote/src/keeper.rs index bab3747d9..b55f7f61d 100644 --- a/code/vote/src/keeper.rs +++ b/code/vote/src/keeper.rs @@ -1,6 +1,6 @@ //! For tallying votes and emitting messages when certain thresholds are reached. -use core::fmt; +use derive_where::derive_where; use alloc::collections::{BTreeMap, BTreeSet}; @@ -33,6 +33,7 @@ pub enum Output { } /// Keeps track of votes and emitted outputs for a given round. +#[derive_where(Clone, Debug, PartialEq, Eq, Default)] pub struct PerRound where Ctx: Context, @@ -50,7 +51,7 @@ where Ctx: Context, { /// Create a new `PerRound` instance. - fn new() -> Self { + pub fn new() -> Self { Self { votes: RoundVotes::new(), addresses_weights: RoundWeights::new(), @@ -74,35 +75,8 @@ where } } -impl Clone for PerRound -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn clone(&self) -> Self { - Self { - votes: self.votes.clone(), - addresses_weights: self.addresses_weights.clone(), - emitted_outputs: self.emitted_outputs.clone(), - } - } -} - -impl fmt::Debug for PerRound -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PerRound") - .field("votes", &self.votes) - .field("addresses_weights", &self.addresses_weights) - .field("emitted_outputs", &self.emitted_outputs) - .finish() - } -} - /// Keeps track of votes and emits messages when thresholds are reached. +#[derive_where(Clone, Debug)] pub struct VoteKeeper where Ctx: Context, @@ -261,31 +235,3 @@ fn threshold_to_output(typ: VoteType, threshold: Threshold) -> Opt (VoteType::Precommit, Threshold::Value(v)) => Some(Output::PrecommitValue(v)), } } - -impl Clone for VoteKeeper -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn clone(&self) -> Self { - Self { - total_weight: self.total_weight, - threshold_params: self.threshold_params, - per_round: self.per_round.clone(), - } - } -} - -impl fmt::Debug for VoteKeeper -where - Ctx: Context, -{ - #[cfg_attr(coverage_nightly, coverage(off))] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("VoteKeeper") - .field("total_weight", &self.total_weight) - .field("threshold_params", &self.threshold_params) - .field("per_round", &self.per_round) - .finish() - } -} diff --git a/code/vote/src/round_votes.rs b/code/vote/src/round_votes.rs index 2a1f29ccb..c4053bcda 100644 --- a/code/vote/src/round_votes.rs +++ b/code/vote/src/round_votes.rs @@ -6,7 +6,7 @@ use crate::count::VoteCount; use crate::{Threshold, ThresholdParam, Weight}; /// Tracks all the votes for a single round -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct RoundVotes { /// The prevotes for this round. prevotes: VoteCount, diff --git a/code/vote/src/round_weights.rs b/code/vote/src/round_weights.rs index 026bf7ed1..2b82e1233 100644 --- a/code/vote/src/round_weights.rs +++ b/code/vote/src/round_weights.rs @@ -5,7 +5,7 @@ use alloc::collections::BTreeMap; use crate::Weight; /// Keeps track of the weight (ie. voting power) of each validator. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct RoundWeights
{ map: BTreeMap, } From ec43b9ac08fe23d820f2a2002b20315edd9ac592 Mon Sep 17 00:00:00 2001 From: Romain Ruetschi Date: Tue, 20 Feb 2024 08:44:32 +0100 Subject: [PATCH 2/4] Remove manual impls of common trait and use `derive-where` instead --- code/node/Cargo.toml | 1 + code/node/src/network/msg.rs | 83 ++++++++++++++++++------------------ 2 files changed, 43 insertions(+), 41 deletions(-) diff --git a/code/node/Cargo.toml b/code/node/Cargo.toml index 98933b090..51dbe53b8 100644 --- a/code/node/Cargo.toml +++ b/code/node/Cargo.toml @@ -14,6 +14,7 @@ malachite-driver = { version = "0.1.0", path = "../driver" } malachite-round = { version = "0.1.0", path = "../round" } malachite-vote = { version = "0.1.0", path = "../vote" } +derive-where = { workspace = true } futures = { workspace = true } tokio = { workspace = true, features = ["full"] } tokio-stream = { workspace = true } diff --git a/code/node/src/network/msg.rs b/code/node/src/network/msg.rs index 1ddad0a14..f14e56024 100644 --- a/code/node/src/network/msg.rs +++ b/code/node/src/network/msg.rs @@ -1,7 +1,8 @@ -use core::fmt; +use derive_where::derive_where; use malachite_common::Context; +#[derive_where(Clone, Debug, PartialEq, Eq)] pub enum Msg { Vote(Ctx::Vote), Proposal(Ctx::Proposal), @@ -30,43 +31,43 @@ impl Msg { } } } - -impl Clone for Msg { - fn clone(&self) -> Self { - match self { - Msg::Vote(vote) => Msg::Vote(vote.clone()), - Msg::Proposal(proposal) => Msg::Proposal(proposal.clone()), - - #[cfg(test)] - Msg::Dummy(n) => Msg::Dummy(*n), - } - } -} - -impl fmt::Debug for Msg { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Msg::Vote(vote) => write!(f, "Vote({vote:?})"), - Msg::Proposal(proposal) => write!(f, "Proposal({proposal:?})"), - - #[cfg(test)] - Msg::Dummy(n) => write!(f, "Dummy({n:?})"), - } - } -} - -impl PartialEq for Msg { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Msg::Vote(vote), Msg::Vote(other_vote)) => vote == other_vote, - (Msg::Proposal(proposal), Msg::Proposal(other_proposal)) => proposal == other_proposal, - - #[cfg(test)] - (Msg::Dummy(n1), Msg::Dummy(n2)) => n1 == n2, - - _ => false, - } - } -} - -impl Eq for Msg {} +// +// impl Clone for Msg { +// fn clone(&self) -> Self { +// match self { +// Msg::Vote(vote) => Msg::Vote(vote.clone()), +// Msg::Proposal(proposal) => Msg::Proposal(proposal.clone()), +// +// #[cfg(test)] +// Msg::Dummy(n) => Msg::Dummy(*n), +// } +// } +// } +// +// impl fmt::Debug for Msg { +// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +// match self { +// Msg::Vote(vote) => write!(f, "Vote({vote:?})"), +// Msg::Proposal(proposal) => write!(f, "Proposal({proposal:?})"), +// +// #[cfg(test)] +// Msg::Dummy(n) => write!(f, "Dummy({n:?})"), +// } +// } +// } +// +// impl PartialEq for Msg { +// fn eq(&self, other: &Self) -> bool { +// match (self, other) { +// (Msg::Vote(vote), Msg::Vote(other_vote)) => vote == other_vote, +// (Msg::Proposal(proposal), Msg::Proposal(other_proposal)) => proposal == other_proposal, +// +// #[cfg(test)] +// (Msg::Dummy(n1), Msg::Dummy(n2)) => n1 == n2, +// +// _ => false, +// } +// } +// } +// +// impl Eq for Msg {} From 30e0765f4217101b44ed33a5c0f38f0842d55fa4 Mon Sep 17 00:00:00 2001 From: Romain Ruetschi Date: Tue, 20 Feb 2024 08:57:13 +0100 Subject: [PATCH 3/4] First draft of Protobuf definitions for Malachite types --- code/Cargo.toml | 16 +++- code/common/Cargo.toml | 5 +- code/common/src/height.rs | 3 + code/common/src/lib.rs | 2 + code/common/src/proposal.rs | 3 + code/common/src/round.rs | 17 ++++ code/common/src/validator_set.rs | 3 + code/common/src/value.rs | 31 ++++++++ code/common/src/vote.rs | 24 ++++++ code/node/Cargo.toml | 17 ++-- code/node/src/network/broadcast.rs | 7 +- code/node/src/network/msg.rs | 124 +++++++++++++++++------------ code/proto/Cargo.toml | 15 ++++ code/proto/build.rs | 9 +++ code/proto/src/lib.rs | 55 +++++++++++++ code/proto/src/malachite.proto | 43 ++++++++++ code/test/Cargo.toml | 12 +-- code/test/src/height.rs | 16 ++++ code/test/src/proposal.rs | 24 ++++++ code/test/src/validator_set.rs | 26 ++++++ code/test/src/value.rs | 45 +++++++++++ code/test/src/vote.rs | 35 ++++++++ 22 files changed, 463 insertions(+), 69 deletions(-) create mode 100644 code/proto/Cargo.toml create mode 100644 code/proto/build.rs create mode 100644 code/proto/src/lib.rs create mode 100644 code/proto/src/malachite.proto diff --git a/code/Cargo.toml b/code/Cargo.toml index 280b9c5af..136214354 100644 --- a/code/Cargo.toml +++ b/code/Cargo.toml @@ -5,7 +5,8 @@ members = [ "common", "driver", "itf", - "node", + "node", + "proto", "round", "test", "vote", @@ -19,6 +20,15 @@ license = "Apache-2.0" publish = false [workspace.dependencies] +malachite-common = { version = "0.1.0", path = "common" } +malachite-driver = { version = "0.1.0", path = "driver" } +malachite-itf = { version = "0.1.0", path = "itf" } +malachite-node = { version = "0.1.0", path = "node" } +malachite-proto = { version = "0.1.0", path = "proto" } +malachite-round = { version = "0.1.0", path = "round" } +malachite-test = { version = "0.1.0", path = "test" } +malachite-vote = { version = "0.1.0", path = "vote" } + derive-where = "1.2.7" ed25519-consensus = "2.1.0" futures = "0.3" @@ -27,11 +37,15 @@ itf = "0.2.2" num-bigint = "0.4.4" num-traits = "0.2.17" pretty_assertions = "1.4" +prost = "0.12.3" +prost-types = "0.12.3" +prost-build = "0.12.3" rand = { version = "0.8.5", features = ["std_rng"] } serde = "1.0" serde_json = "1.0" serde_with = "3.4" sha2 = "0.10.8" signature = "2.1.0" +thiserror = "1.0" tokio = "1.35.1" tokio-stream = "0.1" diff --git a/code/common/Cargo.toml b/code/common/Cargo.toml index 90510329e..dd356d70c 100644 --- a/code/common/Cargo.toml +++ b/code/common/Cargo.toml @@ -9,5 +9,6 @@ license.workspace = true publish.workspace = true [dependencies] -derive-where.workspace = true -signature.workspace = true +malachite-proto.workspace = true +derive-where.workspace = true +signature.workspace = true diff --git a/code/common/src/height.rs b/code/common/src/height.rs index 5f458c123..2293bf2c2 100644 --- a/code/common/src/height.rs +++ b/code/common/src/height.rs @@ -1,5 +1,7 @@ use core::fmt::Debug; +use malachite_proto::Protobuf; + /// Defines the requirements for a height type. /// /// A height denotes the number of blocks (values) created since the chain began. @@ -8,5 +10,6 @@ use core::fmt::Debug; pub trait Height where Self: Default + Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord, + Self: Protobuf, { } diff --git a/code/common/src/lib.rs b/code/common/src/lib.rs index 3abfccdcd..3c7f702a6 100644 --- a/code/common/src/lib.rs +++ b/code/common/src/lib.rs @@ -12,6 +12,8 @@ #![cfg_attr(not(test), deny(clippy::unwrap_used, clippy::panic))] #![cfg_attr(coverage_nightly, feature(coverage_attribute))] +extern crate alloc; + mod context; mod height; mod proposal; diff --git a/code/common/src/proposal.rs b/code/common/src/proposal.rs index d63168493..c8c2a21db 100644 --- a/code/common/src/proposal.rs +++ b/code/common/src/proposal.rs @@ -1,11 +1,14 @@ use core::fmt::Debug; +use malachite_proto::Protobuf; + use crate::{Context, Round}; /// Defines the requirements for a proposal type. pub trait Proposal where Self: Clone + Debug + Eq + Send + Sync + 'static, + Self: Protobuf, Ctx: Context, { /// The height for which the proposal is for. diff --git a/code/common/src/round.rs b/code/common/src/round.rs index 713c8954d..57a5c2f34 100644 --- a/code/common/src/round.rs +++ b/code/common/src/round.rs @@ -1,4 +1,5 @@ use core::cmp; +use core::convert::Infallible; /// A round number. /// @@ -72,6 +73,22 @@ impl Ord for Round { } } +impl TryFrom for Round { + type Error = Infallible; + + fn try_from(proto: malachite_proto::Round) -> Result { + Ok(Self::new(proto.round)) + } +} + +impl From for malachite_proto::Round { + fn from(round: Round) -> malachite_proto::Round { + malachite_proto::Round { + round: round.as_i64(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/code/common/src/validator_set.rs b/code/common/src/validator_set.rs index ccc4dd1f6..1b9758948 100644 --- a/code/common/src/validator_set.rs +++ b/code/common/src/validator_set.rs @@ -1,5 +1,7 @@ use core::fmt::{Debug, Display}; +use malachite_proto::Protobuf; + use crate::{Context, PublicKey}; /// Voting power held by a validator. @@ -11,6 +13,7 @@ pub type VotingPower = u64; pub trait Address where Self: Clone + Debug + Display + Eq + Ord, + Self: Protobuf, { } diff --git a/code/common/src/value.rs b/code/common/src/value.rs index a16ac246e..60b47c69b 100644 --- a/code/common/src/value.rs +++ b/code/common/src/value.rs @@ -1,5 +1,7 @@ use core::fmt::Debug; +use malachite_proto::Protobuf; + /// Represents either `Nil` or a value of type `Value`. /// /// This type is isomorphic to `Option` but is more explicit about its intent. @@ -53,10 +55,39 @@ impl NilOrVal { } } +// impl TryFrom for NilOrVal +// where +// Value: From, // FIXME +// { +// type Error = String; +// +// fn try_from(proto: malachite_proto::Value) -> Result { +// match proto.value { +// Some(value) => Ok(NilOrVal::Val(Value::from(value))), // FIXME +// None => Ok(NilOrVal::Nil), +// } +// } +// } +// +// impl TryFrom for NilOrVal +// where +// Value: TryFrom>, // FIXME +// { +// type Error = String; +// +// fn try_from(proto: malachite_proto::ValueId) -> Result { +// match proto.value { +// Some(value) => Ok(NilOrVal::Val(Value::from(value))), // FIXME +// None => Ok(NilOrVal::Nil), +// } +// } +// } + /// Defines the requirements for the type of value to decide on. pub trait Value where Self: Clone + Debug + PartialEq + Eq + PartialOrd + Ord, + Self: Protobuf, { /// The type of the ID of the value. /// Typically a representation of the value with a lower memory footprint. diff --git a/code/common/src/vote.rs b/code/common/src/vote.rs index 82f9c77a9..6f1e43b3e 100644 --- a/code/common/src/vote.rs +++ b/code/common/src/vote.rs @@ -1,5 +1,8 @@ +use core::convert::Infallible; use core::fmt::Debug; +use malachite_proto::Protobuf; + use crate::{Context, NilOrVal, Round, Value}; /// A type of vote. @@ -12,6 +15,26 @@ pub enum VoteType { Precommit, } +impl TryFrom for VoteType { + type Error = Infallible; + + fn try_from(vote_type: malachite_proto::VoteType) -> Result { + match vote_type { + malachite_proto::VoteType::Prevote => Ok(VoteType::Prevote), + malachite_proto::VoteType::Precommit => Ok(VoteType::Precommit), + } + } +} + +impl From for malachite_proto::VoteType { + fn from(vote_type: VoteType) -> malachite_proto::VoteType { + match vote_type { + VoteType::Prevote => malachite_proto::VoteType::Prevote, + VoteType::Precommit => malachite_proto::VoteType::Precommit, + } + } +} + /// Defines the requirements for a vote. /// /// Votes are signed messages from validators for a particular value which @@ -19,6 +42,7 @@ pub enum VoteType { pub trait Vote where Self: Clone + Debug + Eq + Send + Sync + 'static, + Self: Protobuf, Ctx: Context, { /// The height for which the vote is for. diff --git a/code/node/Cargo.toml b/code/node/Cargo.toml index 51dbe53b8..a6ccbd83b 100644 --- a/code/node/Cargo.toml +++ b/code/node/Cargo.toml @@ -9,15 +9,18 @@ license.workspace = true publish.workspace = true [dependencies] -malachite-common = { version = "0.1.0", path = "../common" } -malachite-driver = { version = "0.1.0", path = "../driver" } -malachite-round = { version = "0.1.0", path = "../round" } -malachite-vote = { version = "0.1.0", path = "../vote" } +malachite-common.workspace = true +malachite-driver.workspace = true +malachite-round.workspace = true +malachite-vote.workspace = true +malachite-proto.workspace = true derive-where = { workspace = true } -futures = { workspace = true } -tokio = { workspace = true, features = ["full"] } +futures = { workspace = true } +tokio = { workspace = true, features = ["full"] } tokio-stream = { workspace = true } +prost = { workspace = true } +prost-types = { workspace = true } [dev-dependencies] -malachite-test = { version = "0.1.0", path = "../test" } +malachite-test.workspace = true diff --git a/code/node/src/network/broadcast.rs b/code/node/src/network/broadcast.rs index f09d90e38..c551b4541 100644 --- a/code/node/src/network/broadcast.rs +++ b/code/node/src/network/broadcast.rs @@ -3,11 +3,12 @@ use std::fmt::Debug; use std::net::SocketAddr; use futures::channel::oneshot; -use malachite_common::Context; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{broadcast, mpsc}; +use malachite_common::Context; + use super::{Msg, Network, PeerId}; pub enum PeerEvent { @@ -115,7 +116,7 @@ async fn connect_to_peer( println!("[{id}] Sending message to {peer_info}: {msg:?}"); - let bytes = msg.as_bytes(); + let bytes = msg.to_network_bytes().unwrap(); stream.write_u32(bytes.len() as u32).await.unwrap(); stream.write_all(&bytes).await.unwrap(); stream.flush().await.unwrap(); @@ -151,7 +152,7 @@ async fn listen( let len = socket.read_u32().await.unwrap(); let mut buf = vec![0; len as usize]; socket.read_exact(&mut buf).await.unwrap(); - let msg: Msg = Msg::from_bytes(&buf); + let msg: Msg = Msg::from_network_bytes(&buf).unwrap(); println!( "[{id}] Received message from {peer_id} ({addr}): {msg:?}", diff --git a/code/node/src/network/msg.rs b/code/node/src/network/msg.rs index f14e56024..841a8668a 100644 --- a/code/node/src/network/msg.rs +++ b/code/node/src/network/msg.rs @@ -1,6 +1,11 @@ use derive_where::derive_where; +use prost::Message; +use prost_types::Any; + use malachite_common::Context; +use malachite_proto::Error as ProtoError; +use malachite_proto::Protobuf; #[derive_where(Clone, Debug, PartialEq, Eq)] pub enum Msg { @@ -8,66 +13,85 @@ pub enum Msg { Proposal(Ctx::Proposal), #[cfg(test)] - Dummy(u32), + Dummy(u64), } impl Msg { - pub fn as_bytes(&self) -> Vec { - match self { - Msg::Vote(_vote) => todo!(), - Msg::Proposal(_proposal) => todo!(), + pub fn from_network_bytes(bytes: &[u8]) -> Result { + Protobuf::::from_bytes(bytes) + } + + pub fn to_network_bytes(&self) -> Result, ProtoError> { + Protobuf::::to_bytes(self) + } +} + +impl Protobuf for Msg { + fn from_bytes(bytes: &[u8]) -> Result + where + Self: Sized, + { + use prost::Name; + let any = Any::decode(bytes)?; + + if any.type_url == malachite_proto::Vote::type_url() { + let vote = Ctx::Vote::from_bytes(&any.value)?; + Ok(Msg::Vote(vote)) + } else if any.type_url == malachite_proto::Proposal::type_url() { + let proposal = Ctx::Proposal::from_bytes(&any.value)?; + Ok(Msg::Proposal(proposal)) + } else if any.type_url == "malachite.proto.Dummy" { #[cfg(test)] - Msg::Dummy(n) => [&[0x42], n.to_be_bytes().as_slice()].concat(), + { + let value = u64::from_be_bytes(any.value.try_into().unwrap()); + Ok(Msg::Dummy(value)) + } + + #[cfg(not(test))] + { + Err(malachite_proto::Error::Other( + "unknown message type: malachite.proto.Dummy".to_string(), + )) + } + } else { + Err(malachite_proto::Error::Other(format!( + "unknown message type: {}", + any.type_url + ))) } } - pub fn from_bytes(bytes: &[u8]) -> Self { - match bytes { + fn into_bytes(self) -> Result, malachite_proto::Error> { + use prost::Name; + + match self { + Msg::Vote(vote) => { + let any = Any { + type_url: malachite_proto::Vote::type_url(), + value: vote.into_bytes()?, + }; + + Ok(any.encode_to_vec()) + } + Msg::Proposal(proposal) => { + let any = Any { + type_url: malachite_proto::Proposal::type_url(), + value: proposal.into_bytes()?, + }; + + Ok(any.encode_to_vec()) + } + #[cfg(test)] - [0x42, a, b, c, d] => Msg::Dummy(u32::from_be_bytes([*a, *b, *c, *d])), + Msg::Dummy(value) => { + let any = Any { + type_url: "malachite.proto.Dummy".to_string(), + value: value.to_be_bytes().to_vec(), + }; - _ => todo!(), + Ok(any.encode_to_vec()) + } } } } -// -// impl Clone for Msg { -// fn clone(&self) -> Self { -// match self { -// Msg::Vote(vote) => Msg::Vote(vote.clone()), -// Msg::Proposal(proposal) => Msg::Proposal(proposal.clone()), -// -// #[cfg(test)] -// Msg::Dummy(n) => Msg::Dummy(*n), -// } -// } -// } -// -// impl fmt::Debug for Msg { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// match self { -// Msg::Vote(vote) => write!(f, "Vote({vote:?})"), -// Msg::Proposal(proposal) => write!(f, "Proposal({proposal:?})"), -// -// #[cfg(test)] -// Msg::Dummy(n) => write!(f, "Dummy({n:?})"), -// } -// } -// } -// -// impl PartialEq for Msg { -// fn eq(&self, other: &Self) -> bool { -// match (self, other) { -// (Msg::Vote(vote), Msg::Vote(other_vote)) => vote == other_vote, -// (Msg::Proposal(proposal), Msg::Proposal(other_proposal)) => proposal == other_proposal, -// -// #[cfg(test)] -// (Msg::Dummy(n1), Msg::Dummy(n2)) => n1 == n2, -// -// _ => false, -// } -// } -// } -// -// impl Eq for Msg {} diff --git a/code/proto/Cargo.toml b/code/proto/Cargo.toml new file mode 100644 index 000000000..7d09ef8d1 --- /dev/null +++ b/code/proto/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "malachite-proto" +version.workspace = true +edition.workspace = true +repository.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +prost.workspace = true +prost-types.workspace = true +thiserror.workspace = true + +[build-dependencies] +prost-build.workspace = true diff --git a/code/proto/build.rs b/code/proto/build.rs new file mode 100644 index 000000000..c02bf2f34 --- /dev/null +++ b/code/proto/build.rs @@ -0,0 +1,9 @@ +use std::io::Result; + +fn main() -> Result<()> { + let mut config = prost_build::Config::new(); + config.enable_type_names(); + config.compile_protos(&["src/malachite.proto"], &["src/"])?; + + Ok(()) +} diff --git a/code/proto/src/lib.rs b/code/proto/src/lib.rs new file mode 100644 index 000000000..e29435191 --- /dev/null +++ b/code/proto/src/lib.rs @@ -0,0 +1,55 @@ +use std::fmt::Display; + +use thiserror::Error; + +use prost::{DecodeError, EncodeError, Message}; + +include!(concat!(env!("OUT_DIR"), "/malachite.rs")); + +#[derive(Debug, Error)] +pub enum Error { + #[error("Failed to decode Protobuf message")] + Decode(#[from] DecodeError), + + #[error("Failed to encode Protobuf message")] + Encode(#[from] EncodeError), + + #[error("{0}")] + Other(String), +} + +pub trait Protobuf { + fn from_bytes(bytes: &[u8]) -> Result + where + Self: Sized; + fn into_bytes(self) -> Result, Error>; + + fn to_bytes(&self) -> Result, Error> + where + Self: Clone, + { + self.clone().into_bytes() + } +} + +impl Protobuf for T +where + T: TryFrom, + T::Error: Display, + Proto: Message + From + Default, +{ + fn from_bytes(bytes: &[u8]) -> Result + where + Self: Sized, + { + let proto = Proto::decode(bytes)?; + Self::try_from(proto).map_err(|e| Error::Other(e.to_string())) + } + + fn into_bytes(self) -> Result, Error> { + let proto = Proto::from(self); + let mut bytes = Vec::with_capacity(proto.encoded_len()); + proto.encode(&mut bytes)?; + Ok(bytes) + } +} diff --git a/code/proto/src/malachite.proto b/code/proto/src/malachite.proto new file mode 100644 index 000000000..7d145e58b --- /dev/null +++ b/code/proto/src/malachite.proto @@ -0,0 +1,43 @@ +syntax = "proto3"; + +package malachite; + +message Height { + uint64 value = 1; +} + +message Address { + bytes value = 1; +} + +message Value { + optional bytes value = 2; +} + +message ValueId { + optional bytes value = 1; +} + +message Round { + int64 round = 1; +} + +message Vote { + VoteType vote_type = 1; + Height height = 2; + Round round = 3; + ValueId value = 4; + Address validator_address = 5; +} + +message Proposal { + Height height = 1; + Round round = 2; + Value value = 3; + Round pol_round = 4; +} + +enum VoteType { + PREVOTE = 0; + PRECOMMIT = 1; +} diff --git a/code/test/Cargo.toml b/code/test/Cargo.toml index e2463c636..77d01986f 100644 --- a/code/test/Cargo.toml +++ b/code/test/Cargo.toml @@ -9,13 +9,13 @@ repository.workspace = true license.workspace = true [dependencies] -malachite-common = { version = "0.1.0", path = "../common" } -malachite-driver = { version = "0.1.0", path = "../driver" } -malachite-round = { version = "0.1.0", path = "../round" } -malachite-vote = { version = "0.1.0", path = "../vote" } - -futures = { workspace = true, features = ["executor"] } +malachite-common.workspace = true +malachite-driver.workspace = true +malachite-round.workspace = true +malachite-vote.workspace = true +malachite-proto.workspace = true +futures = { workspace = true, features = ["executor"] } ed25519-consensus.workspace = true signature.workspace = true rand.workspace = true diff --git a/code/test/src/height.rs b/code/test/src/height.rs index 3f4a32c3c..e2737ca76 100644 --- a/code/test/src/height.rs +++ b/code/test/src/height.rs @@ -1,3 +1,5 @@ +use std::convert::Infallible; + /// A blockchain height #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] pub struct Height(u64); @@ -13,3 +15,17 @@ impl Height { } impl malachite_common::Height for Height {} + +impl TryFrom for Height { + type Error = Infallible; + + fn try_from(height: malachite_proto::Height) -> Result { + Ok(Self(height.value)) + } +} + +impl From for malachite_proto::Height { + fn from(height: Height) -> malachite_proto::Height { + malachite_proto::Height { value: height.0 } + } +} diff --git a/code/test/src/proposal.rs b/code/test/src/proposal.rs index 407b4be02..012084ee8 100644 --- a/code/test/src/proposal.rs +++ b/code/test/src/proposal.rs @@ -39,3 +39,27 @@ impl malachite_common::Proposal for Proposal { self.pol_round } } + +impl TryFrom for Proposal { + type Error = String; + + fn try_from(proposal: malachite_proto::Proposal) -> Result { + Ok(Self { + height: proposal.height.unwrap().try_into().unwrap(), // infallible + round: proposal.round.unwrap().try_into().unwrap(), // infallible + value: proposal.value.unwrap().try_into().unwrap(), // FIXME + pol_round: proposal.pol_round.unwrap().try_into().unwrap(), // infallible + }) + } +} + +impl From for malachite_proto::Proposal { + fn from(proposal: Proposal) -> malachite_proto::Proposal { + malachite_proto::Proposal { + height: Some(proposal.height.into()), + round: Some(proposal.round.into()), + value: Some(proposal.value.into()), + pol_round: Some(proposal.pol_round.into()), + } + } +} diff --git a/code/test/src/validator_set.rs b/code/test/src/validator_set.rs index 91b903496..fd08eea27 100644 --- a/code/test/src/validator_set.rs +++ b/code/test/src/validator_set.rs @@ -36,6 +36,32 @@ impl fmt::Display for Address { impl malachite_common::Address for Address {} +impl TryFrom for Address { + type Error = String; + + fn try_from(proto: malachite_proto::Address) -> Result { + if proto.value.len() != Self::LENGTH { + return Err(format!( + "Invalid address length: expected {}, got {}", + Self::LENGTH, + proto.value.len() + )); + } + + let mut address = [0; Self::LENGTH]; + address.copy_from_slice(&proto.value); + Ok(Self(address)) + } +} + +impl From
for malachite_proto::Address { + fn from(address: Address) -> Self { + Self { + value: address.0.to_vec(), + } + } +} + /// A validator is a public key and voting power #[derive(Clone, Debug, PartialEq, Eq)] pub struct Validator { diff --git a/code/test/src/value.rs b/code/test/src/value.rs index bed49dbf6..185c9952a 100644 --- a/code/test/src/value.rs +++ b/code/test/src/value.rs @@ -17,6 +17,28 @@ impl From for ValueId { } } +impl TryFrom for ValueId { + type Error = String; + + fn try_from(proto: malachite_proto::ValueId) -> Result { + match proto.value { + Some(bytes) => { + let bytes = <[u8; 8]>::try_from(bytes).unwrap(); // FIXME + Ok(ValueId::new(u64::from_be_bytes(bytes))) + } + None => Err("ValueId not present".to_string()), + } + } +} + +impl From for malachite_proto::ValueId { + fn from(value: ValueId) -> malachite_proto::ValueId { + malachite_proto::ValueId { + value: Some(value.0.to_be_bytes().to_vec()), + } + } +} + /// The value to decide on #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Value(u64); @@ -42,3 +64,26 @@ impl malachite_common::Value for Value { self.id() } } + +impl TryFrom for Value { + type Error = String; + + fn try_from(proto: malachite_proto::Value) -> Result { + match proto.value { + Some(bytes) => { + let bytes = <[u8; 8]>::try_from(bytes).unwrap(); // FIXME + let value = u64::from_be_bytes(bytes); + Ok(Value::new(value)) + } + None => Err("Value not present".to_string()), + } + } +} + +impl From for malachite_proto::Value { + fn from(value: Value) -> malachite_proto::Value { + malachite_proto::Value { + value: Some(value.0.to_be_bytes().to_vec()), + } + } +} diff --git a/code/test/src/vote.rs b/code/test/src/vote.rs index 31b9a250d..0acebd12e 100644 --- a/code/test/src/vote.rs +++ b/code/test/src/vote.rs @@ -98,3 +98,38 @@ impl malachite_common::Vote for Vote { &self.validator_address } } + +impl TryFrom for Vote { + type Error = String; + + fn try_from(vote: malachite_proto::Vote) -> Result { + Ok(Self { + typ: malachite_proto::VoteType::try_from(vote.vote_type) + .unwrap() + .try_into() + .unwrap(), // infallible + height: vote.height.unwrap().try_into().unwrap(), // infallible + round: vote.round.unwrap().try_into().unwrap(), // infallible + value: match vote.value { + Some(value) => NilOrVal::Val(value.try_into().unwrap()), // FIXME + None => NilOrVal::Nil, + }, + validator_address: vote.validator_address.unwrap().try_into().unwrap(), // FIXME + }) + } +} + +impl From for malachite_proto::Vote { + fn from(vote: Vote) -> malachite_proto::Vote { + malachite_proto::Vote { + vote_type: i32::from(malachite_proto::VoteType::from(vote.typ)), + height: Some(vote.height.into()), + round: Some(vote.round.into()), + value: match vote.value { + NilOrVal::Nil => None, + NilOrVal::Val(v) => Some(v.into()), + }, + validator_address: Some(vote.validator_address.into()), + } + } +} From 57c0094fc4837a1ec046d32ad3f67d7e322e0e85 Mon Sep 17 00:00:00 2001 From: Romain Ruetschi Date: Tue, 20 Feb 2024 09:10:47 +0100 Subject: [PATCH 4/4] Cleanup --- code/node/src/network/broadcast.rs | 110 +++++++++++++++++++++-------- 1 file changed, 81 insertions(+), 29 deletions(-) diff --git a/code/node/src/network/broadcast.rs b/code/node/src/network/broadcast.rs index c551b4541..81f70c622 100644 --- a/code/node/src/network/broadcast.rs +++ b/code/node/src/network/broadcast.rs @@ -105,7 +105,10 @@ async fn connect_to_peer( let mut per_peer_rx = per_peer_tx.subscribe(); - send_peer_id(&mut stream, id.clone()).await; + Frame::::PeerId(id.clone()) + .write(&mut stream) + .await + .unwrap(); tokio::spawn(async move { loop { @@ -115,11 +118,7 @@ async fn connect_to_peer( } println!("[{id}] Sending message to {peer_info}: {msg:?}"); - - let bytes = msg.to_network_bytes().unwrap(); - stream.write_u32(bytes.len() as u32).await.unwrap(); - stream.write_all(&bytes).await.unwrap(); - stream.flush().await.unwrap(); + Frame::Msg(msg).write(&mut stream).await.unwrap(); } }); } @@ -143,40 +142,87 @@ async fn listen( peer = socket.peer_addr().unwrap() ); - let peer_id = read_peer_id(&mut socket).await; + let Frame::PeerId(peer_id) = Frame::::read(&mut socket).await.unwrap() else { + eprintln!("[{id}] Peer did not send its ID"); + continue; + }; let id = id.clone(); let tx_received = tx_received.clone(); tokio::spawn(async move { - let len = socket.read_u32().await.unwrap(); - let mut buf = vec![0; len as usize]; - socket.read_exact(&mut buf).await.unwrap(); - let msg: Msg = Msg::from_network_bytes(&buf).unwrap(); - - println!( - "[{id}] Received message from {peer_id} ({addr}): {msg:?}", - addr = socket.peer_addr().unwrap(), - ); - - tx_received.send((peer_id.clone(), msg)).await.unwrap(); // FIXME + loop { + let Frame::Msg(msg) = Frame::::read(&mut socket).await.unwrap() else { + eprintln!("[{id}] Peer did not send a message"); + return; + }; + + println!( + "[{id}] Received message from {peer_id} ({addr}): {msg:?}", + addr = socket.peer_addr().unwrap(), + ); + + tx_received.send((peer_id.clone(), msg)).await.unwrap(); // FIXME + } }); } } -async fn send_peer_id(socket: &mut TcpStream, id: PeerId) { - let bytes = id.0.as_bytes(); - socket.write_u32(bytes.len() as u32).await.unwrap(); - socket.write_all(bytes).await.unwrap(); - socket.flush().await.unwrap(); +pub enum Frame { + PeerId(PeerId), + Msg(Msg), } -async fn read_peer_id(socket: &mut TcpStream) -> PeerId { - let len = socket.read_u32().await.unwrap(); - let mut buf = vec![0; len as usize]; - socket.read_exact(&mut buf).await.unwrap(); - let id = String::from_utf8(buf).unwrap(); - PeerId(id) +impl Frame { + /// Write a frame to the given writer, prefixing it with its discriminant. + pub async fn write( + &self, + writer: &mut W, + ) -> Result<(), std::io::Error> { + match self { + Frame::PeerId(id) => { + writer.write_u8(0x40).await?; + let bytes = id.0.as_bytes(); + writer.write_u32(bytes.len() as u32).await?; + writer.write_all(bytes).await?; + writer.flush().await?; + } + Frame::Msg(msg) => { + writer.write_u8(0x41).await?; + let bytes = msg + .to_network_bytes() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + writer.write_u32(bytes.len() as u32).await?; + writer.write_all(&bytes).await?; + writer.flush().await?; + } + } + + Ok(()) + } + + pub async fn read(reader: &mut R) -> Result { + let discriminant = reader.read_u8().await?; + + match discriminant { + 0x40 => { + let len = reader.read_u32().await?; + let mut buf = vec![0; len as usize]; + reader.read_exact(&mut buf).await?; + Ok(Frame::PeerId(PeerId(String::from_utf8(buf).unwrap()))) + } + 0x41 => { + let len = reader.read_u32().await?; + let mut buf = vec![0; len as usize]; + reader.read_exact(&mut buf).await?; + Ok(Frame::Msg(Msg::from_network_bytes(&buf).unwrap())) + } + _ => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Invalid frame discriminant: {discriminant}"), + )), + } + } } pub struct Handle { @@ -274,6 +320,7 @@ mod tests { handle3.connect_to_peer(peer2_info.clone()).await; handle1.broadcast(Msg::Dummy(1)).await; + handle1.broadcast(Msg::Dummy(2)).await; let deadline = Duration::from_millis(100); @@ -281,5 +328,10 @@ mod tests { dbg!(&msg2); let msg3 = timeout(deadline, handle3.recv()).await.unwrap(); dbg!(&msg3); + + let msg4 = timeout(deadline, handle2.recv()).await.unwrap(); + dbg!(&msg4); + let msg5 = timeout(deadline, handle3.recv()).await.unwrap(); + dbg!(&msg5); } }