Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the splits weights uint256 #345

Merged
merged 2 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/Drips.sol
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ contract Drips is Managed, Streams, Splits {
/// Limits the cost of splitting.
uint256 public constant MAX_SPLITS_RECEIVERS = _MAX_SPLITS_RECEIVERS;
/// @notice The total splits weight of an account
uint32 public constant TOTAL_SPLITS_WEIGHT = _TOTAL_SPLITS_WEIGHT;
uint256 public constant TOTAL_SPLITS_WEIGHT = _TOTAL_SPLITS_WEIGHT;
/// @notice The offset of the controlling driver ID in the account ID.
/// In other words the controlling driver ID is the highest 32 bits of the account ID.
/// Every account ID is a 256-bit integer constructed by concatenating:
Expand Down Expand Up @@ -471,7 +471,7 @@ contract Drips is Managed, Streams, Splits {
/// @return collectableAmt The amount made collectable for the account
/// on top of what was collectable before.
/// @return splitAmt The amount split to the account's splits receivers
function splitResult(uint256 accountId, SplitsReceiver[] memory currReceivers, uint128 amount)
function splitResult(uint256 accountId, SplitsReceiver[] calldata currReceivers, uint128 amount)
public
view
onlyProxy
Expand Down Expand Up @@ -503,7 +503,7 @@ contract Drips is Managed, Streams, Splits {
/// @return collectableAmt The amount made collectable for the account
/// on top of what was collectable before.
/// @return splitAmt The amount split to the account's splits receivers
function split(uint256 accountId, IERC20 erc20, SplitsReceiver[] memory currReceivers)
function split(uint256 accountId, IERC20 erc20, SplitsReceiver[] calldata currReceivers)
public
onlyProxy
returns (uint128 collectableAmt, uint128 splitAmt)
Expand Down Expand Up @@ -747,7 +747,7 @@ contract Drips is Managed, Streams, Splits {
/// This is usually unwanted, because if splitting is repeated,
/// funds split to themselves will be again split using the current configuration.
/// Splitting 100% to self effectively blocks splitting unless the configuration is updated.
function setSplits(uint256 accountId, SplitsReceiver[] memory receivers)
function setSplits(uint256 accountId, SplitsReceiver[] calldata receivers)
public
onlyProxy
onlyDriver(accountId)
Expand All @@ -766,7 +766,7 @@ contract Drips is Managed, Streams, Splits {
/// @param receivers The list of the splits receivers.
/// Must be sorted by the account IDs, without duplicate account IDs and without 0 weights.
/// @return receiversHash The hash of the list of splits receivers.
function hashSplits(SplitsReceiver[] memory receivers)
function hashSplits(SplitsReceiver[] calldata receivers)
public
pure
returns (bytes32 receiversHash)
Expand Down
6 changes: 4 additions & 2 deletions src/ImmutableSplitsDriver.sol
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ contract ImmutableSplitsDriver is Managed {
/// @notice The driver ID which this driver uses when calling Drips.
uint32 public immutable driverId;
/// @notice The required total splits weight of each splits configuration
uint32 public immutable totalSplitsWeight;
uint256 public immutable totalSplitsWeight;
/// @notice The ERC-1967 storage slot holding a single `uint256` counter of created identities.
bytes32 private immutable _counterSlot = _erc1967Slot("eip1967.immutableSplitsDriver.storage");

Expand Down Expand Up @@ -76,7 +76,9 @@ contract ImmutableSplitsDriver is Managed {
uint256 weightSum = 0;
unchecked {
for (uint256 i = 0; i < receivers.length; i++) {
weightSum += receivers[i].weight;
uint256 weight = receivers[i].weight;
if (weight > totalSplitsWeight) weight = totalSplitsWeight + 1;
weightSum += weight;
}
}
require(weightSum == totalSplitsWeight, "Invalid total receivers weight");
Expand Down
33 changes: 19 additions & 14 deletions src/Splits.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct SplitsReceiver {
/// @notice The splits weight. Must never be zero.
/// The account will be getting `weight / _TOTAL_SPLITS_WEIGHT`
/// share of the funds collected by the splitting account.
uint32 weight;
uint256 weight;
}

/// @notice Splits can keep track of at most `type(uint128).max`
Expand All @@ -22,7 +22,7 @@ abstract contract Splits {
/// Limits the cost of splitting.
uint256 internal constant _MAX_SPLITS_RECEIVERS = 200;
/// @notice The total splits weight of an account.
uint32 internal constant _TOTAL_SPLITS_WEIGHT = 1_000_000;
uint256 internal constant _TOTAL_SPLITS_WEIGHT = 1_000_000;
/// @notice The amount the contract can keep track of each ERC-20 token.
// slither-disable-next-line unused-state
uint128 internal constant _MAX_SPLITS_BALANCE = _SPLITTABLE_MASK;
Expand Down Expand Up @@ -115,18 +115,20 @@ abstract contract Splits {
/// @return collectableAmt The amount made collectable for the account
/// on top of what was collectable before.
/// @return splitAmt The amount split to the account's splits receivers
function _splitResult(uint256 accountId, SplitsReceiver[] memory currReceivers, uint128 amount)
internal
view
returns (uint128 collectableAmt, uint128 splitAmt)
{
function _splitResult(
uint256 accountId,
SplitsReceiver[] calldata currReceivers,
uint128 amount
) internal view returns (uint128 collectableAmt, uint128 splitAmt) {
_assertCurrSplits(accountId, currReceivers);
if (amount == 0) {
return (0, 0);
}
unchecked {
uint256 splitsWeight = 0;
for (uint256 i = currReceivers.length; i != 0;) {
// This will not overflow because the receivers list
// is verified to add up to no more than _TOTAL_SPLITS_WEIGHT
splitsWeight += currReceivers[--i].weight;
}
splitAmt = uint128(amount * splitsWeight / _TOTAL_SPLITS_WEIGHT);
Expand All @@ -145,7 +147,7 @@ abstract contract Splits {
/// @return collectableAmt The amount made collectable for the account
/// on top of what was collectable before.
/// @return splitAmt The amount split to the account's splits receivers
function _split(uint256 accountId, IERC20 erc20, SplitsReceiver[] memory currReceivers)
function _split(uint256 accountId, IERC20 erc20, SplitsReceiver[] calldata currReceivers)
internal
returns (uint128 collectableAmt, uint128 splitAmt)
{
Expand All @@ -164,6 +166,8 @@ abstract contract Splits {
unchecked {
uint256 splitsWeight = 0;
for (uint256 i = 0; i < currReceivers.length; i++) {
// This will not overflow because the receivers list
// is verified to add up to no more than _TOTAL_SPLITS_WEIGHT
splitsWeight += currReceivers[i].weight;
uint128 currSplitAmt = splitAmt;
splitAmt = uint128(splittable * splitsWeight / _TOTAL_SPLITS_WEIGHT);
Expand Down Expand Up @@ -227,7 +231,7 @@ abstract contract Splits {
/// This is usually unwanted, because if splitting is repeated,
/// funds split to themselves will be again split using the current configuration.
/// Splitting 100% to self effectively blocks splitting unless the configuration is updated.
function _setSplits(uint256 accountId, SplitsReceiver[] memory receivers) internal {
function _setSplits(uint256 accountId, SplitsReceiver[] calldata receivers) internal {
SplitsState storage state = _splitsStorage().splitsStates[accountId];
bytes32 newSplitsHash = _hashSplits(receivers);
if (newSplitsHash == state.splitsHash) return;
Expand All @@ -240,15 +244,16 @@ abstract contract Splits {
/// @notice Validates a list of splits receivers and emits events for them
/// @param receivers The list of splits receivers
/// Must be sorted by the account IDs, without duplicate account IDs and without 0 weights.
function _assertSplitsValid(SplitsReceiver[] memory receivers) private pure {
function _assertSplitsValid(SplitsReceiver[] calldata receivers) private pure {
unchecked {
require(receivers.length <= _MAX_SPLITS_RECEIVERS, "Too many splits receivers");
uint256 totalWeight = 0;
uint256 prevAccountId = 0;
for (uint256 i = 0; i < receivers.length; i++) {
SplitsReceiver memory receiver = receivers[i];
uint32 weight = receiver.weight;
SplitsReceiver calldata receiver = receivers[i];
uint256 weight = receiver.weight;
require(weight != 0, "Splits receiver weight is zero");
if (weight > _TOTAL_SPLITS_WEIGHT) weight = _TOTAL_SPLITS_WEIGHT + 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, just to verify if I understood this correctly.

weight = _TOTAL_SPLITS_WEIGHT + 1;

This is a gas optimization. You want to operate with unchecked and still have a the potential overflow protection. Therefore you check if the weight is greater than the _TOTAL_SPLITS_WEIGHT.

You add +1 to ensure it will revert later with totalWeight <= _TOTAL_SPLITS_WEIGHT an additional require would be more expensive.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly 👍

totalWeight += weight;
uint256 accountId = receiver.accountId;
if (accountId <= prevAccountId) require(i == 0, "Splits receivers not sorted");
Expand All @@ -262,7 +267,7 @@ abstract contract Splits {
/// @param accountId The account ID.
/// @param currReceivers The list of the account's current splits receivers.
/// If the splits have never been set, pass an empty array.
function _assertCurrSplits(uint256 accountId, SplitsReceiver[] memory currReceivers)
function _assertCurrSplits(uint256 accountId, SplitsReceiver[] calldata currReceivers)
internal
view
{
Expand All @@ -282,7 +287,7 @@ abstract contract Splits {
/// @param receivers The list of the splits receivers.
/// If the splits have never been set, pass an empty array.
/// @return receiversHash The hash of the list of splits receivers.
function _hashSplits(SplitsReceiver[] memory receivers)
function _hashSplits(SplitsReceiver[] calldata receivers)
internal
pure
returns (bytes32 receiversHash)
Expand Down
10 changes: 5 additions & 5 deletions test/Drips.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ contract DripsTest is Test {
list = new SplitsReceiver[](0);
}

function splitsReceivers(uint256 splitsReceiver, uint32 weight)
function splitsReceivers(uint256 splitsReceiver, uint256 weight)
internal
pure
returns (SplitsReceiver[] memory list)
Expand All @@ -203,9 +203,9 @@ contract DripsTest is Test {

function splitsReceivers(
uint256 splitsReceiver1,
uint32 weight1,
uint256 weight1,
uint256 splitsReceiver2,
uint32 weight2
uint256 weight2
) internal pure returns (SplitsReceiver[] memory list) {
list = new SplitsReceiver[](2);
list[0] = SplitsReceiver(splitsReceiver1, weight1);
Expand Down Expand Up @@ -418,7 +418,7 @@ contract DripsTest is Test {
}

function testUncollectedFundsAreSplitUsingCurrentConfig() public {
uint32 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
uint256 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
setSplits(accountId1, splitsReceivers(receiver1, totalWeight));
setStreams(accountId2, 0, 5, streamsReceivers(accountId1, 5));
skipToCycleEnd();
Expand Down Expand Up @@ -508,7 +508,7 @@ contract DripsTest is Test {
}

function testSplitSplitsFundsReceivedFromAllSources() public {
uint32 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
uint256 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
// Gives
give(accountId2, accountId1, 1);

Expand Down
32 changes: 25 additions & 7 deletions test/ImmutableSplitsDriver.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {Test} from "forge-std/Test.sol";
contract ImmutableSplitsDriverTest is Test {
Drips internal drips;
ImmutableSplitsDriver internal driver;
uint32 internal totalSplitsWeight;
uint256 internal totalSplitsWeight;

function setUp() public {
Drips dripsLogic = new Drips(10);
Expand All @@ -25,10 +25,18 @@ contract ImmutableSplitsDriverTest is Test {
totalSplitsWeight = driver.totalSplitsWeight();
}

function splitsReceivers(uint256 weight1, uint256 weight2)
internal
pure
returns (SplitsReceiver[] memory list)
{
list = new SplitsReceiver[](2);
list[0] = SplitsReceiver(1, weight1);
list[1] = SplitsReceiver(2, weight2);
}

function testCreateSplits() public {
SplitsReceiver[] memory receivers = new SplitsReceiver[](2);
receivers[0] = SplitsReceiver({accountId: 1, weight: totalSplitsWeight - 1});
receivers[1] = SplitsReceiver({accountId: 2, weight: 1});
SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 1, 1);
uint256 nextAccountId = driver.nextAccountId();
AccountMetadata[] memory metadata = new AccountMetadata[](1);
metadata[0] = AccountMetadata("key", "value");
Expand All @@ -43,10 +51,20 @@ contract ImmutableSplitsDriverTest is Test {
}

function testCreateSplitsRevertsWhenWeightsSumTooLow() public {
SplitsReceiver[] memory receivers = new SplitsReceiver[](2);
receivers[0] = SplitsReceiver({accountId: 1, weight: totalSplitsWeight - 2});
receivers[1] = SplitsReceiver({accountId: 2, weight: 1});
SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 2, 1);
vm.expectRevert("Invalid total receivers weight");
driver.createSplits(receivers, new AccountMetadata[](0));
}

function testCreateSplitsRevertsWhenWeightsSumTooHigh() public {
SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 1, 2);
vm.expectRevert("Invalid total receivers weight");
driver.createSplits(receivers, new AccountMetadata[](0));
}

function testCreateSplitsRevertsWhenWeightsSumOverflows() public {
SplitsReceiver[] memory receivers =
splitsReceivers(totalSplitsWeight + 1, type(uint256).max);
vm.expectRevert("Invalid total receivers weight");
driver.createSplits(receivers, new AccountMetadata[](0));
}
Expand Down
Loading
Loading