Skip to content

Commit

Permalink
refactor: more xdist friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Dec 14, 2024
1 parent b72b965 commit 2080314
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions src/ape/managers/_contractscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import contextmanager
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union

from ethpm_types import ABI, ContractType
from pydantic import BaseModel
Expand Down Expand Up @@ -618,14 +618,20 @@ def instance_at(
Returns:
:class:`~ape.contracts.base.ContractInstance`
"""

if self.conversion_manager.is_type(address, AddressType):
contract_address = cast(AddressType, address)
else:
if contract_type and not isinstance(contract_type, ContractType):
prefix = f"Expected type '{ContractType.__name__}' for argument 'contract_type'"
try:
contract_address = self.conversion_manager.convert(address, AddressType)
except ConversionError as err:
raise ValueError(f"Unknown address value '{address}'.") from err
suffix = f"; Given '{type(contract_type).__name__}'."
except Exception:
suffix = "."

raise TypeError(f"{prefix}{suffix}")

try:
contract_address = self.conversion_manager.convert(address, AddressType)
except ConversionError:
# Attempt as str.
raise ValueError(f"Unknown address value '{address}'.")

try:
# Always attempt to get an existing contract type to update caches
Expand Down Expand Up @@ -688,18 +694,10 @@ def instance_at(
self.provider.network_choice,
)

elif not isinstance(contract_type, ContractType):
prefix = f"Expected type '{ContractType.__name__}' for argument 'contract_type'"
try:
suffix = f"; Given '{type(contract_type).__name__}'."
except Exception:
suffix = "."

raise TypeError(f"{prefix}{suffix}")

if not txn_hash:
# Check for txn_hash in deployments.
deployments = self.deployments[contract_type.name or ""]
contract_name = getattr(contract_type, "name", f"{contract_type}") or ""
deployments = self.deployments[contract_name]
for deployment in deployments[::-1]:
if deployment.address == contract_address and deployment.transaction_hash:
txn_hash = deployment.transaction_hash
Expand Down

0 comments on commit 2080314

Please sign in to comment.