diff --git a/Code/Cargo.toml b/Code/Cargo.toml index b96fe3227..24c12d303 100644 --- a/Code/Cargo.toml +++ b/Code/Cargo.toml @@ -17,6 +17,8 @@ license = "Apache-2.0" publish = false [workspace.dependencies] +async-trait = "0.1" +futures = "0.3" ed25519-consensus = "2.1.0" rand = { version = "0.8.5", features = ["std_rng"] } secrecy = "0.8.0" diff --git a/Code/driver/Cargo.toml b/Code/driver/Cargo.toml index bf2b34458..b81fc6273 100644 --- a/Code/driver/Cargo.toml +++ b/Code/driver/Cargo.toml @@ -13,4 +13,5 @@ malachite-common = { version = "0.1.0", path = "../common" } malachite-round = { version = "0.1.0", path = "../round" } malachite-vote = { version = "0.1.0", path = "../vote" } -secrecy.workspace = true +async-trait.workspace = true +secrecy.workspace = true diff --git a/Code/driver/src/client.rs b/Code/driver/src/client.rs index 1e240d730..3b3f3f69c 100644 --- a/Code/driver/src/client.rs +++ b/Code/driver/src/client.rs @@ -1,14 +1,17 @@ +use async_trait::async_trait; + use malachite_common::Context; /// Client for use by the [`Driver`](crate::Driver) to ask /// for a value to propose and validate proposals. +#[async_trait] pub trait Client where Ctx: Context, { /// Get the value to propose. - fn get_value(&self) -> Ctx::Value; + async fn get_value(&self) -> Ctx::Value; /// Validate a proposal. - fn validate_proposal(&self, proposal: &Ctx::Proposal) -> bool; + async fn validate_proposal(&self, proposal: &Ctx::Proposal) -> bool; } diff --git a/Code/driver/src/driver.rs b/Code/driver/src/driver.rs index 818a98196..4c0b66e27 100644 --- a/Code/driver/src/driver.rs +++ b/Code/driver/src/driver.rs @@ -20,7 +20,7 @@ use crate::event::Event; use crate::message::Message; use crate::ProposerSelector; -/// Driver for the state machine of the Malachite consensus engine. +/// Driver for the state machine of the Malachite consensus engine at a given height. #[derive(Clone, Debug)] pub struct Driver where @@ -73,16 +73,16 @@ where } } - fn get_value(&self) -> Ctx::Value { - self.client.get_value() + async fn get_value(&self) -> Ctx::Value { + self.client.get_value().await } - fn validate_proposal(&self, proposal: &Ctx::Proposal) -> bool { - self.client.validate_proposal(proposal) + async fn validate_proposal(&self, proposal: &Ctx::Proposal) -> bool { + self.client.validate_proposal(proposal).await } - pub fn execute(&mut self, msg: Event) -> Option> { - let round_msg = match self.apply(msg) { + pub async fn execute(&mut self, msg: Event) -> Option> { + let round_msg = match self.apply(msg).await { Some(msg) => msg, None => return None, }; @@ -113,16 +113,16 @@ where } } - fn apply(&mut self, msg: Event) -> Option> { + async fn apply(&mut self, msg: Event) -> Option> { match msg { - Event::NewRound(round) => self.apply_new_round(round), - Event::Proposal(proposal) => self.apply_proposal(proposal), + Event::NewRound(round) => self.apply_new_round(round).await, + Event::Proposal(proposal) => self.apply_proposal(proposal).await, Event::Vote(signed_vote) => self.apply_vote(signed_vote), Event::TimeoutElapsed(timeout) => self.apply_timeout(timeout), } } - fn apply_new_round(&mut self, round: Round) -> Option> { + async fn apply_new_round(&mut self, round: Round) -> Option> { let proposer_address = self .proposer_selector .select_proposer(round, &self.validator_set); @@ -134,7 +134,7 @@ where // TODO: Write this check differently, maybe just based on the address let event = if proposer.public_key() == &self.private_key.expose_secret().verifying_key() { - let value = self.get_value(); + let value = self.get_value().await; RoundEvent::NewRoundProposer(value) } else { RoundEvent::NewRound @@ -148,7 +148,7 @@ where self.apply_event(round, event) } - fn apply_proposal(&mut self, proposal: Ctx::Proposal) -> Option> { + async fn apply_proposal(&mut self, proposal: Ctx::Proposal) -> Option> { // Check that there is an ongoing round let Some(round_state) = self.round_states.get(&self.round) else { // TODO: Add logging @@ -172,7 +172,7 @@ where // TODO: Verify proposal signature (make some of these checks part of message validation) - let is_valid = self.validate_proposal(&proposal); + let is_valid = self.validate_proposal(&proposal).await; match proposal.pol_round() { Round::Nil => { diff --git a/Code/driver/src/lib.rs b/Code/driver/src/lib.rs index 23d4987a6..98be64ff0 100644 --- a/Code/driver/src/lib.rs +++ b/Code/driver/src/lib.rs @@ -22,3 +22,6 @@ pub use driver::Driver; pub use event::Event; pub use message::Message; pub use proposer::ProposerSelector; + +// Re-export `#[async_trait]` macro for convenience. +pub use async_trait::async_trait; diff --git a/Code/test/Cargo.toml b/Code/test/Cargo.toml index 6b536bce3..188d95fd8 100644 --- a/Code/test/Cargo.toml +++ b/Code/test/Cargo.toml @@ -14,6 +14,9 @@ malachite-driver = { version = "0.1.0", path = "../driver" } malachite-round = { version = "0.1.0", path = "../round" } malachite-vote = { version = "0.1.0", path = "../vote" } +futures = { workspace = true, features = ["executor"] } + +async-trait.workspace = true ed25519-consensus.workspace = true signature.workspace = true rand.workspace = true diff --git a/Code/test/src/client.rs b/Code/test/src/client.rs index cc66f2381..3973f766e 100644 --- a/Code/test/src/client.rs +++ b/Code/test/src/client.rs @@ -1,3 +1,5 @@ +use async_trait::async_trait; + use malachite_driver::Client; use crate::{Proposal, TestContext, Value}; @@ -13,12 +15,13 @@ impl TestClient { } } +#[async_trait] impl Client for TestClient { - fn get_value(&self) -> Value { + async fn get_value(&self) -> Value { self.value.clone() } - fn validate_proposal(&self, proposal: &Proposal) -> bool { + async fn validate_proposal(&self, proposal: &Proposal) -> bool { (self.is_valid)(proposal) } } diff --git a/Code/test/tests/driver.rs b/Code/test/tests/driver.rs index 1113ae4f5..ca3300a85 100644 --- a/Code/test/tests/driver.rs +++ b/Code/test/tests/driver.rs @@ -1,3 +1,4 @@ +use futures::executor::block_on; use malachite_common::{Context, Round, Timeout}; use malachite_driver::{Driver, Event, Message, ProposerSelector}; use malachite_round::state::{RoundValue, State, Step}; @@ -222,7 +223,7 @@ fn driver_steps_proposer() { .input_event .unwrap_or_else(|| previous_message.unwrap()); - let output = driver.execute(execute_message); + let output = block_on(driver.execute(execute_message)); assert_eq!(output, step.expected_output, "expected output message"); assert_eq!(driver.round, step.expected_round, "expected round"); @@ -418,7 +419,7 @@ fn driver_steps_not_proposer_valid() { .input_event .unwrap_or_else(|| previous_message.unwrap()); - let output = driver.execute(execute_message); + let output = block_on(driver.execute(execute_message)); assert_eq!(output, step.expected_output, "expected output message"); assert_eq!(driver.round, step.expected_round, "expected round"); @@ -560,7 +561,7 @@ fn driver_steps_not_proposer_invalid() { .input_event .unwrap_or_else(|| previous_message.unwrap()); - let output = driver.execute(execute_message); + let output = block_on(driver.execute(execute_message)); assert_eq!(output, step.expected_output, "expected output"); assert_eq!(driver.round, step.expected_round, "expected round"); @@ -765,7 +766,7 @@ fn driver_steps_not_proposer_timeout_multiple_rounds() { .input_event .unwrap_or_else(|| previous_message.unwrap()); - let output = driver.execute(execute_message); + let output = block_on(driver.execute(execute_message)); assert_eq!(output, step.expected_output, "expected output message"); assert_eq!(driver.round, step.expected_round, "expected round");