From 41db908218893e9b4984316d3fe67865faf00dfc Mon Sep 17 00:00:00 2001 From: skhomuti Date: Fri, 24 Nov 2023 15:18:46 +0500 Subject: [PATCH] add invariant checks and tests fix for review --- src/CSModule.sol | 55 ++++++++++-- src/lib/ValidatorCountsReport.sol | 2 +- test/CSModule.t.sol | 138 ++++++++++++++++++++++++++++-- 3 files changed, 178 insertions(+), 17 deletions(-) diff --git a/src/CSModule.sol b/src/CSModule.sol index 8c4ed37e..0797a1c0 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -111,6 +111,9 @@ contract CSModuleBase { error SameAddress(); error AlreadyProposed(); error InvalidVetKeysPointer(); + error InvalidTargetLimit(); + error IncreasingTargetLimitWhenStuckKeys(); + error StuckKeysHigherThanTotalDeposited(); error QueueLookupNoLimit(); error QueueEmptyBatch(); @@ -614,6 +617,11 @@ contract CSModule is IStakingModule, CSModuleBase { // TODO: staking router role only } + // @notice update stuck validators count by StakingRouter + // Presence of stuck validators leads to setting target limit to total deposited keys + // to prevent further deposits and clean batches from the deposit queue. + // @param nodeOperatorIds - bytes packed array of node operator ids + // @param stuckValidatorsCounts - bytes packed array of stuck validators counts function updateStuckValidatorsCount( bytes calldata nodeOperatorIds, bytes calldata stuckValidatorsCounts @@ -633,24 +641,27 @@ contract CSModule is IStakingModule, CSModuleBase { stuckValidatorsCounts, i ); + if (nodeOperatorId >= _nodeOperatorsCount) + revert NodeOperatorDoesNotExist(); NodeOperator storage no = _nodeOperators[nodeOperatorId]; + if (stuckValidatorsCount > no.totalDepositedKeys) + revert StuckKeysHigherThanTotalDeposited(); + if (stuckValidatorsCount == no.stuckValidatorsCount) continue; + no.stuckValidatorsCount = stuckValidatorsCount; if (stuckValidatorsCount == 0) { - no.isTargetLimitActive = false; - no.targetLimit = 0; + _setTargetLimit(nodeOperatorId, false, 0); } else { - no.isTargetLimitActive = true; - no.targetLimit = no.totalDepositedKeys; + _setTargetLimit(nodeOperatorId, true, no.totalDepositedKeys); } emit StuckSigningKeysCountChanged( nodeOperatorId, stuckValidatorsCount ); - emit TargetValidatorsCountChanged(nodeOperatorId, no.targetLimit); } - _incrementNonce(); + _incrementModuleNonce(); } function updateExitedValidatorsCount( @@ -677,11 +688,37 @@ contract CSModule is IStakingModule, CSModuleBase { uint256 targetLimit ) external onlyExistingNodeOperator(nodeOperatorId) onlyStakingRouter { // TODO sanity checks? + _setTargetLimit(nodeOperatorId, isTargetLimitActive, targetLimit); + _incrementModuleNonce(); + } + + // @notice update target limits with event emission + // target limit decreasing (or appearing) must unvet node operator's keys from the queue + // @dev it's not expected (yet) that target limit can be enabled or disabled without or with some value. + // only (!isTargetLimitActive && targetLimit == 0) means that target limit is disabled + function _setTargetLimit( + uint256 nodeOperatorId, + bool isTargetLimitActive, + uint256 targetLimit + ) internal { + if (isTargetLimitActive && targetLimit == 0) + revert InvalidTargetLimit(); + if (!isTargetLimitActive && targetLimit != 0) + revert InvalidTargetLimit(); + NodeOperator storage no = _nodeOperators[nodeOperatorId]; - no.isTargetLimitActive = isTargetLimitActive; + if (no.stuckValidatorsCount > 0 && targetLimit > no.totalDepositedKeys) + revert IncreasingTargetLimitWhenStuckKeys(); + + if (no.isTargetLimitActive != isTargetLimitActive) { + no.isTargetLimitActive = isTargetLimitActive; + } + if (no.targetLimit == targetLimit) return; + if (targetLimit < no.targetLimit || no.targetLimit == 0) + _unvetKeys(nodeOperatorId); + no.targetLimit = targetLimit; emit TargetValidatorsCountChanged(nodeOperatorId, targetLimit); - _incrementNonce(); } function onExitedAndStuckValidatorsCountsUpdated() external { @@ -705,6 +742,8 @@ contract CSModule is IStakingModule, CSModuleBase { if (vetKeysPointer <= no.totalVettedKeys) revert InvalidVetKeysPointer(); if (vetKeysPointer > no.totalAddedKeys) revert InvalidVetKeysPointer(); + if (no.isTargetLimitActive && vetKeysPointer > no.targetLimit) + revert InvalidVetKeysPointer(); uint64 count = SafeCast.toUint64(vetKeysPointer - no.totalVettedKeys); uint64 start = SafeCast.toUint64(no.totalVettedKeys); diff --git a/src/lib/ValidatorCountsReport.sol b/src/lib/ValidatorCountsReport.sol index 2c59e7a9..75230720 100644 --- a/src/lib/ValidatorCountsReport.sol +++ b/src/lib/ValidatorCountsReport.sol @@ -23,7 +23,7 @@ library ValidatorCountsReport { function next( bytes calldata ids, bytes calldata counts, uint256 offset - ) internal returns (uint256 nodeOperatorId, uint256 keysCount) { + ) internal pure returns (uint256 nodeOperatorId, uint256 keysCount) { nodeOperatorId = uint256(bytes32(ids[8 * offset:8 * offset + 8]) >> 192); keysCount = uint256(bytes32(counts[16 * offset:16 * offset + 16]) >> 128); } diff --git a/test/CSModule.t.sol b/test/CSModule.t.sol index ad9cf954..4befd8b2 100644 --- a/test/CSModule.t.sol +++ b/test/CSModule.t.sol @@ -182,7 +182,7 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { function getNodeOperatorSummary( uint256 noId - ) public returns (NodeOperatorSummary memory) { + ) public view returns (NodeOperatorSummary memory) { ( bool isTargetLimitActive, uint256 targetValidatorsCount, @@ -801,6 +801,15 @@ contract CsmVetKeys is CSMCommon { vm.expectRevert(InvalidVetKeysPointer.selector); csm.vetKeys(noId, 2); } + + function test_vetKeys_RevertWhenPointerGreaterThanTargetLimit() public { + uint256 noId = createNodeOperator(2); + csm.vetKeys(noId, 1); + csm.updateTargetValidatorsLimits(noId, true, 1); + + vm.expectRevert(InvalidVetKeysPointer.selector); + csm.vetKeys(noId, 2); + } } contract CsmQueueOps is CSMCommon { @@ -1036,20 +1045,22 @@ contract CsmGetNodeOperatorSummary is CSMCommon { NodeOperatorSummary memory summary = getNodeOperatorSummary(noId); assertTrue(summary.isTargetLimitActive); assertEq(summary.targetValidatorsCount, 2); - assertEq(summary.depositableValidatorsCount, 2); + // should be unvetted + assertEq(summary.depositableValidatorsCount, 0); } function test_getNodeOperatorSummary_targetLimitHigherThanVettedKeys() public { uint256 noId = createNodeOperator(3); - csm.vetKeys(noId, 3); + csm.updateTargetValidatorsLimits(noId, true, 1); + csm.vetKeys(noId, 1); - csm.updateTargetValidatorsLimits(noId, true, 5); + csm.updateTargetValidatorsLimits(noId, true, 3); NodeOperatorSummary memory summary = getNodeOperatorSummary(noId); assertTrue(summary.isTargetLimitActive); - assertEq(summary.targetValidatorsCount, 5); - assertEq(summary.depositableValidatorsCount, 3); + assertEq(summary.targetValidatorsCount, 3); + assertEq(summary.depositableValidatorsCount, 1); } } @@ -1062,12 +1073,71 @@ contract CsmUpdateTargetValidatorsLimits is CSMCommon { csm.updateTargetValidatorsLimits(noId, true, 1); } + function test_updateTargetValidatorsLimits_unvetKeys() public { + uint256 noId = createNodeOperator(); + csm.vetKeys(noId, 1); + + csm.updateTargetValidatorsLimits(noId, true, 1); + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalVettedValidators, 0); + } + + function test_updateTargetValidatorsLimits_NoUnvetKeysWhenLimitHigher() + public + { + uint256 noId = createNodeOperator(2); + csm.updateTargetValidatorsLimits(noId, true, 1); + + csm.vetKeys(noId, 1); + + csm.updateTargetValidatorsLimits(noId, true, 2); + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalVettedValidators, 1); + } + function test_updateTargetValidatorsLimits_RevertWhenNoNodeOperator() public { vm.expectRevert(NodeOperatorDoesNotExist.selector); csm.updateTargetValidatorsLimits(0, true, 1); } + + function test_updateTargetValidatorsLimits_RevertWhenNotStakingRouter() + public + { + // TODO implement + vm.skip(true); + } + + function test_updateTargetValidatorsLimits_RevertWhenTargetLimitIsZero() + public + { + uint256 noId = createNodeOperator(); + vm.expectRevert(InvalidTargetLimit.selector); + csm.updateTargetValidatorsLimits(noId, true, 0); + } + + function test_updateTargetValidatorsLimits_RevertWhenTargetLimitIsNonZero() + public + { + uint256 noId = createNodeOperator(); + vm.expectRevert(InvalidTargetLimit.selector); + csm.updateTargetValidatorsLimits(noId, false, 10); + } + + function test_updateTargetValidatorsLimits_RevertWhenTargetLimitIsGreaterThanDepositedWhenStuck() + public + { + uint256 noId = createNodeOperator(); + csm.vetKeys(noId, 1); + csm.obtainDepositData(1, ""); + csm.updateStuckValidatorsCount( + bytes.concat(bytes8(0x0000000000000000)), + bytes.concat(bytes16(0x00000000000000000000000000000001)) + ); + vm.expectRevert(IncreasingTargetLimitWhenStuckKeys.selector); + csm.updateTargetValidatorsLimits(noId, true, 4); + } } contract CsmUpdateStuckValidatorsCount is CSMCommon { @@ -1076,10 +1146,10 @@ contract CsmUpdateStuckValidatorsCount is CSMCommon { csm.vetKeys(noId, 3); csm.obtainDepositData(1, ""); - vm.expectEmit(true, true, false, true, address(csm)); - emit StuckSigningKeysCountChanged(noId, 1); vm.expectEmit(true, true, false, true, address(csm)); emit TargetValidatorsCountChanged(noId, 1); + vm.expectEmit(true, true, false, true, address(csm)); + emit StuckSigningKeysCountChanged(noId, 1); csm.updateStuckValidatorsCount( bytes.concat(bytes8(0x0000000000000000)), bytes.concat(bytes16(0x00000000000000000000000000000001)) @@ -1098,4 +1168,56 @@ contract CsmUpdateStuckValidatorsCount is CSMCommon { ); assertTrue(summary.isTargetLimitActive, "isTargetLimitActive is false"); } + + function test_updateStuckValidatorsCount_RevertWhenNoNodeOperator() public { + vm.expectRevert(NodeOperatorDoesNotExist.selector); + csm.updateStuckValidatorsCount( + bytes.concat(bytes8(0x0000000000000000)), + bytes.concat(bytes16(0x00000000000000000000000000000001)) + ); + } + + function test_updateStuckValidatorsCount_RevertWhenNotStakingRouter() + public + { + // TODO implement + vm.skip(true); + } + + function test_updateStuckValidatorsCount_RevertWhenCountMoreThanDeposited() + public + { + uint256 noId = createNodeOperator(3); + csm.vetKeys(noId, 3); + csm.obtainDepositData(1, ""); + + vm.expectRevert(StuckKeysHigherThanTotalDeposited.selector); + csm.updateStuckValidatorsCount( + bytes.concat(bytes8(0x0000000000000000)), + bytes.concat(bytes16(0x00000000000000000000000000000002)) + ); + } + + // @dev this is ugly solution to test that events are not emitted when stuckKeysCount is not changed + // we can't do it properly while vm.expectNotEmit is not implemented in forge (or smth like that) + function testFail_updateStuckValidatorsCount_NoEventWhenStuckKeysCountSame() + public + { + uint256 noId = createNodeOperator(); + csm.vetKeys(noId, 1); + csm.obtainDepositData(1, ""); + csm.updateStuckValidatorsCount( + bytes.concat(bytes8(0x0000000000000000)), + bytes.concat(bytes16(0x00000000000000000000000000000001)) + ); + + vm.expectEmit(true, true, false, true, address(csm)); + emit TargetValidatorsCountChanged(noId, 1); + vm.expectEmit(true, true, false, true, address(csm)); + emit StuckSigningKeysCountChanged(noId, 1); + csm.updateStuckValidatorsCount( + bytes.concat(bytes8(0x0000000000000000)), + bytes.concat(bytes16(0x00000000000000000000000000000001)) + ); + } }