Skip to content

Commit

Permalink
Merge pull request #38 from lidofinance/queue-improvements
Browse files Browse the repository at this point in the history
feat: vetKeys fixes and tests
  • Loading branch information
skhomuti authored Nov 20, 2023
2 parents 0313fbf + 0521be3 commit ee8ccbb
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 81 deletions.
47 changes: 30 additions & 17 deletions src/CSModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,21 @@ contract CSModuleBase {
event UnvettingFeeSet(uint256 unvettingFee);

error NodeOperatorDoesNotExist();
error MaxNodeOperatorsCountReached();
error SenderIsNotManagerAddress();
error SenderIsNotRewardAddress();
error SenderIsNotProposedAddress();
error SameAddress();
error AlreadyProposed();
error InvalidVetKeysPointer();
}

contract CSModule is IStakingModule, CSModuleBase {
using QueueLib for QueueLib.Queue;

uint256 public constant MAX_NODE_OPERATOR_NAME_LENGTH = 255;
// @dev max number of node operators is limited by uint128 due to Batch serialization in 32 bytes
// it seems to be enough
uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint128).max;
bytes32 public constant SIGNING_KEYS_POSITION =
keccak256("lido.CommunityStakingModule.signingKeysPosition");

Expand Down Expand Up @@ -182,6 +186,8 @@ contract CSModule is IStakingModule, CSModuleBase {
);

uint256 id = _nodeOperatorsCount;
if (id == MAX_NODE_OPERATORS_COUNT)
revert MaxNodeOperatorsCountReached();
NodeOperator storage no = _nodeOperators[id];

no.managerAddress = msg.sender;
Expand All @@ -205,6 +211,8 @@ contract CSModule is IStakingModule, CSModuleBase {
// TODO: sanity checks

uint256 id = _nodeOperatorsCount;
if (id == MAX_NODE_OPERATORS_COUNT)
revert MaxNodeOperatorsCountReached();
NodeOperator storage no = _nodeOperators[id];

no.managerAddress = msg.sender;
Expand Down Expand Up @@ -233,6 +241,8 @@ contract CSModule is IStakingModule, CSModuleBase {
// TODO sanity checks

uint256 id = _nodeOperatorsCount;
if (id == MAX_NODE_OPERATORS_COUNT)
revert MaxNodeOperatorsCountReached();
NodeOperator storage no = _nodeOperators[id];
no.rewardAddress = msg.sender;
no.managerAddress = msg.sender;
Expand Down Expand Up @@ -260,6 +270,8 @@ contract CSModule is IStakingModule, CSModuleBase {
// TODO sanity checks

uint256 id = _nodeOperatorsCount;
if (id == MAX_NODE_OPERATORS_COUNT)
revert MaxNodeOperatorsCountReached();
NodeOperator storage no = _nodeOperators[id];

no.managerAddress = msg.sender;
Expand Down Expand Up @@ -288,6 +300,8 @@ contract CSModule is IStakingModule, CSModuleBase {
// TODO sanity checks

uint256 id = _nodeOperatorsCount;
if (id == MAX_NODE_OPERATORS_COUNT)
revert MaxNodeOperatorsCountReached();
NodeOperator storage no = _nodeOperators[id];
no.rewardAddress = msg.sender;
no.managerAddress = msg.sender;
Expand Down Expand Up @@ -533,7 +547,7 @@ contract CSModule is IStakingModule, CSModuleBase {
stuckPenaltyEndTimestamp = no.stuckPenaltyEndTimestamp;
totalExitedValidators = no.totalExitedKeys;
totalDepositedValidators = no.totalDepositedKeys;
depositableValidatorsCount = no.totalAddedKeys - no.totalExitedKeys;
depositableValidatorsCount = no.totalVettedKeys - no.totalExitedKeys;
}

function getNonce() external view returns (uint256) {
Expand Down Expand Up @@ -622,22 +636,17 @@ contract CSModule is IStakingModule, CSModuleBase {

function vetKeys(
uint256 nodeOperatorId,
uint64 vettedKeysCount
) external onlyKeyValidator {
uint64 vetKeysPointer
) external onlyExistingNodeOperator(nodeOperatorId) onlyKeyValidator {
NodeOperator storage no = _nodeOperators[nodeOperatorId];

require(
vettedKeysCount > no.totalVettedKeys,
"Wrong vettedKeysCount: less than already vetted"
);
require(
vettedKeysCount <= no.totalAddedKeys,
"Wrong vettedKeysCount: more than added"
);
if (vetKeysPointer <= no.totalVettedKeys)
revert InvalidVetKeysPointer();
if (vetKeysPointer > no.totalAddedKeys) revert InvalidVetKeysPointer();

uint64 count = SafeCast.toUint64(vettedKeysCount - no.totalVettedKeys);
uint64 count = SafeCast.toUint64(vetKeysPointer - no.totalVettedKeys);
uint64 start = SafeCast.toUint64(
no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys - 1
no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys
);

bytes32 pointer = Batch.serialize({
Expand All @@ -646,18 +655,22 @@ contract CSModule is IStakingModule, CSModuleBase {
count: count
});

no.totalVettedKeys = vettedKeysCount;
no.totalVettedKeys = vetKeysPointer;
queue.enqueue(pointer);

emit BatchEnqueued(nodeOperatorId, start, count);
emit VettedSigningKeysCountChanged(nodeOperatorId, vettedKeysCount);
emit VettedSigningKeysCountChanged(nodeOperatorId, vetKeysPointer);

_incrementNonce();
}

function unvetKeys(
uint256 nodeOperatorId
) external onlyKeyValidatorOrNodeOperatorManager {
)
external
onlyExistingNodeOperator(nodeOperatorId)
onlyKeyValidatorOrNodeOperatorManager
{
_unvetKeys(nodeOperatorId);
accounting.penalize(nodeOperatorId, unvettingFee);
}
Expand Down
59 changes: 0 additions & 59 deletions test/CSMInit.t.sol

This file was deleted.

95 changes: 90 additions & 5 deletions test/CSMAddValidator.t.sol → test/CSModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity ^0.8.21;
import "forge-std/Test.sol";
import "../src/CSModule.sol";
import "../src/CSAccounting.sol";
import "../src/lib/Batch.sol";
import "./helpers/Fixtures.sol";
import "./helpers/mocks/StETHMock.sol";
import "./helpers/mocks/CommunityStakingFeeDistributorMock.sol";
Expand Down Expand Up @@ -56,23 +57,45 @@ contract CSMCommon is Test, Fixtures, Utilities, CSModuleBase {
}

function createNodeOperator() internal returns (uint256) {
return createNodeOperator(nodeOperator);
return createNodeOperator(nodeOperator, 1);
}

function createNodeOperator(uint256 keysCount) internal returns (uint256) {
return createNodeOperator(nodeOperator, keysCount);
}

function createNodeOperator(
address managerAddress
address managerAddress,
uint256 keysCount
) internal returns (uint256) {
uint256 keysCount = 1;
(bytes memory keys, bytes memory signatures) = keysSignatures(
keysCount
);
vm.deal(managerAddress, 2 ether);
vm.deal(managerAddress, keysCount * 2 ether);
vm.prank(managerAddress);
csm.addNodeOperatorETH{ value: 2 ether }(keysCount, keys, signatures);
csm.addNodeOperatorETH{ value: keysCount * 2 ether }(
keysCount,
keys,
signatures
);
return csm.getNodeOperatorsCount() - 1;
}
}

contract CsmInitialization is CSMCommon {
function test_initContract() public {
csm = new CSModule("community-staking-module", address(locator));
assertEq(csm.getType(), "community-staking-module");
assertEq(csm.getNodeOperatorsCount(), 0);
}

function test_setAccounting() public {
csm = new CSModule("community-staking-module", address(locator));
csm.setAccounting(address(accounting));
assertEq(address(csm.accounting()), address(accounting));
}
}

contract CSMAddNodeOperator is CSMCommon, PermitTokenBase {
function test_AddNodeOperatorWstETH() public {
uint16 keysCount = 1;
Expand Down Expand Up @@ -582,3 +605,65 @@ contract CsmResetNodeOperatorManagerAddress is CSMCommon {
csm.resetNodeOperatorManagerAddress(noId);
}
}

contract CsmVetKeys is CSMCommon {
function test_vetKeys() public {
uint256 noId = createNodeOperator();

vm.expectEmit(true, true, true, true, address(csm));
emit BatchEnqueued(noId, 0, 1);
vm.expectEmit(true, true, false, true, address(csm));
emit VettedSigningKeysCountChanged(noId, 1);
csm.vetKeys(noId, 1);

NodeOperatorInfo memory no = csm.getNodeOperator(noId);
assertEq(no.totalVettedValidators, 1);
(bytes32[] memory items, , ) = csm.depositQueue(1, bytes32(0));
(uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize(
items[0]
);
assertEq(batchNoId, uint128(noId));
assertEq(start, 0);
assertEq(count, 1);
}

function test_vetKeys_totalVettedKeysIsNotZero() public {
uint256 noId = createNodeOperator(2);
csm.vetKeys(noId, 1);

vm.expectEmit(true, true, true, true, address(csm));
emit BatchEnqueued(noId, 1, 1);
vm.expectEmit(true, true, false, true, address(csm));
emit VettedSigningKeysCountChanged(noId, 2);
csm.vetKeys(noId, 2);

NodeOperatorInfo memory no = csm.getNodeOperator(noId);
assertEq(no.totalVettedValidators, 2);
(bytes32[] memory items, , ) = csm.depositQueue(2, bytes32(0));
(uint128 batchNoId, uint64 start, uint64 count) = Batch.deserialize(
items[1]
);
assertEq(batchNoId, uint128(noId));
assertEq(start, 1);
assertEq(count, 1);
}

function test_vetKeys_RevertWhenNoNodeOperator() public {
vm.expectRevert(NodeOperatorDoesNotExist.selector);
csm.vetKeys(0, 1);
}

function test_vetKeys_RevertWhenPointerLessThanTotalVetted() public {
uint256 noId = createNodeOperator();
csm.vetKeys(noId, 1);

vm.expectRevert(InvalidVetKeysPointer.selector);
csm.vetKeys(noId, 1);
}

function test_vetKeys_RevertWhenPointerGreaterThanTotalAdded() public {
uint256 noId = createNodeOperator();
vm.expectRevert(InvalidVetKeysPointer.selector);
csm.vetKeys(noId, 2);
}
}

0 comments on commit ee8ccbb

Please sign in to comment.