From 74aaeb1c7523c44a9a15dbc326d264d941562f71 Mon Sep 17 00:00:00 2001 From: Nikolay Eskov Date: Mon, 22 Apr 2024 00:22:12 +0300 Subject: [PATCH] Fix 'newBlocks.current' getter for binary blocks. (#1388) --- pkg/state/state.go | 51 ++++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/pkg/state/state.go b/pkg/state/state.go index 4584865c2..b612e955c 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -350,30 +350,47 @@ func (n *newBlocks) unmarshalBlock(block *proto.Block, blockBytes []byte) error return nil } +func getBlockDataWithOptionalSnapshot[T interface{ []byte | *proto.Block }]( + curPos int, + blocks []T, + snapshots []*proto.BlockSnapshot, +) (T, *proto.BlockSnapshot, error) { + if curPos > len(blocks) || curPos < 1 { + var zero T + return zero, nil, errors.New("bad current position") + } + var ( + pos = curPos - 1 + block = blocks[pos] + optionalSnapshot *proto.BlockSnapshot + ) + if sl := len(snapshots); sl != 0 { // snapshots aren't empty + if bl := len(blocks); sl != bl { // if snapshots are present, they must have the same length as blocks + var zero T + return zero, nil, errors.Errorf("snapshots and blocks slices have different lengths %d and %d", sl, bl) + } + optionalSnapshot = snapshots[pos] // blocks and snapshots have the same length + } + return block, optionalSnapshot, nil +} + func (n *newBlocks) current() (*proto.Block, *proto.BlockSnapshot, error) { if !n.binary { - if n.curPos > len(n.blocks) || n.curPos < 1 { - return nil, nil, errors.New("bad current position") - } - var ( - pos = n.curPos - 1 - block = n.blocks[pos] - optionalSnapshot *proto.BlockSnapshot - ) - if len(n.snapshots) == len(n.blocks) { // return block with snapshot if it is set - optionalSnapshot = n.snapshots[pos] + block, optionalSnapshot, err := getBlockDataWithOptionalSnapshot(n.curPos, n.blocks, n.snapshots) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to get current deserialized block") } return block, optionalSnapshot, nil } - if n.curPos > len(n.binBlocks) || n.curPos < 1 { - return nil, nil, errors.New("bad current position") + blockBytes, optionalSnapshot, err := getBlockDataWithOptionalSnapshot(n.curPos, n.binBlocks, n.snapshots) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to get current binary block") } - blockBytes := n.binBlocks[n.curPos-1] - b := &proto.Block{} - if err := n.unmarshalBlock(b, blockBytes); err != nil { - return nil, nil, err + block := &proto.Block{} + if unmErr := n.unmarshalBlock(block, blockBytes); unmErr != nil { + return nil, nil, unmErr } - return b, nil, nil + return block, optionalSnapshot, nil } func (n *newBlocks) reset() {