From 0521be3c8337518055d8f00a4f9bcfcea9f8a44c Mon Sep 17 00:00:00 2001 From: skhomuti Date: Wed, 8 Nov 2023 12:10:25 +0500 Subject: [PATCH] feat: vetKeys fixes and tests --- src/CSModule.sol | 47 +++++---- test/CSMInit.t.sol | 59 ------------ .../{CSMAddValidator.t.sol => CSModule.t.sol} | 95 ++++++++++++++++++- 3 files changed, 120 insertions(+), 81 deletions(-) delete mode 100644 test/CSMInit.t.sol rename test/{CSMAddValidator.t.sol => CSModule.t.sol} (86%) diff --git a/src/CSModule.sol b/src/CSModule.sol index c3815e93..b34e9fd9 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -95,17 +95,21 @@ contract CSModuleBase { event UnvettingFeeSet(uint256 unvettingFee); error NodeOperatorDoesNotExist(); + error MaxNodeOperatorsCountReached(); error SenderIsNotManagerAddress(); error SenderIsNotRewardAddress(); error SenderIsNotProposedAddress(); error SameAddress(); error AlreadyProposed(); + error InvalidVetKeysPointer(); } contract CSModule is IStakingModule, CSModuleBase { using QueueLib for QueueLib.Queue; - uint256 public constant MAX_NODE_OPERATOR_NAME_LENGTH = 255; + // @dev max number of node operators is limited by uint128 due to Batch serialization in 32 bytes + // it seems to be enough + uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint128).max; bytes32 public constant SIGNING_KEYS_POSITION = keccak256("lido.CommunityStakingModule.signingKeysPosition"); @@ -182,6 +186,8 @@ contract CSModule is IStakingModule, CSModuleBase { ); uint256 id = _nodeOperatorsCount; + if (id == MAX_NODE_OPERATORS_COUNT) + revert MaxNodeOperatorsCountReached(); NodeOperator storage no = _nodeOperators[id]; no.managerAddress = msg.sender; @@ -205,6 +211,8 @@ contract CSModule is IStakingModule, CSModuleBase { // TODO: sanity checks uint256 id = _nodeOperatorsCount; + if (id == MAX_NODE_OPERATORS_COUNT) + revert MaxNodeOperatorsCountReached(); NodeOperator storage no = _nodeOperators[id]; no.managerAddress = msg.sender; @@ -233,6 +241,8 @@ contract CSModule is IStakingModule, CSModuleBase { // TODO sanity checks uint256 id = _nodeOperatorsCount; + if (id == MAX_NODE_OPERATORS_COUNT) + revert MaxNodeOperatorsCountReached(); NodeOperator storage no = _nodeOperators[id]; no.rewardAddress = msg.sender; no.managerAddress = msg.sender; @@ -260,6 +270,8 @@ contract CSModule is IStakingModule, CSModuleBase { // TODO sanity checks uint256 id = _nodeOperatorsCount; + if (id == MAX_NODE_OPERATORS_COUNT) + revert MaxNodeOperatorsCountReached(); NodeOperator storage no = _nodeOperators[id]; no.managerAddress = msg.sender; @@ -288,6 +300,8 @@ contract CSModule is IStakingModule, CSModuleBase { // TODO sanity checks uint256 id = _nodeOperatorsCount; + if (id == MAX_NODE_OPERATORS_COUNT) + revert MaxNodeOperatorsCountReached(); NodeOperator storage no = _nodeOperators[id]; no.rewardAddress = msg.sender; no.managerAddress = msg.sender; @@ -533,7 +547,7 @@ contract CSModule is IStakingModule, CSModuleBase { stuckPenaltyEndTimestamp = no.stuckPenaltyEndTimestamp; totalExitedValidators = no.totalExitedKeys; totalDepositedValidators = no.totalDepositedKeys; - depositableValidatorsCount = no.totalAddedKeys - no.totalExitedKeys; + depositableValidatorsCount = no.totalVettedKeys - no.totalExitedKeys; } function getNonce() external view returns (uint256) { @@ -622,22 +636,17 @@ contract CSModule is IStakingModule, CSModuleBase { function vetKeys( uint256 nodeOperatorId, - uint64 vettedKeysCount - ) external onlyKeyValidator { + uint64 vetKeysPointer + ) external onlyExistingNodeOperator(nodeOperatorId) onlyKeyValidator { NodeOperator storage no = _nodeOperators[nodeOperatorId]; - require( - vettedKeysCount > no.totalVettedKeys, - "Wrong vettedKeysCount: less than already vetted" - ); - require( - vettedKeysCount <= no.totalAddedKeys, - "Wrong vettedKeysCount: more than added" - ); + if (vetKeysPointer <= no.totalVettedKeys) + revert InvalidVetKeysPointer(); + if (vetKeysPointer > no.totalAddedKeys) revert InvalidVetKeysPointer(); - uint64 count = SafeCast.toUint64(vettedKeysCount - no.totalVettedKeys); + uint64 count = SafeCast.toUint64(vetKeysPointer - no.totalVettedKeys); uint64 start = SafeCast.toUint64( - no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys - 1 + no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys ); bytes32 pointer = Batch.serialize({ @@ -646,18 +655,22 @@ contract CSModule is IStakingModule, CSModuleBase { count: count }); - no.totalVettedKeys = vettedKeysCount; + no.totalVettedKeys = vetKeysPointer; queue.enqueue(pointer); emit BatchEnqueued(nodeOperatorId, start, count); - emit VettedSigningKeysCountChanged(nodeOperatorId, vettedKeysCount); + emit VettedSigningKeysCountChanged(nodeOperatorId, vetKeysPointer); _incrementNonce(); } function unvetKeys( uint256 nodeOperatorId - ) external onlyKeyValidatorOrNodeOperatorManager { + ) + external + onlyExistingNodeOperator(nodeOperatorId) + onlyKeyValidatorOrNodeOperatorManager + { _unvetKeys(nodeOperatorId); accounting.penalize(nodeOperatorId, unvettingFee); } diff --git a/test/CSMInit.t.sol b/test/CSMInit.t.sol deleted file mode 100644 index db7ec6d6..00000000 --- a/test/CSMInit.t.sol +++ /dev/null @@ -1,59 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity 0.8.21; - -import "forge-std/Test.sol"; -import "../src/CSModule.sol"; -import "../src/CSAccounting.sol"; -import "./helpers/Fixtures.sol"; -import "./helpers/mocks/StETHMock.sol"; -import "./helpers/mocks/CommunityStakingFeeDistributorMock.sol"; -import "./helpers/mocks/LidoLocatorMock.sol"; -import "./helpers/mocks/LidoMock.sol"; -import "./helpers/mocks/WstETHMock.sol"; - -contract CSMInitTest is Test, Fixtures { - LidoLocatorMock public locator; - WstETHMock public wstETH; - LidoMock public stETH; - Stub public burner; - - CSModule public csm; - CSAccounting public accounting; - CommunityStakingFeeDistributorMock public communityStakingFeeDistributor; - - address internal stranger; - address internal alice; - - function setUp() public { - alice = address(1); - address[] memory penalizeRoleMembers = new address[](1); - penalizeRoleMembers[0] = alice; - - (locator, wstETH, stETH, burner) = initLido(); - - csm = new CSModule("community-staking-module", address(locator)); - communityStakingFeeDistributor = new CommunityStakingFeeDistributorMock( - address(locator), - address(accounting) - ); - accounting = new CSAccounting( - 2 ether, - alice, - address(locator), - address(wstETH), - address(csm), - 8 weeks, - 1 days - ); - } - - function test_InitContract() public { - assertEq(csm.getType(), "community-staking-module"); - assertEq(csm.getNodeOperatorsCount(), 0); - } - - function test_SetAccounting() public { - csm.setAccounting(address(accounting)); - assertEq(address(csm.accounting()), address(accounting)); - } -} diff --git a/test/CSMAddValidator.t.sol b/test/CSModule.t.sol similarity index 86% rename from test/CSMAddValidator.t.sol rename to test/CSModule.t.sol index 32519463..17a94578 100644 --- a/test/CSMAddValidator.t.sol +++ b/test/CSModule.t.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.21; import "forge-std/Test.sol"; import "../src/CSModule.sol"; import "../src/CSAccounting.sol"; +import "../src/lib/Batch.sol"; import "./helpers/Fixtures.sol"; import "./helpers/mocks/StETHMock.sol"; import "./helpers/mocks/CommunityStakingFeeDistributorMock.sol"; @@ -56,23 +57,45 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { } function createNodeOperator() internal returns (uint256) { - return createNodeOperator(nodeOperator); + return createNodeOperator(nodeOperator, 1); + } + + function createNodeOperator(uint256 keysCount) internal returns (uint256) { + return createNodeOperator(nodeOperator, keysCount); } function createNodeOperator( - address managerAddress + address managerAddress, + uint256 keysCount ) internal returns (uint256) { - uint256 keysCount = 1; (bytes memory keys, bytes memory signatures) = keysSignatures( keysCount ); - vm.deal(managerAddress, 2 ether); + vm.deal(managerAddress, keysCount * 2 ether); vm.prank(managerAddress); - csm.addNodeOperatorETH{ value: 2 ether }(keysCount, keys, signatures); + csm.addNodeOperatorETH{ value: keysCount * 2 ether }( + keysCount, + keys, + signatures + ); return csm.getNodeOperatorsCount() - 1; } } +contract CsmInitialization is CSMCommon { + function test_initContract() public { + csm = new CSModule("community-staking-module", address(locator)); + assertEq(csm.getType(), "community-staking-module"); + assertEq(csm.getNodeOperatorsCount(), 0); + } + + function test_setAccounting() public { + csm = new CSModule("community-staking-module", address(locator)); + csm.setAccounting(address(accounting)); + assertEq(address(csm.accounting()), address(accounting)); + } +} + contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { function test_AddNodeOperatorWstETH() public { uint16 keysCount = 1; @@ -582,3 +605,65 @@ contract CsmResetNodeOperatorManagerAddress is CSMCommon { csm.resetNodeOperatorManagerAddress(noId); } } + +contract CsmVetKeys is CSMCommon { + function test_vetKeys() public { + uint256 noId = createNodeOperator(); + + vm.expectEmit(true, true, true, true, address(csm)); + emit BatchEnqueued(noId, 0, 1); + vm.expectEmit(true, true, false, true, address(csm)); + emit VettedSigningKeysCountChanged(noId, 1); + csm.vetKeys(noId, 1); + + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalVettedValidators, 1); + (bytes32[] memory items, , ) = csm.depositQueue(1, bytes32(0)); + (uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize( + items[0] + ); + assertEq(batchNoId, uint128(noId)); + assertEq(start, 0); + assertEq(count, 1); + } + + function test_vetKeys_totalVettedKeysIsNotZero() public { + uint256 noId = createNodeOperator(2); + csm.vetKeys(noId, 1); + + vm.expectEmit(true, true, true, true, address(csm)); + emit BatchEnqueued(noId, 1, 1); + vm.expectEmit(true, true, false, true, address(csm)); + emit VettedSigningKeysCountChanged(noId, 2); + csm.vetKeys(noId, 2); + + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalVettedValidators, 2); + (bytes32[] memory items, , ) = csm.depositQueue(2, bytes32(0)); + (uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize( + items[1] + ); + assertEq(batchNoId, uint128(noId)); + assertEq(start, 1); + assertEq(count, 1); + } + + function test_vetKeys_RevertWhenNoNodeOperator() public { + vm.expectRevert(NodeOperatorDoesNotExist.selector); + csm.vetKeys(0, 1); + } + + function test_vetKeys_RevertWhenPointerLessThanTotalVetted() public { + uint256 noId = createNodeOperator(); + csm.vetKeys(noId, 1); + + vm.expectRevert(InvalidVetKeysPointer.selector); + csm.vetKeys(noId, 1); + } + + function test_vetKeys_RevertWhenPointerGreaterThanTotalAdded() public { + uint256 noId = createNodeOperator(); + vm.expectRevert(InvalidVetKeysPointer.selector); + csm.vetKeys(noId, 2); + } +}