From c778b3ca38fe4e27abb186eda6e08f40d5a6e969 Mon Sep 17 00:00:00 2001 From: nkrishang <62195808+nkrishang@users.noreply.github.com> Date: Tue, 31 Oct 2023 01:41:47 +0530 Subject: [PATCH] Proxy pattern smart wallet factory contracts (#562) * Enable multiple accounts per admin by not hardcoding _data * Update sender address in tests * initial unit test * Update unit test: create multiple accounts with same admin * Add test for Dynamic and Managed smart wallets * Use abi.encode instead of encodePacked * Move deposit fns to AccountExtension * Create BaseAccountFactoryStorage * Make account factory contracts initializable * Fix build errors in tests * Store factory in initialize fn * Update tests --- .../account/dynamic/DynamicAccount.sol | 11 +++++-- .../account/dynamic/DynamicAccountFactory.sol | 19 +++++++---- .../account/managed/ManagedAccount.sol | 6 ++-- .../account/managed/ManagedAccountFactory.sol | 27 ++++++++++----- .../account/non-upgradeable/Account.sol | 4 +-- .../non-upgradeable/AccountFactory.sol | 13 +++++--- .../prebuilts/account/utils/AccountCore.sol | 31 ++++++++++------- .../account/utils/AccountCoreStorage.sol | 1 + .../account/utils/BaseAccountFactory.sol | 26 +++++++++------ .../utils/BaseAccountFactoryStorage.sol | 23 +++++++++++++ src/test/benchmark/AccountBenchmark.t.sol | 15 +++++++-- src/test/smart-wallet/Account.t.sol | 27 +++++++++++++-- src/test/smart-wallet/AccountVulnPOC.t.sol | 15 +++++++-- src/test/smart-wallet/DynamicAccount.t.sol | 29 +++++++++++++--- src/test/smart-wallet/ManagedAccount.t.sol | 33 +++++++++++++------ .../account-core/isValidSigner.t.sol | 15 +++++++-- .../setPermissionsForSigner.t.sol | 15 +++++++-- 17 files changed, 236 insertions(+), 74 deletions(-) create mode 100644 contracts/prebuilts/account/utils/BaseAccountFactoryStorage.sol diff --git a/contracts/prebuilts/account/dynamic/DynamicAccount.sol b/contracts/prebuilts/account/dynamic/DynamicAccount.sol index 49301309d..04b638086 100644 --- a/contracts/prebuilts/account/dynamic/DynamicAccount.sol +++ b/contracts/prebuilts/account/dynamic/DynamicAccount.sol @@ -24,16 +24,23 @@ contract DynamicAccount is AccountCore, BaseRouter { //////////////////////////////////////////////////////////////*/ constructor(IEntryPoint _entrypoint, Extension[] memory _defaultExtensions) - AccountCore(_entrypoint, msg.sender) + AccountCore(_entrypoint) BaseRouter(_defaultExtensions) { _disableInitializers(); } /// @notice Initializes the smart contract wallet. - function initialize(address _defaultAdmin, bytes calldata _data) public override initializer { + function initialize( + address _defaultAdmin, + address _factory, + bytes calldata _data + ) public override initializer { __BaseRouter_init(); + + // This is passed as data in the `_registerOnFactory()` call in `AccountExtension` / `Account`. AccountCoreStorage.data().creationSalt = _generateSalt(_defaultAdmin, _data); + AccountCoreStorage.data().factory = _factory; _setAdmin(_defaultAdmin, true); } diff --git a/contracts/prebuilts/account/dynamic/DynamicAccountFactory.sol b/contracts/prebuilts/account/dynamic/DynamicAccountFactory.sol index ea5c320da..3da2d766d 100644 --- a/contracts/prebuilts/account/dynamic/DynamicAccountFactory.sol +++ b/contracts/prebuilts/account/dynamic/DynamicAccountFactory.sol @@ -4,10 +4,11 @@ pragma solidity ^0.8.12; // Utils import "../utils/BaseAccountFactory.sol"; import "@thirdweb-dev/dynamic-contracts/src/interface/IExtension.sol"; +import "../../../extension/upgradeable/Initializable.sol"; // Extensions -import "../../../extension/upgradeable//PermissionsEnumerable.sol"; -import "../../../extension/upgradeable//ContractMetadata.sol"; +import "../../../extension/upgradeable/PermissionsEnumerable.sol"; +import "../../../extension/upgradeable/ContractMetadata.sol"; // Smart wallet implementation import { DynamicAccount, IEntryPoint } from "./DynamicAccount.sol"; @@ -21,20 +22,24 @@ import { DynamicAccount, IEntryPoint } from "./DynamicAccount.sol"; // \$$$$ |$$ | $$ |$$ |$$ | \$$$$$$$ |\$$$$$\$$$$ |\$$$$$$$\ $$$$$$$ | // \____/ \__| \__|\__|\__| \_______| \_____\____/ \_______|\_______/ -contract DynamicAccountFactory is BaseAccountFactory, ContractMetadata, PermissionsEnumerable { +contract DynamicAccountFactory is Initializable, BaseAccountFactory, ContractMetadata, PermissionsEnumerable { address public constant ENTRYPOINT_ADDRESS = 0x5FF137D4b0FDCD49DcA30c7CF57E578a026d2789; /*/////////////////////////////////////////////////////////////// Constructor //////////////////////////////////////////////////////////////*/ - constructor(address _defaultAdmin, IExtension.Extension[] memory _defaultExtensions) + constructor(IExtension.Extension[] memory _defaultExtensions) BaseAccountFactory( - payable(address(new DynamicAccount(IEntryPoint(ENTRYPOINT_ADDRESS), _defaultExtensions))), + address(new DynamicAccount(IEntryPoint(ENTRYPOINT_ADDRESS), _defaultExtensions)), ENTRYPOINT_ADDRESS ) - { + {} + + /// @notice Initializes the factory contract. + function initialize(address _defaultAdmin, string memory _contractURI) external initializer { _setupRole(DEFAULT_ADMIN_ROLE, _defaultAdmin); + _setupContractURI(_contractURI); } /*/////////////////////////////////////////////////////////////// @@ -47,7 +52,7 @@ contract DynamicAccountFactory is BaseAccountFactory, ContractMetadata, Permissi address _admin, bytes calldata _data ) internal override { - DynamicAccount(payable(_account)).initialize(_admin, _data); + DynamicAccount(payable(_account)).initialize(_admin, address(this), _data); } /// @dev Returns whether contract metadata can be set in the given execution context. diff --git a/contracts/prebuilts/account/managed/ManagedAccount.sol b/contracts/prebuilts/account/managed/ManagedAccount.sol index 4e70bb605..5ba01c850 100644 --- a/contracts/prebuilts/account/managed/ManagedAccount.sol +++ b/contracts/prebuilts/account/managed/ManagedAccount.sol @@ -19,15 +19,15 @@ import "@thirdweb-dev/dynamic-contracts/src/core/Router.sol"; import "@thirdweb-dev/dynamic-contracts/src/interface/IRouterState.sol"; contract ManagedAccount is AccountCore, Router, IRouterState { - constructor(IEntryPoint _entrypoint, address _factory) AccountCore(_entrypoint, _factory) {} + constructor(IEntryPoint _entrypoint) AccountCore(_entrypoint) {} /// @notice Returns the implementation contract address for a given function signature. function getImplementationForFunction(bytes4 _functionSelector) public view virtual override returns (address) { - return Router(payable(factory)).getImplementationForFunction(_functionSelector); + return Router(payable(AccountCoreStorage.data().factory)).getImplementationForFunction(_functionSelector); } /// @notice Returns all extensions of the Router. function getAllExtensions() external view returns (Extension[] memory) { - return IRouterState(payable(factory)).getAllExtensions(); + return IRouterState(payable(AccountCoreStorage.data().factory)).getAllExtensions(); } } diff --git a/contracts/prebuilts/account/managed/ManagedAccountFactory.sol b/contracts/prebuilts/account/managed/ManagedAccountFactory.sol index f46e1f2dd..dded66e86 100644 --- a/contracts/prebuilts/account/managed/ManagedAccountFactory.sol +++ b/contracts/prebuilts/account/managed/ManagedAccountFactory.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.12; // Utils import "@thirdweb-dev/dynamic-contracts/src/presets/BaseRouter.sol"; import "../utils/BaseAccountFactory.sol"; +import "../../../extension/upgradeable/Initializable.sol"; // Extensions import "../../../extension/upgradeable//PermissionsEnumerable.sol"; @@ -21,25 +22,33 @@ import { ManagedAccount, IEntryPoint } from "./ManagedAccount.sol"; // \$$$$ |$$ | $$ |$$ |$$ | \$$$$$$$ |\$$$$$\$$$$ |\$$$$$$$\ $$$$$$$ | // \____/ \__| \__|\__|\__| \_______| \_____\____/ \_______|\_______/ -contract ManagedAccountFactory is BaseAccountFactory, ContractMetadata, PermissionsEnumerable, BaseRouter { +contract ManagedAccountFactory is + Initializable, + BaseAccountFactory, + ContractMetadata, + PermissionsEnumerable, + BaseRouter +{ /*/////////////////////////////////////////////////////////////// Constructor //////////////////////////////////////////////////////////////*/ - constructor( - address _defaultAdmin, - IEntryPoint _entrypoint, - Extension[] memory _defaultExtensions - ) + constructor(IEntryPoint _entrypoint, Extension[] memory _defaultExtensions) BaseRouter(_defaultExtensions) - BaseAccountFactory(payable(address(new ManagedAccount(_entrypoint, address(this)))), address(_entrypoint)) - { + BaseAccountFactory(address(new ManagedAccount(_entrypoint)), address(_entrypoint)) + {} + + /// @notice Initializes the factory contract. + function initialize(address _defaultAdmin, string memory _contractURI) external initializer { __BaseRouter_init(); + _setupRole(DEFAULT_ADMIN_ROLE, _defaultAdmin); bytes32 _extensionRole = keccak256("EXTENSION_ROLE"); _setupRole(_extensionRole, _defaultAdmin); _setRoleAdmin(_extensionRole, _extensionRole); + + _setupContractURI(_contractURI); } /*/////////////////////////////////////////////////////////////// @@ -52,7 +61,7 @@ contract ManagedAccountFactory is BaseAccountFactory, ContractMetadata, Permissi address _admin, bytes calldata _data ) internal override { - ManagedAccount(payable(_account)).initialize(_admin, _data); + ManagedAccount(payable(_account)).initialize(_admin, address(this), _data); } /// @dev Returns whether all relevant permission and other checks are met before any upgrade. diff --git a/contracts/prebuilts/account/non-upgradeable/Account.sol b/contracts/prebuilts/account/non-upgradeable/Account.sol index 67519b0f8..82212d517 100644 --- a/contracts/prebuilts/account/non-upgradeable/Account.sol +++ b/contracts/prebuilts/account/non-upgradeable/Account.sol @@ -37,7 +37,7 @@ contract Account is AccountCore, ContractMetadata, ERC1271, ERC721Holder, ERC115 Constructor, Initializer, Modifiers //////////////////////////////////////////////////////////////*/ - constructor(IEntryPoint _entrypoint, address _factory) AccountCore(_entrypoint, _factory) {} + constructor(IEntryPoint _entrypoint) AccountCore(_entrypoint) {} /// @notice Checks whether the caller is the EntryPoint contract or the admin. modifier onlyAdminOrEntrypoint() virtual { @@ -132,7 +132,7 @@ contract Account is AccountCore, ContractMetadata, ERC1271, ERC721Holder, ERC115 /// @dev Registers the account on the factory if it hasn't been registered yet. function _registerOnFactory() internal virtual { - BaseAccountFactory factoryContract = BaseAccountFactory(factory); + BaseAccountFactory factoryContract = BaseAccountFactory(AccountCoreStorage.data().factory); if (!factoryContract.isRegistered(address(this))) { factoryContract.onRegister(AccountCoreStorage.data().creationSalt); } diff --git a/contracts/prebuilts/account/non-upgradeable/AccountFactory.sol b/contracts/prebuilts/account/non-upgradeable/AccountFactory.sol index e570f0b4b..61741573e 100644 --- a/contracts/prebuilts/account/non-upgradeable/AccountFactory.sol +++ b/contracts/prebuilts/account/non-upgradeable/AccountFactory.sol @@ -5,6 +5,7 @@ pragma solidity ^0.8.12; import "../utils/BaseAccountFactory.sol"; import "../utils/BaseAccount.sol"; import "../../../external-deps/openzeppelin/proxy/Clones.sol"; +import "../../../extension/upgradeable/Initializable.sol"; // Extensions import "../../../extension/upgradeable//PermissionsEnumerable.sol"; @@ -25,15 +26,17 @@ import { Account } from "./Account.sol"; // \$$$$ |$$ | $$ |$$ |$$ | \$$$$$$$ |\$$$$$\$$$$ |\$$$$$$$\ $$$$$$$ | // \____/ \__| \__|\__|\__| \_______| \_____\____/ \_______|\_______/ -contract AccountFactory is BaseAccountFactory, ContractMetadata, PermissionsEnumerable { +contract AccountFactory is Initializable, BaseAccountFactory, ContractMetadata, PermissionsEnumerable { /*/////////////////////////////////////////////////////////////// Constructor //////////////////////////////////////////////////////////////*/ - constructor(address _defaultAdmin, IEntryPoint _entrypoint) - BaseAccountFactory(address(new Account(_entrypoint, address(this))), address(_entrypoint)) - { + constructor(IEntryPoint _entrypoint) BaseAccountFactory(address(new Account(_entrypoint)), address(_entrypoint)) {} + + /// @notice Initializes the factory contract. + function initialize(address _defaultAdmin, string memory _contractURI) external initializer { _setupRole(DEFAULT_ADMIN_ROLE, _defaultAdmin); + _setupContractURI(_contractURI); } /*/////////////////////////////////////////////////////////////// @@ -46,7 +49,7 @@ contract AccountFactory is BaseAccountFactory, ContractMetadata, PermissionsEnum address _admin, bytes calldata _data ) internal override { - Account(payable(_account)).initialize(_admin, _data); + Account(payable(_account)).initialize(_admin, address(this), _data); } /// @dev Returns whether contract metadata can be set in the given execution context. diff --git a/contracts/prebuilts/account/utils/AccountCore.sol b/contracts/prebuilts/account/utils/AccountCore.sol index 79b31ec37..77b9c58ff 100644 --- a/contracts/prebuilts/account/utils/AccountCore.sol +++ b/contracts/prebuilts/account/utils/AccountCore.sol @@ -39,9 +39,6 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc State //////////////////////////////////////////////////////////////*/ - /// @notice EIP 4337 factory for this contract. - address public immutable factory; - /// @notice EIP 4337 Entrypoint contract. IEntryPoint private immutable entrypointContract; @@ -49,16 +46,20 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc Constructor, Initializer, Modifiers //////////////////////////////////////////////////////////////*/ - constructor(IEntryPoint _entrypoint, address _factory) EIP712("Account", "1") { + constructor(IEntryPoint _entrypoint) EIP712("Account", "1") { _disableInitializers(); - factory = _factory; entrypointContract = _entrypoint; } /// @notice Initializes the smart contract wallet. - function initialize(address _defaultAdmin, bytes calldata _data) public virtual initializer { + function initialize( + address _defaultAdmin, + address _factory, + bytes calldata _data + ) public virtual initializer { // This is passed as data in the `_registerOnFactory()` call in `AccountExtension` / `Account`. AccountCoreStorage.data().creationSalt = _generateSalt(_defaultAdmin, _data); + AccountCoreStorage.data().factory = _factory; _setAdmin(_defaultAdmin, true); } @@ -66,6 +67,11 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc View functions //////////////////////////////////////////////////////////////*/ + /// @notice Returns the address of the account factory. + function factory() public view virtual override returns (address) { + return AccountCoreStorage.data().factory; + } + /// @notice Returns the EIP 4337 entrypoint contract. function entryPoint() public view virtual override returns (IEntryPoint) { address entrypointOverride = AccountCoreStorage.data().entrypointOverride; @@ -235,19 +241,22 @@ contract AccountCore is IAccountCore, Initializable, Multicall, BaseAccount, Acc /// @notice Makes the given account an admin. function _setAdmin(address _account, bool _isAdmin) internal virtual override { super._setAdmin(_account, _isAdmin); - if (factory.code.length > 0) { + + address factoryAddr = factory(); + if (factoryAddr.code.length > 0) { if (_isAdmin) { - BaseAccountFactory(factory).onSignerAdded(_account, AccountCoreStorage.data().creationSalt); + BaseAccountFactory(factoryAddr).onSignerAdded(_account, AccountCoreStorage.data().creationSalt); } else { - BaseAccountFactory(factory).onSignerRemoved(_account, AccountCoreStorage.data().creationSalt); + BaseAccountFactory(factoryAddr).onSignerRemoved(_account, AccountCoreStorage.data().creationSalt); } } } /// @notice Runs after every `changeRole` run. function _afterSignerPermissionsUpdate(SignerPermissionRequest calldata _req) internal virtual override { - if (factory.code.length > 0) { - BaseAccountFactory(factory).onSignerAdded(_req.signer, AccountCoreStorage.data().creationSalt); + address factoryAddr = factory(); + if (factoryAddr.code.length > 0) { + BaseAccountFactory(factoryAddr).onSignerAdded(_req.signer, AccountCoreStorage.data().creationSalt); } } } diff --git a/contracts/prebuilts/account/utils/AccountCoreStorage.sol b/contracts/prebuilts/account/utils/AccountCoreStorage.sol index 4356ef94a..8a3f4899f 100644 --- a/contracts/prebuilts/account/utils/AccountCoreStorage.sol +++ b/contracts/prebuilts/account/utils/AccountCoreStorage.sol @@ -9,6 +9,7 @@ library AccountCoreStorage { struct Data { address entrypointOverride; + address factory; bytes32 creationSalt; } diff --git a/contracts/prebuilts/account/utils/BaseAccountFactory.sol b/contracts/prebuilts/account/utils/BaseAccountFactory.sol index 37da16076..aa19be908 100644 --- a/contracts/prebuilts/account/utils/BaseAccountFactory.sol +++ b/contracts/prebuilts/account/utils/BaseAccountFactory.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.12; // Utils +import "./BaseAccountFactoryStorage.sol"; import "../../../extension/Multicall.sol"; import "../../../external-deps/openzeppelin/proxy/Clones.sol"; import "../../../external-deps/openzeppelin/utils/structs/EnumerableSet.sol"; @@ -32,9 +33,6 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall { address public immutable accountImplementation; address public immutable entrypoint; - EnumerableSet.AddressSet private allAccounts; - mapping(address => EnumerableSet.AddressSet) internal accountsOfSigner; - /*/////////////////////////////////////////////////////////////// Constructor //////////////////////////////////////////////////////////////*/ @@ -61,7 +59,10 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall { account = Clones.cloneDeterministic(impl, salt); if (msg.sender != entrypoint) { - require(allAccounts.add(account), "AccountFactory: account already registered"); + require( + _baseAccountFactoryStorage().allAccounts.add(account), + "AccountFactory: account already registered" + ); } _initializeAccount(account, _admin, _data); @@ -76,14 +77,14 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall { address account = msg.sender; require(_isAccountOfFactory(account, _salt), "AccountFactory: not an account."); - require(allAccounts.add(account), "AccountFactory: account already registered"); + require(_baseAccountFactoryStorage().allAccounts.add(account), "AccountFactory: account already registered"); } function onSignerAdded(address _signer, bytes32 _salt) external { address account = msg.sender; require(_isAccountOfFactory(account, _salt), "AccountFactory: not an account."); - bool isNewSigner = accountsOfSigner[_signer].add(account); + bool isNewSigner = _baseAccountFactoryStorage().accountsOfSigner[_signer].add(account); if (isNewSigner) { emit SignerAdded(account, _signer); @@ -95,7 +96,7 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall { address account = msg.sender; require(_isAccountOfFactory(account, _salt), "AccountFactory: not an account."); - bool isAccount = accountsOfSigner[_signer].remove(account); + bool isAccount = _baseAccountFactoryStorage().accountsOfSigner[_signer].remove(account); if (isAccount) { emit SignerRemoved(account, _signer); @@ -108,12 +109,12 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall { /// @notice Returns whether an account is registered on this factory. function isRegistered(address _account) external view returns (bool) { - return allAccounts.contains(_account); + return _baseAccountFactoryStorage().allAccounts.contains(_account); } /// @notice Returns all accounts created on the factory. function getAllAccounts() external view returns (address[] memory) { - return allAccounts.values(); + return _baseAccountFactoryStorage().allAccounts.values(); } /// @notice Returns the address of an Account that would be deployed with the given admin signer. @@ -124,7 +125,7 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall { /// @notice Returns all accounts that the given address is a signer of. function getAccountsOfSigner(address signer) external view returns (address[] memory accounts) { - return accountsOfSigner[signer].values(); + return _baseAccountFactoryStorage().accountsOfSigner[signer].values(); } /*/////////////////////////////////////////////////////////////// @@ -147,6 +148,11 @@ abstract contract BaseAccountFactory is IAccountFactory, Multicall { return keccak256(abi.encode(_admin, _data)); } + /// @dev Returns the BaseAccountFactory contract's storage. + function _baseAccountFactoryStorage() internal pure returns (BaseAccountFactoryStorage.Data storage) { + return BaseAccountFactoryStorage.data(); + } + /// @dev Called in `createAccount`. Initializes the account contract created in `createAccount`. function _initializeAccount( address _account, diff --git a/contracts/prebuilts/account/utils/BaseAccountFactoryStorage.sol b/contracts/prebuilts/account/utils/BaseAccountFactoryStorage.sol new file mode 100644 index 000000000..10bc27c61 --- /dev/null +++ b/contracts/prebuilts/account/utils/BaseAccountFactoryStorage.sol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.11; + +import "../../../external-deps/openzeppelin/utils/structs/EnumerableSet.sol"; + +library BaseAccountFactoryStorage { + /// @custom:storage-location erc7201:base.account.factory.storage + /// @dev keccak256(abi.encode(uint256(keccak256("base.account.factory.storage")) - 1)) & ~bytes32(uint256(0xff)) + bytes32 public constant BASE_ACCOUNT_FACTORY_STORAGE_POSITION = + 0x82f5b3e5f5ca1c04b70bced106a2c3b72d9cb53ebbafb3cad0740983db742900; + + struct Data { + EnumerableSet.AddressSet allAccounts; + mapping(address => EnumerableSet.AddressSet) accountsOfSigner; + } + + function data() internal pure returns (Data storage baseAccountFactoryData) { + bytes32 position = BASE_ACCOUNT_FACTORY_STORAGE_POSITION; + assembly { + baseAccountFactoryData.slot := position + } + } +} diff --git a/src/test/benchmark/AccountBenchmark.t.sol b/src/test/benchmark/AccountBenchmark.t.sol index 522540d5a..82032e599 100644 --- a/src/test/benchmark/AccountBenchmark.t.sol +++ b/src/test/benchmark/AccountBenchmark.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; // Test utils import "../utils/BaseTest.sol"; +import { TWProxy } from "contracts/infra/TWProxy.sol"; // Account Abstraction setup for smart wallets. import { EntryPoint, IEntryPoint } from "contracts/prebuilts/account/utils/Entrypoint.sol"; @@ -49,7 +50,7 @@ contract AccountBenchmarkTest is BaseTest { address private nonSigner; // UserOp terminology: `sender` is the smart wallet. - address private sender = 0x0df2C3523703d165Aa7fA1a552f3F0B56275DfC6; + address private sender = 0xDD1d01438DcF28eb45a611c7faBD716B0dECE259; address payable private beneficiary = payable(address(0x45654)); bytes32 private uidCache = bytes32("random uid"); @@ -181,7 +182,17 @@ contract AccountBenchmarkTest is BaseTest { // Setup contracts entrypoint = new EntryPoint(); // deploy account factory - accountFactory = new AccountFactory(deployer, IEntryPoint(payable(address(entrypoint)))); + address factoryImpl = address(new AccountFactory(IEntryPoint(payable(address(entrypoint))))); + accountFactory = AccountFactory( + address( + payable( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) + ); // deploy dummy contract numberContract = new Number(); } diff --git a/src/test/smart-wallet/Account.t.sol b/src/test/smart-wallet/Account.t.sol index 13be02e91..80538ccb5 100644 --- a/src/test/smart-wallet/Account.t.sol +++ b/src/test/smart-wallet/Account.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; // Test utils import "../utils/BaseTest.sol"; +import { TWProxy } from "contracts/infra/TWProxy.sol"; // Account Abstraction setup for smart wallets. import { EntryPoint, IEntryPoint } from "contracts/prebuilts/account/utils/Entrypoint.sol"; @@ -49,7 +50,7 @@ contract SimpleAccountTest is BaseTest { address private nonSigner; // UserOp terminology: `sender` is the smart wallet. - address private sender = 0x0df2C3523703d165Aa7fA1a552f3F0B56275DfC6; + address private sender = 0xDD1d01438DcF28eb45a611c7faBD716B0dECE259; address payable private beneficiary = payable(address(0x45654)); bytes32 private uidCache = bytes32("random uid"); @@ -250,11 +251,31 @@ contract SimpleAccountTest is BaseTest { // Setup contracts entrypoint = new EntryPoint(); // deploy account factory - accountFactory = new AccountFactory(deployer, IEntryPoint(payable(address(entrypoint)))); + address factoryImpl = address(new AccountFactory(IEntryPoint(payable(address(entrypoint))))); + accountFactory = AccountFactory( + address( + payable( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) + ); // deploy dummy contract numberContract = new Number(); } + /*/////////////////////////////////////////////////////////////// + Test: initial state + //////////////////////////////////////////////////////////////*/ + + function test_initialState() external { + assertEq(accountFactory.entrypoint(), address(entrypoint)); + assertEq(accountFactory.contractURI(), "https://example.com"); + assertEq(accountFactory.hasRole(0x00, deployer), true); + } + /*/////////////////////////////////////////////////////////////// Test: creating an account //////////////////////////////////////////////////////////////*/ @@ -271,7 +292,7 @@ contract SimpleAccountTest is BaseTest { } /// @dev Create an account via Entrypoint. - function test_state_createAccount_viaEntrypoint() public { + function test_state_createAccount_viaEntrypointSingle() public { bytes memory initCallData = abi.encodeWithSignature("createAccount(address,bytes)", accountAdmin, bytes("")); bytes memory initCode = abi.encodePacked(abi.encodePacked(address(accountFactory)), initCallData); diff --git a/src/test/smart-wallet/AccountVulnPOC.t.sol b/src/test/smart-wallet/AccountVulnPOC.t.sol index f590da824..67df08504 100644 --- a/src/test/smart-wallet/AccountVulnPOC.t.sol +++ b/src/test/smart-wallet/AccountVulnPOC.t.sol @@ -6,6 +6,7 @@ import "../utils/BaseTest.sol"; // Account Abstraction setup for smart wallets. import { EntryPoint, IEntryPoint } from "contracts/prebuilts/account/utils/Entrypoint.sol"; import { UserOperation } from "contracts/prebuilts/account/utils/UserOperation.sol"; +import { TWProxy } from "contracts/infra/TWProxy.sol"; // Target import { IAccountPermissions } from "contracts/extension/interface/IAccountPermissions.sol"; @@ -73,7 +74,7 @@ contract SimpleAccountVulnPOCTest is BaseTest { address private nonSigner; // UserOp terminology: `sender` is the smart wallet. - address private sender = 0x0df2C3523703d165Aa7fA1a552f3F0B56275DfC6; + address private sender = 0xDD1d01438DcF28eb45a611c7faBD716B0dECE259; address payable private beneficiary = payable(address(0x45654)); bytes32 private uidCache = bytes32("random uid"); @@ -212,7 +213,17 @@ contract SimpleAccountVulnPOCTest is BaseTest { // Setup contracts entrypoint = new EntryPoint(); // deploy account factory - accountFactory = new AccountFactory(deployer, IEntryPoint(payable(address(entrypoint)))); + address factoryImpl = address(new AccountFactory(IEntryPoint(payable(address(entrypoint))))); + accountFactory = AccountFactory( + address( + payable( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) + ); // deploy dummy contract numberContract = new Number(); } diff --git a/src/test/smart-wallet/DynamicAccount.t.sol b/src/test/smart-wallet/DynamicAccount.t.sol index bbf12aa19..7054ad3b6 100644 --- a/src/test/smart-wallet/DynamicAccount.t.sol +++ b/src/test/smart-wallet/DynamicAccount.t.sol @@ -7,6 +7,7 @@ import "@thirdweb-dev/dynamic-contracts/src/interface/IExtension.sol"; import { IAccountPermissions } from "contracts/extension/interface/IAccountPermissions.sol"; import { AccountPermissions } from "contracts/extension/upgradeable/AccountPermissions.sol"; import { AccountExtension } from "contracts/prebuilts/account/utils/AccountExtension.sol"; +import { TWProxy } from "contracts/infra/TWProxy.sol"; // Account Abstraction setup for smart wallets. import { EntryPoint, IEntryPoint } from "contracts/prebuilts/account/utils/Entrypoint.sol"; @@ -65,7 +66,7 @@ contract DynamicAccountTest is BaseTest { bytes internal data = bytes(""); // UserOp terminology: `sender` is the smart wallet. - address private sender = 0x78b942FBC9126b4Ed8384Bb9dd1420Ea865be91a; + address private sender = 0x96b1d554981298ED415f2D5788A6D093A39eECfF; address payable private beneficiary = payable(address(0x45654)); bytes32 private uidCache = bytes32("random uid"); @@ -319,7 +320,17 @@ contract DynamicAccountTest is BaseTest { extensions[0] = defaultExtension; // deploy account factory - accountFactory = new DynamicAccountFactory(deployer, extensions); + address factoryImpl = address(new DynamicAccountFactory(extensions)); + accountFactory = DynamicAccountFactory( + address( + payable( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) + ); // deploy dummy contract numberContract = new Number(); } @@ -374,7 +385,17 @@ contract DynamicAccountTest is BaseTest { extensions[0] = defaultExtension; // deploy account factory - DynamicAccountFactory factory = new DynamicAccountFactory(deployer, extensions); + address factoryImpl = address(new DynamicAccountFactory(extensions)); + DynamicAccountFactory factory = DynamicAccountFactory( + address( + payable( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) + ); } /// @dev Create an account by directly calling the factory. @@ -389,7 +410,7 @@ contract DynamicAccountTest is BaseTest { } /// @dev Create an account via Entrypoint. - function test_state_createAccount_viaEntrypoint() public { + function test_state_createAccount_viaEntrypointSingle() public { bytes memory initCallData = abi.encodeWithSignature("createAccount(address,bytes)", accountAdmin, data); bytes memory initCode = abi.encodePacked(abi.encodePacked(address(accountFactory)), initCallData); diff --git a/src/test/smart-wallet/ManagedAccount.t.sol b/src/test/smart-wallet/ManagedAccount.t.sol index dd53a920f..36ab1c4d4 100644 --- a/src/test/smart-wallet/ManagedAccount.t.sol +++ b/src/test/smart-wallet/ManagedAccount.t.sol @@ -7,6 +7,7 @@ import "@thirdweb-dev/dynamic-contracts/src/interface/IExtension.sol"; import { IAccountPermissions } from "contracts/extension/interface/IAccountPermissions.sol"; import { AccountPermissions } from "contracts/extension/upgradeable/AccountPermissions.sol"; import { AccountExtension } from "contracts/prebuilts/account/utils/AccountExtension.sol"; +import { TWProxy } from "contracts/infra/TWProxy.sol"; // Account Abstraction setup for smart wallets. import { EntryPoint, IEntryPoint } from "contracts/prebuilts/account/utils/Entrypoint.sol"; @@ -66,7 +67,7 @@ contract ManagedAccountTest is BaseTest { address private nonSigner; // UserOp terminology: `sender` is the smart wallet. - address private sender = 0xbEA1Fa134A1727187A8f2e7E714B660f3a95478D; + address private sender = 0x48670c84959b8854A9036D88a020382bA89a47Ce; address payable private beneficiary = payable(address(0x45654)); bytes32 private uidCache = bytes32("random uid"); @@ -320,10 +321,16 @@ contract ManagedAccountTest is BaseTest { // deploy account factory vm.prank(factoryDeployer); - accountFactory = new ManagedAccountFactory( - factoryDeployer, - IEntryPoint(payable(address(entrypoint))), - extensions + address factoryImpl = address(new ManagedAccountFactory(IEntryPoint(payable(address(entrypoint))), extensions)); + accountFactory = ManagedAccountFactory( + payable( + address( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", factoryDeployer, "https://example.com") + ) + ) + ) ); // deploy dummy contract numberContract = new Number(); @@ -376,10 +383,16 @@ contract ManagedAccountTest is BaseTest { // deploy account factory vm.prank(factoryDeployer); - ManagedAccountFactory factory = new ManagedAccountFactory( - factoryDeployer, - IEntryPoint(payable(address(entrypoint))), - extensions + address factoryImpl = address(new ManagedAccountFactory(IEntryPoint(payable(address(entrypoint))), extensions)); + ManagedAccountFactory factory = ManagedAccountFactory( + payable( + address( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) ); assertTrue(address(factory) != address(0), "factory address should not be zero"); } @@ -400,7 +413,7 @@ contract ManagedAccountTest is BaseTest { } /// @dev Create an account via Entrypoint. - function test_state_createAccount_viaEntrypoint() public { + function test_state_createAccount_viaEntrypointSingle() public { bytes memory initCallData = abi.encodeWithSignature("createAccount(address,bytes)", accountAdmin, data); bytes memory initCode = abi.encodePacked(abi.encodePacked(address(accountFactory)), initCallData); diff --git a/src/test/smart-wallet/account-core/isValidSigner.t.sol b/src/test/smart-wallet/account-core/isValidSigner.t.sol index c289ca237..8d9c45bbc 100644 --- a/src/test/smart-wallet/account-core/isValidSigner.t.sol +++ b/src/test/smart-wallet/account-core/isValidSigner.t.sol @@ -7,6 +7,7 @@ import "contracts/external-deps/openzeppelin/proxy/Clones.sol"; import "@thirdweb-dev/dynamic-contracts/src/interface/IExtension.sol"; import { IAccountPermissions } from "contracts/extension/interface/IAccountPermissions.sol"; import { AccountPermissions, EnumerableSet, ECDSA } from "contracts/extension/upgradeable/AccountPermissions.sol"; +import { TWProxy } from "contracts/infra/TWProxy.sol"; // Account Abstraction setup for smart wallets. import { EntryPoint, IEntryPoint } from "contracts/prebuilts/account/utils/Entrypoint.sol"; @@ -190,14 +191,24 @@ contract AccountCoreTest_isValidSigner is BaseTest { IExtension.Extension[] memory extensions; // deploy account factory - accountFactory = new DynamicAccountFactory(deployer, extensions); + address factoryImpl = address(new DynamicAccountFactory(extensions)); + accountFactory = DynamicAccountFactory( + address( + payable( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) + ); // deploy dummy contract numberContract = new Number(); address accountImpl = address(new MyDynamicAccount(IEntryPoint(payable(address(entrypoint))), extensions)); address _account = Clones.cloneDeterministic(accountImpl, "salt"); account = MyDynamicAccount(payable(_account)); - account.initialize(accountAdmin, ""); + account.initialize(accountAdmin, address(this), ""); } function test_isValidSigner_whenSignerIsAdmin() public { diff --git a/src/test/smart-wallet/account-permissions/setPermissionsForSigner.t.sol b/src/test/smart-wallet/account-permissions/setPermissionsForSigner.t.sol index 96ef0d462..3d3adca66 100644 --- a/src/test/smart-wallet/account-permissions/setPermissionsForSigner.t.sol +++ b/src/test/smart-wallet/account-permissions/setPermissionsForSigner.t.sol @@ -7,6 +7,7 @@ import "@thirdweb-dev/dynamic-contracts/src/interface/IExtension.sol"; import { IAccountPermissions } from "contracts/extension/interface/IAccountPermissions.sol"; import { AccountPermissions } from "contracts/extension/upgradeable/AccountPermissions.sol"; import { AccountExtension } from "contracts/prebuilts/account/utils/AccountExtension.sol"; +import { TWProxy } from "contracts/infra/TWProxy.sol"; // Account Abstraction setup for smart wallets. import { EntryPoint, IEntryPoint } from "contracts/prebuilts/account/utils/Entrypoint.sol"; @@ -73,7 +74,7 @@ contract AccountPermissionsTest_setPermissionsForSigner is BaseTest { bytes internal data = bytes(""); // UserOp terminology: `sender` is the smart wallet. - address private sender = 0x78b942FBC9126b4Ed8384Bb9dd1420Ea865be91a; + address private sender = 0x96b1d554981298ED415f2D5788A6D093A39eECfF; address payable private beneficiary = payable(address(0x45654)); bytes32 private uidCache = bytes32("random uid"); @@ -267,7 +268,17 @@ contract AccountPermissionsTest_setPermissionsForSigner is BaseTest { extensions[0] = defaultExtension; // deploy account factory - accountFactory = new DynamicAccountFactory(deployer, extensions); + address factoryImpl = address(new DynamicAccountFactory(extensions)); + accountFactory = DynamicAccountFactory( + address( + payable( + new TWProxy( + factoryImpl, + abi.encodeWithSignature("initialize(address,string)", deployer, "https://example.com") + ) + ) + ) + ); // deploy dummy contract numberContract = new Number(); }