Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: queue cleanup #42

Merged
merged 7 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/CSAccounting.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-FileCopyrightText: 2023 Lido <info@lido.fi>
// SPDX-License-Identifier: GPL-3.0

pragma solidity 0.8.21;

Check warning on line 4 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Found more than One contract per file. 2 contracts found!

import { AccessControlEnumerable } from "@openzeppelin/contracts/access/AccessControlEnumerable.sol";

Expand Down Expand Up @@ -83,18 +83,18 @@
bytes32 s;
}
struct BlockedBond {
uint256 ETHAmount;

Check warning on line 86 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Variable name must be in mixedCase
uint256 retentionUntil;
}

bytes32 public constant INSTANT_PENALIZE_BOND_ROLE =
keccak256("INSTANT_PENALIZE_BOND_ROLE");
keccak256("INSTANT_PENALIZE_BOND_ROLE"); // 0x9909cf24c2d3bafa8c229558d86a1b726ba57c3ef6350848dcf434a4181b56c7
bytes32 public constant EL_REWARDS_STEALING_PENALTY_INIT_ROLE =
keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE");
keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE"); // 0xcc2e7ce7be452f766dd24d55d87a3d42901c31ffa5b600cd1dff475abec91c1f
bytes32 public constant EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE =
keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE");
keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE"); // 0xdf6226649a1ca132f86d419e46892001284368a8f7445b5eb0d3fadf91329fe6
bytes32 public constant SET_BOND_MULTIPLIER_ROLE =
keccak256("SET_BOND_MULTIPLIER_ROLE");
keccak256("SET_BOND_MULTIPLIER_ROLE"); // 0x62131145aee19b18b85aa8ead52ba87f0efb6e61e249155edc68a2c24e8f79b5

// todo: should be reconsidered
uint256 public constant MIN_BLOCKED_BOND_RETENTION_PERIOD = 4 weeks;
Expand All @@ -110,7 +110,7 @@
ICSModule private immutable CSM;
IWstETH private immutable WSTETH;

address public FEE_DISTRIBUTOR;

Check warning on line 113 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Variable name must be in mixedCase
uint256 public totalBondShares;

uint256 public blockedBondRetentionPeriod;
Expand Down Expand Up @@ -139,13 +139,13 @@
uint256 _blockedBondManagementPeriod
) {
// check zero addresses
require(admin != address(0), "admin is zero address");

Check warning on line 142 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Use Custom Errors instead of require statements
require(lidoLocator != address(0), "lido locator is zero address");

Check warning on line 143 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Use Custom Errors instead of require statements
require(

Check warning on line 144 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Use Custom Errors instead of require statements
communityStakingModule != address(0),
"community staking module is zero address"
);
require(wstETH != address(0), "wstETH is zero address");

Check warning on line 148 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Use Custom Errors instead of require statements
_validateBlockedBondPeriods(
_blockedBondRetentionPeriod,
_blockedBondManagementPeriod
Expand Down Expand Up @@ -435,7 +435,7 @@

/// @notice Returns the number of keys by the given bond ETH amount
function getKeysCountByBondETH(
uint256 ETHAmount

Check warning on line 438 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Variable name must be in mixedCase
) public view returns (uint256) {
return ETHAmount / getRequiredBondETHForKeys(1);
}
Expand Down Expand Up @@ -678,7 +678,7 @@
bytes32[] memory rewardsProof,
uint256 nodeOperatorId,
uint256 cumulativeFeeShares,
uint256 ETHAmount

Check warning on line 681 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Variable name must be in mixedCase
)
external
onlyExistingNodeOperator(nodeOperatorId)
Expand Down Expand Up @@ -761,7 +761,7 @@
function compensateBlockedBondETH(
uint256 nodeOperatorId
) external payable onlyExistingNodeOperator(nodeOperatorId) {
require(msg.value > 0, "value should be greater than zero");

Check warning on line 764 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Use Custom Errors instead of require statements
payable(LIDO_LOCATOR.elRewardsVault()).transfer(msg.value);
emit BlockedBondCompensated(nodeOperatorId, msg.value);
_reduceBlockedBondETH(nodeOperatorId, msg.value);
Expand Down
108 changes: 61 additions & 47 deletions src/CSModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct NodeOperator {
uint256 stuckValidatorsCount;
uint256 refundedValidatorsCount;
bool isTargetLimitActive;
uint256 queueNonce;
}

struct NodeOperatorInfo {
Expand Down Expand Up @@ -102,14 +103,21 @@ contract CSModuleBase {
error SameAddress();
error AlreadyProposed();
error InvalidVetKeysPointer();

error QueueLookupNoLimit();
error QueueEmptyBatch();
error QueueBatchInvalidNonce(bytes32 batch);
error QueueBatchInvalidStart(bytes32 batch);
error QueueBatchInvalidCount(bytes32 batch);
error QueueBatchUnvettedKeys(bytes32 batch);
}

contract CSModule is IStakingModule, CSModuleBase {
using QueueLib for QueueLib.Queue;

// @dev max number of node operators is limited by uint128 due to Batch serialization in 32 bytes
// @dev max number of node operators is limited by uint64 due to Batch serialization in 32 bytes
// it seems to be enough
uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint128).max;
uint64 public constant MAX_NODE_OPERATORS_COUNT = type(uint64).max;
bytes32 public constant SIGNING_KEYS_POSITION =
keccak256("lido.CommunityStakingModule.signingKeysPosition");

Expand Down Expand Up @@ -645,14 +653,13 @@ contract CSModule is IStakingModule, CSModuleBase {
if (vetKeysPointer > no.totalAddedKeys) revert InvalidVetKeysPointer();

uint64 count = SafeCast.toUint64(vetKeysPointer - no.totalVettedKeys);
uint64 start = SafeCast.toUint64(
no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys
);
uint64 start = SafeCast.toUint64(no.totalVettedKeys);

bytes32 pointer = Batch.serialize({
nodeOperatorId: SafeCast.toUint128(nodeOperatorId),
nodeOperatorId: SafeCast.toUint64(nodeOperatorId),
start: start,
count: count
count: count,
nonce: SafeCast.toUint64(no.queueNonce)
});

no.totalVettedKeys = vetKeysPointer;
Expand All @@ -661,7 +668,7 @@ contract CSModule is IStakingModule, CSModuleBase {
emit BatchEnqueued(nodeOperatorId, start, count);
emit VettedSigningKeysCountChanged(nodeOperatorId, vetKeysPointer);

_incrementNonce();
_incrementModuleNonce();
}

function unvetKeys(
Expand All @@ -682,8 +689,9 @@ contract CSModule is IStakingModule, CSModuleBase {
function _unvetKeys(uint256 nodeOperatorId) internal {
NodeOperator storage no = _nodeOperators[nodeOperatorId];
no.totalVettedKeys = no.totalDepositedKeys;
no.queueNonce++;
emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys);
_incrementNonce();
_incrementModuleNonce();
}

function onWithdrawalCredentialsChanged() external {
Expand Down Expand Up @@ -716,7 +724,7 @@ contract CSModule is IStakingModule, CSModuleBase {
_nodeOperators[nodeOperatorId].totalAddedKeys
);

_incrementNonce();
_incrementModuleNonce();
}

function obtainDepositData(
Expand Down Expand Up @@ -754,6 +762,7 @@ contract CSModule is IStakingModule, CSModuleBase {
_totalDepositedValidators += keysCount;
NodeOperator storage no = _nodeOperators[nodeOperatorId];
no.totalDepositedKeys += keysCount;
// redundant check, enforced by _assertIsValidBatch
skhomuti marked this conversation as resolved.
Show resolved Hide resolved
require(
no.totalDepositedKeys <= no.totalVettedKeys,
"too many keys"
Expand All @@ -773,7 +782,7 @@ contract CSModule is IStakingModule, CSModuleBase {
}

require(loadedKeysCount == depositsCount, "NOT_ENOUGH_KEYS");
_incrementNonce();
_incrementModuleNonce();
}

function _depositableKeysInBatch(
Expand All @@ -789,42 +798,40 @@ contract CSModule is IStakingModule, CSModuleBase {
{
uint256 start;
uint256 count;
uint256 nonce;

(nodeOperatorId, start, count) = Batch.deserialize(batch);
(nodeOperatorId, start, count, nonce) = Batch.deserialize(batch);

NodeOperator storage no = _nodeOperators[nodeOperatorId];
_assertIsValidBatch(no, start, count);
// solhint-disable-next-line func-named-parameters
_assertIsValidBatch(no, batch, start, count, nonce);

startIndex = Math.max(start, no.totalDepositedKeys);
depositableKeysCount = start + count - startIndex;
}

function _assertIsValidBatch(
NodeOperator storage no,
bytes32 batch,
uint256 start,
uint256 count
uint256 count,
uint256 nonce
) internal view {
require(count != 0, "Empty batch given");
require(
_unvettedKeysInBatch(no, start, count) == false,
"Batch contains unvetted keys"
);
require(
start + count <= no.totalAddedKeys,
"Invalid batch range: not enough keys"
);
require(
start <= no.totalDepositedKeys,
"Invalid batch range: skipped keys"
);
if (count == 0) revert QueueEmptyBatch();
if (nonce != no.queueNonce) revert QueueBatchInvalidNonce(batch);
if (start > no.totalDepositedKeys) revert QueueBatchInvalidStart(batch);
if (start + count > no.totalAddedKeys)
revert QueueBatchInvalidCount(batch);
if (_unvettedKeysInBatch(no, start, count))
revert QueueBatchUnvettedKeys(batch);
}

/// @dev returns the next pointer to start cleanup from
function cleanDepositQueue(
uint256 maxItems,
bytes32 pointer
) external returns (bytes32) {
require(maxItems > 0, "Queue walkthrough limit is not set");
if (maxItems == 0) revert QueueLookupNoLimit();

if (Batch.isNil(pointer)) {
pointer = queue.front;
Expand All @@ -836,11 +843,18 @@ contract CSModule is IStakingModule, CSModuleBase {
break;
}

(uint256 nodeOperatorId, uint256 start, uint256 count) = Batch
.deserialize(item);
(
uint256 nodeOperatorId,
uint256 start,
uint256 count,
uint256 nonce
vgorkavenko marked this conversation as resolved.
Show resolved Hide resolved
) = Batch.deserialize(item);
NodeOperator storage no = _nodeOperators[nodeOperatorId];
if (_unvettedKeysInBatch(no, start, count)) {
if (
_unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce
) {
queue.remove(pointer, item);
continue;
madlabman marked this conversation as resolved.
Show resolved Hide resolved
}

pointer = item;
Expand All @@ -852,16 +866,8 @@ contract CSModule is IStakingModule, CSModuleBase {
function depositQueue(
uint256 maxItems,
bytes32 pointer
)
external
view
returns (
bytes32[] memory items,
bytes32 /* pointer */,
uint256 /* count */
)
{
require(maxItems > 0, "Queue walkthrough limit is not set");
) external view returns (bytes32[] memory items, uint256 /* count */) {
if (maxItems == 0) revert QueueLookupNoLimit();

if (Batch.isNil(pointer)) {
pointer = queue.front;
Expand All @@ -870,12 +876,14 @@ contract CSModule is IStakingModule, CSModuleBase {
return queue.list(pointer, maxItems);
}

/// @dev it is dirty if it contains a batch with unvetted keys
/// or with invalid nonce
/// @dev returns the next pointer to start check from
function isQueueHasUnvettedKeys(
function isQueueDirty(
vgorkavenko marked this conversation as resolved.
Show resolved Hide resolved
uint256 maxItems,
bytes32 pointer
) external view returns (bool, bytes32) {
require(maxItems > 0, "Queue walkthrough limit is not set");
if (maxItems == 0) revert QueueLookupNoLimit();

if (Batch.isNil(pointer)) {
pointer = queue.front;
Expand All @@ -887,10 +895,16 @@ contract CSModule is IStakingModule, CSModuleBase {
break;
}

(uint256 nodeOperatorId, uint256 start, uint256 count) = Batch
.deserialize(item);
(
uint256 nodeOperatorId,
vgorkavenko marked this conversation as resolved.
Show resolved Hide resolved
uint256 start,
uint256 count,
uint256 nonce
) = Batch.deserialize(item);
NodeOperator storage no = _nodeOperators[nodeOperatorId];
if (_unvettedKeysInBatch(no, start, count)) {
if (
_unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce
) {
return (true, pointer);
}

Expand All @@ -908,7 +922,7 @@ contract CSModule is IStakingModule, CSModuleBase {
return start + count > no.totalVettedKeys;
}

function _incrementNonce() internal {
function _incrementModuleNonce() internal {
_nonce++;
}

Expand Down
16 changes: 9 additions & 7 deletions src/lib/Batch.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,23 @@ pragma solidity 0.8.21;
library Batch {
/// @notice Serialize node operator id, batch start and count of keys into a single bytes32 value
function serialize(
uint128 nodeOperatorId,
uint64 nodeOperatorId,
uint64 start,
uint64 count
uint64 count,
uint64 nonce
) internal pure returns (bytes32 s) {
return bytes32(abi.encodePacked(nodeOperatorId, start, count));
return bytes32(abi.encodePacked(nodeOperatorId, start, count, nonce));
}

/// @notice Deserialize node operator id, batch start and count of keys from a single bytes32 value
function deserialize(
bytes32 b
) internal pure returns (uint128 nodeOperatorId, uint64 start, uint64 count) {
) internal pure returns (uint64 nodeOperatorId, uint64 start, uint64 count, uint64 nonce) {
assembly {
nodeOperatorId := shr(128, b)
start := shr(64, b)
count := b
nodeOperatorId := shr(192, b)
start := shr(128, b)
count := shr(64, b)
nonce := b
}
}

Expand Down
9 changes: 7 additions & 2 deletions src/lib/QueueLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pragma solidity 0.8.21;
library QueueLib {
bytes32 public constant NULL_POINTER = bytes32(0);

// @dev Queue is a linked list of items
// @dev front and back are pointers
struct Queue {
mapping(bytes32 => bytes32) queue;
bytes32 front;
Expand Down Expand Up @@ -38,11 +40,13 @@ library QueueLib {
return self.queue[pointer];
}

// @dev returns items array of size `limit` and actual count of items
// @dev reverts if the queue is empty
function list(Queue storage self, bytes32 pointer, uint256 limit) internal notEmpty(self) view returns (
bytes32[] memory items,
bytes32 /* pointer */,
uint256 /* count */
) {
require(limit > 0, "Queue: limit is not set");
items = new bytes32[](limit);

uint256 i;
Expand All @@ -56,7 +60,8 @@ library QueueLib {
pointer = item;
}

return (items, pointer, i);
// TODO: resize items array to actual count
return (items, i);
}

function isEmpty(Queue storage self) internal view returns (bool) {
Expand Down
28 changes: 18 additions & 10 deletions test/Batch.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,52 @@ contract BatchTest is Test {
bytes32 b = Batch.serialize({
nodeOperatorId: 999,
start: 3,
count: 42
count: 42,
nonce: 7
});

assertEq(
b,
// noIndex | start | count |
0x000000000000000000000000000003e70000000000000003000000000000002a
// noIndex | start | count | nonce
0x00000000000003e70000000000000003000000000000002a0000000000000007
);
}

function test_deserialize() public {
(uint128 nodeOperatorId, uint64 start, uint64 count) = Batch
.deserialize(
(
uint256 nodeOperatorId,
uint256 start,
uint256 count,
uint256 nonce
) = Batch.deserialize(
0x0000000000000000000000000000000000000000000000000000000000000000
);

assertEq(nodeOperatorId, 0, "nodeOperatorId != 0");
assertEq(start, 0, "start != 0");
assertEq(count, 0, "count != 0");
assertEq(nonce, 0, "nonce != 0");

(nodeOperatorId, start, count) = Batch.deserialize(
0x000000000000000000000000000003e70000000000000003000000000000002a
(nodeOperatorId, start, count, nonce) = Batch.deserialize(
0x00000000000003e70000000000000003000000000000002a0000000000000007
);

assertEq(nodeOperatorId, 999, "nodeOperatorId != 999");
assertEq(start, 3, "start != 3");
assertEq(count, 42, "count != 42");
assertEq(nonce, 7, "nonce != 7");

(nodeOperatorId, start, count) = Batch.deserialize(
(nodeOperatorId, start, count, nonce) = Batch.deserialize(
0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
);

assertEq(
nodeOperatorId,
type(uint128).max,
"nodeOperatorId != uint128.max"
type(uint64).max,
"nodeOperatorId != uint64.max"
);
assertEq(start, type(uint64).max, "start != uint64.max");
assertEq(count, type(uint64).max, "count != uint64.max");
assertEq(nonce, type(uint64).max, "nonce != uint64.max");
}
}
Loading