diff --git a/src/CSAccounting.sol b/src/CSAccounting.sol index d2ecfb11..fd9922df 100644 --- a/src/CSAccounting.sol +++ b/src/CSAccounting.sol @@ -344,10 +344,11 @@ contract CSAccounting is // todo: can be optimized. get active keys once (uint256 current, uint256 required) = _bondETHSummary(nodeOperatorId); uint256 currentKeysCount = _getNodeOperatorActiveKeys(nodeOperatorId); - uint256 requiredForNextKeys = _getCurveValueByKeysCount( - nodeOperatorId, - currentKeysCount + additionalKeysCount - ) - _getCurveValueByKeysCount(nodeOperatorId, currentKeysCount); + uint256 multiplier = getBondMultiplier(nodeOperatorId); + uint256 requiredForNextKeys = _getBondAmountByKeysCount( + currentKeysCount + additionalKeysCount, + multiplier + ) - _getBondAmountByKeysCount(currentKeysCount, multiplier); uint256 missing = required > current ? required - current : 0; if (missing > 0) { @@ -392,7 +393,7 @@ contract CSAccounting is function getRequiredBondETHForKeys( uint256 keysCount ) public view returns (uint256) { - return _getCurveValueByKeysCount(keysCount); + return _getBondAmountByKeysCount(keysCount); } /// @notice Returns the required bond stETH for the given number of keys. @@ -424,14 +425,14 @@ contract CSAccounting is uint256 currentBond = _ethByShares(_bondShares[nodeOperatorId]); uint256 blockedBond = getBlockedBondETH(nodeOperatorId); if (currentBond > blockedBond) { + uint256 multiplier = getBondMultiplier(nodeOperatorId); currentBond -= blockedBond; - uint256 bondedKeys = _getKeysCountByCurveValue( - nodeOperatorId, - currentBond + uint256 bondedKeys = _getKeysCountByBondAmount( + currentBond, + multiplier ); if ( - currentBond > - _getCurveValueByKeysCount(nodeOperatorId, bondedKeys) + currentBond > _getBondAmountByKeysCount(bondedKeys, multiplier) ) { bondedKeys += 1; } @@ -444,7 +445,7 @@ contract CSAccounting is function getKeysCountByBondETH( uint256 ETHAmount ) public view returns (uint256) { - return _getKeysCountByCurveValue(ETHAmount); + return _getKeysCountByBondAmount(ETHAmount); } /// @notice Returns the number of keys by the given bond stETH amount @@ -932,9 +933,9 @@ contract CSAccounting is ) internal view returns (uint256 current, uint256 required) { current = _ethByShares(getBondShares(nodeOperatorId)); required = - _getCurveValueByKeysCount( - nodeOperatorId, - _getNodeOperatorActiveKeys(nodeOperatorId) + _getBondAmountByKeysCount( + _getNodeOperatorActiveKeys(nodeOperatorId), + getBondMultiplier(nodeOperatorId) ) + getBlockedBondETH(nodeOperatorId); } @@ -945,9 +946,9 @@ contract CSAccounting is current = getBondShares(nodeOperatorId); required = _sharesByEth( - _getCurveValueByKeysCount( - nodeOperatorId, - _getNodeOperatorActiveKeys(nodeOperatorId) + _getBondAmountByKeysCount( + _getNodeOperatorActiveKeys(nodeOperatorId), + getBondMultiplier(nodeOperatorId) ) ) + _sharesByEth(getBlockedBondETH(nodeOperatorId)); diff --git a/src/CSBondCurve.sol b/src/CSBondCurve.sol index b16c87c9..16f3688c 100644 --- a/src/CSBondCurve.sol +++ b/src/CSBondCurve.sol @@ -13,12 +13,13 @@ abstract contract CSBondCurve { // bondCurve[i] | bond amount for `i + 1` keys count uint256[] public bondCurve; - uint256 internal constant MIN_CURVE_LENGTH = 1; // todo: might be redefined in the future uint256 internal constant MAX_CURVE_LENGTH = 20; + uint256 internal constant MIN_CURVE_LENGTH = 1; + uint256 internal constant BASIS_POINTS = 10000; + uint256 internal constant MAX_BOND_MULTIPLIER = BASIS_POINTS; // x1 uint256 internal constant MIN_BOND_MULTIPLIER = MAX_BOND_MULTIPLIER / 2; // x0.5 - uint256 internal constant MAX_BOND_MULTIPLIER = 10000; // x1 /// This mapping contains bond multiplier points (in basis points) for Node Operator's bond. /// By default, all Node Operators have x1 multiplier (10000 basis points). @@ -32,6 +33,7 @@ abstract contract CSBondCurve { function _setBondCurve(uint256[] memory _bondCurve) internal { _checkCurveLength(_bondCurve); + // todo: check curve values (not worse than previous and makes sense) bondCurve = _bondCurve; _bondCurveTrend = _bondCurve[_bondCurve.length - 1] - @@ -43,6 +45,7 @@ abstract contract CSBondCurve { uint256 basisPoints ) internal { _checkMultiplier(basisPoints); + // todo: check curve values (not worse than previous) _bondMultiplierBP[nodeOperatorId] = basisPoints; } @@ -66,69 +69,69 @@ abstract contract CSBondCurve { ) revert InvalidMultiplier(); } - /// @notice Returns the amount of keys for the given bond amount. - function _getKeysCountByCurveValue( + /// @notice Returns keys count for the given bond amount. + function _getKeysCountByBondAmount( uint256 amount ) internal view returns (uint256) { - return _getKeysCountByCurveValue(type(uint256).max, amount); + return _getKeysCountByBondAmount(amount, MAX_BOND_MULTIPLIER); } - /// @notice Returns the amount of keys for the given bond amount for particular node operator. - function _getKeysCountByCurveValue( - uint256 nodeOperatorId, - uint256 amount + /// @notice Returns keys count for the given bond amount for particular node operator. + function _getKeysCountByBondAmount( + uint256 amount, + uint256 multiplier ) internal view returns (uint256) { - uint256 mult = getBondMultiplier(nodeOperatorId); - if (amount < (bondCurve[0] * mult) / BASIS_POINTS) return 0; - uint256 last = (bondCurve[bondCurve.length - 1] * mult) / BASIS_POINTS; - if (amount >= last) { + if (amount < (bondCurve[0] * multiplier) / BASIS_POINTS) return 0; + uint256 maxCurveAmount = (bondCurve[bondCurve.length - 1] * + multiplier) / BASIS_POINTS; + if (amount >= maxCurveAmount) { return bondCurve.length + - ((amount - last) / ((_bondCurveTrend * mult) / BASIS_POINTS)); + ((amount - maxCurveAmount) / + ((_bondCurveTrend * multiplier) / BASIS_POINTS)); } - return _searchKeysByBond(amount, mult); + return _searchKeysCount(amount, multiplier); } - function _searchKeysByBond( - uint256 value, + function _searchKeysCount( + uint256 amount, uint256 multiplier ) internal view returns (uint256) { uint256 low; uint256 high = bondCurve.length - 1; while (low <= high) { uint256 mid = (low + high) / 2; - uint256 midValue = (bondCurve[mid] * multiplier) / BASIS_POINTS; - if (value == midValue) { + uint256 midAmount = (bondCurve[mid] * multiplier) / BASIS_POINTS; + if (amount == midAmount) { return mid + 1; } - if (value < midValue) { + if (amount < midAmount) { // zero mid is avoided above high = mid - 1; - } else if (value > midValue) { + } else if (amount > midAmount) { low = mid + 1; } } return low; } - function _getCurveValueByKeysCount( + function _getBondAmountByKeysCount( uint256 keys ) internal view returns (uint256) { - return _getCurveValueByKeysCount(type(uint256).max, keys); + return _getBondAmountByKeysCount(keys, MAX_BOND_MULTIPLIER); } - function _getCurveValueByKeysCount( - uint256 nodeOperatorId, - uint256 keys + function _getBondAmountByKeysCount( + uint256 keys, + uint256 multiplier ) internal view returns (uint256) { if (keys == 0) return 0; - uint256 mult = getBondMultiplier(nodeOperatorId); if (keys <= bondCurve.length) { - return (bondCurve[keys - 1] * mult) / BASIS_POINTS; + return (bondCurve[keys - 1] * multiplier) / BASIS_POINTS; } return - ((bondCurve[bondCurve.length - 1] * mult) / BASIS_POINTS) + + ((bondCurve[bondCurve.length - 1] * multiplier) / BASIS_POINTS) + (keys - bondCurve.length) * - ((_bondCurveTrend * mult) / BASIS_POINTS); + ((_bondCurveTrend * multiplier) / BASIS_POINTS); } } diff --git a/test/CSBondCurve.t.sol b/test/CSBondCurve.t.sol index 3b10e218..d115edf8 100644 --- a/test/CSBondCurve.t.sol +++ b/test/CSBondCurve.t.sol @@ -21,30 +21,30 @@ contract CSBondCurveTestable is CSBondCurve { _setBondMultiplier(nodeOperatorId, basisPoints); } - function getKeysCountByCurveValue( + function getKeysCountByBondAmount( uint256 amount ) external view returns (uint256) { - return _getKeysCountByCurveValue(amount); + return _getKeysCountByBondAmount(amount); } - function getCurveValueByKeysCount( + function getBondAmountByKeysCount( uint256 keysCount ) external view returns (uint256) { - return _getCurveValueByKeysCount(keysCount); + return _getBondAmountByKeysCount(keysCount); } - function getKeysCountByCurveValue( + function getKeysCountByBondAmount( uint256 nodeOperatorId, uint256 amount ) external view returns (uint256) { - return _getKeysCountByCurveValue(nodeOperatorId, amount); + return _getKeysCountByBondAmount(nodeOperatorId, amount); } - function getCurveValueByKeysCount( + function getBondAmountByKeysCount( uint256 nodeOperatorId, uint256 keysCount ) external view returns (uint256) { - return _getCurveValueByKeysCount(nodeOperatorId, keysCount); + return _getBondAmountByKeysCount(nodeOperatorId, keysCount); } } @@ -71,7 +71,6 @@ contract CSBondCurveTest is Test { uint256[] memory _bondCurve = new uint256[](0); vm.expectRevert(CSBondCurve.InvalidBondCurveLength.selector); - bondCurve.setBondCurve(_bondCurve); } @@ -99,65 +98,65 @@ contract CSBondCurveTest is Test { bondCurve.setBondMultiplier(0, 10001); } - function test_getKeysCountByCurveValue() public { - assertEq(bondCurve.getKeysCountByCurveValue(0), 0); - assertEq(bondCurve.getKeysCountByCurveValue(2 ether), 1); - assertEq(bondCurve.getKeysCountByCurveValue(3 ether), 1); - assertEq(bondCurve.getKeysCountByCurveValue(3.90 ether), 2); - assertEq(bondCurve.getKeysCountByCurveValue(5.70 ether), 3); - assertEq(bondCurve.getKeysCountByCurveValue(7.40 ether), 4); - assertEq(bondCurve.getKeysCountByCurveValue(9.00 ether), 5); - assertEq(bondCurve.getKeysCountByCurveValue(10.50 ether), 6); - assertEq(bondCurve.getKeysCountByCurveValue(11.90 ether), 7); - assertEq(bondCurve.getKeysCountByCurveValue(13.10 ether), 8); - assertEq(bondCurve.getKeysCountByCurveValue(14.30 ether), 9); - assertEq(bondCurve.getKeysCountByCurveValue(15.40 ether), 10); - assertEq(bondCurve.getKeysCountByCurveValue(16.40 ether), 11); - assertEq(bondCurve.getKeysCountByCurveValue(17.40 ether), 12); - - bondCurve.setBondMultiplier(0, 5000); - - assertEq(bondCurve.getKeysCountByCurveValue(0, 0), 0); - assertEq(bondCurve.getKeysCountByCurveValue(0, 2 ether), 2); - assertEq(bondCurve.getKeysCountByCurveValue(0, 3 ether), 3); - assertEq(bondCurve.getKeysCountByCurveValue(0, 3.90 ether), 4); - assertEq(bondCurve.getKeysCountByCurveValue(0, 5.70 ether), 6); - assertEq(bondCurve.getKeysCountByCurveValue(0, 7.40 ether), 9); - assertEq(bondCurve.getKeysCountByCurveValue(0, 9.00 ether), 12); - assertEq(bondCurve.getKeysCountByCurveValue(0, 10.50 ether), 15); - assertEq(bondCurve.getKeysCountByCurveValue(0, 11.90 ether), 18); - assertEq(bondCurve.getKeysCountByCurveValue(0, 13.10 ether), 20); + function test_getKeysCountByBondAmount() public { + assertEq(bondCurve.getKeysCountByBondAmount(0), 0); + assertEq(bondCurve.getKeysCountByBondAmount(2 ether), 1); + assertEq(bondCurve.getKeysCountByBondAmount(3 ether), 1); + assertEq(bondCurve.getKeysCountByBondAmount(3.90 ether), 2); + assertEq(bondCurve.getKeysCountByBondAmount(5.70 ether), 3); + assertEq(bondCurve.getKeysCountByBondAmount(7.40 ether), 4); + assertEq(bondCurve.getKeysCountByBondAmount(9.00 ether), 5); + assertEq(bondCurve.getKeysCountByBondAmount(10.50 ether), 6); + assertEq(bondCurve.getKeysCountByBondAmount(11.90 ether), 7); + assertEq(bondCurve.getKeysCountByBondAmount(13.10 ether), 8); + assertEq(bondCurve.getKeysCountByBondAmount(14.30 ether), 9); + assertEq(bondCurve.getKeysCountByBondAmount(15.40 ether), 10); + assertEq(bondCurve.getKeysCountByBondAmount(16.40 ether), 11); + assertEq(bondCurve.getKeysCountByBondAmount(17.40 ether), 12); + } + + function test_getKeysCountByCurveValue_WithMultiplier() public { + assertEq(bondCurve.getKeysCountByBondAmount(0, 5000), 0); + assertEq(bondCurve.getKeysCountByBondAmount(2 ether, 5000), 2); + assertEq(bondCurve.getKeysCountByBondAmount(3 ether, 5000), 3); + assertEq(bondCurve.getKeysCountByBondAmount(3.90 ether, 5000), 4); + assertEq(bondCurve.getKeysCountByBondAmount(5.70 ether, 5000), 6); + assertEq(bondCurve.getKeysCountByBondAmount(7.40 ether, 5000), 9); + assertEq(bondCurve.getKeysCountByBondAmount(9.00 ether, 5000), 12); + assertEq(bondCurve.getKeysCountByBondAmount(10.50 ether, 5000), 15); + assertEq(bondCurve.getKeysCountByBondAmount(11.90 ether, 5000), 18); + assertEq(bondCurve.getKeysCountByBondAmount(13.10 ether, 5000), 20); + } + + function test_getBondAmountByKeysCount() public { + assertEq(bondCurve.getBondAmountByKeysCount(0), 0); + assertEq(bondCurve.getBondAmountByKeysCount(1), 2 ether); + assertEq(bondCurve.getBondAmountByKeysCount(2), 3.90 ether); + assertEq(bondCurve.getBondAmountByKeysCount(3), 5.70 ether); + assertEq(bondCurve.getBondAmountByKeysCount(4), 7.40 ether); + assertEq(bondCurve.getBondAmountByKeysCount(5), 9.00 ether); + assertEq(bondCurve.getBondAmountByKeysCount(6), 10.50 ether); + assertEq(bondCurve.getBondAmountByKeysCount(7), 11.90 ether); + assertEq(bondCurve.getBondAmountByKeysCount(8), 13.10 ether); + assertEq(bondCurve.getBondAmountByKeysCount(9), 14.30 ether); + assertEq(bondCurve.getBondAmountByKeysCount(10), 15.40 ether); + assertEq(bondCurve.getBondAmountByKeysCount(11), 16.40 ether); + assertEq(bondCurve.getBondAmountByKeysCount(12), 17.40 ether); } - function test_getCurveValueByKeysCount() public { - assertEq(bondCurve.getCurveValueByKeysCount(0), 0); - assertEq(bondCurve.getCurveValueByKeysCount(1), 2 ether); - assertEq(bondCurve.getCurveValueByKeysCount(2), 3.90 ether); - assertEq(bondCurve.getCurveValueByKeysCount(3), 5.70 ether); - assertEq(bondCurve.getCurveValueByKeysCount(4), 7.40 ether); - assertEq(bondCurve.getCurveValueByKeysCount(5), 9.00 ether); - assertEq(bondCurve.getCurveValueByKeysCount(6), 10.50 ether); - assertEq(bondCurve.getCurveValueByKeysCount(7), 11.90 ether); - assertEq(bondCurve.getCurveValueByKeysCount(8), 13.10 ether); - assertEq(bondCurve.getCurveValueByKeysCount(9), 14.30 ether); - assertEq(bondCurve.getCurveValueByKeysCount(10), 15.40 ether); - assertEq(bondCurve.getCurveValueByKeysCount(11), 16.40 ether); - assertEq(bondCurve.getCurveValueByKeysCount(12), 17.40 ether); - - bondCurve.setBondMultiplier(0, 5000); - - assertEq(bondCurve.getCurveValueByKeysCount(0, 0), 0); - assertEq(bondCurve.getCurveValueByKeysCount(0, 1), 1 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 2), 1.95 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 3), 2.85 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 4), 3.70 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 5), 4.50 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 6), 5.25 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 7), 5.95 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 8), 6.55 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 9), 7.15 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 10), 7.7 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 11), 8.20 ether); - assertEq(bondCurve.getCurveValueByKeysCount(0, 12), 8.70 ether); + function test_getBondAmountByKeysCount_WithMultiplier() public { + assertEq(bondCurve.getBondAmountByKeysCount(0, 5000), 0); + assertEq(bondCurve.getBondAmountByKeysCount(1, 5000), 1 ether); + assertEq(bondCurve.getBondAmountByKeysCount(2, 5000), 1.95 ether); + assertEq(bondCurve.getBondAmountByKeysCount(3, 5000), 2.85 ether); + assertEq(bondCurve.getBondAmountByKeysCount(4, 5000), 3.70 ether); + assertEq(bondCurve.getBondAmountByKeysCount(5, 5000), 4.50 ether); + assertEq(bondCurve.getBondAmountByKeysCount(6, 5000), 5.25 ether); + assertEq(bondCurve.getBondAmountByKeysCount(7, 5000), 5.95 ether); + assertEq(bondCurve.getBondAmountByKeysCount(8, 5000), 6.55 ether); + assertEq(bondCurve.getBondAmountByKeysCount(9, 5000), 7.15 ether); + assertEq(bondCurve.getBondAmountByKeysCount(10, 5000), 7.7 ether); + assertEq(bondCurve.getBondAmountByKeysCount(11, 5000), 8.20 ether); + assertEq(bondCurve.getBondAmountByKeysCount(12, 5000), 8.70 ether); } }