Skip to content

Commit

Permalink
Many accounts per EOA (#561)
Browse files Browse the repository at this point in the history
* Enable multiple accounts per admin by not hardcoding _data

* Update sender address in tests

* initial unit test

* Update unit test: create multiple accounts with same admin

* Add test for Dynamic and Managed smart wallets

* Use abi.encode instead of encodePacked

* Move deposit fns to AccountExtension
  • Loading branch information
nkrishang authored Oct 30, 2023
1 parent a453dd8 commit 80fa323
Show file tree
Hide file tree
Showing 13 changed files with 400 additions and 65 deletions.
4 changes: 2 additions & 2 deletions contracts/prebuilts/account/dynamic/DynamicAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ contract DynamicAccount is AccountCore, BaseRouter {
}

/// @notice Initializes the smart contract wallet.
function initialize(address _defaultAdmin, bytes calldata) public override initializer {
function initialize(address _defaultAdmin, bytes calldata _data) public override initializer {
__BaseRouter_init();
AccountCoreStorage.data().firstAdmin = _defaultAdmin;
AccountCoreStorage.data().creationSalt = _generateSalt(_defaultAdmin, _data);
_setAdmin(_defaultAdmin, true);
}

Expand Down
12 changes: 2 additions & 10 deletions contracts/prebuilts/account/interface/IAccountFactory.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,8 @@ interface IAccountFactory is IAccountFactoryCore {
//////////////////////////////////////////////////////////////*/

/// @notice Callback function for an Account to register its signers.
function onSignerAdded(
address signer,
address creatorAdmin,
bytes memory data
) external;
function onSignerAdded(address signer, bytes32 salt) external;

/// @notice Callback function for an Account to un-register its signers.
function onSignerRemoved(
address signer,
address creatorAdmin,
bytes memory data
) external;
function onSignerRemoved(address signer, bytes32 salt) external;
}
13 changes: 12 additions & 1 deletion contracts/prebuilts/account/non-upgradeable/Account.sol
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ contract Account is AccountCore, ContractMetadata, ERC1271, ERC721Holder, ERC115
}
}

/// @notice Deposit funds for this account in Entrypoint.
function addDeposit() public payable {
entryPoint().depositTo{ value: msg.value }(address(this));
}

/// @notice Withdraw funds for this account from Entrypoint.
function withdrawDepositTo(address payable withdrawAddress, uint256 amount) public {
_onlyAdmin();
entryPoint().withdrawTo(withdrawAddress, amount);
}

/*///////////////////////////////////////////////////////////////
Internal functions
//////////////////////////////////////////////////////////////*/
Expand All @@ -123,7 +134,7 @@ contract Account is AccountCore, ContractMetadata, ERC1271, ERC721Holder, ERC115
function _registerOnFactory() internal virtual {
BaseAccountFactory factoryContract = BaseAccountFactory(factory);
if (!factoryContract.isRegistered(address(this))) {
factoryContract.onRegister(AccountCoreStorage.data().firstAdmin, "");
factoryContract.onRegister(AccountCoreStorage.data().creationSalt);
}
}

Expand Down
26 changes: 10 additions & 16 deletions contracts/prebuilts/account/utils/AccountCore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc
}

/// @notice Initializes the smart contract wallet.
function initialize(address _defaultAdmin, bytes calldata) public virtual initializer {
function initialize(address _defaultAdmin, bytes calldata _data) public virtual initializer {
// This is passed as data in the `_registerOnFactory()` call in `AccountExtension` / `Account`.
AccountCoreStorage.data().firstAdmin = _defaultAdmin;
AccountCoreStorage.data().creationSalt = _generateSalt(_defaultAdmin, _data);
_setAdmin(_defaultAdmin, true);
}

Expand Down Expand Up @@ -168,17 +168,6 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc
External functions
//////////////////////////////////////////////////////////////*/

/// @notice Deposit funds for this account in Entrypoint.
function addDeposit() public payable {
entryPoint().depositTo{ value: msg.value }(address(this));
}

/// @notice Withdraw funds for this account from Entrypoint.
function withdrawDepositTo(address payable withdrawAddress, uint256 amount) public {
_onlyAdmin();
entryPoint().withdrawTo(withdrawAddress, amount);
}

/// @notice Overrides the Entrypoint contract being used.
function setEntrypointOverride(IEntryPoint _entrypointOverride) public virtual {
_onlyAdmin();
Expand All @@ -189,6 +178,11 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc
Internal functions
//////////////////////////////////////////////////////////////*/

/// @dev Returns the salt used when deploying an Account.
function _generateSalt(address _admin, bytes memory _data) internal view virtual returns (bytes32) {
return keccak256(abi.encode(_admin, _data));
}

function getFunctionSignature(bytes calldata data) internal pure returns (bytes4 functionSelector) {
require(data.length >= 4, "!Data");
return bytes4(data[:4]);
Expand Down Expand Up @@ -243,17 +237,17 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc
super._setAdmin(_account, _isAdmin);
if (factory.code.length > 0) {
if (_isAdmin) {
BaseAccountFactory(factory).onSignerAdded(_account, AccountCoreStorage.data().firstAdmin, "");
BaseAccountFactory(factory).onSignerAdded(_account, AccountCoreStorage.data().creationSalt);
} else {
BaseAccountFactory(factory).onSignerRemoved(_account, AccountCoreStorage.data().firstAdmin, "");
BaseAccountFactory(factory).onSignerRemoved(_account, AccountCoreStorage.data().creationSalt);
}
}
}

/// @notice Runs after every `changeRole` run.
function _afterSignerPermissionsUpdate(SignerPermissionRequest calldata _req) internal virtual override {
if (factory.code.length > 0) {
BaseAccountFactory(factory).onSignerAdded(_req.signer, AccountCoreStorage.data().firstAdmin, "");
BaseAccountFactory(factory).onSignerAdded(_req.signer, AccountCoreStorage.data().creationSalt);
}
}
}
2 changes: 1 addition & 1 deletion contracts/prebuilts/account/utils/AccountCoreStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ library AccountCoreStorage {

struct Data {
address entrypointOverride;
address firstAdmin;
bytes32 creationSalt;
}

function data() internal pure returns (Data storage acountCoreData) {
Expand Down
13 changes: 12 additions & 1 deletion contracts/prebuilts/account/utils/AccountExtension.sol
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ contract AccountExtension is ContractMetadata, ERC1271, AccountPermissions, ERC7
}
}

/// @notice Deposit funds for this account in Entrypoint.
function addDeposit() public payable {
AccountCore(payable(address(this))).entryPoint().depositTo{ value: msg.value }(address(this));
}

/// @notice Withdraw funds for this account from Entrypoint.
function withdrawDepositTo(address payable withdrawAddress, uint256 amount) public {
_onlyAdmin();
AccountCore(payable(address(this))).entryPoint().withdrawTo(withdrawAddress, amount);
}

/*///////////////////////////////////////////////////////////////
Internal functions
//////////////////////////////////////////////////////////////*/
Expand All @@ -125,7 +136,7 @@ contract AccountExtension is ContractMetadata, ERC1271, AccountPermissions, ERC7
address factory = AccountCore(payable(address(this))).factory();
BaseAccountFactory factoryContract = BaseAccountFactory(factory);
if (!factoryContract.isRegistered(address(this))) {
factoryContract.onRegister(AccountCoreStorage.data().firstAdmin, "");
factoryContract.onRegister(AccountCoreStorage.data().creationSalt);
}
}

Expand Down
33 changes: 10 additions & 23 deletions contracts/prebuilts/account/utils/BaseAccountFactory.sol
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,16 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall {
}

/// @notice Callback function for an Account to register itself on the factory.
function onRegister(address _defaultAdmin, bytes memory _data) external {
function onRegister(bytes32 _salt) external {
address account = msg.sender;
require(_isAccountOfFactory(account, _defaultAdmin, _data), "AccountFactory: not an account.");
require(_isAccountOfFactory(account, _salt), "AccountFactory: not an account.");

require(allAccounts.add(account), "AccountFactory: account already registered");
}

function onSignerAdded(
address _signer,
address _defaultAdmin,
bytes memory _data
) external {
function onSignerAdded(address _signer, bytes32 _salt) external {
address account = msg.sender;
require(_isAccountOfFactory(account, _defaultAdmin, _data), "AccountFactory: not an account.");
require(_isAccountOfFactory(account, _salt), "AccountFactory: not an account.");

bool isNewSigner = accountsOfSigner[_signer].add(account);

Expand All @@ -95,13 +91,9 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall {
}

/// @notice Callback function for an Account to un-register its signers.
function onSignerRemoved(
address _signer,
address _defaultAdmin,
bytes memory _data
) external {
function onSignerRemoved(address _signer, bytes32 _salt) external {
address account = msg.sender;
require(_isAccountOfFactory(account, _defaultAdmin, _data), "AccountFactory: not an account.");
require(_isAccountOfFactory(account, _salt), "AccountFactory: not an account.");

bool isAccount = accountsOfSigner[_signer].remove(account);

Expand Down Expand Up @@ -140,13 +132,8 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall {
//////////////////////////////////////////////////////////////*/

/// @dev Returns whether the caller is an account deployed by this factory.
function _isAccountOfFactory(
address _account,
address _admin,
bytes memory _data
) internal view virtual returns (bool) {
bytes32 salt = _generateSalt(_admin, _data);
address predicted = Clones.predictDeterministicAddress(accountImplementation, salt);
function _isAccountOfFactory(address _account, bytes32 _salt) internal view virtual returns (bool) {
address predicted = Clones.predictDeterministicAddress(accountImplementation, _salt);
return _account == predicted;
}

Expand All @@ -156,8 +143,8 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall {
}

/// @dev Returns the salt used when deploying an Account.
function _generateSalt(address _admin, bytes memory) internal view virtual returns (bytes32) {
return keccak256(abi.encode(_admin));
function _generateSalt(address _admin, bytes memory _data) internal view virtual returns (bytes32) {
return keccak256(abi.encode(_admin, _data));
}

/// @dev Called in `createAccount`. Initializes the account contract created in `createAccount`.
Expand Down
2 changes: 1 addition & 1 deletion src/test/benchmark/AccountBenchmark.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ contract AccountBenchmarkTest is BaseTest {
address private nonSigner;

// UserOp terminology: `sender` is the smart wallet.
address private sender = 0xBB956D56140CA3f3060986586A2631922a4B347E;
address private sender = 0x0df2C3523703d165Aa7fA1a552f3F0B56275DfC6;
address payable private beneficiary = payable(address(0x45654));

bytes32 private uidCache = bytes32("random uid");
Expand Down
112 changes: 110 additions & 2 deletions src/test/smart-wallet/Account.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ contract SimpleAccountTest is BaseTest {
address private nonSigner;

// UserOp terminology: `sender` is the smart wallet.
address private sender = 0xBB956D56140CA3f3060986586A2631922a4B347E;
address private sender = 0x0df2C3523703d165Aa7fA1a552f3F0B56275DfC6;
address payable private beneficiary = payable(address(0x45654));

bytes32 private uidCache = bytes32("random uid");
Expand Down Expand Up @@ -141,6 +141,63 @@ contract SimpleAccountTest is BaseTest {
ops[0] = op;
}

function _setupUserOpWithSender(
bytes memory _initCode,
bytes memory _callDataForEntrypoint,
address _sender
) internal returns (UserOperation[] memory ops) {
uint256 nonce = entrypoint.getNonce(_sender, 0);

// Get user op fields
UserOperation memory op = UserOperation({
sender: _sender,
nonce: nonce,
initCode: _initCode,
callData: _callDataForEntrypoint,
callGasLimit: 500_000,
verificationGasLimit: 500_000,
preVerificationGas: 500_000,
maxFeePerGas: 0,
maxPriorityFeePerGas: 0,
paymasterAndData: bytes(""),
signature: bytes("")
});

// Sign UserOp
bytes32 opHash = EntryPoint(entrypoint).getUserOpHash(op);
bytes32 msgHash = ECDSA.toEthSignedMessageHash(opHash);

(uint8 v, bytes32 r, bytes32 s) = vm.sign(accountAdminPKey, msgHash);
bytes memory userOpSignature = abi.encodePacked(r, s, v);

address recoveredSigner = ECDSA.recover(msgHash, v, r, s);
address expectedSigner = vm.addr(accountAdminPKey);
assertEq(recoveredSigner, expectedSigner);

op.signature = userOpSignature;

// Store UserOp
ops = new UserOperation[](1);
ops[0] = op;
}

function _setupUserOpExecuteWithSender(
bytes memory _initCode,
address _target,
uint256 _value,
bytes memory _callData,
address _sender
) internal returns (UserOperation[] memory) {
bytes memory callDataForEntrypoint = abi.encodeWithSignature(
"execute(address,uint256,bytes)",
_target,
_value,
_callData
);

return _setupUserOpWithSender(_initCode, callDataForEntrypoint, _sender);
}

function _setupUserOpExecute(
uint256 _signerPKey,
bytes memory _initCode,
Expand Down Expand Up @@ -175,6 +232,11 @@ contract SimpleAccountTest is BaseTest {
return _setupUserOp(_signerPKey, _initCode, callDataForEntrypoint);
}

/// @dev Returns the salt used when deploying an Account.
function _generateSalt(address _admin, bytes memory _data) internal view virtual returns (bytes32) {
return keccak256(abi.encode(_admin, _data));
}

function setUp() public override {
super.setUp();

Expand Down Expand Up @@ -234,7 +296,53 @@ contract SimpleAccountTest is BaseTest {
function test_revert_onRegister_nonFactoryChildContract() public {
vm.prank(address(0x12345));
vm.expectRevert("AccountFactory: not an account.");
accountFactory.onRegister(accountAdmin, "");
accountFactory.onRegister(_generateSalt(accountAdmin, ""));
}

/// @dev Create more than one accounts with the same admin.
function test_state_createAccount_viaEntrypoint_multipleAccountSameAdmin() public {
uint256 amount = 100;

for (uint256 i = 0; i < amount; i += 1) {
bytes memory initCallData = abi.encodeWithSignature(
"createAccount(address,bytes)",
accountAdmin,
bytes(abi.encode(i))
);
bytes memory initCode = abi.encodePacked(abi.encodePacked(address(accountFactory)), initCallData);

address expectedSenderAddress = Clones.predictDeterministicAddress(
accountFactory.accountImplementation(),
_generateSalt(accountAdmin, bytes(abi.encode(i))),
address(accountFactory)
);

UserOperation[] memory userOpCreateAccount = _setupUserOpExecuteWithSender(
initCode,
address(0),
0,
bytes(abi.encode(i)),
expectedSenderAddress
);

vm.expectEmit(true, true, false, true);
emit AccountCreated(expectedSenderAddress, accountAdmin);
EntryPoint(entrypoint).handleOps(userOpCreateAccount, beneficiary);
}

address[] memory allAccounts = accountFactory.getAllAccounts();
assertEq(allAccounts.length, amount);

for (uint256 i = 0; i < amount; i += 1) {
assertEq(
allAccounts[i],
Clones.predictDeterministicAddress(
accountFactory.accountImplementation(),
_generateSalt(accountAdmin, bytes(abi.encode(i))),
address(accountFactory)
)
);
}
}

/*///////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion src/test/smart-wallet/AccountVulnPOC.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ contract SimpleAccountVulnPOCTest is BaseTest {
address private nonSigner;

// UserOp terminology: `sender` is the smart wallet.
address private sender = 0xBB956D56140CA3f3060986586A2631922a4B347E;
address private sender = 0x0df2C3523703d165Aa7fA1a552f3F0B56275DfC6;
address payable private beneficiary = payable(address(0x45654));

bytes32 private uidCache = bytes32("random uid");
Expand Down
Loading

0 comments on commit 80fa323

Please sign in to comment.