Skip to content

Commit

Permalink
feat: keys removal (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman authored Dec 4, 2023
1 parent c1e68ea commit 4348a8a
Show file tree
Hide file tree
Showing 4 changed files with 488 additions and 21 deletions.
87 changes: 82 additions & 5 deletions src/CSModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ contract CSModuleBase {
event LocatorContractSet(address locatorAddress);
event UnvettingFeeSet(uint256 unvettingFee);

event UnvettingFeeApplied(uint256 indexed nodeOperatorId);

error NodeOperatorDoesNotExist();
error MaxNodeOperatorsCountReached();
error SenderIsNotManagerAddress();
Expand All @@ -110,6 +112,8 @@ contract CSModuleBase {
error QueueBatchInvalidStart(bytes32 batch);
error QueueBatchInvalidCount(bytes32 batch);
error QueueBatchUnvettedKeys(bytes32 batch);

error SigningKeysInvalidOffset();
}

contract CSModule is IStakingModule, CSModuleBase {
Expand Down Expand Up @@ -151,10 +155,10 @@ contract CSModule is IStakingModule, CSModuleBase {
accounting = ICSAccounting(_accounting);
}

function setUnvettingFee(uint256 unvettingFee_) external {
function setUnvettingFee(uint256 _unvettingFee) external {
// TODO: add role check
unvettingFee = unvettingFee_;
emit UnvettingFeeSet(unvettingFee_);
unvettingFee = _unvettingFee;
emit UnvettingFeeSet(_unvettingFee);
}

function _lido() internal view returns (ILido) {
Expand Down Expand Up @@ -558,6 +562,25 @@ contract CSModule is IStakingModule, CSModuleBase {
depositableValidatorsCount = no.totalVettedKeys - no.totalExitedKeys;
}

function getNodeOperatorSigningKeys(
uint256 nodeOperatorId,
uint256 startIndex,
uint256 keysCount
)
external
view
onlyExistingNodeOperator(nodeOperatorId)
returns (bytes memory)
{
return
SigningKeys.loadKeys(
SIGNING_KEYS_POSITION,
nodeOperatorId,
startIndex,
keysCount
);
}

function getNonce() external view returns (uint256) {
return _nonce;
}
Expand Down Expand Up @@ -679,19 +702,45 @@ contract CSModule is IStakingModule, CSModuleBase {
onlyKeyValidatorOrNodeOperatorManager
{
_unvetKeys(nodeOperatorId);
accounting.penalize(nodeOperatorId, unvettingFee);
_applyUnvettingFee(nodeOperatorId);
_incrementModuleNonce();
}

function unsafeUnvetKeys(uint256 nodeOperatorId) external onlyKeyValidator {
_unvetKeys(nodeOperatorId);
_incrementModuleNonce();
}

function removeKeys(
uint256 nodeOperatorId,
uint256 startIndex,
uint256 keysCount
)
external
onlyExistingNodeOperator(nodeOperatorId)
onlyNodeOperatorManager(nodeOperatorId)
{
NodeOperator storage no = _nodeOperators[nodeOperatorId];
if (no.totalVettedKeys > startIndex) {
_unvetKeys(nodeOperatorId);
_applyUnvettingFee(nodeOperatorId);
}

_removeSigningKeys(nodeOperatorId, startIndex, keysCount);
_incrementModuleNonce();
}

/// @dev NB! doesn't increment module nonce
function _unvetKeys(uint256 nodeOperatorId) internal {
NodeOperator storage no = _nodeOperators[nodeOperatorId];
no.totalVettedKeys = no.totalDepositedKeys;
no.queueNonce++;
emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys);
_incrementModuleNonce();
}

function _applyUnvettingFee(uint256 nodeOperatorId) internal {
accounting.penalize(nodeOperatorId, unvettingFee);
emit UnvettingFeeApplied(nodeOperatorId);
}

function onWithdrawalCredentialsChanged() external {
Expand Down Expand Up @@ -727,6 +776,34 @@ contract CSModule is IStakingModule, CSModuleBase {
_incrementModuleNonce();
}

function _removeSigningKeys(
uint256 nodeOperatorId,
uint256 startIndex,
uint256 keysCount
) internal {
NodeOperator storage no = _nodeOperators[nodeOperatorId];

if (startIndex < no.totalDepositedKeys) {
revert SigningKeysInvalidOffset();
}

if (startIndex + keysCount > no.totalAddedKeys) {
revert SigningKeysInvalidOffset();
}

// solhint-disable-next-line func-named-parameters
uint256 newTotalSigningKeys = SigningKeys.removeKeysSigs(
SIGNING_KEYS_POSITION,
nodeOperatorId,
startIndex,
keysCount,
no.totalAddedKeys
);

no.totalAddedKeys = newTotalSigningKeys;
emit TotalSigningKeysCountChanged(nodeOperatorId, newTotalSigningKeys);
}

function obtainDepositData(
uint256 depositsCount,
bytes calldata /* _depositCalldata */
Expand Down
41 changes: 31 additions & 10 deletions src/lib/SigningKeys.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See contracts/COMPILERS.md
pragma solidity 0.8.21;

import "@openzeppelin/contracts/utils/math/SafeMath.sol";
import { SafeMath } from "@openzeppelin/contracts/utils/math/SafeMath.sol";

/// @title Library for manage operator keys in storage
/// @author KRogLA
Expand All @@ -23,7 +23,7 @@ library SigningKeys {
return uint256(keccak256(abi.encodePacked(position, nodeOperatorId, keyIndex)));
}

/// @dev store opeartor keys to storage
/// @dev store operator keys to storage
/// @param position storage slot
/// @param nodeOperatorId operator id
/// @param startIndex start index
Expand Down Expand Up @@ -78,7 +78,7 @@ library SigningKeys {
return startIndex;
}

/// @dev remove opeartor keys from storage
/// @dev remove operator keys from storage
/// @param position storage slot
/// @param nodeOperatorId operator id
/// @param startIndex start index
Expand All @@ -105,14 +105,14 @@ library SigningKeys {
for (uint256 i = startIndex + keysCount; i > startIndex;) {
curOffset = position.getKeyOffset(nodeOperatorId, i - 1);
assembly {
// read key
// read key
mstore(add(tmpKey, 0x30), shr(128, sload(add(curOffset, 1)))) // bytes 16..47
mstore(add(tmpKey, 0x20), sload(curOffset)) // bytes 0..31
}
if (i < totalKeysCount) {
lastOffset = position.getKeyOffset(nodeOperatorId, totalKeysCount - 1);
// move last key to deleted key index
for (j = 0; j < 5;) {
for (j = 0; j < 5;) { // load 160 bytes (5 slots) containing key and signature
assembly {
sstore(add(curOffset, j), sload(add(lastOffset, j)))
j := add(j, 1)
Expand All @@ -136,7 +136,7 @@ library SigningKeys {
return totalKeysCount;
}

/// @dev laod opeartor keys from storage
/// @dev load operator keys and signatures from storage
/// @param position storage slot
/// @param nodeOperatorId operator id
/// @param startIndex start index
Expand All @@ -157,12 +157,12 @@ library SigningKeys {
for (uint256 i; i < keysCount;) {
curOffset = position.getKeyOffset(nodeOperatorId, startIndex + i);
assembly {
// read key
let _ofs := add(add(pubkeys, 0x20), mul(add(bufOffset, i), 48)) //PUBKEY_LENGTH = 48
// read key
let _ofs := add(add(pubkeys, 0x20), mul(add(bufOffset, i), 48)) // PUBKEY_LENGTH = 48
mstore(add(_ofs, 0x10), shr(128, sload(add(curOffset, 1)))) // bytes 16..47
mstore(_ofs, sload(curOffset)) // bytes 0..31
// store signature
_ofs := add(add(signatures, 0x20), mul(add(bufOffset, i), 96)) //SIGNATURE_LENGTH = 96
// store signature
_ofs := add(add(signatures, 0x20), mul(add(bufOffset, i), 96)) // SIGNATURE_LENGTH = 96
mstore(_ofs, sload(add(curOffset, 2)))
mstore(add(_ofs, 0x20), sload(add(curOffset, 3)))
mstore(add(_ofs, 0x40), sload(add(curOffset, 4)))
Expand All @@ -171,6 +171,27 @@ library SigningKeys {
}
}

function loadKeys(
bytes32 position,
uint256 nodeOperatorId,
uint256 startIndex,
uint256 keysCount
) internal view returns (bytes memory pubkeys) {
uint256 curOffset;

pubkeys = new bytes(keysCount.mul(PUBKEY_LENGTH));
for (uint256 i; i < keysCount;) {
curOffset = position.getKeyOffset(nodeOperatorId, startIndex + i);
assembly {
// read key
let offset := add(add(pubkeys, 0x20), mul(i, 48)) // PUBKEY_LENGTH = 48
mstore(add(offset, 0x10), shr(128, sload(add(curOffset, 1)))) // bytes 16..47
mstore(offset, sload(curOffset)) // bytes 0..31
i := add(i, 1)
}
}
}

function initKeysSigsBuf(uint256 count) internal pure returns (bytes memory, bytes memory) {
return (new bytes(count.mul(PUBKEY_LENGTH)), new bytes(count.mul(SIGNATURE_LENGTH)));
}
Expand Down
Loading

0 comments on commit 4348a8a

Please sign in to comment.