Skip to content

Commit

Permalink
fix: test-account negative index did not refer to correct account (#2444
Browse files Browse the repository at this point in the history
)

Co-authored-by: antazoey <antyzoa@gmail.com>
  • Loading branch information
antazoey and antazoey authored Jan 13, 2025
1 parent 316379a commit 3855bf9
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 15 deletions.
21 changes: 13 additions & 8 deletions src/ape/managers/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class TestAccountManager(list, ManagerAccessMixin):

@log_instead_of_fail(default="<TestAccountManager>")
def __repr__(self) -> str:
accounts_str = ", ".join([a.address for a in self.accounts])
return f"[{accounts_str}]"
return f"<apetest-wallet {self.hd_path}>"

@cached_property
def containers(self) -> dict[str, TestAccountContainerAPI]:
Expand All @@ -54,6 +53,13 @@ def containers(self) -> dict[str, TestAccountContainerAPI]:
for plugin_name, (container_type, account_type) in account_types
}

@property
def hd_path(self) -> str:
"""
The HD path used for generating the test accounts.
"""
return self.config_manager.get_config("test").hd_path

@property
def accounts(self) -> Iterator[AccountAPI]:
for container in self.containers.values():
Expand All @@ -76,15 +82,14 @@ def __getitem__(self, account_id):

@__getitem__.register
def __getitem_int(self, account_id: int):
if account_id in self._accounts_by_index:
return self._accounts_by_index[account_id]

original_account_id = account_id
if account_id < 0:
account_id = len(self) + account_id

if account_id in self._accounts_by_index:
return self._accounts_by_index[account_id]

account = self.containers["test"].get_test_account(account_id)
self._accounts_by_index[original_account_id] = account
self._accounts_by_index[account_id] = account
return account

@__getitem__.register
Expand Down Expand Up @@ -265,7 +270,7 @@ def __iter__(self) -> Iterator[AccountAPI]:

@log_instead_of_fail(default="<AccountManager>")
def __repr__(self) -> str:
return "[" + ", ".join(repr(a) for a in self) + "]"
return "<AccountManager>"

@cached_property
def test_accounts(self) -> TestAccountManager:
Expand Down
5 changes: 5 additions & 0 deletions src/ape_test/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ape.exceptions import ProviderNotConnectedError, SignatureError
from ape.types.signatures import MessageSignature, TransactionSignature
from ape.utils._web3_compat import sign_hash
from ape.utils.misc import log_instead_of_fail
from ape.utils.testing import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_HD_PATH,
Expand Down Expand Up @@ -113,6 +114,10 @@ def alias(self) -> str:
def address(self) -> "AddressType":
return self.network_manager.ethereum.decode_address(self.address_str)

@log_instead_of_fail(default="<TestAccount>")
def __repr__(self) -> str:
return f"<{self.__class__.__name__}_{self.index} {self.address_str}>"

def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]:
# Convert str and int to SignableMessage if needed
if isinstance(msg, str):
Expand Down
41 changes: 34 additions & 7 deletions tests/functional/test_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,19 @@ def test_send_transaction_sets_defaults(sender, receiver):
assert receipt.required_confirmations == 0


def test_account_index_access(accounts):
account = accounts[0]
assert account.index == 0
last_account = accounts[-1]
assert last_account.index == len(accounts) - 1


def test_accounts_splice_access(accounts):
a, b = accounts[:2]
assert a == accounts[0]
assert b == accounts[1]
c = accounts[-1]
assert c == accounts[len(accounts) - 1]
alice, bob = accounts[:2]
assert alice == accounts[0]
assert bob == accounts[1]
cat = accounts[-1]
assert cat == accounts[len(accounts) - 1]
expected = (len(accounts) // 2) if len(accounts) % 2 == 0 else (len(accounts) // 2 + 1)
assert len(accounts[::2]) == expected

Expand Down Expand Up @@ -612,9 +619,9 @@ def test_custom_num_of_test_accounts_config(accounts, project):
assert len(accounts) == custom_number_of_test_accounts


def test_test_accounts_repr(accounts):
def test_test_accounts_repr(accounts, config):
actual = repr(accounts)
assert all(a.address in actual for a in accounts)
assert config.get_config("test").hd_path in actual


def test_account_comparison_to_non_account(core_account):
Expand All @@ -629,11 +636,20 @@ def test_create_account(accounts):
assert isinstance(created_account, TestAccount)
assert created_account.index == length_at_start

length_at_start = len(accounts)
second_created_account = accounts.generate_test_account()
assert len(accounts) == length_at_start + 1

assert created_account.address != second_created_account.address
assert second_created_account.index == created_account.index + 1

# Last index should now refer to the last-created account.
last_idx_acct = accounts[-1]
assert last_idx_acct.index == second_created_account.index
assert last_idx_acct.address == second_created_account.address
assert last_idx_acct.address != accounts[0].address
assert last_idx_acct.address != created_account.address


def test_dir(core_account):
actual = dir(core_account)
Expand Down Expand Up @@ -951,3 +967,14 @@ def test_get_deployment_address(owner, vyper_contract_container):
assert instance_1.address == deployment_address_1
instance_2 = owner.deploy(vyper_contract_container, 490)
assert instance_2.address == deployment_address_2


def test_repr(account_manager):
"""
NOTE: __repr__ should be simple and fast!
Previously, we showed the repr of all the accounts.
That was a bad idea, as that can be very unnecessarily slow.
Hence, this test exists to ensure care is taken.
"""
actual = repr(account_manager)
assert actual == "<AccountManager>"

0 comments on commit 3855bf9

Please sign in to comment.