Skip to content

Commit

Permalink
Merge pull request #42 from lidofinance/fix-queue-cleanup
Browse files Browse the repository at this point in the history
fix: queue cleanup
  • Loading branch information
madlabman authored Nov 24, 2023
2 parents 0fe17c3 + 56c1ca6 commit c1e68ea
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 135 deletions.
8 changes: 4 additions & 4 deletions src/CSAccounting.sol
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ contract CSAccounting is CSAccountingBase, AccessControlEnumerable {
}

bytes32 public constant INSTANT_PENALIZE_BOND_ROLE =
keccak256("INSTANT_PENALIZE_BOND_ROLE");
keccak256("INSTANT_PENALIZE_BOND_ROLE"); // 0x9909cf24c2d3bafa8c229558d86a1b726ba57c3ef6350848dcf434a4181b56c7
bytes32 public constant EL_REWARDS_STEALING_PENALTY_INIT_ROLE =
keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE");
keccak256("EL_REWARDS_STEALING_PENALTY_INIT_ROLE"); // 0xcc2e7ce7be452f766dd24d55d87a3d42901c31ffa5b600cd1dff475abec91c1f
bytes32 public constant EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE =
keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE");
keccak256("EL_REWARDS_STEALING_PENALTY_SETTLE_ROLE"); // 0xdf6226649a1ca132f86d419e46892001284368a8f7445b5eb0d3fadf91329fe6
bytes32 public constant SET_BOND_MULTIPLIER_ROLE =
keccak256("SET_BOND_MULTIPLIER_ROLE");
keccak256("SET_BOND_MULTIPLIER_ROLE"); // 0x62131145aee19b18b85aa8ead52ba87f0efb6e61e249155edc68a2c24e8f79b5

// todo: should be reconsidered
uint256 public constant MIN_BLOCKED_BOND_RETENTION_PERIOD = 4 weeks;
Expand Down
108 changes: 61 additions & 47 deletions src/CSModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct NodeOperator {
uint256 stuckValidatorsCount;
uint256 refundedValidatorsCount;
bool isTargetLimitActive;
uint256 queueNonce;
}

struct NodeOperatorInfo {
Expand Down Expand Up @@ -102,14 +103,21 @@ contract CSModuleBase {
error SameAddress();
error AlreadyProposed();
error InvalidVetKeysPointer();

error QueueLookupNoLimit();
error QueueEmptyBatch();
error QueueBatchInvalidNonce(bytes32 batch);
error QueueBatchInvalidStart(bytes32 batch);
error QueueBatchInvalidCount(bytes32 batch);
error QueueBatchUnvettedKeys(bytes32 batch);
}

contract CSModule is IStakingModule, CSModuleBase {
using QueueLib for QueueLib.Queue;

// @dev max number of node operators is limited by uint128 due to Batch serialization in 32 bytes
// @dev max number of node operators is limited by uint64 due to Batch serialization in 32 bytes
// it seems to be enough
uint128 public constant MAX_NODE_OPERATORS_COUNT = type(uint128).max;
uint64 public constant MAX_NODE_OPERATORS_COUNT = type(uint64).max;
bytes32 public constant SIGNING_KEYS_POSITION =
keccak256("lido.CommunityStakingModule.signingKeysPosition");

Expand Down Expand Up @@ -645,14 +653,13 @@ contract CSModule is IStakingModule, CSModuleBase {
if (vetKeysPointer > no.totalAddedKeys) revert InvalidVetKeysPointer();

uint64 count = SafeCast.toUint64(vetKeysPointer - no.totalVettedKeys);
uint64 start = SafeCast.toUint64(
no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys
);
uint64 start = SafeCast.toUint64(no.totalVettedKeys);

bytes32 pointer = Batch.serialize({
nodeOperatorId: SafeCast.toUint128(nodeOperatorId),
nodeOperatorId: SafeCast.toUint64(nodeOperatorId),
start: start,
count: count
count: count,
nonce: SafeCast.toUint64(no.queueNonce)
});

no.totalVettedKeys = vetKeysPointer;
Expand All @@ -661,7 +668,7 @@ contract CSModule is IStakingModule, CSModuleBase {
emit BatchEnqueued(nodeOperatorId, start, count);
emit VettedSigningKeysCountChanged(nodeOperatorId, vetKeysPointer);

_incrementNonce();
_incrementModuleNonce();
}

function unvetKeys(
Expand All @@ -682,8 +689,9 @@ contract CSModule is IStakingModule, CSModuleBase {
function _unvetKeys(uint256 nodeOperatorId) internal {
NodeOperator storage no = _nodeOperators[nodeOperatorId];
no.totalVettedKeys = no.totalDepositedKeys;
no.queueNonce++;
emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys);
_incrementNonce();
_incrementModuleNonce();
}

function onWithdrawalCredentialsChanged() external {
Expand Down Expand Up @@ -716,7 +724,7 @@ contract CSModule is IStakingModule, CSModuleBase {
_nodeOperators[nodeOperatorId].totalAddedKeys
);

_incrementNonce();
_incrementModuleNonce();
}

function obtainDepositData(
Expand Down Expand Up @@ -754,6 +762,7 @@ contract CSModule is IStakingModule, CSModuleBase {
_totalDepositedValidators += keysCount;
NodeOperator storage no = _nodeOperators[nodeOperatorId];
no.totalDepositedKeys += keysCount;
// redundant check, enforced by _assertIsValidBatch
require(
no.totalDepositedKeys <= no.totalVettedKeys,
"too many keys"
Expand All @@ -773,7 +782,7 @@ contract CSModule is IStakingModule, CSModuleBase {
}

require(loadedKeysCount == depositsCount, "NOT_ENOUGH_KEYS");
_incrementNonce();
_incrementModuleNonce();
}

function _depositableKeysInBatch(
Expand All @@ -789,42 +798,40 @@ contract CSModule is IStakingModule, CSModuleBase {
{
uint256 start;
uint256 count;
uint256 nonce;

(nodeOperatorId, start, count) = Batch.deserialize(batch);
(nodeOperatorId, start, count, nonce) = Batch.deserialize(batch);

NodeOperator storage no = _nodeOperators[nodeOperatorId];
_assertIsValidBatch(no, start, count);
// solhint-disable-next-line func-named-parameters
_assertIsValidBatch(no, batch, start, count, nonce);

startIndex = Math.max(start, no.totalDepositedKeys);
depositableKeysCount = start + count - startIndex;
}

function _assertIsValidBatch(
NodeOperator storage no,
bytes32 batch,
uint256 start,
uint256 count
uint256 count,
uint256 nonce
) internal view {
require(count != 0, "Empty batch given");
require(
_unvettedKeysInBatch(no, start, count) == false,
"Batch contains unvetted keys"
);
require(
start + count <= no.totalAddedKeys,
"Invalid batch range: not enough keys"
);
require(
start <= no.totalDepositedKeys,
"Invalid batch range: skipped keys"
);
if (count == 0) revert QueueEmptyBatch();
if (nonce != no.queueNonce) revert QueueBatchInvalidNonce(batch);
if (start > no.totalDepositedKeys) revert QueueBatchInvalidStart(batch);
if (start + count > no.totalAddedKeys)
revert QueueBatchInvalidCount(batch);
if (_unvettedKeysInBatch(no, start, count))
revert QueueBatchUnvettedKeys(batch);
}

/// @dev returns the next pointer to start cleanup from
function cleanDepositQueue(
uint256 maxItems,
bytes32 pointer
) external returns (bytes32) {
require(maxItems > 0, "Queue walkthrough limit is not set");
if (maxItems == 0) revert QueueLookupNoLimit();

if (Batch.isNil(pointer)) {
pointer = queue.front;
Expand All @@ -836,11 +843,18 @@ contract CSModule is IStakingModule, CSModuleBase {
break;
}

(uint256 nodeOperatorId, uint256 start, uint256 count) = Batch
.deserialize(item);
(
uint256 nodeOperatorId,
uint256 start,
uint256 count,
uint256 nonce
) = Batch.deserialize(item);
NodeOperator storage no = _nodeOperators[nodeOperatorId];
if (_unvettedKeysInBatch(no, start, count)) {
if (
_unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce
) {
queue.remove(pointer, item);
continue;
}

pointer = item;
Expand All @@ -852,16 +866,8 @@ contract CSModule is IStakingModule, CSModuleBase {
function depositQueue(
uint256 maxItems,
bytes32 pointer
)
external
view
returns (
bytes32[] memory items,
bytes32 /* pointer */,
uint256 /* count */
)
{
require(maxItems > 0, "Queue walkthrough limit is not set");
) external view returns (bytes32[] memory items, uint256 /* count */) {
if (maxItems == 0) revert QueueLookupNoLimit();

if (Batch.isNil(pointer)) {
pointer = queue.front;
Expand All @@ -870,12 +876,14 @@ contract CSModule is IStakingModule, CSModuleBase {
return queue.list(pointer, maxItems);
}

/// @dev it is dirty if it contains a batch with unvetted keys
/// or with invalid nonce
/// @dev returns the next pointer to start check from
function isQueueHasUnvettedKeys(
function isQueueDirty(
uint256 maxItems,
bytes32 pointer
) external view returns (bool, bytes32) {
require(maxItems > 0, "Queue walkthrough limit is not set");
if (maxItems == 0) revert QueueLookupNoLimit();

if (Batch.isNil(pointer)) {
pointer = queue.front;
Expand All @@ -887,10 +895,16 @@ contract CSModule is IStakingModule, CSModuleBase {
break;
}

(uint256 nodeOperatorId, uint256 start, uint256 count) = Batch
.deserialize(item);
(
uint256 nodeOperatorId,
uint256 start,
uint256 count,
uint256 nonce
) = Batch.deserialize(item);
NodeOperator storage no = _nodeOperators[nodeOperatorId];
if (_unvettedKeysInBatch(no, start, count)) {
if (
_unvettedKeysInBatch(no, start, count) || nonce != no.queueNonce
) {
return (true, pointer);
}

Expand All @@ -908,7 +922,7 @@ contract CSModule is IStakingModule, CSModuleBase {
return start + count > no.totalVettedKeys;
}

function _incrementNonce() internal {
function _incrementModuleNonce() internal {
_nonce++;
}

Expand Down
16 changes: 9 additions & 7 deletions src/lib/Batch.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,23 @@ pragma solidity 0.8.21;
library Batch {
/// @notice Serialize node operator id, batch start and count of keys into a single bytes32 value
function serialize(
uint128 nodeOperatorId,
uint64 nodeOperatorId,
uint64 start,
uint64 count
uint64 count,
uint64 nonce
) internal pure returns (bytes32 s) {
return bytes32(abi.encodePacked(nodeOperatorId, start, count));
return bytes32(abi.encodePacked(nodeOperatorId, start, count, nonce));
}

/// @notice Deserialize node operator id, batch start and count of keys from a single bytes32 value
function deserialize(
bytes32 b
) internal pure returns (uint128 nodeOperatorId, uint64 start, uint64 count) {
) internal pure returns (uint64 nodeOperatorId, uint64 start, uint64 count, uint64 nonce) {
assembly {
nodeOperatorId := shr(128, b)
start := shr(64, b)
count := b
nodeOperatorId := shr(192, b)
start := shr(128, b)
count := shr(64, b)
nonce := b
}
}

Expand Down
9 changes: 7 additions & 2 deletions src/lib/QueueLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pragma solidity 0.8.21;
library QueueLib {
bytes32 public constant NULL_POINTER = bytes32(0);

// @dev Queue is a linked list of items
// @dev front and back are pointers
struct Queue {
mapping(bytes32 => bytes32) queue;
bytes32 front;
Expand Down Expand Up @@ -38,11 +40,13 @@ library QueueLib {
return self.queue[pointer];
}

// @dev returns items array of size `limit` and actual count of items
// @dev reverts if the queue is empty
function list(Queue storage self, bytes32 pointer, uint256 limit) internal notEmpty(self) view returns (
bytes32[] memory items,
bytes32 /* pointer */,
uint256 /* count */
) {
require(limit > 0, "Queue: limit is not set");
items = new bytes32[](limit);

uint256 i;
Expand All @@ -56,7 +60,8 @@ library QueueLib {
pointer = item;
}

return (items, pointer, i);
// TODO: resize items array to actual count
return (items, i);
}

function isEmpty(Queue storage self) internal view returns (bool) {
Expand Down
28 changes: 18 additions & 10 deletions test/Batch.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,52 @@ contract BatchTest is Test {
bytes32 b = Batch.serialize({
nodeOperatorId: 999,
start: 3,
count: 42
count: 42,
nonce: 7
});

assertEq(
b,
// noIndex | start | count |
0x000000000000000000000000000003e70000000000000003000000000000002a
// noIndex | start | count | nonce
0x00000000000003e70000000000000003000000000000002a0000000000000007
);
}

function test_deserialize() public {
(uint128 nodeOperatorId, uint64 start, uint64 count) = Batch
.deserialize(
(
uint256 nodeOperatorId,
uint256 start,
uint256 count,
uint256 nonce
) = Batch.deserialize(
0x0000000000000000000000000000000000000000000000000000000000000000
);

assertEq(nodeOperatorId, 0, "nodeOperatorId != 0");
assertEq(start, 0, "start != 0");
assertEq(count, 0, "count != 0");
assertEq(nonce, 0, "nonce != 0");

(nodeOperatorId, start, count) = Batch.deserialize(
0x000000000000000000000000000003e70000000000000003000000000000002a
(nodeOperatorId, start, count, nonce) = Batch.deserialize(
0x00000000000003e70000000000000003000000000000002a0000000000000007
);

assertEq(nodeOperatorId, 999, "nodeOperatorId != 999");
assertEq(start, 3, "start != 3");
assertEq(count, 42, "count != 42");
assertEq(nonce, 7, "nonce != 7");

(nodeOperatorId, start, count) = Batch.deserialize(
(nodeOperatorId, start, count, nonce) = Batch.deserialize(
0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
);

assertEq(
nodeOperatorId,
type(uint128).max,
"nodeOperatorId != uint128.max"
type(uint64).max,
"nodeOperatorId != uint64.max"
);
assertEq(start, type(uint64).max, "start != uint64.max");
assertEq(count, type(uint64).max, "count != uint64.max");
assertEq(nonce, type(uint64).max, "nonce != uint64.max");
}
}
Loading

0 comments on commit c1e68ea

Please sign in to comment.