Skip to content

Commit

Permalink
feat(driver): Allow the driver to raise errors in some occasions
Browse files Browse the repository at this point in the history
  • Loading branch information
romac committed Nov 13, 2023
1 parent b2d63ed commit 0af13bc
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 30 deletions.
4 changes: 2 additions & 2 deletions Code/common/src/validator_set.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::fmt::Debug;
use core::fmt::{Debug, Display};

use crate::{Context, PublicKey};

Expand All @@ -12,7 +12,7 @@ pub type VotingPower = u64;
/// TODO: Keep this trait or just add the bounds to Consensus::Address?
pub trait Address
where
Self: Clone + Debug + Eq + Ord,
Self: Clone + Debug + Display + Eq + Ord,
{
}

Expand Down
65 changes: 41 additions & 24 deletions Code/driver/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use malachite_vote::Threshold;
use crate::env::Env as DriverEnv;
use crate::event::Event;
use crate::message::Message;
use crate::Error;
use crate::ProposerSelector;

/// Driver for the state machine of the Malachite consensus engine at a given height.
Expand Down Expand Up @@ -76,58 +77,66 @@ where
self.env.validate_proposal(proposal).await
}

pub async fn execute(&mut self, msg: Event<Ctx>) -> Option<Message<Ctx>> {
let round_msg = match self.apply(msg).await {
pub async fn execute(&mut self, msg: Event<Ctx>) -> Result<Option<Message<Ctx>>, Error<Ctx>> {
let round_msg = match self.apply(msg).await? {
Some(msg) => msg,
None => return None,
None => return Ok(None),
};

match round_msg {
let msg = match round_msg {
RoundMessage::NewRound(round) => {
// XXX: Check if there is an existing state?
assert!(self.round < round);
Some(Message::NewRound(round))
Message::NewRound(round)
}

RoundMessage::Proposal(proposal) => {
// sign the proposal
Some(Message::Propose(proposal))
Message::Propose(proposal)
}

RoundMessage::Vote(vote) => {
let signed_vote = self.ctx.sign_vote(vote);
Some(Message::Vote(signed_vote))
Message::Vote(signed_vote)
}

RoundMessage::ScheduleTimeout(timeout) => Some(Message::ScheduleTimeout(timeout)),
RoundMessage::ScheduleTimeout(timeout) => Message::ScheduleTimeout(timeout),

RoundMessage::Decision(value) => {
// TODO: update the state
Some(Message::Decide(value.round, value.value))
Message::Decide(value.round, value.value)
}
}
};

Ok(Some(msg))
}

async fn apply(&mut self, msg: Event<Ctx>) -> Option<RoundMessage<Ctx>> {
async fn apply(&mut self, msg: Event<Ctx>) -> Result<Option<RoundMessage<Ctx>>, Error<Ctx>> {
match msg {
Event::NewRound(round) => self.apply_new_round(round).await,
Event::Proposal(proposal) => self.apply_proposal(proposal).await,
Event::Proposal(proposal) => Ok(self.apply_proposal(proposal).await),
Event::Vote(signed_vote) => self.apply_vote(signed_vote),
Event::TimeoutElapsed(timeout) => self.apply_timeout(timeout),
Event::TimeoutElapsed(timeout) => Ok(self.apply_timeout(timeout)),
}
}

async fn apply_new_round(&mut self, round: Round) -> Option<RoundMessage<Ctx>> {
async fn apply_new_round(
&mut self,
round: Round,
) -> Result<Option<RoundMessage<Ctx>>, Error<Ctx>> {
let proposer_address = self
.proposer_selector
.select_proposer(round, &self.validator_set);

let proposer = self
.validator_set
.get_by_address(&proposer_address)
.expect("proposer not found"); // FIXME: expect
.ok_or_else(|| Error::ProposerNotFound(proposer_address.clone()))?;

let event = if proposer.address() == &self.address {
// We are the proposer
// TODO: Schedule propose timeout

let value = self.get_value().await;
RoundEvent::NewRoundProposer(value)
} else {
Expand All @@ -139,7 +148,7 @@ where
.insert(round, RoundState::default().new_round(round));
self.round = round;

self.apply_event(round, event)
Ok(self.apply_event(round, event))
}

async fn apply_proposal(&mut self, proposal: Ctx::Proposal) -> Option<RoundMessage<Ctx>> {
Expand Down Expand Up @@ -201,25 +210,33 @@ where
}
}

fn apply_vote(&mut self, signed_vote: SignedVote<Ctx>) -> Option<RoundMessage<Ctx>> {
// TODO: How to handle missing validator?
fn apply_vote(
&mut self,
signed_vote: SignedVote<Ctx>,
) -> Result<Option<RoundMessage<Ctx>>, Error<Ctx>> {
let validator = self
.validator_set
.get_by_address(signed_vote.validator_address())?;
.get_by_address(signed_vote.validator_address())
.ok_or_else(|| Error::ValidatorNotFound(signed_vote.validator_address().clone()))?;

if !self
.ctx
.verify_signed_vote(&signed_vote, validator.public_key())
{
// TODO: How to handle invalid votes?
return None;
return Err(Error::InvalidVoteSignature(
signed_vote.clone(),
validator.clone(),
));
}

let round = signed_vote.vote.round();

let vote_msg = self
let Some(vote_msg) = self
.votes
.apply_vote(signed_vote.vote, validator.voting_power())?;
.apply_vote(signed_vote.vote, validator.voting_power())
else {
return Ok(None);
};

let round_event = match vote_msg {
VoteMessage::PolkaAny => RoundEvent::PolkaAny,
Expand All @@ -230,7 +247,7 @@ where
VoteMessage::SkipRound(r) => RoundEvent::SkipRound(r),
};

self.apply_event(round, round_event)
Ok(self.apply_event(round, round_event))
}

fn apply_timeout(&mut self, timeout: Timeout) -> Option<RoundMessage<Ctx>> {
Expand Down
52 changes: 52 additions & 0 deletions Code/driver/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use core::fmt;

use malachite_common::{Context, SignedVote, Validator};

#[derive(Clone, Debug)]
pub enum Error<Ctx>
where
Ctx: Context,
{
/// Proposer not found
ProposerNotFound(Ctx::Address),

/// Validator not found in validator set
ValidatorNotFound(Ctx::Address),

/// Invalid vote signature
InvalidVoteSignature(SignedVote<Ctx>, Ctx::Validator),
}

impl<Ctx> fmt::Display for Error<Ctx>
where
Ctx: Context,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::ProposerNotFound(addr) => write!(f, "Proposer not found: {addr}"),
Error::ValidatorNotFound(addr) => write!(f, "Validator not found: {addr}"),
Error::InvalidVoteSignature(vote, validator) => write!(
f,
"Invalid vote signature by {} on vote {vote:?}",
validator.address()
),
}
}
}

impl<Ctx> PartialEq for Error<Ctx>
where
Ctx: Context,
{
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Error::ProposerNotFound(addr1), Error::ProposerNotFound(addr2)) => addr1 == addr2,
(Error::ValidatorNotFound(addr1), Error::ValidatorNotFound(addr2)) => addr1 == addr2,
(
Error::InvalidVoteSignature(vote1, validator1),
Error::InvalidVoteSignature(vote2, validator2),
) => vote1 == vote2 && validator1 == validator2,
_ => false,
}
}
}
2 changes: 2 additions & 0 deletions Code/driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ extern crate alloc;

mod driver;
mod env;
mod error;
mod event;
mod message;
mod proposer;

pub use driver::Driver;
pub use env::Env;
pub use error::Error;
pub use event::Event;
pub use message::Message;
pub use proposer::ProposerSelector;
Expand Down
11 changes: 11 additions & 0 deletions Code/test/src/validator_set.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use core::fmt;

use malachite_common::VotingPower;

use crate::{signing::PublicKey, TestContext};
Expand All @@ -22,6 +24,15 @@ impl Address {
}
}

impl fmt::Display for Address {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for byte in self.0.iter() {
write!(f, "{:02x}", byte)?;
}
Ok(())
}
}

impl malachite_common::Address for Address {}

/// A validator is a public key and voting power
Expand Down
8 changes: 4 additions & 4 deletions Code/test/tests/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ fn driver_steps_proposer() {
.input_event
.unwrap_or_else(|| previous_message.unwrap());

let output = block_on(driver.execute(execute_message));
let output = block_on(driver.execute(execute_message)).expect("execute succeeded");
assert_eq!(output, step.expected_output, "expected output message");

assert_eq!(driver.round, step.expected_round, "expected round");
Expand Down Expand Up @@ -419,7 +419,7 @@ fn driver_steps_not_proposer_valid() {
.input_event
.unwrap_or_else(|| previous_message.unwrap());

let output = block_on(driver.execute(execute_message));
let output = block_on(driver.execute(execute_message)).expect("execute succeeded");
assert_eq!(output, step.expected_output, "expected output message");

assert_eq!(driver.round, step.expected_round, "expected round");
Expand Down Expand Up @@ -561,7 +561,7 @@ fn driver_steps_not_proposer_invalid() {
.input_event
.unwrap_or_else(|| previous_message.unwrap());

let output = block_on(driver.execute(execute_message));
let output = block_on(driver.execute(execute_message)).expect("execute succeeded");
assert_eq!(output, step.expected_output, "expected output");

assert_eq!(driver.round, step.expected_round, "expected round");
Expand Down Expand Up @@ -766,7 +766,7 @@ fn driver_steps_not_proposer_timeout_multiple_rounds() {
.input_event
.unwrap_or_else(|| previous_message.unwrap());

let output = block_on(driver.execute(execute_message));
let output = block_on(driver.execute(execute_message)).expect("execute succeeded");
assert_eq!(output, step.expected_output, "expected output message");

assert_eq!(driver.round, step.expected_round, "expected round");
Expand Down

0 comments on commit 0af13bc

Please sign in to comment.