From d39451d082864d97e15751c14f9ed30fdac7f626 Mon Sep 17 00:00:00 2001 From: madlabman <10616301+madlabman@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:24:12 +0100 Subject: [PATCH 1/7] fix: queue cleanup --- src/CSAccounting.sol | 8 +- src/CSModule.sol | 52 +++++--- src/lib/Batch.sol | 16 ++- test/Batch.t.sol | 28 ++-- test/CSModule.t.sol | 256 +++++++++++++++++++++++++++++-------- test/helpers/Utilities.sol | 10 +- 6 files changed, 280 insertions(+), 90 deletions(-) diff --git a/src/CSAccounting.sol b/src/CSAccounting.sol index d123d35c..8f27a3e8 100644 --- a/src/CSAccounting.sol +++ b/src/CSAccounting.sol @@ -88,13 +88,13 @@ contract CSAccounting is CSAccountingBase, AccessControlEnumerable { } bytes32 public constant INSTANT_PENALIZE_BOND_ROLE = - keccak256("INSTANT_PENALIZE_BOND_ROLE"); + keccak256("INSTANT_PENALIZE_BOND_ROLE"); // 0x9909cf24c2d3bafa8c229558d86a1b726ba57c3ef6350848dcf434a4181b56c7 bytes32 public constant EL_REWARDS_STEALING_PENALTY_INIT_ROLE = - keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE"); + keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE"); // 0xcc2e7ce7be452f766dd24d55d87a3d42901c31ffa5b600cd1dff475abec91c1f bytes32 public constant EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE = - keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE"); + keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE"); // 0xdf6226649a1ca132f86d419e46892001284368a8f7445b5eb0d3fadf91329fe6 bytes32 public constant SET_BOND_MULTIPLIER_ROLE = - keccak256("SET_BOND_MULTIPLIER_ROLE"); + keccak256("SET_BOND_MULTIPLIER_ROLE"); // 0x62131145aee19b18b85aa8ead52ba87f0efb6e61e249155edc68a2c24e8f79b5 // todo: should be reconsidered uint256 public constant MIN_BLOCKED_BOND_RETENTION_PERIOD = 4 weeks; diff --git a/src/CSModule.sol b/src/CSModule.sol index b34e9fd9..8a19e3e1 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -33,6 +33,7 @@ struct NodeOperator { uint256 stuckValidatorsCount; uint256 refundedValidatorsCount; bool isTargetLimitActive; + uint256 queueNonce; } struct NodeOperatorInfo { @@ -107,9 +108,9 @@ contract CSModuleBase { contract CSModule is IStakingModule, CSModuleBase { using QueueLib for QueueLib.Queue; - // @dev max number of node operators is limited by uint128 due to Batch serialization in 32 bytes + // @dev max number of node operators is limited by uint64 due to Batch serialization in 32 bytes // it seems to be enough - uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint128).max; + uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint64).max; bytes32 public constant SIGNING_KEYS_POSITION = keccak256("lido.CommunityStakingModule.signingKeysPosition"); @@ -645,14 +646,13 @@ contract CSModule is IStakingModule, CSModuleBase { if (vetKeysPointer > no.totalAddedKeys) revert InvalidVetKeysPointer(); uint64 count = SafeCast.toUint64(vetKeysPointer - no.totalVettedKeys); - uint64 start = SafeCast.toUint64( - no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys - ); + uint64 start = SafeCast.toUint64(no.totalVettedKeys); bytes32 pointer = Batch.serialize({ - nodeOperatorId: SafeCast.toUint128(nodeOperatorId), + nodeOperatorId: SafeCast.toUint64(nodeOperatorId), start: start, - count: count + count: count, + nonce: SafeCast.toUint64(no.queueNonce) }); no.totalVettedKeys = vetKeysPointer; @@ -682,6 +682,7 @@ contract CSModule is IStakingModule, CSModuleBase { function _unvetKeys(uint256 nodeOperatorId) internal { NodeOperator storage no = _nodeOperators[nodeOperatorId]; no.totalVettedKeys = no.totalDepositedKeys; + no.queueNonce++; emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys); _incrementNonce(); } @@ -754,6 +755,7 @@ contract CSModule is IStakingModule, CSModuleBase { _totalDepositedValidators += keysCount; NodeOperator storage no = _nodeOperators[nodeOperatorId]; no.totalDepositedKeys += keysCount; + // redundant check, enforced by _assertIsValidBatch require( no.totalDepositedKeys <= no.totalVettedKeys, "too many keys" @@ -789,11 +791,12 @@ contract CSModule is IStakingModule, CSModuleBase { { uint256 start; uint256 count; + uint256 nonce; - (nodeOperatorId, start, count) = Batch.deserialize(batch); + (nodeOperatorId, start, count, nonce) = Batch.deserialize(batch); NodeOperator storage no = _nodeOperators[nodeOperatorId]; - _assertIsValidBatch(no, start, count); + _assertIsValidBatch(no, start, count, nonce); startIndex = Math.max(start, no.totalDepositedKeys); depositableKeysCount = start + count - startIndex; @@ -802,9 +805,11 @@ contract CSModule is IStakingModule, CSModuleBase { function _assertIsValidBatch( NodeOperator storage no, uint256 start, - uint256 count + uint256 count, + uint256 nonce ) internal view { require(count != 0, "Empty batch given"); + require(nonce == no.queueNonce, "Invalid batch nonce"); require( _unvettedKeysInBatch(no, start, count) == false, "Batch contains unvetted keys" @@ -836,11 +841,18 @@ contract CSModule is IStakingModule, CSModuleBase { break; } - (uint256 nodeOperatorId, uint256 start, uint256 count) = Batch - .deserialize(item); + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize(item); NodeOperator storage no = _nodeOperators[nodeOperatorId]; - if (_unvettedKeysInBatch(no, start, count)) { + if ( + _unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce + ) { queue.remove(pointer, item); + continue; } pointer = item; @@ -871,7 +883,7 @@ contract CSModule is IStakingModule, CSModuleBase { } /// @dev returns the next pointer to start check from - function isQueueHasUnvettedKeys( + function isQueueDirty( uint256 maxItems, bytes32 pointer ) external view returns (bool, bytes32) { @@ -887,10 +899,16 @@ contract CSModule is IStakingModule, CSModuleBase { break; } - (uint256 nodeOperatorId, uint256 start, uint256 count) = Batch - .deserialize(item); + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize(item); NodeOperator storage no = _nodeOperators[nodeOperatorId]; - if (_unvettedKeysInBatch(no, start, count)) { + if ( + _unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce + ) { return (true, pointer); } diff --git a/src/lib/Batch.sol b/src/lib/Batch.sol index 664f88f7..cc7c3712 100644 --- a/src/lib/Batch.sol +++ b/src/lib/Batch.sol @@ -6,21 +6,23 @@ pragma solidity 0.8.21; library Batch { /// @notice Serialize node operator id, batch start and count of keys into a single bytes32 value function serialize( - uint128 nodeOperatorId, + uint64 nodeOperatorId, uint64 start, - uint64 count + uint64 count, + uint64 nonce ) internal pure returns (bytes32 s) { - return bytes32(abi.encodePacked(nodeOperatorId, start, count)); + return bytes32(abi.encodePacked(nodeOperatorId, start, count, nonce)); } /// @notice Deserialize node operator id, batch start and count of keys from a single bytes32 value function deserialize( bytes32 b - ) internal pure returns (uint128 nodeOperatorId, uint64 start, uint64 count) { + ) internal pure returns (uint64 nodeOperatorId, uint64 start, uint64 count, uint64 nonce) { assembly { - nodeOperatorId := shr(128, b) - start := shr(64, b) - count := b + nodeOperatorId := shr(192, b) + start := shr(128, b) + count := shr(64, b) + nonce := b } } diff --git a/test/Batch.t.sol b/test/Batch.t.sol index a5a197ad..cc93af46 100644 --- a/test/Batch.t.sol +++ b/test/Batch.t.sol @@ -11,44 +11,52 @@ contract BatchTest is Test { bytes32 b = Batch.serialize({ nodeOperatorId: 999, start: 3, - count: 42 + count: 42, + nonce: 7 }); assertEq( b, - // noIndex | start | count | - 0x000000000000000000000000000003e70000000000000003000000000000002a + // noIndex | start | count | nonce + 0x00000000000003e70000000000000003000000000000002a0000000000000007 ); } function test_deserialize() public { - (uint128 nodeOperatorId, uint64 start, uint64 count) = Batch - .deserialize( + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize( 0x0000000000000000000000000000000000000000000000000000000000000000 ); assertEq(nodeOperatorId, 0, "nodeOperatorId != 0"); assertEq(start, 0, "start != 0"); assertEq(count, 0, "count != 0"); + assertEq(nonce, 0, "nonce != 0"); - (nodeOperatorId, start, count) = Batch.deserialize( - 0x000000000000000000000000000003e70000000000000003000000000000002a + (nodeOperatorId, start, count, nonce) = Batch.deserialize( + 0x00000000000003e70000000000000003000000000000002a0000000000000007 ); assertEq(nodeOperatorId, 999, "nodeOperatorId != 999"); assertEq(start, 3, "start != 3"); assertEq(count, 42, "count != 42"); + assertEq(nonce, 7, "nonce != 7"); - (nodeOperatorId, start, count) = Batch.deserialize( + (nodeOperatorId, start, count, nonce) = Batch.deserialize( 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff ); assertEq( nodeOperatorId, - type(uint128).max, - "nodeOperatorId != uint128.max" + type(uint64).max, + "nodeOperatorId != uint64.max" ); assertEq(start, type(uint64).max, "start != uint64.max"); assertEq(count, type(uint64).max, "count != uint64.max"); + assertEq(nonce, type(uint64).max, "nonce != uint64.max"); } } diff --git a/test/CSModule.t.sol b/test/CSModule.t.sol index 17a94578..8c088613 100644 --- a/test/CSModule.t.sol +++ b/test/CSModule.t.sol @@ -13,7 +13,18 @@ import "./helpers/mocks/LidoMock.sol"; import "./helpers/mocks/WstETHMock.sol"; import "./helpers/Utilities.sol"; +import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; + contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { + using Strings for uint256; + + struct BatchInfo { + uint256 nodeOperatorId; + uint256 start; + uint256 count; + uint256 nonce; + } + LidoLocatorMock public locator; WstETHMock public wstETH; LidoMock public stETH; @@ -22,16 +33,16 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { CSAccounting public accounting; CommunityStakingFeeDistributorMock public communityStakingFeeDistributor; + address internal admin; address internal stranger; - address internal alice; address internal nodeOperator; function setUp() public { - alice = address(1); - nodeOperator = address(2); - stranger = address(3); - address[] memory penalizeRoleMembers = new address[](1); - penalizeRoleMembers[0] = alice; + vm.label(address(this), "TEST"); + + nodeOperator = nextAddress("NODE_OPERATOR"); + stranger = nextAddress("STRANGER"); + admin = nextAddress("ADMIN"); (locator, wstETH, stETH, burner) = initLido(); @@ -46,7 +57,7 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { csm = new CSModule("community-staking-module", address(locator)); accounting = new CSAccounting( 2 ether, - alice, + admin, address(locator), address(wstETH), address(csm), @@ -54,6 +65,13 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { 1 days ); csm.setAccounting(address(accounting)); + + vm.startPrank(admin); + accounting.grantRole( + accounting.INSTANT_PENALIZE_BOND_ROLE(), + address(csm) + ); // NOTE: required because of `unvetKeys` + vm.stopPrank(); } function createNodeOperator() internal returns (uint256) { @@ -80,6 +98,66 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { ); return csm.getNodeOperatorsCount() - 1; } + + function _assertQueueState(BatchInfo[] memory exp) internal { + // (bytes32 pointer,) = csm.queue(); // it works, but how? + bytes32 pointer = bytes32(0); + + for (uint256 i = 0; i < exp.length; i++) { + BatchInfo memory b = exp[i]; + + assertFalse( + _isLastElementInQueue(pointer), + string.concat("unexpected end of queue at index ", i.toString()) + ); + + pointer = _nextPointer(pointer); + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize(pointer); + + assertEq( + nodeOperatorId, + b.nodeOperatorId, + string.concat( + "unexpected `nodeOperatorId` at index ", + i.toString() + ) + ); + assertEq( + start, + b.start, + string.concat("unexpected `start` at index ", i.toString()) + ); + assertEq( + count, + b.count, + string.concat("unexpected `count` at index ", i.toString()) + ); + assertEq( + nonce, + b.nonce, + string.concat("unexpected `nonce` at index ", i.toString()) + ); + } + + assertTrue(_isLastElementInQueue(pointer), "unexpected tail of queue"); + } + + function _isLastElementInQueue( + bytes32 pointer + ) internal view returns (bool) { + bytes32 next = _nextPointer(pointer); + return next == pointer; + } + + function _nextPointer(bytes32 pointer) internal view returns (bytes32) { + (, bytes32 next, ) = csm.depositQueue(1, pointer); + return next; + } } contract CsmInitialization is CSMCommon { @@ -369,9 +447,9 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.expectEmit(true, true, false, true, address(csm)); - emit NodeOperatorManagerAddressChangeProposed(noId, alice); + emit NodeOperatorManagerAddressChangeProposed(noId, stranger); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); assertEq(no.managerAddress, nodeOperator); assertEq(no.rewardAddress, nodeOperator); } @@ -380,7 +458,7 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { public { vm.expectRevert(NodeOperatorDoesNotExist.selector); - csm.proposeNodeOperatorManagerAddressChange(0, alice); + csm.proposeNodeOperatorManagerAddressChange(0, stranger); } function test_proposeNodeOperatorManagerAddressChange_RevertWhenNotManager() @@ -388,7 +466,7 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotManagerAddress.selector); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); } function test_proposeNodeOperatorManagerAddressChange_RevertWhenAlreadyProposed() @@ -396,11 +474,11 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); vm.expectRevert(AlreadyProposed.selector); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); } function test_proposeNodeOperatorManagerAddressChange_RevertWhenSameAddressProposed() @@ -421,15 +499,15 @@ contract CsmConfirmNodeOperatorManagerAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); vm.expectEmit(true, true, true, true, address(csm)); - emit NodeOperatorManagerAddressChanged(noId, nodeOperator, alice); - vm.prank(alice); + emit NodeOperatorManagerAddressChanged(noId, nodeOperator, stranger); + vm.prank(stranger); csm.confirmNodeOperatorManagerAddressChange(noId); no = csm.getNodeOperator(noId); - assertEq(no.managerAddress, alice); + assertEq(no.managerAddress, stranger); assertEq(no.rewardAddress, nodeOperator); } @@ -445,7 +523,7 @@ contract CsmConfirmNodeOperatorManagerAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(stranger); csm.confirmNodeOperatorManagerAddressChange(noId); } @@ -457,7 +535,7 @@ contract CsmConfirmNodeOperatorManagerAddressChange is CSMCommon { csm.proposeNodeOperatorManagerAddressChange(noId, stranger); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(nextAddress()); csm.confirmNodeOperatorManagerAddressChange(noId); } } @@ -470,9 +548,9 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.expectEmit(true, true, false, true, address(csm)); - emit NodeOperatorRewardAddressChangeProposed(noId, alice); + emit NodeOperatorRewardAddressChangeProposed(noId, stranger); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); assertEq(no.managerAddress, nodeOperator); assertEq(no.rewardAddress, nodeOperator); } @@ -481,7 +559,7 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { public { vm.expectRevert(NodeOperatorDoesNotExist.selector); - csm.proposeNodeOperatorRewardAddressChange(0, alice); + csm.proposeNodeOperatorRewardAddressChange(0, stranger); } function test_proposeNodeOperatorRewardAddressChange_RevertWhenNotRewardAddress() @@ -489,7 +567,7 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotRewardAddress.selector); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); } function test_proposeNodeOperatorRewardAddressChange_RevertWhenAlreadyProposed() @@ -497,11 +575,11 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); vm.expectRevert(AlreadyProposed.selector); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); } function test_proposeNodeOperatorRewardAddressChange_RevertWhenSameAddressProposed() @@ -522,16 +600,16 @@ contract CsmConfirmNodeOperatorRewardAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); vm.expectEmit(true, true, true, true, address(csm)); - emit NodeOperatorRewardAddressChanged(noId, nodeOperator, alice); - vm.prank(alice); + emit NodeOperatorRewardAddressChanged(noId, nodeOperator, stranger); + vm.prank(stranger); csm.confirmNodeOperatorRewardAddressChange(noId); no = csm.getNodeOperator(noId); assertEq(no.managerAddress, nodeOperator); - assertEq(no.rewardAddress, alice); + assertEq(no.rewardAddress, stranger); } function test_confirmNodeOperatorRewardAddressChange_RevertWhenNoNodeOperator() @@ -546,7 +624,7 @@ contract CsmConfirmNodeOperatorRewardAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(stranger); csm.confirmNodeOperatorRewardAddressChange(noId); } @@ -558,7 +636,7 @@ contract CsmConfirmNodeOperatorRewardAddressChange is CSMCommon { csm.proposeNodeOperatorRewardAddressChange(noId, stranger); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(nextAddress()); csm.confirmNodeOperatorRewardAddressChange(noId); } } @@ -568,18 +646,18 @@ contract CsmResetNodeOperatorManagerAddress is CSMCommon { uint256 noId = createNodeOperator(); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); - vm.prank(alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); + vm.prank(stranger); csm.confirmNodeOperatorRewardAddressChange(noId); vm.expectEmit(true, true, true, true, address(csm)); - emit NodeOperatorManagerAddressChanged(noId, nodeOperator, alice); - vm.prank(alice); + emit NodeOperatorManagerAddressChanged(noId, nodeOperator, stranger); + vm.prank(stranger); csm.resetNodeOperatorManagerAddress(noId); NodeOperatorInfo memory no = csm.getNodeOperator(noId); - assertEq(no.managerAddress, alice); - assertEq(no.rewardAddress, alice); + assertEq(no.managerAddress, stranger); + assertEq(no.rewardAddress, stranger); } function test_resetNodeOperatorManagerAddress_RevertWhenNoNodeOperator() @@ -618,13 +696,15 @@ contract CsmVetKeys is CSMCommon { NodeOperatorInfo memory no = csm.getNodeOperator(noId); assertEq(no.totalVettedValidators, 1); - (bytes32[] memory items, , ) = csm.depositQueue(1, bytes32(0)); - (uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize( - items[0] - ); - assertEq(batchNoId, uint128(noId)); - assertEq(start, 0); - assertEq(count, 1); + + BatchInfo[] memory exp = new BatchInfo[](1); + exp[0] = BatchInfo({ + nodeOperatorId: noId, + start: 0, + count: 1, + nonce: 0 + }); + _assertQueueState(exp); } function test_vetKeys_totalVettedKeysIsNotZero() public { @@ -639,13 +719,21 @@ contract CsmVetKeys is CSMCommon { NodeOperatorInfo memory no = csm.getNodeOperator(noId); assertEq(no.totalVettedValidators, 2); - (bytes32[] memory items, , ) = csm.depositQueue(2, bytes32(0)); - (uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize( - items[1] - ); - assertEq(batchNoId, uint128(noId)); - assertEq(start, 1); - assertEq(count, 1); + + BatchInfo[] memory exp = new BatchInfo[](2); + exp[0] = BatchInfo({ + nodeOperatorId: noId, + start: 0, + count: 1, + nonce: 0 + }); + exp[1] = BatchInfo({ + nodeOperatorId: noId, + start: 1, + count: 1, + nonce: 0 + }); + _assertQueueState(exp); } function test_vetKeys_RevertWhenNoNodeOperator() public { @@ -667,3 +755,69 @@ contract CsmVetKeys is CSMCommon { csm.vetKeys(noId, 2); } } + +contract CsmQueueOps is CSMCommon { + function test_queueIsCleanByDefault() public { + createNodeOperator({ keysCount: 2 }); + csm.vetKeys(0, 1); + + (bool isDirty /* next */, ) = csm.isQueueDirty(1, bytes32(0)); + assertFalse(isDirty, "queue should be clean"); + } + + function test_queueIsDirty_WhenUnvettedKeys() public { + createNodeOperator({ keysCount: 2 }); + csm.vetKeys(0, 1); + csm.unvetKeys(0); + + (bool isDirty /* next */, ) = csm.isQueueDirty(1, bytes32(0)); + assertTrue(isDirty, "queue should be dirty"); + } + + function test_queueIsClean_AfterCleanup() public { + createNodeOperator({ keysCount: 2 }); + csm.vetKeys(0, 1); + csm.unvetKeys(0); + csm.cleanDepositQueue(1, bytes32(0)); + + (bool isDirty /* next */, ) = csm.isQueueDirty(1, bytes32(0)); + assertFalse(isDirty, "queue should be clean"); + } + + function test_queueIsDirty_WhenDanglingBatch() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.vetKeys(0, 2); + csm.unvetKeys(0); + csm.vetKeys(0, 2); + + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](3); + exp[0] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 1, nonce: 0 }); + exp[1] = BatchInfo({ nodeOperatorId: 0, start: 1, count: 1, nonce: 0 }); + exp[2] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 2, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty(3, bytes32(0)); + assertTrue(isDirty, "queue should be dirty"); + } + + function test_queueIsClean_WhenDanglingBatchCleanedUp() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.vetKeys(0, 2); + csm.unvetKeys(0); + csm.vetKeys(0, 2); + + csm.cleanDepositQueue(3, bytes32(0)); + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](1); + exp[0] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 2, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty(3, bytes32(0)); + assertFalse(isDirty, "queue should be clean"); + } +} diff --git a/test/helpers/Utilities.sol b/test/helpers/Utilities.sol index cbd01853..73695604 100644 --- a/test/helpers/Utilities.sol +++ b/test/helpers/Utilities.sol @@ -2,8 +2,10 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity 0.8.21; +import { CommonBase } from "forge-std/Base.sol"; + /// @author madlabman -contract Utilities { +contract Utilities is CommonBase { bytes32 internal seed = keccak256("seed sEed seEd"); function nextAddress() internal returns (address) { @@ -14,6 +16,12 @@ contract Utilities { return a; } + function nextAddress(string memory label) internal returns (address) { + address a = nextAddress(); + vm.label(a, label); + return a; + } + function keysSignatures( uint256 keysCount ) public pure returns (bytes memory, bytes memory) { From 3600100424cbac84379627dc4711284520b736fa Mon Sep 17 00:00:00 2001 From: madlabman <10616301+madlabman@users.noreply.github.com> Date: Thu, 23 Nov 2023 13:01:13 +0100 Subject: [PATCH 2/7] refactor: remove pointer return value from Queue.list --- src/CSModule.sol | 10 +--------- src/lib/QueueLib.sol | 9 +++++++-- test/CSModule.t.sol | 10 ++++++++-- test/QueueLib.t.sol | 18 +++++------------- 4 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/CSModule.sol b/src/CSModule.sol index 8a19e3e1..ef5b2ea4 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -864,15 +864,7 @@ contract CSModule is IStakingModule, CSModuleBase { function depositQueue( uint256 maxItems, bytes32 pointer - ) - external - view - returns ( - bytes32[] memory items, - bytes32 /* pointer */, - uint256 /* count */ - ) - { + ) external view returns (bytes32[] memory items, uint256 /* count */) { require(maxItems > 0, "Queue walkthrough limit is not set"); if (Batch.isNil(pointer)) { diff --git a/src/lib/QueueLib.sol b/src/lib/QueueLib.sol index 0283d87f..ad7824d3 100644 --- a/src/lib/QueueLib.sol +++ b/src/lib/QueueLib.sol @@ -7,6 +7,8 @@ pragma solidity 0.8.21; library QueueLib { bytes32 public constant NULL_POINTER = bytes32(0); + // @dev Queue is a linked list of items + // @dev front and back are pointers struct Queue { mapping(bytes32 => bytes32) queue; bytes32 front; @@ -38,11 +40,13 @@ library QueueLib { return self.queue[pointer]; } + // @dev returns items array of size `limit` and actual count of items + // @dev reverts if the queue is empty function list(Queue storage self, bytes32 pointer, uint256 limit) internal notEmpty(self) view returns ( bytes32[] memory items, - bytes32 /* pointer */, uint256 /* count */ ) { + require(limit > 0, "Queue: limit is not set"); items = new bytes32[](limit); uint256 i; @@ -56,7 +60,8 @@ library QueueLib { pointer = item; } - return (items, pointer, i); + // TODO: resize items array to actual count + return (items, i); } function isEmpty(Queue storage self) internal view returns (bool) { diff --git a/test/CSModule.t.sol b/test/CSModule.t.sol index 8c088613..081d49eb 100644 --- a/test/CSModule.t.sol +++ b/test/CSModule.t.sol @@ -100,6 +100,12 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { } function _assertQueueState(BatchInfo[] memory exp) internal { + if (exp.length == 0) { + vm.expectRevert("Queue: empty"); + _nextPointer(bytes32(0)); + return; + } + // (bytes32 pointer,) = csm.queue(); // it works, but how? bytes32 pointer = bytes32(0); @@ -155,8 +161,8 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { } function _nextPointer(bytes32 pointer) internal view returns (bytes32) { - (, bytes32 next, ) = csm.depositQueue(1, pointer); - return next; + (bytes32[] memory items, uint256 count) = csm.depositQueue(1, pointer); + return count == 0 ? pointer : items[0]; } } diff --git a/test/QueueLib.t.sol b/test/QueueLib.t.sol index 2f6acc66..eb8d8f12 100644 --- a/test/QueueLib.t.sol +++ b/test/QueueLib.t.sol @@ -56,32 +56,24 @@ contract QueueLibTest is Test { q.enqueue(p2); { - (bytes32[] memory items, bytes32 pointer, uint256 count) = q.list( - q.front, - 2 - ); + (bytes32[] memory items, uint256 count) = q.list(q.front, 2); assertEq(count, 2); - assertEq(pointer, p1); assertEq(items[0], p0); assertEq(items[1], p1); } { - (bytes32[] memory items, bytes32 pointer, uint256 count) = q.list( - p1, - 999 - ); + (bytes32[] memory items, uint256 count) = q.list(p1, 999); assertEq(count, 1); - assertEq(pointer, p2); assertEq(items[0], p2); } q.dequeue(); { - (, bytes32 pointer, uint256 count) = q.list(q.front, 0); - assertEq(count, 0); - assertEq(pointer, q.front); + (bytes32[] memory items, uint256 count) = q.list(q.front, 1); + assertEq(count, 1); + assertEq(items[0], p1); } } From 539bf9d3ff0acdc6491001fe71f3d43fdeafa940 Mon Sep 17 00:00:00 2001 From: madlabman <10616301+madlabman@users.noreply.github.com> Date: Thu, 23 Nov 2023 16:54:10 +0100 Subject: [PATCH 3/7] test: more test cases for queue operations --- test/CSModule.t.sol | 119 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 102 insertions(+), 17 deletions(-) diff --git a/test/CSModule.t.sol b/test/CSModule.t.sol index 081d49eb..77107361 100644 --- a/test/CSModule.t.sol +++ b/test/CSModule.t.sol @@ -25,6 +25,8 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { uint256 nonce; } + bytes32 public constant NULL_POINTER = bytes32(0); + LidoLocatorMock public locator; WstETHMock public wstETH; LidoMock public stETH; @@ -101,13 +103,10 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { function _assertQueueState(BatchInfo[] memory exp) internal { if (exp.length == 0) { - vm.expectRevert("Queue: empty"); - _nextPointer(bytes32(0)); - return; + revert("NOTE: use _assertQueueIsEmpty"); } - // (bytes32 pointer,) = csm.queue(); // it works, but how? - bytes32 pointer = bytes32(0); + (bytes32 pointer, ) = csm.queue(); // queue.front for (uint256 i = 0; i < exp.length; i++) { BatchInfo memory b = exp[i]; @@ -153,6 +152,11 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { assertTrue(_isLastElementInQueue(pointer), "unexpected tail of queue"); } + function _assertQueueIsEmpty() internal { + (bytes32 front, bytes32 back) = csm.queue(); + assertEq(front, back, "queue is not empty"); + } + function _isLastElementInQueue( bytes32 pointer ) internal view returns (bool) { @@ -763,11 +767,13 @@ contract CsmVetKeys is CSMCommon { } contract CsmQueueOps is CSMCommon { - function test_queueIsCleanByDefault() public { - createNodeOperator({ keysCount: 2 }); - csm.vetKeys(0, 1); + uint256 internal constant LOOKUP_DEPTH = 150; // derived from maxDepositsPerBlock - (bool isDirty /* next */, ) = csm.isQueueDirty(1, bytes32(0)); + function test_emptyQueueIsClean() public { + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); assertFalse(isDirty, "queue should be clean"); } @@ -776,7 +782,10 @@ contract CsmQueueOps is CSMCommon { csm.vetKeys(0, 1); csm.unvetKeys(0); - (bool isDirty /* next */, ) = csm.isQueueDirty(1, bytes32(0)); + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); assertTrue(isDirty, "queue should be dirty"); } @@ -784,13 +793,16 @@ contract CsmQueueOps is CSMCommon { createNodeOperator({ keysCount: 2 }); csm.vetKeys(0, 1); csm.unvetKeys(0); - csm.cleanDepositQueue(1, bytes32(0)); + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); - (bool isDirty /* next */, ) = csm.isQueueDirty(1, bytes32(0)); + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); assertFalse(isDirty, "queue should be clean"); } - function test_queueIsDirty_WhenDanglingBatch() public { + function test_queueIsDirty_WhenDanglingBatches() public { createNodeOperator({ keysCount: 2 }); csm.vetKeys(0, 1); @@ -805,11 +817,14 @@ contract CsmQueueOps is CSMCommon { exp[2] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 2, nonce: 1 }); _assertQueueState(exp); - (bool isDirty /* next */, ) = csm.isQueueDirty(3, bytes32(0)); + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); assertTrue(isDirty, "queue should be dirty"); } - function test_queueIsClean_WhenDanglingBatchCleanedUp() public { + function test_queueIsClean_WhenDanglingBatchesCleanedUp() public { createNodeOperator({ keysCount: 2 }); csm.vetKeys(0, 1); @@ -817,13 +832,83 @@ contract CsmQueueOps is CSMCommon { csm.unvetKeys(0); csm.vetKeys(0, 2); - csm.cleanDepositQueue(3, bytes32(0)); + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); // let's check the state of the queue BatchInfo[] memory exp = new BatchInfo[](1); exp[0] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 2, nonce: 1 }); _assertQueueState(exp); - (bool isDirty /* next */, ) = csm.isQueueDirty(3, bytes32(0)); + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); assertFalse(isDirty, "queue should be clean"); } + + function test_cleanup_emptyQueue() public { + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + _assertQueueIsEmpty(); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_cleanup_WhenOneInvalidBatchInRow() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.unvetKeys(0); + csm.vetKeys(0, 1); + + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](1); + exp[0] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 1, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_cleanup_WhenMultipleInvalidBatchesInRow() public { + createNodeOperator({ keysCount: 3 }); + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); // <-- invalid + csm.vetKeys(1, 1); + csm.vetKeys(0, 2); // <-- invalid + csm.vetKeys(0, 3); // <-- invalid + csm.unvetKeys(0); + csm.vetKeys(0, 3); + + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](2); + exp[0] = BatchInfo({ nodeOperatorId: 1, start: 0, count: 1, nonce: 0 }); + exp[1] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 3, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_cleanup_WhenAllBatchesInvalid() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.vetKeys(0, 2); + csm.unvetKeys(0); + + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + _assertQueueIsEmpty(); + } } From 441c784f6560088fc162d94188911461d40df670 Mon Sep 17 00:00:00 2001 From: madlabman <10616301+madlabman@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:24:08 +0100 Subject: [PATCH 4/7] refactor: custom errors for queue --- src/CSModule.sol | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/CSModule.sol b/src/CSModule.sol index ef5b2ea4..7193a697 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -103,6 +103,13 @@ contract CSModuleBase { error SameAddress(); error AlreadyProposed(); error InvalidVetKeysPointer(); + + error QueueLookupNoLimit(); + error QueueEmptyBatch(); + error QueueBatchInvalidNonce(bytes32 batch); + error QueueBatchInvalidStart(bytes32 batch); + error QueueBatchInvalidCount(bytes32 batch); + error QueueBatchUnvettedKeys(bytes32 batch); } contract CSModule is IStakingModule, CSModuleBase { @@ -796,7 +803,8 @@ contract CSModule is IStakingModule, CSModuleBase { (nodeOperatorId, start, count, nonce) = Batch.deserialize(batch); NodeOperator storage no = _nodeOperators[nodeOperatorId]; - _assertIsValidBatch(no, start, count, nonce); + // solhint-disable-next-line func-named-parameters + _assertIsValidBatch(no, batch, start, count, nonce); startIndex = Math.max(start, no.totalDepositedKeys); depositableKeysCount = start + count - startIndex; @@ -804,24 +812,18 @@ contract CSModule is IStakingModule, CSModuleBase { function _assertIsValidBatch( NodeOperator storage no, + bytes32 batch, uint256 start, uint256 count, uint256 nonce ) internal view { - require(count != 0, "Empty batch given"); - require(nonce == no.queueNonce, "Invalid batch nonce"); - require( - _unvettedKeysInBatch(no, start, count) == false, - "Batch contains unvetted keys" - ); - require( - start + count <= no.totalAddedKeys, - "Invalid batch range: not enough keys" - ); - require( - start <= no.totalDepositedKeys, - "Invalid batch range: skipped keys" - ); + if (count == 0) revert QueueEmptyBatch(); + if (nonce != no.queueNonce) revert QueueBatchInvalidNonce(batch); + if (start > no.totalDepositedKeys) revert QueueBatchInvalidStart(batch); + if (start + count > no.totalAddedKeys) + revert QueueBatchInvalidCount(batch); + if (_unvettedKeysInBatch(no, start, count)) + revert QueueBatchUnvettedKeys(batch); } /// @dev returns the next pointer to start cleanup from @@ -829,7 +831,7 @@ contract CSModule is IStakingModule, CSModuleBase { uint256 maxItems, bytes32 pointer ) external returns (bytes32) { - require(maxItems > 0, "Queue walkthrough limit is not set"); + if (maxItems == 0) revert QueueLookupNoLimit(); if (Batch.isNil(pointer)) { pointer = queue.front; @@ -865,7 +867,7 @@ contract CSModule is IStakingModule, CSModuleBase { uint256 maxItems, bytes32 pointer ) external view returns (bytes32[] memory items, uint256 /* count */) { - require(maxItems > 0, "Queue walkthrough limit is not set"); + if (maxItems == 0) revert QueueLookupNoLimit(); if (Batch.isNil(pointer)) { pointer = queue.front; @@ -879,7 +881,7 @@ contract CSModule is IStakingModule, CSModuleBase { uint256 maxItems, bytes32 pointer ) external view returns (bool, bytes32) { - require(maxItems > 0, "Queue walkthrough limit is not set"); + if (maxItems == 0) revert QueueLookupNoLimit(); if (Batch.isNil(pointer)) { pointer = queue.front; From a54b8a5972aa58abb8abad9685ef088eca993d8c Mon Sep 17 00:00:00 2001 From: madlabman <10616301+madlabman@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:26:43 +0100 Subject: [PATCH 5/7] refactor: _incrementNonce -> _incrementModuleNonce --- src/CSModule.sol | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/CSModule.sol b/src/CSModule.sol index 7193a697..df2148fd 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -668,7 +668,7 @@ contract CSModule is IStakingModule, CSModuleBase { emit BatchEnqueued(nodeOperatorId, start, count); emit VettedSigningKeysCountChanged(nodeOperatorId, vetKeysPointer); - _incrementNonce(); + _incrementModuleNonce(); } function unvetKeys( @@ -691,7 +691,7 @@ contract CSModule is IStakingModule, CSModuleBase { no.totalVettedKeys = no.totalDepositedKeys; no.queueNonce++; emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys); - _incrementNonce(); + _incrementModuleNonce(); } function onWithdrawalCredentialsChanged() external { @@ -724,7 +724,7 @@ contract CSModule is IStakingModule, CSModuleBase { _nodeOperators[nodeOperatorId].totalAddedKeys ); - _incrementNonce(); + _incrementModuleNonce(); } function obtainDepositData( @@ -782,7 +782,7 @@ contract CSModule is IStakingModule, CSModuleBase { } require(loadedKeysCount == depositsCount, "NOT_ENOUGH_KEYS"); - _incrementNonce(); + _incrementModuleNonce(); } function _depositableKeysInBatch( @@ -920,7 +920,7 @@ contract CSModule is IStakingModule, CSModuleBase { return start + count > no.totalVettedKeys; } - function _incrementNonce() internal { + function _incrementModuleNonce() internal { _nonce++; } From b32e1840579da7fbbb4ff1ecc4a0383d9eba8662 Mon Sep 17 00:00:00 2001 From: madlabman <10616301+madlabman@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:30:35 +0100 Subject: [PATCH 6/7] chore: update isQueueDirty comment --- src/CSModule.sol | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/CSModule.sol b/src/CSModule.sol index df2148fd..864f1446 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -876,6 +876,8 @@ contract CSModule is IStakingModule, CSModuleBase { return queue.list(pointer, maxItems); } + /// @dev it is dirty if it contains a batch with unvetted keys + /// or with invalid nonce /// @dev returns the next pointer to start check from function isQueueDirty( uint256 maxItems, From 56c1ca6beb87527b85624c140e7086cb5e203523 Mon Sep 17 00:00:00 2001 From: madlabman <10616301+madlabman@users.noreply.github.com> Date: Thu, 23 Nov 2023 20:59:31 +0100 Subject: [PATCH 7/7] chore: MAX_NODE_OPERATORS_COUNT type --- src/CSModule.sol | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CSModule.sol b/src/CSModule.sol index 864f1446..30bac7a7 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -117,7 +117,7 @@ contract CSModule is IStakingModule, CSModuleBase { // @dev max number of node operators is limited by uint64 due to Batch serialization in 32 bytes // it seems to be enough - uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint64).max; + uint64 public constant MAX_NODE_OPERATORS_COUNT = type(uint64).max; bytes32 public constant SIGNING_KEYS_POSITION = keccak256("lido.CommunityStakingModule.signingKeysPosition");