From 4348a8a83971bf17fc76a7c09d81c110a8285548 Mon Sep 17 00:00:00 2001 From: Don Perignom <10616301+madlabman@users.noreply.github.com> Date: Mon, 4 Dec 2023 08:58:39 +0100 Subject: [PATCH] feat: keys removal (#44) --- src/CSModule.sol | 87 ++++++++- src/lib/SigningKeys.sol | 41 +++-- test/CSModule.t.sol | 357 ++++++++++++++++++++++++++++++++++++- test/helpers/Utilities.sol | 24 ++- 4 files changed, 488 insertions(+), 21 deletions(-) diff --git a/src/CSModule.sol b/src/CSModule.sol index 30bac7a7..631b841c 100644 --- a/src/CSModule.sol +++ b/src/CSModule.sol @@ -95,6 +95,8 @@ contract CSModuleBase { event LocatorContractSet(address locatorAddress); event UnvettingFeeSet(uint256 unvettingFee); + event UnvettingFeeApplied(uint256 indexed nodeOperatorId); + error NodeOperatorDoesNotExist(); error MaxNodeOperatorsCountReached(); error SenderIsNotManagerAddress(); @@ -110,6 +112,8 @@ contract CSModuleBase { error QueueBatchInvalidStart(bytes32 batch); error QueueBatchInvalidCount(bytes32 batch); error QueueBatchUnvettedKeys(bytes32 batch); + + error SigningKeysInvalidOffset(); } contract CSModule is IStakingModule, CSModuleBase { @@ -151,10 +155,10 @@ contract CSModule is IStakingModule, CSModuleBase { accounting = ICSAccounting(_accounting); } - function setUnvettingFee(uint256 unvettingFee_) external { + function setUnvettingFee(uint256 _unvettingFee) external { // TODO: add role check - unvettingFee = unvettingFee_; - emit UnvettingFeeSet(unvettingFee_); + unvettingFee = _unvettingFee; + emit UnvettingFeeSet(_unvettingFee); } function _lido() internal view returns (ILido) { @@ -558,6 +562,25 @@ contract CSModule is IStakingModule, CSModuleBase { depositableValidatorsCount = no.totalVettedKeys - no.totalExitedKeys; } + function getNodeOperatorSigningKeys( + uint256 nodeOperatorId, + uint256 startIndex, + uint256 keysCount + ) + external + view + onlyExistingNodeOperator(nodeOperatorId) + returns (bytes memory) + { + return + SigningKeys.loadKeys( + SIGNING_KEYS_POSITION, + nodeOperatorId, + startIndex, + keysCount + ); + } + function getNonce() external view returns (uint256) { return _nonce; } @@ -679,19 +702,45 @@ contract CSModule is IStakingModule, CSModuleBase { onlyKeyValidatorOrNodeOperatorManager { _unvetKeys(nodeOperatorId); - accounting.penalize(nodeOperatorId, unvettingFee); + _applyUnvettingFee(nodeOperatorId); + _incrementModuleNonce(); } function unsafeUnvetKeys(uint256 nodeOperatorId) external onlyKeyValidator { _unvetKeys(nodeOperatorId); + _incrementModuleNonce(); } + function removeKeys( + uint256 nodeOperatorId, + uint256 startIndex, + uint256 keysCount + ) + external + onlyExistingNodeOperator(nodeOperatorId) + onlyNodeOperatorManager(nodeOperatorId) + { + NodeOperator storage no = _nodeOperators[nodeOperatorId]; + if (no.totalVettedKeys > startIndex) { + _unvetKeys(nodeOperatorId); + _applyUnvettingFee(nodeOperatorId); + } + + _removeSigningKeys(nodeOperatorId, startIndex, keysCount); + _incrementModuleNonce(); + } + + /// @dev NB! doesn't increment module nonce function _unvetKeys(uint256 nodeOperatorId) internal { NodeOperator storage no = _nodeOperators[nodeOperatorId]; no.totalVettedKeys = no.totalDepositedKeys; no.queueNonce++; emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys); - _incrementModuleNonce(); + } + + function _applyUnvettingFee(uint256 nodeOperatorId) internal { + accounting.penalize(nodeOperatorId, unvettingFee); + emit UnvettingFeeApplied(nodeOperatorId); } function onWithdrawalCredentialsChanged() external { @@ -727,6 +776,34 @@ contract CSModule is IStakingModule, CSModuleBase { _incrementModuleNonce(); } + function _removeSigningKeys( + uint256 nodeOperatorId, + uint256 startIndex, + uint256 keysCount + ) internal { + NodeOperator storage no = _nodeOperators[nodeOperatorId]; + + if (startIndex < no.totalDepositedKeys) { + revert SigningKeysInvalidOffset(); + } + + if (startIndex + keysCount > no.totalAddedKeys) { + revert SigningKeysInvalidOffset(); + } + + // solhint-disable-next-line func-named-parameters + uint256 newTotalSigningKeys = SigningKeys.removeKeysSigs( + SIGNING_KEYS_POSITION, + nodeOperatorId, + startIndex, + keysCount, + no.totalAddedKeys + ); + + no.totalAddedKeys = newTotalSigningKeys; + emit TotalSigningKeysCountChanged(nodeOperatorId, newTotalSigningKeys); + } + function obtainDepositData( uint256 depositsCount, bytes calldata /* _depositCalldata */ diff --git a/src/lib/SigningKeys.sol b/src/lib/SigningKeys.sol index 0a0c2fea..f6e34d19 100644 --- a/src/lib/SigningKeys.sol +++ b/src/lib/SigningKeys.sol @@ -4,7 +4,7 @@ // See contracts/COMPILERS.md pragma solidity 0.8.21; -import "@openzeppelin/contracts/utils/math/SafeMath.sol"; +import { SafeMath } from "@openzeppelin/contracts/utils/math/SafeMath.sol"; /// @title Library for manage operator keys in storage /// @author KRogLA @@ -23,7 +23,7 @@ library SigningKeys { return uint256(keccak256(abi.encodePacked(position, nodeOperatorId, keyIndex))); } - /// @dev store opeartor keys to storage + /// @dev store operator keys to storage /// @param position storage slot /// @param nodeOperatorId operator id /// @param startIndex start index @@ -78,7 +78,7 @@ library SigningKeys { return startIndex; } - /// @dev remove opeartor keys from storage + /// @dev remove operator keys from storage /// @param position storage slot /// @param nodeOperatorId operator id /// @param startIndex start index @@ -105,14 +105,14 @@ library SigningKeys { for (uint256 i = startIndex + keysCount; i > startIndex;) { curOffset = position.getKeyOffset(nodeOperatorId, i - 1); assembly { - // read key + // read key mstore(add(tmpKey, 0x30), shr(128, sload(add(curOffset, 1)))) // bytes 16..47 mstore(add(tmpKey, 0x20), sload(curOffset)) // bytes 0..31 } if (i < totalKeysCount) { lastOffset = position.getKeyOffset(nodeOperatorId, totalKeysCount - 1); // move last key to deleted key index - for (j = 0; j < 5;) { + for (j = 0; j < 5;) { // load 160 bytes (5 slots) containing key and signature assembly { sstore(add(curOffset, j), sload(add(lastOffset, j))) j := add(j, 1) @@ -136,7 +136,7 @@ library SigningKeys { return totalKeysCount; } - /// @dev laod opeartor keys from storage + /// @dev load operator keys and signatures from storage /// @param position storage slot /// @param nodeOperatorId operator id /// @param startIndex start index @@ -157,12 +157,12 @@ library SigningKeys { for (uint256 i; i < keysCount;) { curOffset = position.getKeyOffset(nodeOperatorId, startIndex + i); assembly { - // read key - let _ofs := add(add(pubkeys, 0x20), mul(add(bufOffset, i), 48)) //PUBKEY_LENGTH = 48 + // read key + let _ofs := add(add(pubkeys, 0x20), mul(add(bufOffset, i), 48)) // PUBKEY_LENGTH = 48 mstore(add(_ofs, 0x10), shr(128, sload(add(curOffset, 1)))) // bytes 16..47 mstore(_ofs, sload(curOffset)) // bytes 0..31 - // store signature - _ofs := add(add(signatures, 0x20), mul(add(bufOffset, i), 96)) //SIGNATURE_LENGTH = 96 + // store signature + _ofs := add(add(signatures, 0x20), mul(add(bufOffset, i), 96)) // SIGNATURE_LENGTH = 96 mstore(_ofs, sload(add(curOffset, 2))) mstore(add(_ofs, 0x20), sload(add(curOffset, 3))) mstore(add(_ofs, 0x40), sload(add(curOffset, 4))) @@ -171,6 +171,27 @@ library SigningKeys { } } + function loadKeys( + bytes32 position, + uint256 nodeOperatorId, + uint256 startIndex, + uint256 keysCount + ) internal view returns (bytes memory pubkeys) { + uint256 curOffset; + + pubkeys = new bytes(keysCount.mul(PUBKEY_LENGTH)); + for (uint256 i; i < keysCount;) { + curOffset = position.getKeyOffset(nodeOperatorId, startIndex + i); + assembly { + // read key + let offset := add(add(pubkeys, 0x20), mul(i, 48)) // PUBKEY_LENGTH = 48 + mstore(add(offset, 0x10), shr(128, sload(add(curOffset, 1)))) // bytes 16..47 + mstore(offset, sload(curOffset)) // bytes 0..31 + i := add(i, 1) + } + } + } + function initKeysSigsBuf(uint256 count) internal pure returns (bytes memory, bytes memory) { return (new bytes(count.mul(PUBKEY_LENGTH)), new bytes(count.mul(SIGNATURE_LENGTH))); } diff --git a/test/CSModule.t.sol b/test/CSModule.t.sol index 77107361..3fb7eeef 100644 --- a/test/CSModule.t.sol +++ b/test/CSModule.t.sol @@ -40,14 +40,13 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { address internal nodeOperator; function setUp() public { - vm.label(address(this), "TEST"); - nodeOperator = nextAddress("NODE_OPERATOR"); stranger = nextAddress("STRANGER"); admin = nextAddress("ADMIN"); (locator, wstETH, stETH, burner) = initLido(); + // FIXME: move to the corresponding tests vm.deal(nodeOperator, 2 ether + 1 wei); vm.prank(nodeOperator); stETH.submit{ value: 2 ether + 1 wei }(address(0)); @@ -67,6 +66,7 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { 1 days ); csm.setAccounting(address(accounting)); + csm.setUnvettingFee(0.05 ether); vm.startPrank(admin); accounting.grantRole( @@ -91,6 +91,15 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase { (bytes memory keys, bytes memory signatures) = keysSignatures( keysCount ); + return createNodeOperator(managerAddress, keysCount, keys, signatures); + } + + function createNodeOperator( + address managerAddress, + uint256 keysCount, + bytes memory keys, + bytes memory signatures + ) internal returns (uint256) { vm.deal(managerAddress, keysCount * 2 ether); vm.prank(managerAddress); csm.addNodeOperatorETH{ value: keysCount * 2 ether }( @@ -912,3 +921,347 @@ contract CsmQueueOps is CSMCommon { _assertQueueIsEmpty(); } } + +contract CsmViewKeys is CSMCommon { + function test_viewAllKeys() public { + bytes memory keys = randomBytes(48 * 3); + + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 3, + keys: keys, + signatures: randomBytes(96 * 3) + }); + + bytes memory obtainedKeys = csm.getNodeOperatorSigningKeys({ + nodeOperatorId: noId, + startIndex: 0, + keysCount: 3 + }); + + assertEq(obtainedKeys, keys, "unexpected keys"); + } + + function test_viewKeysFromOffset() public { + bytes memory wantedKey = randomBytes(48); + bytes memory keys = bytes.concat( + randomBytes(48), + wantedKey, + randomBytes(48) + ); + + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 3, + keys: keys, + signatures: randomBytes(96 * 3) + }); + + bytes memory obtainedKeys = csm.getNodeOperatorSigningKeys({ + nodeOperatorId: noId, + startIndex: 1, + keysCount: 1 + }); + + assertEq(obtainedKeys, wantedKey, "unexpected key at position 1"); + } +} + +contract CsmRemoveKeys is CSMCommon { + event SigningKeyRemoved(uint256 indexed nodeOperatorId, bytes pubkey); + + bytes key0 = randomBytes(48); + bytes key1 = randomBytes(48); + bytes key2 = randomBytes(48); + bytes key3 = randomBytes(48); + bytes key4 = randomBytes(48); + + function test_singleKeyRemoval() public { + bytes memory keys = bytes.concat(key0, key1, key2, key3, key4); + + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 5, + keys: keys, + signatures: randomBytes(96 * 5) + }); + + // at the beginning + { + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key0); + + vm.expectEmit(true, true, true, true, address(csm)); + emit TotalSigningKeysCountChanged(noId, 4); + } + csm.removeKeys({ nodeOperatorId: noId, startIndex: 0, keysCount: 1 }); + /* + key4 + key1 + key2 + key3 + */ + + // in between + { + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key1); + + vm.expectEmit(true, true, true, true, address(csm)); + emit TotalSigningKeysCountChanged(noId, 3); + } + csm.removeKeys({ nodeOperatorId: noId, startIndex: 1, keysCount: 1 }); + /* + key4 + key3 + key2 + */ + + // at the end + { + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key2); + + vm.expectEmit(true, true, true, true, address(csm)); + emit TotalSigningKeysCountChanged(noId, 2); + } + csm.removeKeys({ nodeOperatorId: noId, startIndex: 2, keysCount: 1 }); + /* + key4 + key3 + */ + + bytes memory obtainedKeys = csm.getNodeOperatorSigningKeys({ + nodeOperatorId: noId, + startIndex: 0, + keysCount: 2 + }); + assertEq(obtainedKeys, bytes.concat(key4, key3), "unexpected keys"); + + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalAddedValidators, 2); + } + + function test_multipleKeysRemovalFromStart() public { + bytes memory keys = bytes.concat(key0, key1, key2, key3, key4); + + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 5, + keys: keys, + signatures: randomBytes(96 * 5) + }); + + { + // NOTE: keys are being removed in reverse order to keep an original order of keys at the end of the list + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key1); + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key0); + + vm.expectEmit(true, true, true, true, address(csm)); + emit TotalSigningKeysCountChanged(noId, 3); + } + + csm.removeKeys({ nodeOperatorId: noId, startIndex: 0, keysCount: 2 }); + + bytes memory obtainedKeys = csm.getNodeOperatorSigningKeys({ + nodeOperatorId: noId, + startIndex: 0, + keysCount: 3 + }); + assertEq( + obtainedKeys, + bytes.concat(key3, key4, key2), + "unexpected keys" + ); + + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalAddedValidators, 3); + } + + function test_multipleKeysRemovalInBetween() public { + bytes memory keys = bytes.concat(key0, key1, key2, key3, key4); + + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 5, + keys: keys, + signatures: randomBytes(96 * 5) + }); + + { + // NOTE: keys are being removed in reverse order to keep an original order of keys at the end of the list + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key2); + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key1); + + vm.expectEmit(true, true, true, true, address(csm)); + emit TotalSigningKeysCountChanged(noId, 3); + } + + csm.removeKeys({ nodeOperatorId: noId, startIndex: 1, keysCount: 2 }); + + bytes memory obtainedKeys = csm.getNodeOperatorSigningKeys({ + nodeOperatorId: noId, + startIndex: 0, + keysCount: 3 + }); + assertEq( + obtainedKeys, + bytes.concat(key0, key3, key4), + "unexpected keys" + ); + + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalAddedValidators, 3); + } + + function test_multipleKeysRemovalFromEnd() public { + bytes memory keys = bytes.concat(key0, key1, key2, key3, key4); + + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 5, + keys: keys, + signatures: randomBytes(96 * 5) + }); + + { + // NOTE: keys are being removed in reverse order to keep an original order of keys at the end of the list + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key4); + vm.expectEmit(true, true, true, true, address(csm)); + emit SigningKeyRemoved(noId, key3); + + vm.expectEmit(true, true, true, true, address(csm)); + emit TotalSigningKeysCountChanged(noId, 3); + } + + csm.removeKeys({ nodeOperatorId: noId, startIndex: 3, keysCount: 2 }); + + bytes memory obtainedKeys = csm.getNodeOperatorSigningKeys({ + nodeOperatorId: noId, + startIndex: 0, + keysCount: 3 + }); + assertEq( + obtainedKeys, + bytes.concat(key0, key1, key2), + "unexpected keys" + ); + + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalAddedValidators, 3); + } + + function test_removeAllKeys() public { + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 5, + keys: randomBytes(48 * 5), + signatures: randomBytes(96 * 5) + }); + + { + vm.expectEmit(true, true, true, true, address(csm)); + emit TotalSigningKeysCountChanged(noId, 0); + } + + csm.removeKeys({ nodeOperatorId: noId, startIndex: 0, keysCount: 5 }); + + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalAddedValidators, 0); + } + + function test_removingVettedKeysUnvetsOperator() public { + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 5, + keys: randomBytes(48 * 5), + signatures: randomBytes(96 * 5) + }); + + csm.vetKeys(noId, 3); + csm.obtainDepositData(1, ""); + + /* + no.totalVettedValidators = 3 + no.totalDepositedKeys = 1 + no.totalAddedKeys = 5 + */ + + { + vm.expectEmit(true, true, true, true, address(csm)); + emit VettedSigningKeysCountChanged(noId, 1); + vm.expectEmit(true, true, true, true, address(csm)); + emit UnvettingFeeApplied(noId); + } + + csm.removeKeys({ nodeOperatorId: noId, startIndex: 1, keysCount: 2 }); + NodeOperatorInfo memory no = csm.getNodeOperator(noId); + assertEq(no.totalVettedValidators, 1); + } + + function test_removingNotVettedKeysDoesntUnvetOperator() public { + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 5, + keys: randomBytes(48 * 5), + signatures: randomBytes(96 * 5) + }); + + csm.vetKeys(noId, 3); + csm.obtainDepositData(1, ""); + + /* + no.totalVettedValidators = 3 + no.totalDepositedKeys = 1 + no.totalAddedKeys = 5 + */ + + csm.removeKeys({ nodeOperatorId: noId, startIndex: 3, keysCount: 2 }); + NodeOperatorInfo memory no = csm.getNodeOperator(0); + assertEq(no.totalVettedValidators, 3); + } + + function test_removeKeys_RevertWhenNoNodeOperator() public { + vm.expectRevert(NodeOperatorDoesNotExist.selector); + csm.removeKeys({ nodeOperatorId: 0, startIndex: 0, keysCount: 1 }); + } + + function test_removeKeys_RevertWhenMoreThanAdded() public { + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 1 + }); + + vm.expectRevert(SigningKeysInvalidOffset.selector); + csm.removeKeys({ nodeOperatorId: noId, startIndex: 0, keysCount: 2 }); + } + + function test_removeKeys_RevertWhenLessThanDeposited() public { + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 2 + }); + + csm.vetKeys(noId, 1); + csm.obtainDepositData(1, ""); + + vm.expectRevert(SigningKeysInvalidOffset.selector); + csm.removeKeys({ nodeOperatorId: noId, startIndex: 0, keysCount: 1 }); + } + + function test_removeKeys_RevertWhenNotManager() public { + uint256 noId = createNodeOperator({ + managerAddress: address(this), + keysCount: 1 + }); + + vm.prank(stranger); + vm.expectRevert(SenderIsNotManagerAddress.selector); + csm.removeKeys({ nodeOperatorId: noId, startIndex: 0, keysCount: 1 }); + } +} diff --git a/test/helpers/Utilities.sol b/test/helpers/Utilities.sol index 73695604..2dab7a6f 100644 --- a/test/helpers/Utilities.sol +++ b/test/helpers/Utilities.sol @@ -9,10 +9,9 @@ contract Utilities is CommonBase { bytes32 internal seed = keccak256("seed sEed seEd"); function nextAddress() internal returns (address) { - address a = address( - uint160(uint256(keccak256(abi.encodePacked(seed)))) - ); - seed = keccak256(abi.encodePacked(seed)); + bytes32 buf = keccak256(abi.encodePacked(seed)); + address a = address(uint160(uint256(buf))); + seed = buf; return a; } @@ -50,6 +49,23 @@ contract Utilities is CommonBase { return (keys, signatures); } + function randomBytes(uint256 length) public returns (bytes memory b) { + b = new bytes(length); + + for (;;) { + bytes32 buf = keccak256(abi.encodePacked(seed)); + seed = buf; + + for (uint256 i = 0; i < 32; i++) { + if (length == 0) { + return b; + } + length--; + b[length] = buf[i]; + } + } + } + function checkChainId(uint256 chainId) public view { if (chainId != block.chainid) { revert("wrong chain id");