Skip to content

Commit

Permalink
feat: update stuck keys
Browse files Browse the repository at this point in the history
now if NO has stuck keys, targetLimit will be applied to prevent further deposits
  • Loading branch information
skhomuti committed Nov 22, 2023
1 parent ee8ccbb commit 60bd48f
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 6 deletions.
59 changes: 53 additions & 6 deletions src/CSModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { ILido } from "./interfaces/ILido.sol";

import { QueueLib } from "./lib/QueueLib.sol";
import { Batch } from "./lib/Batch.sol";
import { ValidatorCountsReport } from "./lib/ValidatorCountsReport.sol";

import "./lib/SigningKeys.sol";

Expand All @@ -23,7 +24,6 @@ struct NodeOperator {
address proposedRewardAddress;
bool active;
uint256 targetLimit;
uint256 targetLimitTimestamp;
uint256 stuckPenaltyEndTimestamp;
uint256 totalExitedKeys;
uint256 totalAddedKeys;
Expand Down Expand Up @@ -83,6 +83,14 @@ contract CSModuleBase {
uint256 indexed nodeOperatorId,
uint256 totalValidatorsCount
);
event StuckSigningKeysCountChanged(
uint256 indexed nodeOperatorId,
uint256 stuckValidatorsCount
);
event TargetValidatorsCountChanged(
uint256 indexed nodeOperatorId,
uint256 targetValidatorsCount
);

event BatchEnqueued(
uint256 indexed nodeOperatorId,
Expand Down Expand Up @@ -547,7 +555,9 @@ contract CSModule is IStakingModule, CSModuleBase {
stuckPenaltyEndTimestamp = no.stuckPenaltyEndTimestamp;
totalExitedValidators = no.totalExitedKeys;
totalDepositedValidators = no.totalDepositedKeys;
depositableValidatorsCount = no.totalVettedKeys - no.totalExitedKeys;
depositableValidatorsCount = no.isTargetLimitActive
? no.targetLimit - no.totalDepositedKeys
: no.totalVettedKeys - no.totalDepositedKeys;
}

function getNonce() external view returns (uint256) {
Expand Down Expand Up @@ -590,10 +600,42 @@ contract CSModule is IStakingModule, CSModuleBase {
}

function updateStuckValidatorsCount(
bytes calldata /*_nodeOperatorIds*/,
bytes calldata /*_stuckValidatorsCounts*/
) external {
// TODO: implement
bytes calldata nodeOperatorIds,
bytes calldata stuckValidatorsCounts
) external onlyStakingRouter {
ValidatorCountsReport.validate(nodeOperatorIds, stuckValidatorsCounts);

for (
uint256 i = 0;
i < ValidatorCountsReport.count(nodeOperatorIds);
i++
) {
(
uint256 nodeOperatorId,
uint256 stuckValidatorsCount
) = ValidatorCountsReport.next(
nodeOperatorIds,
stuckValidatorsCounts,
i
);
NodeOperator storage no = _nodeOperators[nodeOperatorId];
no.stuckValidatorsCount = stuckValidatorsCount;

if (stuckValidatorsCount == 0) {
no.isTargetLimitActive = false;
no.targetLimit = 0;
} else {
no.isTargetLimitActive = true;
no.targetLimit = no.totalDepositedKeys;
}

emit StuckSigningKeysCountChanged(
nodeOperatorId,
stuckValidatorsCount
);
emit TargetValidatorsCountChanged(nodeOperatorId, no.targetLimit);
}
_incrementNonce();
}

function updateExitedValidatorsCount(
Expand Down Expand Up @@ -951,4 +993,9 @@ contract CSModule is IStakingModule, CSModuleBase {
// TODO: check the role
_;
}

modifier onlyStakingRouter() {
// TODO check the role
_;
}
}
28 changes: 28 additions & 0 deletions src/lib/ValidatorCountsReport.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: 2023 Lido <info@lido.fi>
// SPDX-License-Identifier: GPL-3.0
pragma solidity 0.8.21;

/// @author skhomuti
library ValidatorCountsReport {
error InvalidReportData();


function count(bytes calldata ids) internal pure returns (uint256) {
return ids.length / 8;
}

function validate(bytes calldata ids, bytes calldata counts) internal pure {
uint256 count = count(ids);
if (
counts.length / 16 != count ||
ids.length % 8 != 0 ||
counts.length % 16 != 0) revert InvalidReportData();
}

function next(
bytes calldata ids, bytes calldata counts, uint256 offset
) internal returns (uint256 nodeOperatorId, uint256 keysCount) {
nodeOperatorId = uint256(bytes32(ids[8 * offset:8 * offset + 8]) >> 192);
keysCount = uint256(bytes32(counts[16 * offset:16 * offset + 16]) >> 128);
}
}
59 changes: 59 additions & 0 deletions test/CSModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -667,3 +667,62 @@ contract CsmVetKeys is CSMCommon {
csm.vetKeys(noId, 2);
}
}

contract CsmGetNodeOperatorSummary is CSMCommon {
// TODO add more tests here

function test_depositableValidatorsCount_whenStuckKeys() public {
uint256 noId = createNodeOperator(3);
csm.vetKeys(noId, 3);
csm.obtainDepositData(1, "");

(, , , , , , , uint256 depositableValidatorsCount) = csm
.getNodeOperatorSummary(noId);
assertEq(depositableValidatorsCount, 2);

csm.updateStuckValidatorsCount(
bytes.concat(bytes8(0x0000000000000000)),
bytes.concat(bytes16(0x00000000000000000000000000000001))
);

(, , , , , , , depositableValidatorsCount) = csm.getNodeOperatorSummary(
noId
);
assertEq(depositableValidatorsCount, 0);
}
}

contract CsmUpdateStuckValidatorsCount is CSMCommon {
function test_updateStuckValidatorsCount() public {
uint256 noId = createNodeOperator(3);
csm.vetKeys(noId, 3);
csm.obtainDepositData(1, "");

vm.expectEmit(true, true, false, true, address(csm));
emit StuckSigningKeysCountChanged(noId, 1);
vm.expectEmit(true, true, false, true, address(csm));
emit TargetValidatorsCountChanged(noId, 1);
csm.updateStuckValidatorsCount(
bytes.concat(bytes8(0x0000000000000000)),
bytes.concat(bytes16(0x00000000000000000000000000000001))
);

(
bool isTargetLimitActive,
uint256 targetValidatorsCount,
uint256 stuckValidatorsCount,
,
,
,
uint256 totalDepositedValidators,

) = csm.getNodeOperatorSummary(noId);
assertEq(stuckValidatorsCount, 1, "stuckValidatorsCount");
assertEq(
targetValidatorsCount,
totalDepositedValidators,
"targetValidatorsCount"
);
assertTrue(isTargetLimitActive, "isTargetLimitActive");
}
}
114 changes: 114 additions & 0 deletions test/ValidatorCountsReport.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import "forge-std/Test.sol";

import { ValidatorCountsReport } from "../src/lib/ValidatorCountsReport.sol";

contract ReportCaller {
function count(bytes calldata ids) public pure returns (uint256) {
return ValidatorCountsReport.count(ids);
}

function validate(bytes calldata ids, bytes calldata counts) public pure {
ValidatorCountsReport.validate(ids, counts);
}

function next(
bytes calldata ids,
bytes calldata counts,
uint256 offset
) public returns (uint256 nodeOperatorId, uint256 count) {
return ValidatorCountsReport.next(ids, counts, offset);
}
}

contract ValidatorCountsReportTest is Test {
ReportCaller caller;

function setUp() public {
caller = new ReportCaller();
}

function test_validate() public {
(bytes memory ids, bytes memory counts) = (
bytes.concat(bytes8(0x0000000000000001)),
bytes.concat(bytes16(0x00000000000000000000000000000001))
);

caller.validate(ids, counts);
}

function test_validate_invalidIdsLength() public {
(bytes memory ids, bytes memory counts) = (
bytes.concat(bytes8(0x0000000000000001), bytes4(0x00000001)),
bytes.concat(bytes16(0x00000000000000000000000000000001))
);

vm.expectRevert(ValidatorCountsReport.InvalidReportData.selector);
caller.validate(ids, counts);
}

function test_validate_invalidCountsLength() public {
(bytes memory ids, bytes memory counts) = (
bytes.concat(bytes8(0x0000000000000001)),
bytes.concat(
bytes16(0x00000000000000000000000000000001),
bytes4(0x00000001)
)
);

vm.expectRevert(ValidatorCountsReport.InvalidReportData.selector);
caller.validate(ids, counts);
}

function test_validate_differentItemsCount() public {
(bytes memory ids, bytes memory counts) = (
bytes.concat(
bytes8(0x0000000000000001),
bytes8(0x0000000000000002)
),
bytes.concat(bytes16(0x00000000000000000000000000000001))
);

vm.expectRevert(ValidatorCountsReport.InvalidReportData.selector);
caller.validate(ids, counts);
}

function test_count() public {
(bytes memory ids, bytes memory counts) = (
bytes.concat(bytes8(0x0000000000000001)),
bytes.concat(bytes16(0x00000000000000000000000000000001))
);

assertEq(caller.count(ids), 1);
}

function test_next() public {
(bytes memory ids, bytes memory counts) = (
bytes.concat(bytes8(0x0000000000000001)),
bytes.concat(bytes16(0x00000000000000000000000000000001))
);

(uint256 nodeOperatorId, uint256 count) = caller.next(ids, counts, 0);
assertEq(nodeOperatorId, 1, "nodeOperatorId != 1");
assertEq(count, 1, "count != 1");
}

function test_nextWithOffset() public {
(bytes memory ids, bytes memory counts) = (
bytes.concat(
bytes8(0x0000000000000001),
bytes8(0x0000000000000002)
),
bytes.concat(
bytes16(0x00000000000000000000000000000001),
bytes16(0x00000000000000000000000000000002)
)
);

(uint256 nodeOperatorId, uint256 count) = caller.next(ids, counts, 1);
assertEq(nodeOperatorId, 2, "nodeOperatorId != 2");
assertEq(count, 2, "count != 2");
}
}

0 comments on commit 60bd48f

Please sign in to comment.