Skip to content

Commit

Permalink
Extract immutable state and contextual data into a RoundContext
Browse files Browse the repository at this point in the history
  • Loading branch information
romac committed Oct 31, 2023
1 parent 26a5573 commit 480d110
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 102 deletions.
19 changes: 9 additions & 10 deletions Code/consensus/src/executor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::BTreeMap;

use malachite_round::state_machine::RoundContext;
use secrecy::{ExposeSecret, Secret};

use malachite_common::signature::Keypair;
Expand Down Expand Up @@ -93,10 +94,9 @@ where
RoundMessage::NewRound(round) => {
// TODO: check if we are the proposer

self.round_states.insert(
round,
RoundState::new(self.height.clone(), self.address.clone()).new_round(round),
);
// XXX: Check if there is an existing state?
self.round_states
.insert(round, RoundState::default().new_round(round));

None
}
Expand Down Expand Up @@ -160,7 +160,7 @@ where
}

// Check that the proposal is for the current height and round
if round_state.height != proposal.height() || proposal.round() != self.round {
if self.height != proposal.height() || self.round != proposal.round() {
return None;
}

Expand Down Expand Up @@ -232,13 +232,12 @@ where
/// Apply the event, update the state.
fn apply_event(&mut self, round: Round, event: RoundEvent<Ctx>) -> Option<RoundMessage<Ctx>> {
// Get the round state, or create a new one
let round_state = self
.round_states
.remove(&round)
.unwrap_or_else(|| RoundState::new(self.height.clone(), self.address.clone()));
let round_state = self.round_states.remove(&round).unwrap_or_default();

let round_ctx = RoundContext::new(round, &self.height, &self.address);

// Apply the event to the round state machine
let transition = round_state.apply_event(round, event);
let transition = round_state.apply_event(&round_ctx, event);

// Update state
self.round_states.insert(round, transition.next_state);
Expand Down
22 changes: 13 additions & 9 deletions Code/round/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::events::Event;
use crate::state_machine::RoundContext;
use crate::transition::Transition;

use malachite_common::{Context, Round};
Expand Down Expand Up @@ -32,8 +33,6 @@ pub struct State<Ctx>
where
Ctx: Context,
{
pub address: Ctx::Address,
pub height: Ctx::Height,
pub round: Round,
pub step: Step,
pub proposal: Option<Ctx::Proposal>,
Expand All @@ -47,8 +46,6 @@ where
{
fn clone(&self) -> Self {
Self {
address: self.address.clone(),
height: self.height.clone(),
round: self.round,
step: self.step,
proposal: self.proposal.clone(),
Expand All @@ -62,10 +59,8 @@ impl<Ctx> State<Ctx>
where
Ctx: Context,
{
pub fn new(height: Ctx::Height, address: Ctx::Address) -> Self {
pub fn new() -> Self {
Self {
address,
height,
round: Round::INITIAL,
step: Step::NewRound,
proposal: None,
Expand Down Expand Up @@ -114,7 +109,16 @@ where
}
}

pub fn apply_event(self, round: Round, event: Event<Ctx>) -> Transition<Ctx> {
crate::state_machine::apply_event(self, round, event)
pub fn apply_event(self, ctx: &RoundContext<Ctx>, event: Event<Ctx>) -> Transition<Ctx> {
crate::state_machine::apply_event(self, ctx, event)
}
}

impl<Ctx> Default for State<Ctx>
where
Ctx: Context,
{
fn default() -> Self {
Self::new()
}
}
93 changes: 66 additions & 27 deletions Code/round/src/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ use crate::message::Message;
use crate::state::{State, Step};
use crate::transition::Transition;

pub struct RoundContext<'a, Ctx>
where
Ctx: Context,
{
pub round: Round,
pub height: &'a Ctx::Height,
pub address: &'a Ctx::Address,
}

impl<'a, Ctx> RoundContext<'a, Ctx>
where
Ctx: Context,
{
pub fn new(round: Round, height: &'a Ctx::Height, address: &'a Ctx::Address) -> Self {
Self {
round,
height,
address,
}
}
}

/// Check that a proposal has a valid Proof-Of-Lock round
fn is_valid_pol_round<Ctx>(state: &State<Ctx>, pol_round: Round) -> bool
where
Expand All @@ -21,15 +43,21 @@ where
/// Valid transitions result in at least a change to the state and/or an output message.
///
/// Commented numbers refer to line numbers in the spec paper.
pub fn apply_event<Ctx>(mut state: State<Ctx>, round: Round, event: Event<Ctx>) -> Transition<Ctx>
pub fn apply_event<Ctx>(
mut state: State<Ctx>,
ctx: &RoundContext<Ctx>,
event: Event<Ctx>,
) -> Transition<Ctx>
where
Ctx: Context,
{
let this_round = state.round == round;
let this_round = state.round == ctx.round;

match (state.step, event) {
// From NewRound. Event must be for current round.
(Step::NewRound, Event::NewRoundProposer(value)) if this_round => propose(state, value), // L11/L14
(Step::NewRound, Event::NewRoundProposer(value)) if this_round => {
propose(state, ctx.height, value) // L11/L14
}
(Step::NewRound, Event::NewRound) if this_round => schedule_timeout_propose(state), // L11/L20

// From Propose. Event must be for current round.
Expand All @@ -44,9 +72,9 @@ where
.map_or(true, |locked| &locked.value == proposal.value())
{
state.proposal = Some(proposal.clone());
prevote(state, proposal.round(), proposal.value().id())
prevote(state, ctx.address, proposal.round(), proposal.value().id())
} else {
prevote_nil(state)
prevote_nil(state, ctx.address)
}
}

Expand All @@ -62,33 +90,35 @@ where
if proposal.value().is_valid()
&& (locked.round <= proposal.pol_round() || &locked.value == proposal.value())
{
prevote(state, proposal.round(), proposal.value().id())
prevote(state, ctx.address, proposal.round(), proposal.value().id())
} else {
prevote_nil(state)
prevote_nil(state, ctx.address)
}
}
(Step::Propose, Event::ProposalInvalid) if this_round => prevote_nil(state), // L22/L25, L28/L31
(Step::Propose, Event::TimeoutPropose) if this_round => prevote_nil(state), // L57
(Step::Propose, Event::ProposalInvalid) if this_round => prevote_nil(state, ctx.address), // L22/L25, L28/L31
(Step::Propose, Event::TimeoutPropose) if this_round => prevote_nil(state, ctx.address), // L57

// From Prevote. Event must be for current round.
(Step::Prevote, Event::PolkaAny) if this_round => schedule_timeout_prevote(state), // L34
(Step::Prevote, Event::PolkaNil) if this_round => precommit_nil(state), // L44
(Step::Prevote, Event::PolkaValue(value_id)) if this_round => precommit(state, value_id), // L36/L37 - NOTE: only once?
(Step::Prevote, Event::TimeoutPrevote) if this_round => precommit_nil(state), // L61
(Step::Prevote, Event::PolkaNil) if this_round => precommit_nil(state, ctx), // L44
(Step::Prevote, Event::PolkaValue(value_id)) if this_round => {
precommit(state, ctx.address, value_id) // L36/L37 - NOTE: only once?
}
(Step::Prevote, Event::TimeoutPrevote) if this_round => precommit_nil(state, ctx), // L61

// From Precommit. Event must be for current round.
(Step::Precommit, Event::PolkaValue(value_id)) if this_round => {
set_valid_value(state, value_id)
} // L36/L42 - NOTE: only once?
set_valid_value(state, value_id) // L36/L42 - NOTE: only once?
}

// From Commit. No more state transitions.
(Step::Commit, _) => Transition::invalid(state),

// From all (except Commit). Various round guards.
(_, Event::PrecommitAny) if this_round => schedule_timeout_precommit(state), // L47
(_, Event::TimeoutPrecommit) if this_round => round_skip(state, round.increment()), // L65
(_, Event::RoundSkip) if state.round < round => round_skip(state, round), // L55
(_, Event::PrecommitValue(value_id)) => commit(state, round, value_id), // L49
(_, Event::TimeoutPrecommit) if this_round => round_skip(state, ctx.round.increment()), // L65
(_, Event::RoundSkip) if state.round < ctx.round => round_skip(state, ctx.round), // L55
(_, Event::PrecommitValue(value_id)) => commit(state, ctx.round, value_id), // L49

// Invalid transition.
_ => Transition::invalid(state),
Expand All @@ -103,7 +133,7 @@ where
/// otherwise propose the given value.
///
/// Ref: L11/L14
pub fn propose<Ctx>(state: State<Ctx>, value: Ctx::Value) -> Transition<Ctx>
pub fn propose<Ctx>(state: State<Ctx>, height: &Ctx::Height, value: Ctx::Value) -> Transition<Ctx>
where
Ctx: Context,
{
Expand All @@ -112,7 +142,7 @@ where
None => (value, Round::Nil),
};

let proposal = Message::proposal(state.height.clone(), state.round, value, pol_round);
let proposal = Message::proposal(height.clone(), state.round, value, pol_round);
Transition::to(state.next_step()).with_message(proposal)
}

Expand All @@ -124,7 +154,12 @@ where
/// unless we are locked on something else at a higher round.
///
/// Ref: L22/L28
pub fn prevote<Ctx>(state: State<Ctx>, vr: Round, proposed: ValueId<Ctx>) -> Transition<Ctx>
pub fn prevote<Ctx>(
state: State<Ctx>,
address: &Ctx::Address,
vr: Round,
proposed: ValueId<Ctx>,
) -> Transition<Ctx>
where
Ctx: Context,
{
Expand All @@ -135,18 +170,18 @@ where
None => Some(proposed), // not locked, prevote the value
};

let message = Message::prevote(state.round, value, state.address.clone());
let message = Message::prevote(state.round, value, address.clone());
Transition::to(state.next_step()).with_message(message)
}

/// Received a complete proposal for an empty or invalid value, or timed out; prevote nil.
///
/// Ref: L22/L25, L28/L31, L57
pub fn prevote_nil<Ctx>(state: State<Ctx>) -> Transition<Ctx>
pub fn prevote_nil<Ctx>(state: State<Ctx>, address: &Ctx::Address) -> Transition<Ctx>
where
Ctx: Context,
{
let message = Message::prevote(state.round, None, state.address.clone());
let message = Message::prevote(state.round, None, address.clone());
Transition::to(state.next_step()).with_message(message)
}

Expand All @@ -160,11 +195,15 @@ where
///
/// NOTE: Only one of this and set_valid_value should be called once in a round
/// How do we enforce this?
pub fn precommit<Ctx>(state: State<Ctx>, value_id: ValueId<Ctx>) -> Transition<Ctx>
pub fn precommit<Ctx>(
state: State<Ctx>,
address: &Ctx::Address,
value_id: ValueId<Ctx>,
) -> Transition<Ctx>
where
Ctx: Context,
{
let message = Message::precommit(state.round, Some(value_id), state.address.clone());
let message = Message::precommit(state.round, Some(value_id), address.clone());

let Some(value) = state
.proposal
Expand All @@ -183,11 +222,11 @@ where
/// Received a polka for nil or timed out of prevote; precommit nil.
///
/// Ref: L44, L61
pub fn precommit_nil<Ctx>(state: State<Ctx>) -> Transition<Ctx>
pub fn precommit_nil<Ctx>(state: State<Ctx>, ctx: &RoundContext<Ctx>) -> Transition<Ctx>
where
Ctx: Context,
{
let message = Message::precommit(state.round, None, state.address.clone());
let message = Message::precommit(state.round, None, ctx.address.clone());
Transition::to(state.next_step()).with_message(message)
}

Expand Down
Loading

0 comments on commit 480d110

Please sign in to comment.