diff --git a/core/ibft.go b/core/ibft.go index 6806a31..78bc980 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -46,8 +46,9 @@ type Messages interface { } const ( - round0Timeout = 10 * time.Second - roundFactorBase = float64(2) + // DefaultBaseRoundTimeout is the default base round (round 0) timeout + DefaultBaseRoundTimeout = 10 * time.Second + roundFactorBase = float64(2) ) var ( @@ -129,7 +130,7 @@ func NewIBFT( roundStarted: false, name: newRound, }, - baseRoundTimeout: round0Timeout, + baseRoundTimeout: DefaultBaseRoundTimeout, validatorManager: NewValidatorManager(backend, log), } } @@ -1033,10 +1034,10 @@ func (i *IBFT) buildProposal(ctx context.Context, view *proto.View) *proto.Messa } proposal := messages.ExtractProposal(latestPC.ProposalMessage) - round := proposal.Round + preparedCertificateRound := proposal.Round // skip if message's round is equals to/less than maxRound - if previousProposal != nil && round <= maxRound { + if previousProposal != nil && preparedCertificateRound <= maxRound { continue } @@ -1046,7 +1047,7 @@ func (i *IBFT) buildProposal(ctx context.Context, view *proto.View) *proto.Messa } previousProposal = lastPB.RawProposal - maxRound = round + maxRound = preparedCertificateRound } if previousProposal == nil { @@ -1140,6 +1141,11 @@ func (i *IBFT) ExtendRoundTimeout(amount time.Duration) { i.additionalTimeout = amount } +// SetBaseRoundTimeout sets the base (round 0) timeout +func (i *IBFT) SetBaseRoundTimeout(baseRoundTimeout time.Duration) { + i.baseRoundTimeout = baseRoundTimeout +} + // validPC verifies that the prepared certificate is valid func (i *IBFT) validPC( certificate *proto.PreparedCertificate, @@ -1288,9 +1294,9 @@ func (i *IBFT) subscribe(details messages.SubscriptionDetails) *messages.Subscri // - round 4: 8 sec func getRoundTimeout(baseRoundTimeout, additionalTimeout time.Duration, round uint64) time.Duration { var ( - duration = int(baseRoundTimeout) + baseDuration = int(baseRoundTimeout) roundFactor = int(math.Pow(roundFactorBase, float64(round))) - roundTimeout = time.Duration(duration * roundFactor) + roundTimeout = time.Duration(baseDuration * roundFactor) ) return roundTimeout + additionalTimeout diff --git a/core/ibft_test.go b/core/ibft_test.go index f54d58e..7aea003 100644 --- a/core/ibft_test.go +++ b/core/ibft_test.go @@ -705,7 +705,7 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { proposer := []byte("proposer") round := uint64(1) - correctRoundMessage := newCorrectRoundMessage(round) + roundMessage := newCorrectRoundMessage(round) generateProposalWithNoPrevious := func() *proto.Message { roundChangeMessages := generateMessagesWithUniqueSender(quorum, proto.MessageType_ROUND_CHANGE) @@ -720,8 +720,8 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: correctRoundMessage.proposal, - ProposalHash: correctRoundMessage.hash, + Proposal: roundMessage.proposal, + ProposalHash: roundMessage.hash, Certificate: &proto.RoundChangeCertificate{ RoundChangeMessages: roundChangeMessages, }, @@ -740,13 +740,13 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: correctRoundMessage.proposal, - ProposalHash: correctRoundMessage.hash, + Proposal: roundMessage.proposal, + ProposalHash: roundMessage.hash, Certificate: &proto.RoundChangeCertificate{ RoundChangeMessages: generateFilledRCMessages( quorum, - correctRoundMessage.proposal, - correctRoundMessage.hash, + roundMessage.proposal, + roundMessage.hash, ), }, }, @@ -800,7 +800,7 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: correctRoundMessage.hash, + ProposalHash: roundMessage.hash, }, }, } @@ -857,10 +857,10 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { assert.Equal(t, prepare, i.state.name) // Make sure the accepted proposal is the one that was sent out - assert.Equal(t, correctRoundMessage.proposal, i.state.getProposal()) + assert.Equal(t, roundMessage.proposal, i.state.getProposal()) // Make sure the correct proposal hash was multicasted - assert.True(t, prepareHashMatches(correctRoundMessage.hash, multicastedPrepare)) + assert.True(t, prepareHashMatches(roundMessage.hash, multicastedPrepare)) }) } } @@ -1201,6 +1201,7 @@ func TestIBFT_IsAcceptableMessage(t *testing.T) { }, } ) + i := NewIBFT(log, backend, transport) i.state.view = testCase.stateView @@ -1235,6 +1236,7 @@ func TestIBFT_StartRoundTimer(t *testing.T) { wg.Add(1) i.wg.Add(1) + go func() { i.startRoundTimer(ctx, 0) @@ -1264,6 +1266,7 @@ func TestIBFT_StartRoundTimer(t *testing.T) { ctx, cancelFn := context.WithCancel(context.Background()) wg.Add(1) + go func() { defer func() { wg.Done() @@ -1351,7 +1354,7 @@ func TestIBFT_FutureProposal(t *testing.T) { view *proto.View, roundChangeMessages []*proto.Message, ) *proto.Message { - correctRoundMessage := newCorrectRoundMessage(view.Round) + roundMessage := newCorrectRoundMessage(view.Round) return &proto.Message{ View: view, @@ -1359,8 +1362,8 @@ func TestIBFT_FutureProposal(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: correctRoundMessage.proposal, - ProposalHash: correctRoundMessage.hash, + Proposal: roundMessage.proposal, + ProposalHash: roundMessage.hash, Certificate: &proto.RoundChangeCertificate{ RoundChangeMessages: roundChangeMessages, }, @@ -1464,6 +1467,7 @@ func TestIBFT_FutureProposal(t *testing.T) { i.messages = mMessages wg.Add(1) + go func() { defer func() { cancelFn() @@ -1988,6 +1992,7 @@ func TestIBFT_ValidPC(t *testing.T) { i := NewIBFT(log, backend, transport) require.NoError(t, i.validatorManager.Init(0)) + proposal := generateMessagesWithSender(1, proto.MessageType_PREPREPARE, sender)[0] certificate := &proto.PreparedCertificate{ @@ -3045,6 +3050,18 @@ func TestIBFT_ExtendRoundTimer(t *testing.T) { assert.Equal(t, additionalTimeout, i.additionalTimeout) } +func TestIBFTOverrideBaseRoundTimeout(t *testing.T) { + t.Parallel() + + baseRoundTimeout := 50 * time.Second + + i := NewIBFT(mockLogger{}, mockBackend{}, mockTransport{}) + i.SetBaseRoundTimeout(baseRoundTimeout) + + // Make sure the base round timeout is properly set + assert.Equal(t, baseRoundTimeout, i.baseRoundTimeout) +} + func Test_getRoundTimeout(t *testing.T) { t.Parallel() @@ -3060,7 +3077,7 @@ func Test_getRoundTimeout(t *testing.T) { want time.Duration }{ { - name: "first round duration", + name: "zero round duration", args: args{ baseRoundTimeout: time.Second, additionalTimeout: time.Second, @@ -3069,7 +3086,7 @@ func Test_getRoundTimeout(t *testing.T) { want: time.Second * 2, }, { - name: "zero round duration", + name: "first round duration", args: args{ baseRoundTimeout: time.Second, additionalTimeout: time.Second, @@ -3077,6 +3094,14 @@ func Test_getRoundTimeout(t *testing.T) { }, want: time.Second * 3, }, + { + name: "third round duration", + args: args{ + baseRoundTimeout: time.Second, + round: 3, + }, + want: time.Second * 8, + }, } for _, tt := range tests {