diff --git a/src/CSAccounting.sol b/src/CSAccounting.sol index d123d35c..8f27a3e8 100644 --- a/src/CSAccounting.sol +++ b/src/CSAccounting.sol @@ -88,13 +88,13 @@ contract CSAccounting is CSAccountingBase, AccessControlEnumerable { } bytes32 public constant INSTANT_PENALIZE_BOND_ROLE = - keccak256("INSTANT_PENALIZE_BOND_ROLE"); + keccak256("INSTANT_PENALIZE_BOND_ROLE"); // 0x9909cf24c2d3bafa8c229558d86a1b726ba57c3ef6350848dcf434a4181b56c7 bytes32 public constant EL_REWARDS_STEALING_PENALTY_INIT_ROLE = - keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE"); + keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE"); // 0xcc2e7ce7be452f766dd24d55d87a3d42901c31ffa5b600cd1dff475abec91c1f bytes32 public constant EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE = - keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE"); + keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE"); // 0xdf6226649a1ca132f86d419e46892001284368a8f7445b5eb0d3fadf91329fe6 bytes32 public constant SET_BOND_MULTIPLIER_ROLE = - keccak256("SET_BOND_MULTIPLIER_ROLE"); + keccak256("SET_BOND_MULTIPLIER_ROLE"); // 0x62131145aee19b18b85aa8ead52ba87f0efb6e61e249155edc68a2c24e8f79b5 // todo: should be reconsidered uint256 public constant MIN_BLOCKED_BOND_RETENTION_PERIOD = 4 weeks; diff --git a/src/CSModule.sol b/src/CSModule.sol index b34e9fd9..30bac7a7 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -33,6 +33,7 @@ struct NodeOperator { uint256 stuckValidatorsCount; uint256 refundedValidatorsCount; bool isTargetLimitActive; + uint256 queueNonce; } struct NodeOperatorInfo { @@ -102,14 +103,21 @@ contract CSModuleBase { error SameAddress(); error AlreadyProposed(); error InvalidVetKeysPointer(); + + error QueueLookupNoLimit(); + error QueueEmptyBatch(); + error QueueBatchInvalidNonce(bytes32 batch); + error QueueBatchInvalidStart(bytes32 batch); + error QueueBatchInvalidCount(bytes32 batch); + error QueueBatchUnvettedKeys(bytes32 batch); } contract CSModule is IStakingModule, CSModuleBase { using QueueLib for QueueLib.Queue; - // @dev max number of node operators is limited by uint128 due to Batch serialization in 32 bytes + // @dev max number of node operators is limited by uint64 due to Batch serialization in 32 bytes // it seems to be enough - uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint128).max; + uint64 public constant MAX_NODE_OPERATORS_COUNT = type(uint64).max; bytes32 public constant SIGNING_KEYS_POSITION = keccak256("lido.CommunityStakingModule.signingKeysPosition"); @@ -645,14 +653,13 @@ contract CSModule is IStakingModule, CSModuleBase { if (vetKeysPointer > no.totalAddedKeys) revert InvalidVetKeysPointer(); uint64 count = SafeCast.toUint64(vetKeysPointer - no.totalVettedKeys); - uint64 start = SafeCast.toUint64( - no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys - ); + uint64 start = SafeCast.toUint64(no.totalVettedKeys); bytes32 pointer = Batch.serialize({ - nodeOperatorId: SafeCast.toUint128(nodeOperatorId), + nodeOperatorId: SafeCast.toUint64(nodeOperatorId), start: start, - count: count + count: count, + nonce: SafeCast.toUint64(no.queueNonce) }); no.totalVettedKeys = vetKeysPointer; @@ -661,7 +668,7 @@ contract CSModule is IStakingModule, CSModuleBase { emit BatchEnqueued(nodeOperatorId, start, count); emit VettedSigningKeysCountChanged(nodeOperatorId, vetKeysPointer); - _incrementNonce(); + _incrementModuleNonce(); } function unvetKeys( @@ -682,8 +689,9 @@ contract CSModule is IStakingModule, CSModuleBase { function _unvetKeys(uint256 nodeOperatorId) internal { NodeOperator storage no = _nodeOperators[nodeOperatorId]; no.totalVettedKeys = no.totalDepositedKeys; + no.queueNonce++; emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys); - _incrementNonce(); + _incrementModuleNonce(); } function onWithdrawalCredentialsChanged() external { @@ -716,7 +724,7 @@ contract CSModule is IStakingModule, CSModuleBase { _nodeOperators[nodeOperatorId].totalAddedKeys ); - _incrementNonce(); + _incrementModuleNonce(); } function obtainDepositData( @@ -754,6 +762,7 @@ contract CSModule is IStakingModule, CSModuleBase { _totalDepositedValidators += keysCount; NodeOperator storage no = _nodeOperators[nodeOperatorId]; no.totalDepositedKeys += keysCount; + // redundant check, enforced by _assertIsValidBatch require( no.totalDepositedKeys <= no.totalVettedKeys, "too many keys" @@ -773,7 +782,7 @@ contract CSModule is IStakingModule, CSModuleBase { } require(loadedKeysCount == depositsCount, "NOT_ENOUGH_KEYS"); - _incrementNonce(); + _incrementModuleNonce(); } function _depositableKeysInBatch( @@ -789,11 +798,13 @@ contract CSModule is IStakingModule, CSModuleBase { { uint256 start; uint256 count; + uint256 nonce; - (nodeOperatorId, start, count) = Batch.deserialize(batch); + (nodeOperatorId, start, count, nonce) = Batch.deserialize(batch); NodeOperator storage no = _nodeOperators[nodeOperatorId]; - _assertIsValidBatch(no, start, count); + // solhint-disable-next-line func-named-parameters + _assertIsValidBatch(no, batch, start, count, nonce); startIndex = Math.max(start, no.totalDepositedKeys); depositableKeysCount = start + count - startIndex; @@ -801,22 +812,18 @@ contract CSModule is IStakingModule, CSModuleBase { function _assertIsValidBatch( NodeOperator storage no, + bytes32 batch, uint256 start, - uint256 count + uint256 count, + uint256 nonce ) internal view { - require(count != 0, "Empty batch given"); - require( - _unvettedKeysInBatch(no, start, count) == false, - "Batch contains unvetted keys" - ); - require( - start + count <= no.totalAddedKeys, - "Invalid batch range: not enough keys" - ); - require( - start <= no.totalDepositedKeys, - "Invalid batch range: skipped keys" - ); + if (count == 0) revert QueueEmptyBatch(); + if (nonce != no.queueNonce) revert QueueBatchInvalidNonce(batch); + if (start > no.totalDepositedKeys) revert QueueBatchInvalidStart(batch); + if (start + count > no.totalAddedKeys) + revert QueueBatchInvalidCount(batch); + if (_unvettedKeysInBatch(no, start, count)) + revert QueueBatchUnvettedKeys(batch); } /// @dev returns the next pointer to start cleanup from @@ -824,7 +831,7 @@ contract CSModule is IStakingModule, CSModuleBase { uint256 maxItems, bytes32 pointer ) external returns (bytes32) { - require(maxItems > 0, "Queue walkthrough limit is not set"); + if (maxItems == 0) revert QueueLookupNoLimit(); if (Batch.isNil(pointer)) { pointer = queue.front; @@ -836,11 +843,18 @@ contract CSModule is IStakingModule, CSModuleBase { break; } - (uint256 nodeOperatorId, uint256 start, uint256 count) = Batch - .deserialize(item); + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize(item); NodeOperator storage no = _nodeOperators[nodeOperatorId]; - if (_unvettedKeysInBatch(no, start, count)) { + if ( + _unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce + ) { queue.remove(pointer, item); + continue; } pointer = item; @@ -852,16 +866,8 @@ contract CSModule is IStakingModule, CSModuleBase { function depositQueue( uint256 maxItems, bytes32 pointer - ) - external - view - returns ( - bytes32[] memory items, - bytes32 /* pointer */, - uint256 /* count */ - ) - { - require(maxItems > 0, "Queue walkthrough limit is not set"); + ) external view returns (bytes32[] memory items, uint256 /* count */) { + if (maxItems == 0) revert QueueLookupNoLimit(); if (Batch.isNil(pointer)) { pointer = queue.front; @@ -870,12 +876,14 @@ contract CSModule is IStakingModule, CSModuleBase { return queue.list(pointer, maxItems); } + /// @dev it is dirty if it contains a batch with unvetted keys + /// or with invalid nonce /// @dev returns the next pointer to start check from - function isQueueHasUnvettedKeys( + function isQueueDirty( uint256 maxItems, bytes32 pointer ) external view returns (bool, bytes32) { - require(maxItems > 0, "Queue walkthrough limit is not set"); + if (maxItems == 0) revert QueueLookupNoLimit(); if (Batch.isNil(pointer)) { pointer = queue.front; @@ -887,10 +895,16 @@ contract CSModule is IStakingModule, CSModuleBase { break; } - (uint256 nodeOperatorId, uint256 start, uint256 count) = Batch - .deserialize(item); + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize(item); NodeOperator storage no = _nodeOperators[nodeOperatorId]; - if (_unvettedKeysInBatch(no, start, count)) { + if ( + _unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce + ) { return (true, pointer); } @@ -908,7 +922,7 @@ contract CSModule is IStakingModule, CSModuleBase { return start + count > no.totalVettedKeys; } - function _incrementNonce() internal { + function _incrementModuleNonce() internal { _nonce++; } diff --git a/src/lib/Batch.sol b/src/lib/Batch.sol index 664f88f7..cc7c3712 100644 --- a/src/lib/Batch.sol +++ b/src/lib/Batch.sol @@ -6,21 +6,23 @@ pragma solidity 0.8.21; library Batch { /// @notice Serialize node operator id, batch start and count of keys into a single bytes32 value function serialize( - uint128 nodeOperatorId, + uint64 nodeOperatorId, uint64 start, - uint64 count + uint64 count, + uint64 nonce ) internal pure returns (bytes32 s) { - return bytes32(abi.encodePacked(nodeOperatorId, start, count)); + return bytes32(abi.encodePacked(nodeOperatorId, start, count, nonce)); } /// @notice Deserialize node operator id, batch start and count of keys from a single bytes32 value function deserialize( bytes32 b - ) internal pure returns (uint128 nodeOperatorId, uint64 start, uint64 count) { + ) internal pure returns (uint64 nodeOperatorId, uint64 start, uint64 count, uint64 nonce) { assembly { - nodeOperatorId := shr(128, b) - start := shr(64, b) - count := b + nodeOperatorId := shr(192, b) + start := shr(128, b) + count := shr(64, b) + nonce := b } } diff --git a/src/lib/QueueLib.sol b/src/lib/QueueLib.sol index 0283d87f..ad7824d3 100644 --- a/src/lib/QueueLib.sol +++ b/src/lib/QueueLib.sol @@ -7,6 +7,8 @@ pragma solidity 0.8.21; library QueueLib { bytes32 public constant NULL_POINTER = bytes32(0); + // @dev Queue is a linked list of items + // @dev front and back are pointers struct Queue { mapping(bytes32 => bytes32) queue; bytes32 front; @@ -38,11 +40,13 @@ library QueueLib { return self.queue[pointer]; } + // @dev returns items array of size `limit` and actual count of items + // @dev reverts if the queue is empty function list(Queue storage self, bytes32 pointer, uint256 limit) internal notEmpty(self) view returns ( bytes32[] memory items, - bytes32 /* pointer */, uint256 /* count */ ) { + require(limit > 0, "Queue: limit is not set"); items = new bytes32[](limit); uint256 i; @@ -56,7 +60,8 @@ library QueueLib { pointer = item; } - return (items, pointer, i); + // TODO: resize items array to actual count + return (items, i); } function isEmpty(Queue storage self) internal view returns (bool) { diff --git a/test/Batch.t.sol b/test/Batch.t.sol index a5a197ad..cc93af46 100644 --- a/test/Batch.t.sol +++ b/test/Batch.t.sol @@ -11,44 +11,52 @@ contract BatchTest is Test { bytes32 b = Batch.serialize({ nodeOperatorId: 999, start: 3, - count: 42 + count: 42, + nonce: 7 }); assertEq( b, - // noIndex | start | count | - 0x000000000000000000000000000003e70000000000000003000000000000002a + // noIndex | start | count | nonce + 0x00000000000003e70000000000000003000000000000002a0000000000000007 ); } function test_deserialize() public { - (uint128 nodeOperatorId, uint64 start, uint64 count) = Batch - .deserialize( + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize( 0x0000000000000000000000000000000000000000000000000000000000000000 ); assertEq(nodeOperatorId, 0, "nodeOperatorId != 0"); assertEq(start, 0, "start != 0"); assertEq(count, 0, "count != 0"); + assertEq(nonce, 0, "nonce != 0"); - (nodeOperatorId, start, count) = Batch.deserialize( - 0x000000000000000000000000000003e70000000000000003000000000000002a + (nodeOperatorId, start, count, nonce) = Batch.deserialize( + 0x00000000000003e70000000000000003000000000000002a0000000000000007 ); assertEq(nodeOperatorId, 999, "nodeOperatorId != 999"); assertEq(start, 3, "start != 3"); assertEq(count, 42, "count != 42"); + assertEq(nonce, 7, "nonce != 7"); - (nodeOperatorId, start, count) = Batch.deserialize( + (nodeOperatorId, start, count, nonce) = Batch.deserialize( 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff ); assertEq( nodeOperatorId, - type(uint128).max, - "nodeOperatorId != uint128.max" + type(uint64).max, + "nodeOperatorId != uint64.max" ); assertEq(start, type(uint64).max, "start != uint64.max"); assertEq(count, type(uint64).max, "count != uint64.max"); + assertEq(nonce, type(uint64).max, "nonce != uint64.max"); } } diff --git a/test/CSModule.t.sol b/test/CSModule.t.sol index 17a94578..77107361 100644 --- a/test/CSModule.t.sol +++ b/test/CSModule.t.sol @@ -13,7 +13,20 @@ import "./helpers/mocks/LidoMock.sol"; import "./helpers/mocks/WstETHMock.sol"; import "./helpers/Utilities.sol"; +import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; + contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { + using Strings for uint256; + + struct BatchInfo { + uint256 nodeOperatorId; + uint256 start; + uint256 count; + uint256 nonce; + } + + bytes32 public constant NULL_POINTER = bytes32(0); + LidoLocatorMock public locator; WstETHMock public wstETH; LidoMock public stETH; @@ -22,16 +35,16 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { CSAccounting public accounting; CommunityStakingFeeDistributorMock public communityStakingFeeDistributor; + address internal admin; address internal stranger; - address internal alice; address internal nodeOperator; function setUp() public { - alice = address(1); - nodeOperator = address(2); - stranger = address(3); - address[] memory penalizeRoleMembers = new address[](1); - penalizeRoleMembers[0] = alice; + vm.label(address(this), "TEST"); + + nodeOperator = nextAddress("NODE_OPERATOR"); + stranger = nextAddress("STRANGER"); + admin = nextAddress("ADMIN"); (locator, wstETH, stETH, burner) = initLido(); @@ -46,7 +59,7 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { csm = new CSModule("community-staking-module", address(locator)); accounting = new CSAccounting( 2 ether, - alice, + admin, address(locator), address(wstETH), address(csm), @@ -54,6 +67,13 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { 1 days ); csm.setAccounting(address(accounting)); + + vm.startPrank(admin); + accounting.grantRole( + accounting.INSTANT_PENALIZE_BOND_ROLE(), + address(csm) + ); // NOTE: required because of `unvetKeys` + vm.stopPrank(); } function createNodeOperator() internal returns (uint256) { @@ -80,6 +100,74 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { ); return csm.getNodeOperatorsCount() - 1; } + + function _assertQueueState(BatchInfo[] memory exp) internal { + if (exp.length == 0) { + revert("NOTE: use _assertQueueIsEmpty"); + } + + (bytes32 pointer, ) = csm.queue(); // queue.front + + for (uint256 i = 0; i < exp.length; i++) { + BatchInfo memory b = exp[i]; + + assertFalse( + _isLastElementInQueue(pointer), + string.concat("unexpected end of queue at index ", i.toString()) + ); + + pointer = _nextPointer(pointer); + ( + uint256 nodeOperatorId, + uint256 start, + uint256 count, + uint256 nonce + ) = Batch.deserialize(pointer); + + assertEq( + nodeOperatorId, + b.nodeOperatorId, + string.concat( + "unexpected `nodeOperatorId` at index ", + i.toString() + ) + ); + assertEq( + start, + b.start, + string.concat("unexpected `start` at index ", i.toString()) + ); + assertEq( + count, + b.count, + string.concat("unexpected `count` at index ", i.toString()) + ); + assertEq( + nonce, + b.nonce, + string.concat("unexpected `nonce` at index ", i.toString()) + ); + } + + assertTrue(_isLastElementInQueue(pointer), "unexpected tail of queue"); + } + + function _assertQueueIsEmpty() internal { + (bytes32 front, bytes32 back) = csm.queue(); + assertEq(front, back, "queue is not empty"); + } + + function _isLastElementInQueue( + bytes32 pointer + ) internal view returns (bool) { + bytes32 next = _nextPointer(pointer); + return next == pointer; + } + + function _nextPointer(bytes32 pointer) internal view returns (bytes32) { + (bytes32[] memory items, uint256 count) = csm.depositQueue(1, pointer); + return count == 0 ? pointer : items[0]; + } } contract CsmInitialization is CSMCommon { @@ -369,9 +457,9 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.expectEmit(true, true, false, true, address(csm)); - emit NodeOperatorManagerAddressChangeProposed(noId, alice); + emit NodeOperatorManagerAddressChangeProposed(noId, stranger); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); assertEq(no.managerAddress, nodeOperator); assertEq(no.rewardAddress, nodeOperator); } @@ -380,7 +468,7 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { public { vm.expectRevert(NodeOperatorDoesNotExist.selector); - csm.proposeNodeOperatorManagerAddressChange(0, alice); + csm.proposeNodeOperatorManagerAddressChange(0, stranger); } function test_proposeNodeOperatorManagerAddressChange_RevertWhenNotManager() @@ -388,7 +476,7 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotManagerAddress.selector); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); } function test_proposeNodeOperatorManagerAddressChange_RevertWhenAlreadyProposed() @@ -396,11 +484,11 @@ contract CsmProposeNodeOperatorManagerAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); vm.expectRevert(AlreadyProposed.selector); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); } function test_proposeNodeOperatorManagerAddressChange_RevertWhenSameAddressProposed() @@ -421,15 +509,15 @@ contract CsmConfirmNodeOperatorManagerAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.prank(nodeOperator); - csm.proposeNodeOperatorManagerAddressChange(noId, alice); + csm.proposeNodeOperatorManagerAddressChange(noId, stranger); vm.expectEmit(true, true, true, true, address(csm)); - emit NodeOperatorManagerAddressChanged(noId, nodeOperator, alice); - vm.prank(alice); + emit NodeOperatorManagerAddressChanged(noId, nodeOperator, stranger); + vm.prank(stranger); csm.confirmNodeOperatorManagerAddressChange(noId); no = csm.getNodeOperator(noId); - assertEq(no.managerAddress, alice); + assertEq(no.managerAddress, stranger); assertEq(no.rewardAddress, nodeOperator); } @@ -445,7 +533,7 @@ contract CsmConfirmNodeOperatorManagerAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(stranger); csm.confirmNodeOperatorManagerAddressChange(noId); } @@ -457,7 +545,7 @@ contract CsmConfirmNodeOperatorManagerAddressChange is CSMCommon { csm.proposeNodeOperatorManagerAddressChange(noId, stranger); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(nextAddress()); csm.confirmNodeOperatorManagerAddressChange(noId); } } @@ -470,9 +558,9 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.expectEmit(true, true, false, true, address(csm)); - emit NodeOperatorRewardAddressChangeProposed(noId, alice); + emit NodeOperatorRewardAddressChangeProposed(noId, stranger); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); assertEq(no.managerAddress, nodeOperator); assertEq(no.rewardAddress, nodeOperator); } @@ -481,7 +569,7 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { public { vm.expectRevert(NodeOperatorDoesNotExist.selector); - csm.proposeNodeOperatorRewardAddressChange(0, alice); + csm.proposeNodeOperatorRewardAddressChange(0, stranger); } function test_proposeNodeOperatorRewardAddressChange_RevertWhenNotRewardAddress() @@ -489,7 +577,7 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotRewardAddress.selector); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); } function test_proposeNodeOperatorRewardAddressChange_RevertWhenAlreadyProposed() @@ -497,11 +585,11 @@ contract CsmProposeNodeOperatorRewardAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); vm.expectRevert(AlreadyProposed.selector); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); } function test_proposeNodeOperatorRewardAddressChange_RevertWhenSameAddressProposed() @@ -522,16 +610,16 @@ contract CsmConfirmNodeOperatorRewardAddressChange is CSMCommon { assertEq(no.rewardAddress, nodeOperator); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); vm.expectEmit(true, true, true, true, address(csm)); - emit NodeOperatorRewardAddressChanged(noId, nodeOperator, alice); - vm.prank(alice); + emit NodeOperatorRewardAddressChanged(noId, nodeOperator, stranger); + vm.prank(stranger); csm.confirmNodeOperatorRewardAddressChange(noId); no = csm.getNodeOperator(noId); assertEq(no.managerAddress, nodeOperator); - assertEq(no.rewardAddress, alice); + assertEq(no.rewardAddress, stranger); } function test_confirmNodeOperatorRewardAddressChange_RevertWhenNoNodeOperator() @@ -546,7 +634,7 @@ contract CsmConfirmNodeOperatorRewardAddressChange is CSMCommon { { uint256 noId = createNodeOperator(); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(stranger); csm.confirmNodeOperatorRewardAddressChange(noId); } @@ -558,7 +646,7 @@ contract CsmConfirmNodeOperatorRewardAddressChange is CSMCommon { csm.proposeNodeOperatorRewardAddressChange(noId, stranger); vm.expectRevert(SenderIsNotProposedAddress.selector); - vm.prank(alice); + vm.prank(nextAddress()); csm.confirmNodeOperatorRewardAddressChange(noId); } } @@ -568,18 +656,18 @@ contract CsmResetNodeOperatorManagerAddress is CSMCommon { uint256 noId = createNodeOperator(); vm.prank(nodeOperator); - csm.proposeNodeOperatorRewardAddressChange(noId, alice); - vm.prank(alice); + csm.proposeNodeOperatorRewardAddressChange(noId, stranger); + vm.prank(stranger); csm.confirmNodeOperatorRewardAddressChange(noId); vm.expectEmit(true, true, true, true, address(csm)); - emit NodeOperatorManagerAddressChanged(noId, nodeOperator, alice); - vm.prank(alice); + emit NodeOperatorManagerAddressChanged(noId, nodeOperator, stranger); + vm.prank(stranger); csm.resetNodeOperatorManagerAddress(noId); NodeOperatorInfo memory no = csm.getNodeOperator(noId); - assertEq(no.managerAddress, alice); - assertEq(no.rewardAddress, alice); + assertEq(no.managerAddress, stranger); + assertEq(no.rewardAddress, stranger); } function test_resetNodeOperatorManagerAddress_RevertWhenNoNodeOperator() @@ -618,13 +706,15 @@ contract CsmVetKeys is CSMCommon { NodeOperatorInfo memory no = csm.getNodeOperator(noId); assertEq(no.totalVettedValidators, 1); - (bytes32[] memory items, , ) = csm.depositQueue(1, bytes32(0)); - (uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize( - items[0] - ); - assertEq(batchNoId, uint128(noId)); - assertEq(start, 0); - assertEq(count, 1); + + BatchInfo[] memory exp = new BatchInfo[](1); + exp[0] = BatchInfo({ + nodeOperatorId: noId, + start: 0, + count: 1, + nonce: 0 + }); + _assertQueueState(exp); } function test_vetKeys_totalVettedKeysIsNotZero() public { @@ -639,13 +729,21 @@ contract CsmVetKeys is CSMCommon { NodeOperatorInfo memory no = csm.getNodeOperator(noId); assertEq(no.totalVettedValidators, 2); - (bytes32[] memory items, , ) = csm.depositQueue(2, bytes32(0)); - (uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize( - items[1] - ); - assertEq(batchNoId, uint128(noId)); - assertEq(start, 1); - assertEq(count, 1); + + BatchInfo[] memory exp = new BatchInfo[](2); + exp[0] = BatchInfo({ + nodeOperatorId: noId, + start: 0, + count: 1, + nonce: 0 + }); + exp[1] = BatchInfo({ + nodeOperatorId: noId, + start: 1, + count: 1, + nonce: 0 + }); + _assertQueueState(exp); } function test_vetKeys_RevertWhenNoNodeOperator() public { @@ -667,3 +765,150 @@ contract CsmVetKeys is CSMCommon { csm.vetKeys(noId, 2); } } + +contract CsmQueueOps is CSMCommon { + uint256 internal constant LOOKUP_DEPTH = 150; // derived from maxDepositsPerBlock + + function test_emptyQueueIsClean() public { + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_queueIsDirty_WhenUnvettedKeys() public { + createNodeOperator({ keysCount: 2 }); + csm.vetKeys(0, 1); + csm.unvetKeys(0); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertTrue(isDirty, "queue should be dirty"); + } + + function test_queueIsClean_AfterCleanup() public { + createNodeOperator({ keysCount: 2 }); + csm.vetKeys(0, 1); + csm.unvetKeys(0); + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_queueIsDirty_WhenDanglingBatches() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.vetKeys(0, 2); + csm.unvetKeys(0); + csm.vetKeys(0, 2); + + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](3); + exp[0] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 1, nonce: 0 }); + exp[1] = BatchInfo({ nodeOperatorId: 0, start: 1, count: 1, nonce: 0 }); + exp[2] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 2, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertTrue(isDirty, "queue should be dirty"); + } + + function test_queueIsClean_WhenDanglingBatchesCleanedUp() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.vetKeys(0, 2); + csm.unvetKeys(0); + csm.vetKeys(0, 2); + + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](1); + exp[0] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 2, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_cleanup_emptyQueue() public { + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + _assertQueueIsEmpty(); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_cleanup_WhenOneInvalidBatchInRow() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.unvetKeys(0); + csm.vetKeys(0, 1); + + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](1); + exp[0] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 1, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_cleanup_WhenMultipleInvalidBatchesInRow() public { + createNodeOperator({ keysCount: 3 }); + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); // <-- invalid + csm.vetKeys(1, 1); + csm.vetKeys(0, 2); // <-- invalid + csm.vetKeys(0, 3); // <-- invalid + csm.unvetKeys(0); + csm.vetKeys(0, 3); + + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + // let's check the state of the queue + BatchInfo[] memory exp = new BatchInfo[](2); + exp[0] = BatchInfo({ nodeOperatorId: 1, start: 0, count: 1, nonce: 0 }); + exp[1] = BatchInfo({ nodeOperatorId: 0, start: 0, count: 3, nonce: 1 }); + _assertQueueState(exp); + + (bool isDirty /* next */, ) = csm.isQueueDirty( + LOOKUP_DEPTH, + NULL_POINTER + ); + assertFalse(isDirty, "queue should be clean"); + } + + function test_cleanup_WhenAllBatchesInvalid() public { + createNodeOperator({ keysCount: 2 }); + + csm.vetKeys(0, 1); + csm.vetKeys(0, 2); + csm.unvetKeys(0); + + csm.cleanDepositQueue(LOOKUP_DEPTH, NULL_POINTER); + _assertQueueIsEmpty(); + } +} diff --git a/test/QueueLib.t.sol b/test/QueueLib.t.sol index 2f6acc66..eb8d8f12 100644 --- a/test/QueueLib.t.sol +++ b/test/QueueLib.t.sol @@ -56,32 +56,24 @@ contract QueueLibTest is Test { q.enqueue(p2); { - (bytes32[] memory items, bytes32 pointer, uint256 count) = q.list( - q.front, - 2 - ); + (bytes32[] memory items, uint256 count) = q.list(q.front, 2); assertEq(count, 2); - assertEq(pointer, p1); assertEq(items[0], p0); assertEq(items[1], p1); } { - (bytes32[] memory items, bytes32 pointer, uint256 count) = q.list( - p1, - 999 - ); + (bytes32[] memory items, uint256 count) = q.list(p1, 999); assertEq(count, 1); - assertEq(pointer, p2); assertEq(items[0], p2); } q.dequeue(); { - (, bytes32 pointer, uint256 count) = q.list(q.front, 0); - assertEq(count, 0); - assertEq(pointer, q.front); + (bytes32[] memory items, uint256 count) = q.list(q.front, 1); + assertEq(count, 1); + assertEq(items[0], p1); } } diff --git a/test/helpers/Utilities.sol b/test/helpers/Utilities.sol index cbd01853..73695604 100644 --- a/test/helpers/Utilities.sol +++ b/test/helpers/Utilities.sol @@ -2,8 +2,10 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity 0.8.21; +import { CommonBase } from "forge-std/Base.sol"; + /// @author madlabman -contract Utilities { +contract Utilities is CommonBase { bytes32 internal seed = keccak256("seed sEed seEd"); function nextAddress() internal returns (address) { @@ -14,6 +16,12 @@ contract Utilities { return a; } + function nextAddress(string memory label) internal returns (address) { + address a = nextAddress(); + vm.label(a, label); + return a; + } + function keysSignatures( uint256 keysCount ) public pure returns (bytes memory, bytes memory) {