Skip to content

Commit

Permalink
fix: review (contracts)
Browse files Browse the repository at this point in the history
  • Loading branch information
vgorkavenko committed Nov 29, 2023
1 parent d9e9e87 commit be19ca8
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 114 deletions.
35 changes: 18 additions & 17 deletions src/CSAccounting.sol
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,11 @@ contract CSAccounting is
// todo: can be optimized. get active keys once
(uint256 current, uint256 required) = _bondETHSummary(nodeOperatorId);
uint256 currentKeysCount = _getNodeOperatorActiveKeys(nodeOperatorId);
uint256 requiredForNextKeys = _getCurveValueByKeysCount(
nodeOperatorId,
currentKeysCount + additionalKeysCount
) - _getCurveValueByKeysCount(nodeOperatorId, currentKeysCount);
uint256 multiplier = getBondMultiplier(nodeOperatorId);
uint256 requiredForNextKeys = _getBondAmountByKeysCount(
currentKeysCount + additionalKeysCount,
multiplier
) - _getBondAmountByKeysCount(currentKeysCount, multiplier);

uint256 missing = required > current ? required - current : 0;
if (missing > 0) {
Expand Down Expand Up @@ -392,7 +393,7 @@ contract CSAccounting is
function getRequiredBondETHForKeys(
uint256 keysCount
) public view returns (uint256) {
return _getCurveValueByKeysCount(keysCount);
return _getBondAmountByKeysCount(keysCount);
}

/// @notice Returns the required bond stETH for the given number of keys.
Expand Down Expand Up @@ -424,14 +425,14 @@ contract CSAccounting is
uint256 currentBond = _ethByShares(_bondShares[nodeOperatorId]);
uint256 blockedBond = getBlockedBondETH(nodeOperatorId);
if (currentBond > blockedBond) {
uint256 multiplier = getBondMultiplier(nodeOperatorId);
currentBond -= blockedBond;
uint256 bondedKeys = _getKeysCountByCurveValue(
nodeOperatorId,
currentBond
uint256 bondedKeys = _getKeysCountByBondAmount(
currentBond,
multiplier
);
if (
currentBond >
_getCurveValueByKeysCount(nodeOperatorId, bondedKeys)
currentBond > _getBondAmountByKeysCount(bondedKeys, multiplier)
) {
bondedKeys += 1;
}
Expand All @@ -444,7 +445,7 @@ contract CSAccounting is
function getKeysCountByBondETH(
uint256 ETHAmount

Check warning on line 446 in src/CSAccounting.sol

View workflow job for this annotation

GitHub Actions / Linters

Variable name must be in mixedCase
) public view returns (uint256) {
return _getKeysCountByCurveValue(ETHAmount);
return _getKeysCountByBondAmount(ETHAmount);
}

/// @notice Returns the number of keys by the given bond stETH amount
Expand Down Expand Up @@ -932,9 +933,9 @@ contract CSAccounting is
) internal view returns (uint256 current, uint256 required) {
current = _ethByShares(getBondShares(nodeOperatorId));
required =
_getCurveValueByKeysCount(
nodeOperatorId,
_getNodeOperatorActiveKeys(nodeOperatorId)
_getBondAmountByKeysCount(
_getNodeOperatorActiveKeys(nodeOperatorId),
getBondMultiplier(nodeOperatorId)
) +
getBlockedBondETH(nodeOperatorId);
}
Expand All @@ -945,9 +946,9 @@ contract CSAccounting is
current = getBondShares(nodeOperatorId);
required =
_sharesByEth(
_getCurveValueByKeysCount(
nodeOperatorId,
_getNodeOperatorActiveKeys(nodeOperatorId)
_getBondAmountByKeysCount(
_getNodeOperatorActiveKeys(nodeOperatorId),
getBondMultiplier(nodeOperatorId)
)
) +
_sharesByEth(getBlockedBondETH(nodeOperatorId));
Expand Down
63 changes: 33 additions & 30 deletions src/CSBondCurve.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ abstract contract CSBondCurve {
// bondCurve[i] | bond amount for `i + 1` keys count
uint256[] public bondCurve;

uint256 internal constant MIN_CURVE_LENGTH = 1;
// todo: might be redefined in the future
uint256 internal constant MAX_CURVE_LENGTH = 20;
uint256 internal constant MIN_CURVE_LENGTH = 1;

uint256 internal constant BASIS_POINTS = 10000;
uint256 internal constant MAX_BOND_MULTIPLIER = BASIS_POINTS; // x1
uint256 internal constant MIN_BOND_MULTIPLIER = MAX_BOND_MULTIPLIER / 2; // x0.5
uint256 internal constant MAX_BOND_MULTIPLIER = 10000; // x1

/// This mapping contains bond multiplier points (in basis points) for Node Operator's bond.
/// By default, all Node Operators have x1 multiplier (10000 basis points).
Expand All @@ -32,6 +33,7 @@ abstract contract CSBondCurve {

function _setBondCurve(uint256[] memory _bondCurve) internal {
_checkCurveLength(_bondCurve);
// todo: check curve values (not worse than previous and makes sense)
bondCurve = _bondCurve;
_bondCurveTrend =
_bondCurve[_bondCurve.length - 1] -
Expand All @@ -43,6 +45,7 @@ abstract contract CSBondCurve {
uint256 basisPoints
) internal {
_checkMultiplier(basisPoints);
// todo: check curve values (not worse than previous)
_bondMultiplierBP[nodeOperatorId] = basisPoints;
}

Expand All @@ -66,69 +69,69 @@ abstract contract CSBondCurve {
) revert InvalidMultiplier();
}

/// @notice Returns the amount of keys for the given bond amount.
function _getKeysCountByCurveValue(
/// @notice Returns keys count for the given bond amount.
function _getKeysCountByBondAmount(
uint256 amount
) internal view returns (uint256) {
return _getKeysCountByCurveValue(type(uint256).max, amount);
return _getKeysCountByBondAmount(amount, MAX_BOND_MULTIPLIER);
}

/// @notice Returns the amount of keys for the given bond amount for particular node operator.
function _getKeysCountByCurveValue(
uint256 nodeOperatorId,
uint256 amount
/// @notice Returns keys count for the given bond amount for particular node operator.
function _getKeysCountByBondAmount(
uint256 amount,
uint256 multiplier
) internal view returns (uint256) {
uint256 mult = getBondMultiplier(nodeOperatorId);
if (amount < (bondCurve[0] * mult) / BASIS_POINTS) return 0;
uint256 last = (bondCurve[bondCurve.length - 1] * mult) / BASIS_POINTS;
if (amount >= last) {
if (amount < (bondCurve[0] * multiplier) / BASIS_POINTS) return 0;
uint256 maxCurveAmount = (bondCurve[bondCurve.length - 1] *
multiplier) / BASIS_POINTS;
if (amount >= maxCurveAmount) {
return
bondCurve.length +
((amount - last) / ((_bondCurveTrend * mult) / BASIS_POINTS));
((amount - maxCurveAmount) /
((_bondCurveTrend * multiplier) / BASIS_POINTS));
}
return _searchKeysByBond(amount, mult);
return _searchKeysCount(amount, multiplier);
}

function _searchKeysByBond(
uint256 value,
function _searchKeysCount(
uint256 amount,
uint256 multiplier
) internal view returns (uint256) {
uint256 low;
uint256 high = bondCurve.length - 1;
while (low <= high) {
uint256 mid = (low + high) / 2;
uint256 midValue = (bondCurve[mid] * multiplier) / BASIS_POINTS;
if (value == midValue) {
uint256 midAmount = (bondCurve[mid] * multiplier) / BASIS_POINTS;
if (amount == midAmount) {
return mid + 1;
}
if (value < midValue) {
if (amount < midAmount) {
// zero mid is avoided above
high = mid - 1;
} else if (value > midValue) {
} else if (amount > midAmount) {
low = mid + 1;
}
}
return low;
}

function _getCurveValueByKeysCount(
function _getBondAmountByKeysCount(
uint256 keys
) internal view returns (uint256) {
return _getCurveValueByKeysCount(type(uint256).max, keys);
return _getBondAmountByKeysCount(keys, MAX_BOND_MULTIPLIER);
}

function _getCurveValueByKeysCount(
uint256 nodeOperatorId,
uint256 keys
function _getBondAmountByKeysCount(
uint256 keys,
uint256 multiplier
) internal view returns (uint256) {
if (keys == 0) return 0;
uint256 mult = getBondMultiplier(nodeOperatorId);
if (keys <= bondCurve.length) {
return (bondCurve[keys - 1] * mult) / BASIS_POINTS;
return (bondCurve[keys - 1] * multiplier) / BASIS_POINTS;
}
return
((bondCurve[bondCurve.length - 1] * mult) / BASIS_POINTS) +
((bondCurve[bondCurve.length - 1] * multiplier) / BASIS_POINTS) +
(keys - bondCurve.length) *
((_bondCurveTrend * mult) / BASIS_POINTS);
((_bondCurveTrend * multiplier) / BASIS_POINTS);
}
}
133 changes: 66 additions & 67 deletions test/CSBondCurve.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,30 @@ contract CSBondCurveTestable is CSBondCurve {
_setBondMultiplier(nodeOperatorId, basisPoints);
}

function getKeysCountByCurveValue(
function getKeysCountByBondAmount(
uint256 amount
) external view returns (uint256) {
return _getKeysCountByCurveValue(amount);
return _getKeysCountByBondAmount(amount);
}

function getCurveValueByKeysCount(
function getBondAmountByKeysCount(
uint256 keysCount
) external view returns (uint256) {
return _getCurveValueByKeysCount(keysCount);
return _getBondAmountByKeysCount(keysCount);
}

function getKeysCountByCurveValue(
function getKeysCountByBondAmount(
uint256 nodeOperatorId,
uint256 amount
) external view returns (uint256) {
return _getKeysCountByCurveValue(nodeOperatorId, amount);
return _getKeysCountByBondAmount(nodeOperatorId, amount);
}

function getCurveValueByKeysCount(
function getBondAmountByKeysCount(
uint256 nodeOperatorId,
uint256 keysCount
) external view returns (uint256) {
return _getCurveValueByKeysCount(nodeOperatorId, keysCount);
return _getBondAmountByKeysCount(nodeOperatorId, keysCount);
}
}

Expand All @@ -71,7 +71,6 @@ contract CSBondCurveTest is Test {
uint256[] memory _bondCurve = new uint256[](0);

vm.expectRevert(CSBondCurve.InvalidBondCurveLength.selector);

bondCurve.setBondCurve(_bondCurve);
}

Expand Down Expand Up @@ -99,65 +98,65 @@ contract CSBondCurveTest is Test {
bondCurve.setBondMultiplier(0, 10001);
}

function test_getKeysCountByCurveValue() public {
assertEq(bondCurve.getKeysCountByCurveValue(0), 0);
assertEq(bondCurve.getKeysCountByCurveValue(2 ether), 1);
assertEq(bondCurve.getKeysCountByCurveValue(3 ether), 1);
assertEq(bondCurve.getKeysCountByCurveValue(3.90 ether), 2);
assertEq(bondCurve.getKeysCountByCurveValue(5.70 ether), 3);
assertEq(bondCurve.getKeysCountByCurveValue(7.40 ether), 4);
assertEq(bondCurve.getKeysCountByCurveValue(9.00 ether), 5);
assertEq(bondCurve.getKeysCountByCurveValue(10.50 ether), 6);
assertEq(bondCurve.getKeysCountByCurveValue(11.90 ether), 7);
assertEq(bondCurve.getKeysCountByCurveValue(13.10 ether), 8);
assertEq(bondCurve.getKeysCountByCurveValue(14.30 ether), 9);
assertEq(bondCurve.getKeysCountByCurveValue(15.40 ether), 10);
assertEq(bondCurve.getKeysCountByCurveValue(16.40 ether), 11);
assertEq(bondCurve.getKeysCountByCurveValue(17.40 ether), 12);

bondCurve.setBondMultiplier(0, 5000);

assertEq(bondCurve.getKeysCountByCurveValue(0, 0), 0);
assertEq(bondCurve.getKeysCountByCurveValue(0, 2 ether), 2);
assertEq(bondCurve.getKeysCountByCurveValue(0, 3 ether), 3);
assertEq(bondCurve.getKeysCountByCurveValue(0, 3.90 ether), 4);
assertEq(bondCurve.getKeysCountByCurveValue(0, 5.70 ether), 6);
assertEq(bondCurve.getKeysCountByCurveValue(0, 7.40 ether), 9);
assertEq(bondCurve.getKeysCountByCurveValue(0, 9.00 ether), 12);
assertEq(bondCurve.getKeysCountByCurveValue(0, 10.50 ether), 15);
assertEq(bondCurve.getKeysCountByCurveValue(0, 11.90 ether), 18);
assertEq(bondCurve.getKeysCountByCurveValue(0, 13.10 ether), 20);
function test_getKeysCountByBondAmount() public {
assertEq(bondCurve.getKeysCountByBondAmount(0), 0);
assertEq(bondCurve.getKeysCountByBondAmount(2 ether), 1);
assertEq(bondCurve.getKeysCountByBondAmount(3 ether), 1);
assertEq(bondCurve.getKeysCountByBondAmount(3.90 ether), 2);
assertEq(bondCurve.getKeysCountByBondAmount(5.70 ether), 3);
assertEq(bondCurve.getKeysCountByBondAmount(7.40 ether), 4);
assertEq(bondCurve.getKeysCountByBondAmount(9.00 ether), 5);
assertEq(bondCurve.getKeysCountByBondAmount(10.50 ether), 6);
assertEq(bondCurve.getKeysCountByBondAmount(11.90 ether), 7);
assertEq(bondCurve.getKeysCountByBondAmount(13.10 ether), 8);
assertEq(bondCurve.getKeysCountByBondAmount(14.30 ether), 9);
assertEq(bondCurve.getKeysCountByBondAmount(15.40 ether), 10);
assertEq(bondCurve.getKeysCountByBondAmount(16.40 ether), 11);
assertEq(bondCurve.getKeysCountByBondAmount(17.40 ether), 12);
}

function test_getKeysCountByCurveValue_WithMultiplier() public {
assertEq(bondCurve.getKeysCountByBondAmount(0, 5000), 0);
assertEq(bondCurve.getKeysCountByBondAmount(2 ether, 5000), 2);
assertEq(bondCurve.getKeysCountByBondAmount(3 ether, 5000), 3);
assertEq(bondCurve.getKeysCountByBondAmount(3.90 ether, 5000), 4);
assertEq(bondCurve.getKeysCountByBondAmount(5.70 ether, 5000), 6);
assertEq(bondCurve.getKeysCountByBondAmount(7.40 ether, 5000), 9);
assertEq(bondCurve.getKeysCountByBondAmount(9.00 ether, 5000), 12);
assertEq(bondCurve.getKeysCountByBondAmount(10.50 ether, 5000), 15);
assertEq(bondCurve.getKeysCountByBondAmount(11.90 ether, 5000), 18);
assertEq(bondCurve.getKeysCountByBondAmount(13.10 ether, 5000), 20);
}

function test_getBondAmountByKeysCount() public {
assertEq(bondCurve.getBondAmountByKeysCount(0), 0);
assertEq(bondCurve.getBondAmountByKeysCount(1), 2 ether);
assertEq(bondCurve.getBondAmountByKeysCount(2), 3.90 ether);
assertEq(bondCurve.getBondAmountByKeysCount(3), 5.70 ether);
assertEq(bondCurve.getBondAmountByKeysCount(4), 7.40 ether);
assertEq(bondCurve.getBondAmountByKeysCount(5), 9.00 ether);
assertEq(bondCurve.getBondAmountByKeysCount(6), 10.50 ether);
assertEq(bondCurve.getBondAmountByKeysCount(7), 11.90 ether);
assertEq(bondCurve.getBondAmountByKeysCount(8), 13.10 ether);
assertEq(bondCurve.getBondAmountByKeysCount(9), 14.30 ether);
assertEq(bondCurve.getBondAmountByKeysCount(10), 15.40 ether);
assertEq(bondCurve.getBondAmountByKeysCount(11), 16.40 ether);
assertEq(bondCurve.getBondAmountByKeysCount(12), 17.40 ether);
}

function test_getCurveValueByKeysCount() public {
assertEq(bondCurve.getCurveValueByKeysCount(0), 0);
assertEq(bondCurve.getCurveValueByKeysCount(1), 2 ether);
assertEq(bondCurve.getCurveValueByKeysCount(2), 3.90 ether);
assertEq(bondCurve.getCurveValueByKeysCount(3), 5.70 ether);
assertEq(bondCurve.getCurveValueByKeysCount(4), 7.40 ether);
assertEq(bondCurve.getCurveValueByKeysCount(5), 9.00 ether);
assertEq(bondCurve.getCurveValueByKeysCount(6), 10.50 ether);
assertEq(bondCurve.getCurveValueByKeysCount(7), 11.90 ether);
assertEq(bondCurve.getCurveValueByKeysCount(8), 13.10 ether);
assertEq(bondCurve.getCurveValueByKeysCount(9), 14.30 ether);
assertEq(bondCurve.getCurveValueByKeysCount(10), 15.40 ether);
assertEq(bondCurve.getCurveValueByKeysCount(11), 16.40 ether);
assertEq(bondCurve.getCurveValueByKeysCount(12), 17.40 ether);

bondCurve.setBondMultiplier(0, 5000);

assertEq(bondCurve.getCurveValueByKeysCount(0, 0), 0);
assertEq(bondCurve.getCurveValueByKeysCount(0, 1), 1 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 2), 1.95 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 3), 2.85 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 4), 3.70 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 5), 4.50 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 6), 5.25 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 7), 5.95 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 8), 6.55 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 9), 7.15 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 10), 7.7 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 11), 8.20 ether);
assertEq(bondCurve.getCurveValueByKeysCount(0, 12), 8.70 ether);
function test_getBondAmountByKeysCount_WithMultiplier() public {
assertEq(bondCurve.getBondAmountByKeysCount(0, 5000), 0);
assertEq(bondCurve.getBondAmountByKeysCount(1, 5000), 1 ether);
assertEq(bondCurve.getBondAmountByKeysCount(2, 5000), 1.95 ether);
assertEq(bondCurve.getBondAmountByKeysCount(3, 5000), 2.85 ether);
assertEq(bondCurve.getBondAmountByKeysCount(4, 5000), 3.70 ether);
assertEq(bondCurve.getBondAmountByKeysCount(5, 5000), 4.50 ether);
assertEq(bondCurve.getBondAmountByKeysCount(6, 5000), 5.25 ether);
assertEq(bondCurve.getBondAmountByKeysCount(7, 5000), 5.95 ether);
assertEq(bondCurve.getBondAmountByKeysCount(8, 5000), 6.55 ether);
assertEq(bondCurve.getBondAmountByKeysCount(9, 5000), 7.15 ether);
assertEq(bondCurve.getBondAmountByKeysCount(10, 5000), 7.7 ether);
assertEq(bondCurve.getBondAmountByKeysCount(11, 5000), 8.20 ether);
assertEq(bondCurve.getBondAmountByKeysCount(12, 5000), 8.70 ether);
}
}

0 comments on commit be19ca8

Please sign in to comment.