diff --git a/docs/userguides/contracts.md b/docs/userguides/contracts.md index c06c0f4f4c..ea13a3cb76 100644 --- a/docs/userguides/contracts.md +++ b/docs/userguides/contracts.md @@ -178,8 +178,9 @@ or ```python from ape import project + def main(): - my_contract = project.MyContract.deployments[-1] + my_contract = project.MyContract.deployments[-1] ``` `my_contract` will be of type `ContractInstance`. diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 930ac26b10..9dcf89f915 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -73,7 +73,7 @@ def __init__(self, *args, **kwargs): @classmethod def validate_gas_limit(cls, value): if value is None: - if not cls.network_manager.active_provider: + if not cls.network_manager.connected: raise NetworkError("Must be connected to use default gas config.") value = cls.network_manager.active_provider.network.gas_limit @@ -82,7 +82,7 @@ def validate_gas_limit(cls, value): return None # Delegate to `ProviderAPI.estimate_gas_cost` elif value == "max": - if not cls.network_manager.active_provider: + if not cls.network_manager.connected: raise NetworkError("Must be connected to use 'max'.") return cls.network_manager.active_provider.max_gas diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 25bbdd6bf3..a20d1b1b74 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -1054,7 +1054,7 @@ def from_receipt( # Cache creation. creation = ContractCreation.from_receipt(receipt) - cls.chain_manager.contracts._local_contract_creation[address] = creation + cls.chain_manager.contracts.contract_creations[address] = creation return instance diff --git a/src/ape/managers/_contractscache.py b/src/ape/managers/_contractscache.py new file mode 100644 index 0000000000..fe5677bc3e --- /dev/null +++ b/src/ape/managers/_contractscache.py @@ -0,0 +1,842 @@ +import json +from collections.abc import Collection +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union + +from ethpm_types import ABI, ContractType +from pydantic import BaseModel + +from ape.api.networks import ProxyInfoAPI +from ape.api.query import ContractCreation, ContractCreationQuery +from ape.contracts.base import ContractContainer, ContractInstance +from ape.exceptions import ApeException, ContractNotFoundError, ConversionError, CustomError +from ape.logging import logger +from ape.managers._deploymentscache import Deployment, DeploymentDiskCache +from ape.managers.base import BaseManager +from ape.types.address import AddressType +from ape.utils.misc import nonreentrant +from ape.utils.os import CacheDirectory + +if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + + from ape.api.transactions import ReceiptAPI + + +_BASE_MODEL = TypeVar("_BASE_MODEL", bound=BaseModel) + + +class ApeDataCache(CacheDirectory, Generic[_BASE_MODEL]): + """ + A wrapper around some cached models in the data directory, + such as the cached contract types. + """ + + def __init__( + self, + base_data_folder: Path, + ecosystem_key: str, + network_key: str, + key: str, + model_type: type[_BASE_MODEL], + ): + data_folder = base_data_folder / ecosystem_key + base_path = data_folder / network_key + self._model_type = model_type + self.memory: dict[str, _BASE_MODEL] = {} + + # Only write if we are not testing! + self._write_to_disk = not network_key.endswith("-fork") and network_key != "local" + # Read from disk if using forks or live networks. + self._read_from_disk = network_key.endswith("-fork") or network_key != "local" + + super().__init__(base_path / key) + + def __getitem__(self, key: str) -> Optional[_BASE_MODEL]: # type: ignore + return self.get_type(key) + + def __setitem__(self, key: str, value: _BASE_MODEL): # type: ignore + self.memory[key] = value + if self._write_to_disk: + # Cache to disk. + self.cache_data(key, value.model_dump(mode="json")) + + def __delitem__(self, key: str): + self.memory.pop(key, None) + if self._write_to_disk: + # Delete the cache file. + self.delete_data(key) + + def __contains__(self, key: str) -> bool: + try: + return bool(self[key]) + except KeyError: + return False + + def get_type(self, key: str) -> Optional[_BASE_MODEL]: + if model := self.memory.get(key): + return model + + elif self._read_from_disk: + if data := self.get_data(key): + # Found on disk. + model = self._model_type.model_validate(data) + # Cache locally for next time. + self.memory[key] = model + return model + + return None + + +class ContractCache(BaseManager): + """ + A collection of cached contracts. Contracts can be cached in two ways: + + 1. An in-memory cache of locally deployed contracts + 2. A cache of contracts per network (only permanent networks are stored this way) + + When retrieving a contract, if a :class:`~ape.api.explorers.ExplorerAPI` is used, + it will be cached to disk for faster look-up next time. + """ + + # ecosystem_name -> network_name -> cache_name -> cache + _caches: dict[str, dict[str, dict[str, ApeDataCache]]] = {} + + # chain_id -> address -> custom_err + # Cached to prevent calling `new_class` multiple times with conflicts. + _custom_error_types: dict[int, dict[AddressType, set[type[CustomError]]]] = {} + + @property + def contract_types(self) -> ApeDataCache[ContractType]: + return self._get_data_cache("contract_types", ContractType) + + @property + def proxy_infos(self) -> ApeDataCache[ProxyInfoAPI]: + return self._get_data_cache("proxy_info", ProxyInfoAPI) + + @property + def blueprints(self) -> ApeDataCache[ContractType]: + return self._get_data_cache("blueprints", ContractType) + + @property + def contract_creations(self) -> ApeDataCache[ContractCreation]: + return self._get_data_cache("contract_creation", ContractCreation) + + def _get_data_cache( + self, + key: str, + model_type: type, + ecosystem_key: Optional[str] = None, + network_key: Optional[str] = None, + ): + ecosystem_name = ecosystem_key or self.provider.network.ecosystem.name + network_name = network_key or self.provider.network.name.replace("-fork", "") + self._caches.setdefault(ecosystem_name, {}) + self._caches[ecosystem_name].setdefault(network_name, {}) + + if cache := self._caches[ecosystem_name][network_name].get(key): + return cache + + self._caches[ecosystem_name][network_name][key] = ApeDataCache( + self.config_manager.DATA_FOLDER, ecosystem_name, network_name, key, model_type + ) + return self._caches[ecosystem_name][network_name][key] + + @cached_property + def deployments(self) -> DeploymentDiskCache: + """A manager for contract deployments across networks.""" + return DeploymentDiskCache() + + def __setitem__( + self, address: AddressType, item: Union[ContractType, ProxyInfoAPI, ContractCreation] + ): + """ + Cache the given contract type. Contracts are cached in memory per session. + In live networks, contracts also get cached to disk at + ``.ape/{ecosystem_name}/{network_name}/contract_types/{address}.json`` + for faster look-up next time. + + Args: + address (AddressType): The on-chain address of the contract. + item (ContractType | ProxyInfoAPI | ContractCreation): The contract's type, proxy info, + or creation metadata. + """ + # Note: Can't cache blueprints this way. + address = self.provider.network.ecosystem.decode_address(int(address, 16)) + if isinstance(item, ContractType): + self.cache_contract_type(address, item) + elif isinstance(item, ProxyInfoAPI): + self.cache_proxy_info(address, item) + elif isinstance(item, ContractCreation): + self.cache_contract_creation(address, item) + elif contract_type := getattr(item, "contract_type", None): + self.cache_contract_type(address, contract_type) + else: + raise TypeError(item) + + def cache_contract_type( + self, + address: AddressType, + contract_type: ContractType, + ecosystem_key: Optional[str] = None, + network_key: Optional[str] = None, + ): + """ + Cache a contract type at the given address for the given network. + If not connected, you must provider both ``ecosystem_key:`` and + ``network_key::``. + + Args: + address (AddressType): The address key. + contract_type (ContractType): The contract type to cache. + ecosystem_key (str | None): The ecosystem key. Defaults to + the connected ecosystem's name. + network_key (str | None): The network key. Defaults to the + connected network's name. + """ + # Get the cache in a way that doesn't require an active connection. + cache = self._get_data_cache( + "contract_types", ContractType, ecosystem_key=ecosystem_key, network_key=network_key + ) + cache[address] = contract_type + + # NOTE: The txn_hash is not included when caching this way. + if name := contract_type.name: + self.deployments.cache_deployment( + address, name, ecosystem_key=ecosystem_key, network_key=network_key + ) + + def cache_contract_creation( + self, + address: AddressType, + contract_creation: ContractCreation, + ecosystem_key: Optional[str] = None, + network_key: Optional[str] = None, + ): + """ + Cache a contract creation object. + + Args: + address (AddressType): The address of the contract. + contract_creation (ContractCreation): The object to cache. + ecosystem_key (str | None): The ecosystem key. Defaults to + the connected ecosystem's name. + network_key (str | None): The network key. Defaults to the + connected network's name. + """ + # Get the cache in a way that doesn't require an active connection. + cache = self._get_data_cache( + "contract_creation", + ContractCreation, + ecosystem_key=ecosystem_key, + network_key=network_key, + ) + cache[address] = contract_creation + + def __delitem__(self, address: AddressType): + """ + Delete a cached contract. + If using a live network, it will also delete the file-cache for the contract. + + Args: + address (AddressType): The address to remove from the cache. + """ + del self.contract_types[address] + self._delete_proxy(address) + del self.contract_creations[address] + + @contextmanager + def use_temporary_caches(self): + """ + Create temporary context where there are no cached items. + Useful for testing. + """ + caches = self._caches + self._caches = {} + with self.deployments.use_temporary_cache(): + yield + + self._caches = caches + + def _delete_proxy(self, address: AddressType): + if info := self.proxy_infos[address]: + target = info.target + del self.proxy_infos[target] + del self.contract_types[target] + + def __contains__(self, address: AddressType) -> bool: + return self.get(address) is not None + + def cache_deployment(self, contract_instance: ContractInstance): + """ + Cache the given contract instance's type and deployment information. + + Args: + contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract + to cache. + """ + address = contract_instance.address + contract_type = contract_instance.contract_type # may be a proxy + + # Cache contract type in memory before proxy check, + # in case it is needed somewhere. It may get overridden. + self.contract_types.memory[address] = contract_type + + if proxy_info := self.provider.network.ecosystem.get_proxy_info(address): + # The user is caching a deployment of a proxy with the target already set. + self.cache_proxy_info(address, proxy_info) + if implementation_contract := self.get(proxy_info.target): + updated_proxy_contract = _get_combined_contract_type( + contract_type, proxy_info, implementation_contract + ) + self.contract_types[address] = updated_proxy_contract + + # Use this contract type in the user's contract instance. + contract_instance.contract_type = updated_proxy_contract + + else: + # No implementation yet. Just cache proxy. + self.contract_types[address] = contract_type + + else: + # Regular contract. Cache normally. + self.contract_types[address] = contract_type + + # Cache the deployment now. + txn_hash = contract_instance.txn_hash + if contract_name := contract_type.name: + self.deployments.cache_deployment(address, contract_name, transaction_hash=txn_hash) + + return contract_type + + def cache_proxy_info(self, address: AddressType, proxy_info: ProxyInfoAPI): + """ + Cache proxy info for a particular address, useful for plugins adding already + deployed proxies. When you deploy a proxy locally, it will also call this method. + + Args: + address (AddressType): The address of the proxy contract. + proxy_info (:class:`~ape.api.networks.ProxyInfo`): The proxy info class + to cache. + """ + self.proxy_infos[address] = proxy_info + + def cache_blueprint(self, blueprint_id: str, contract_type: ContractType): + """ + Cache a contract blueprint. + + Args: + blueprint_id (``str``): The ID of the blueprint. For example, in EIP-5202, + it would be the address of the deployed blueprint. For Starknet, it would + be the class identifier. + contract_type (``ContractType``): The contract type associated with the blueprint. + """ + self.blueprints[blueprint_id] = contract_type + + def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: + """ + Get proxy information about a contract using its address, + either from a local cache, a disk cache, or the provider. + + Args: + address (AddressType): The address of the proxy contract. + + Returns: + Optional[:class:`~ape.api.networks.ProxyInfoAPI`] + """ + return self.proxy_infos[address] + + def get_creation_metadata(self, address: AddressType) -> Optional[ContractCreation]: + """ + Get contract creation metadata containing txn_hash, deployer, factory, block. + + Args: + address (AddressType): The address of the contract. + + Returns: + Optional[:class:`~ape.api.query.ContractCreation`] + """ + if creation := self.contract_creations[address]: + return creation + + # Query and cache. + query = ContractCreationQuery(columns=["*"], contract=address) + get_creation = self.query_manager.query(query) + + try: + if not (creation := next(get_creation, None)): # type: ignore[arg-type] + return None + + except ApeException: + return None + + self.contract_creations[address] = creation + return creation + + def get_blueprint(self, blueprint_id: str) -> Optional[ContractType]: + """ + Get a cached blueprint contract type. + + Args: + blueprint_id (``str``): The unique identifier used when caching + the blueprint. + + Returns: + ``ContractType`` + """ + return self.blueprints[blueprint_id] + + def _get_errors( + self, address: AddressType, chain_id: Optional[int] = None + ) -> set[type[CustomError]]: + if chain_id is None and self.network_manager.active_provider is not None: + chain_id = self.provider.chain_id + elif chain_id is None: + raise ValueError("Missing chain ID.") + + if chain_id not in self._custom_error_types: + return set() + + errors = self._custom_error_types[chain_id] + if address in errors: + return errors[address] + + return set() + + def _cache_error( + self, address: AddressType, error: type[CustomError], chain_id: Optional[int] = None + ): + if chain_id is None and self.network_manager.active_provider is not None: + chain_id = self.provider.chain_id + elif chain_id is None: + raise ValueError("Missing chain ID.") + + if chain_id not in self._custom_error_types: + self._custom_error_types[chain_id] = {address: set()} + elif address not in self._custom_error_types[chain_id]: + self._custom_error_types[chain_id][address] = set() + + self._custom_error_types[chain_id][address].add(error) + + def __getitem__(self, address: AddressType) -> ContractType: + contract_type = self.get(address) + if not contract_type: + # Create error message from custom exception cls. + err = ContractNotFoundError( + address, self.provider.network.explorer is not None, self.provider.network_choice + ) + # Must raise KeyError. + raise KeyError(str(err)) + + return contract_type + + def get_multiple( + self, addresses: Collection[AddressType], concurrency: Optional[int] = None + ) -> dict[AddressType, ContractType]: + """ + Get contract types for all given addresses. + + Args: + addresses (list[AddressType): A list of addresses to get contract types for. + concurrency (Optional[int]): The number of threads to use. Defaults to + ``min(4, len(addresses))``. + + Returns: + dict[AddressType, ContractType]: A mapping of addresses to their respective + contract types. + """ + if not addresses: + logger.warning("No addresses provided.") + return {} + + def get_contract_type(addr: AddressType): + addr = self.conversion_manager.convert(addr, AddressType) + ct = self.get(addr) + + if not ct: + logger.warning(f"Failed to locate contract at '{addr}'.") + return addr, None + else: + return addr, ct + + converted_addresses: list[AddressType] = [] + for address in converted_addresses: + if not self.conversion_manager.is_type(address, AddressType): + converted_address = self.conversion_manager.convert(address, AddressType) + converted_addresses.append(converted_address) + else: + converted_addresses.append(address) + + contract_types = {} + default_max_threads = 4 + max_threads = ( + concurrency + if concurrency is not None + else min(len(addresses), default_max_threads) or default_max_threads + ) + with ThreadPoolExecutor(max_workers=max_threads) as pool: + for address, contract_type in pool.map(get_contract_type, addresses): + if contract_type is None: + continue + + contract_types[address] = contract_type + + return contract_types + + @nonreentrant(key_fn=lambda *args, **kwargs: args[1]) + def get( + self, + address: AddressType, + default: Optional[ContractType] = None, + fetch_from_explorer: bool = True, + ) -> Optional[ContractType]: + """ + Get a contract type by address. + If the contract is cached, it will return the contract from the cache. + Otherwise, if on a live network, it fetches it from the + :class:`~ape.api.explorers.ExplorerAPI`. + + Args: + address (AddressType): The address of the contract. + default (Optional[ContractType]): A default contract when none is found. + Defaults to ``None``. + fetch_from_explorer (bool): Set to ``False`` to avoid fetching from an + explorer. Defaults to ``True``. Only fetches if it needs to (uses disk + & memory caching otherwise). + + Returns: + Optional[ContractType]: The contract type if it was able to get one, + otherwise the default parameter. + """ + try: + address_key: AddressType = self.conversion_manager.convert(address, AddressType) + except ConversionError: + if not address.startswith("0x"): + # Still raise conversion errors for ENS and such. + raise + + # In this case, it at least _looked_ like an address. + return None + + if contract_type := self.contract_types[address_key]: + if default and default != contract_type: + # Replacing contract type + self.contract_types[address_key] = default + return default + + return contract_type + + else: + # Contract is not cached yet. Check broader sources, such as an explorer. + # First, detect if this is a proxy. + if not (proxy_info := self.proxy_infos[address_key]): + if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key): + self.proxy_infos[address_key] = proxy_info + + if proxy_info: + # Contract is a proxy. + implementation_contract_type = self.get(proxy_info.target, default=default) + proxy_contract_type = ( + self._get_contract_type_from_explorer(address_key) + if fetch_from_explorer + else None + ) + if proxy_contract_type: + contract_type_to_cache = _get_combined_contract_type( + proxy_contract_type, proxy_info, implementation_contract_type + ) + else: + contract_type_to_cache = implementation_contract_type + + self.contract_types[address_key] = contract_type_to_cache + return contract_type_to_cache + + if not self.provider.get_code(address_key): + if default: + self.contract_types[address_key] = default + + return default + + # Also gets cached to disk for faster lookup next time. + if fetch_from_explorer: + contract_type = self._get_contract_type_from_explorer(address_key) + + # Cache locally for faster in-session look-up. + if contract_type: + self.contract_types[address_key] = contract_type + elif default: + self.contract_types[address_key] = default + return default + + return contract_type + + @classmethod + def get_container(cls, contract_type: ContractType) -> ContractContainer: + """ + Get a contract container for the given contract type. + + Args: + contract_type (ContractType): The contract type to wrap. + + Returns: + ContractContainer: A container object you can deploy. + """ + + return ContractContainer(contract_type) + + def instance_at( + self, + address: Union[str, AddressType], + contract_type: Optional[ContractType] = None, + txn_hash: Optional[Union[str, "HexBytes"]] = None, + abi: Optional[Union[list[ABI], dict, str, Path]] = None, + fetch_from_explorer: bool = True, + ) -> ContractInstance: + """ + Get a contract at the given address. If the contract type of the contract is known, + either from a local deploy or a :class:`~ape.api.explorers.ExplorerAPI`, it will use that + contract type. You can also provide the contract type from which it will cache and use + next time. + + Raises: + TypeError: When passing an invalid type for the `contract_type` arguments + (expects `ContractType`). + :class:`~ape.exceptions.ContractNotFoundError`: When the contract type is not found. + + Args: + address (Union[str, AddressType]): The address of the plugin. If you are using the ENS + plugin, you can also provide an ENS domain name. + contract_type (Optional[``ContractType``]): Optionally provide the contract type + in case it is not already known. + txn_hash (Optional[Union[str, HexBytes]]): The hash of the transaction responsible for + deploying the contract, if known. Useful for publishing. Defaults to ``None``. + abi (Optional[Union[list[ABI], dict, str, Path]]): Use an ABI str, dict, path, + or ethpm models to create a contract instance class. + fetch_from_explorer (bool): Set to ``False`` to avoid fetching from the explorer. + Defaults to ``True``. Won't fetch unless it needs to (uses disk & memory caching + first). + + Returns: + :class:`~ape.contracts.base.ContractInstance` + """ + if contract_type and not isinstance(contract_type, ContractType): + prefix = f"Expected type '{ContractType.__name__}' for argument 'contract_type'" + try: + suffix = f"; Given '{type(contract_type).__name__}'." + except Exception: + suffix = "." + + raise TypeError(f"{prefix}{suffix}") + + try: + contract_address = self.conversion_manager.convert(address, AddressType) + except ConversionError: + # Attempt as str. + raise ValueError(f"Unknown address value '{address}'.") + + try: + # Always attempt to get an existing contract type to update caches + contract_type = self.get( + contract_address, default=contract_type, fetch_from_explorer=fetch_from_explorer + ) + except Exception as err: + if contract_type or abi: + # If a default contract type was provided, don't error and use it. + logger.error(str(err)) + else: + raise # Current exception + + if abi: + # if the ABI is a str then convert it to a JSON dictionary. + if isinstance(abi, Path) or ( + isinstance(abi, str) and "{" not in abi and Path(abi).is_file() + ): + # Handle both absolute and relative paths + abi_path = Path(abi) + if not abi_path.is_absolute(): + abi_path = self.local_project.path / abi + + try: + abi = json.loads(abi_path.read_text()) + except Exception as err: + if contract_type: + # If a default contract type was provided, don't error and use it. + logger.error(str(err)) + else: + raise # Current exception + + elif isinstance(abi, str): + # JSON str + try: + abi = json.loads(abi) + except Exception as err: + if contract_type: + # If a default contract type was provided, don't error and use it. + logger.error(str(err)) + else: + raise # Current exception + + # If the ABI was a str, it should be a list now. + if isinstance(abi, list): + contract_type = ContractType(abi=abi) + + # Ensure we cache the contract-types from ABI! + self[contract_address] = contract_type + + else: + raise TypeError( + f"Invalid ABI type '{type(abi)}', expecting str, list[ABI] or a JSON file." + ) + + if not contract_type: + raise ContractNotFoundError( + contract_address, + self.provider.network.explorer is not None, + self.provider.network_choice, + ) + + if not txn_hash: + # Check for txn_hash in deployments. + contract_name = getattr(contract_type, "name", f"{contract_type}") or "" + deployments = self.deployments[contract_name] + for deployment in deployments[::-1]: + if deployment.address == contract_address and deployment.transaction_hash: + txn_hash = deployment.transaction_hash + break + + return ContractInstance(contract_address, contract_type, txn_hash=txn_hash) + + @classmethod + def instance_from_receipt( + cls, receipt: "ReceiptAPI", contract_type: ContractType + ) -> ContractInstance: + """ + A convenience method for creating instances from receipts. + + Args: + receipt (:class:`~ape.api.transactions.ReceiptAPI`): The receipt. + contract_type (ContractType): The deployed contract type. + + Returns: + :class:`~ape.contracts.base.ContractInstance` + """ + # NOTE: Mostly just needed this method to avoid a local import. + return ContractInstance.from_receipt(receipt, contract_type) + + def get_deployments(self, contract_container: ContractContainer) -> list[ContractInstance]: + """ + Retrieves previous deployments of a contract container or contract type. + Locally deployed contracts are saved for the duration of the script and read from + ``_local_deployments_mapping``, while those deployed on a live network are written to + disk in ``deployments_map.json``. + + Args: + contract_container (:class:`~ape.contracts.ContractContainer`): The + ``ContractContainer`` with deployments. + + Returns: + list[:class:`~ape.contracts.ContractInstance`]: Returns a list of contracts that + have been deployed. + """ + contract_type = contract_container.contract_type + if not (contract_name := contract_type.name or ""): + return [] + + config_deployments = self._get_config_deployments(contract_name) + if not (deployments := [*config_deployments, *self.deployments[contract_name]]): + return [] + + instances: list[ContractInstance] = [] + for deployment in deployments: + instance = ContractInstance( + deployment.address, contract_type, txn_hash=deployment.transaction_hash + ) + instances.append(instance) + + return instances + + def _get_config_deployments(self, contract_name: str) -> list[Deployment]: + if not self.network_manager.connected: + return [] + + ecosystem_name = self.provider.network.ecosystem.name + network_name = self.provider.network.name + all_config_deployments = ( + self.config_manager.deployments if self.config_manager.deployments else {} + ) + ecosystem_deployments = all_config_deployments.get(ecosystem_name, {}) + network_deployments = ecosystem_deployments.get(network_name, {}) + return [ + Deployment(address=c["address"], transaction_hash=c.get("transaction_hash")) + for c in network_deployments + if c["contract_type"] == contract_name + ] + + def clear_local_caches(self): + """ + Reset local caches to a blank state. + """ + if self.network_manager.connected: + for cache in ( + self.contract_types, + self.proxy_infos, + self.contract_creations, + self.blueprints, + ): + cache.memory = {} + + self.deployments.clear_local() + + def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[ContractType]: + if not self.provider.network.explorer: + return None + + try: + contract_type = self.provider.network.explorer.get_contract_type(address) + except Exception as err: + explorer_name = self.provider.network.explorer.name + if "rate limit" in str(err).lower(): + # Don't show any additional error message during rate limit errors, + # if it can be helped, as it may scare users into thinking their + # contracts are not verified. + message = str(err) + else: + # Carefully word this message in a way that doesn't hint at + # any one specific reason, such as un-verified source code, + # which is potentially a scare for users. + message = ( + f"Attempted to retrieve contract type from explorer '{explorer_name}' " + f"from address '{address}' but encountered an exception: {err}\n" + ) + + logger.error(message) + return None + + if contract_type: + # Cache contract so faster look-up next time. + self.contract_types[address] = contract_type + + return contract_type + + +def _get_combined_contract_type( + proxy_contract_type: ContractType, + proxy_info: ProxyInfoAPI, + implementation_contract_type: ContractType, +) -> ContractType: + proxy_abis = [ + abi for abi in proxy_contract_type.abi if abi.type in ("error", "event", "function") + ] + + # Include "hidden" ABIs, such as Safe's `masterCopy()`. + if proxy_info.abi and proxy_info.abi.signature not in [ + abi.signature for abi in implementation_contract_type.abi + ]: + proxy_abis.append(proxy_info.abi) + + combined_contract_type = implementation_contract_type.model_copy(deep=True) + combined_contract_type.abi.extend(proxy_abis) + return combined_contract_type diff --git a/src/ape/managers/_deploymentscache.py b/src/ape/managers/_deploymentscache.py new file mode 100644 index 0000000000..9a496f83b0 --- /dev/null +++ b/src/ape/managers/_deploymentscache.py @@ -0,0 +1,182 @@ +from contextlib import contextmanager +from pathlib import Path +from typing import Optional + +from ape.managers.base import BaseManager +from ape.types.address import AddressType +from ape.utils.basemodel import BaseModel, DiskCacheableModel +from ape.utils.os import create_tempdir + + +class Deployment(BaseModel): + """ + A single deployment of a contract in a network. + """ + + address: AddressType + transaction_hash: Optional[str] = None + + def __getitem__(self, key: str): + # Mainly exists for backwards compat. + if key == "address": + return self.address + elif key == "transaction_hash": + return self.transaction_hash + + raise KeyError(key) + + def get(self, key: str): + # Mainly exists for backwards compat. + try: + return self[key] + except KeyError: + return None + + +class Deployments(DiskCacheableModel): + """The deployments structured JSON.""" + + ecosystems: dict[str, dict[str, dict[str, list[Deployment]]]] = {} + + +class DeploymentDiskCache(BaseManager): + """ + Manage cached contract deployments. + """ + + def __init__(self): + # NOTE: For some reason, deployments are all inside their ecosystem folders, + # but they still have the ecosystem key. Hence, the weird structure here. + self._deployments: dict[str, Deployments] = {} + self._base_path = None + + @property + def _is_live_network(self) -> bool: + return bool(self.network_manager.active_provider) and not self.provider.network.is_dev + + @property + def cachefile(self) -> Path: + base_path = self._base_path or self.provider.network.ecosystem.data_folder + return base_path / "deployments_map.json" + + @property + def _all_deployments(self) -> Deployments: + if not self._is_live_network: + # No file. + if "local" not in self._deployments: + self._deployments["local"] = Deployments() + + return self._deployments["local"] + + ecosystem_name = self.provider.network.ecosystem.name + if ecosystem_name not in self._deployments: + self._deployments[ecosystem_name] = Deployments.model_validate_file(self.cachefile) + + return self._deployments[ecosystem_name] + + def __getitem__(self, contract_name: str) -> list[Deployment]: + return self.get_deployments(contract_name) + + def __setitem__(self, contract_name, deployments: list[Deployment]): + self._set_deployments(contract_name, deployments) + + def __delitem__(self, contract_name: str): + self.remove_deployments(contract_name) + + def get_deployments( + self, + contract_name: str, + ecosystem_key: Optional[str] = None, + network_key: Optional[str] = None, + ) -> list[Deployment]: + """ + Get the deployments of the given contract on the currently connected network. + + Args: + contract_name (str): The name of the deployed contract. + ecosystem_key (str | None): The ecosystem key. Defaults to + the connected ecosystem's name. + network_key (str | None): The network key. Defaults to the + connected network's name. + + Returns: + list[Deployment] + """ + if not self.network_manager.connected and (not ecosystem_key or not network_key): + # Allows it to work when not connected (testing?) + return [] + + ecosystem_name = ecosystem_key or self.provider.network.ecosystem.name + network_name = network_key or self.provider.network.name.replace("-fork", "") + return ( + self._all_deployments.ecosystems.get(ecosystem_name, {}) + .get(network_name, {}) + .get(contract_name, []) + ) + + def cache_deployment( + self, + address: AddressType, + contract_name: str, + transaction_hash: Optional[str] = None, + ecosystem_key: Optional[str] = None, + network_key: Optional[str] = None, + ): + """ + Update the deployments cache with a new contract. + + Args: + address (AddressType): The address of the contract. + contract_name (str): The name of the contract type. + transaction_hash (Optional[str]): Optionally, the transaction has + associated with the deployment transaction. + ecosystem_key (str | None): The ecosystem key. Defaults to + the connected ecosystem's name. + network_key (str | None): The network key. Defaults to the + connected network's name. + """ + deployments = [ + *self.get_deployments(contract_name), + Deployment(address=address, transaction_hash=transaction_hash), + ] + self._set_deployments( + contract_name, + deployments, + ecosystem_key=ecosystem_key, + network_key=network_key, + ) + + @contextmanager + def use_temporary_cache(self): + base_path = self._base_path + deployments = self._deployments + with create_tempdir() as temp_path: + self._base_path = temp_path + self._deployments = {} + yield + + self._base_path = base_path + self._deployments = deployments + + def _set_deployments( + self, + contract_name: str, + deployments: list[Deployment], + ecosystem_key: Optional[str] = None, + network_key: Optional[str] = None, + ): + ecosystem_name = ecosystem_key or self.provider.network.ecosystem.name + network_name = network_key or self.provider.network.name.replace("-fork", "") + self._all_deployments.ecosystems.setdefault(ecosystem_name, {}) + self._all_deployments.ecosystems[ecosystem_name].setdefault(network_name, {}) + self._all_deployments.ecosystems[ecosystem_name][network_name][contract_name] = deployments + + # For live networks, cache the deployments to a file as well. + if self._is_live_network: + self._deployments[ecosystem_name].model_dump_file() + + def remove_deployments(self, contract_name: str): + self._set_deployments(contract_name, []) + + def clear_local(self): + self._deployments["local"] = Deployments() diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index 0ba2379cad..c459b29e26 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -1,51 +1,40 @@ -import json from collections import defaultdict -from collections.abc import Collection, Iterator -from concurrent.futures import ThreadPoolExecutor +from collections.abc import Iterator from contextlib import contextmanager -from functools import partial, singledispatchmethod -from pathlib import Path +from functools import cached_property, partial, singledispatchmethod from statistics import mean, median from typing import IO, TYPE_CHECKING, Optional, Union, cast import pandas as pd -from ethpm_types import ABI, ContractType from rich.box import SIMPLE from rich.table import Table from ape.api.address import BaseAddress -from ape.api.networks import NetworkAPI, ProxyInfoAPI from ape.api.providers import BlockAPI from ape.api.query import ( AccountTransactionQuery, BlockQuery, - ContractCreation, - ContractCreationQuery, extract_fields, validate_and_expand_columns, ) from ape.api.transactions import ReceiptAPI -from ape.contracts import ContractContainer, ContractInstance from ape.exceptions import ( APINotImplementedError, BlockNotFoundError, ChainError, - ContractNotFoundError, - ConversionError, - CustomError, ProviderNotConnectedError, QueryEngineError, TransactionNotFoundError, UnknownSnapshotError, ) from ape.logging import get_rich_console, logger +from ape.managers._contractscache import ContractCache from ape.managers.base import BaseManager from ape.types.address import AddressType from ape.utils.basemodel import BaseInterfaceModel -from ape.utils.misc import is_evm_precompile, is_zero_hex, log_instead_of_fail, nonreentrant +from ape.utils.misc import is_evm_precompile, is_zero_hex, log_instead_of_fail if TYPE_CHECKING: - from eth_pydantic_types import HexBytes from rich.console import Console as RichConsole from ape.types.trace import GasReport, SourceTraceback @@ -631,830 +620,6 @@ def _get_account_history(self, address: Union[BaseAddress, AddressType]) -> Acco return self._account_history_cache[address_key] -class ContractCache(BaseManager): - """ - A collection of cached contracts. Contracts can be cached in two ways: - - 1. An in-memory cache of locally deployed contracts - 2. A cache of contracts per network (only permanent networks are stored this way) - - When retrieving a contract, if a :class:`~ape.api.explorers.ExplorerAPI` is used, - it will be cached to disk for faster look-up next time. - """ - - _local_contract_types: dict[AddressType, ContractType] = {} - _local_proxies: dict[AddressType, ProxyInfoAPI] = {} - _local_blueprints: dict[str, ContractType] = {} - _local_deployments_mapping: dict[str, dict] = {} - _local_contract_creation: dict[str, ContractCreation] = {} - - # chain_id -> address -> custom_err - # Cached to prevent calling `new_class` multiple times with conflicts. - _custom_error_types: dict[int, dict[AddressType, set[type[CustomError]]]] = {} - - @property - def _network(self) -> NetworkAPI: - return self.provider.network - - @property - def _ecosystem_name(self) -> str: - return self._network.ecosystem.name - - @property - def _is_live_network(self) -> bool: - if not self.network_manager.active_provider: - return False - - return not self._network.is_dev - - @property - def _data_network_name(self) -> str: - return self._network.name.replace("-fork", "") - - @property - def _network_cache(self) -> Path: - return self._network.ecosystem.data_folder / self._data_network_name - - @property - def _contract_types_cache(self) -> Path: - return self._network_cache / "contract_types" - - @property - def _deployments_mapping_cache(self) -> Path: - return self._network.ecosystem.data_folder / "deployments_map.json" - - @property - def _proxy_info_cache(self) -> Path: - return self._network_cache / "proxy_info" - - @property - def _blueprint_cache(self) -> Path: - return self._network_cache / "blueprints" - - @property - def _contract_creation_cache(self) -> Path: - return self._network_cache / "contract_creation" - - @property - def _full_deployments(self) -> dict: - deployments = self._local_deployments_mapping - if self._is_live_network: - deployments = {**deployments, **self._load_deployments_cache()} - - return deployments - - @property - def _deployments(self) -> dict: - if not self.network_manager.active_provider: - return {} - - deployments = self._full_deployments - return deployments.get(self._ecosystem_name, {}).get(self._data_network_name, {}) - - @_deployments.setter - def _deployments(self, value): - deployments = self._full_deployments - ecosystem_deployments = self._local_deployments_mapping.get(self._ecosystem_name, {}) - ecosystem_deployments[self._data_network_name] = value - self._local_deployments_mapping[self._ecosystem_name] = ecosystem_deployments - - if self._is_live_network: - self._write_deployments_mapping( - {**deployments, self._ecosystem_name: ecosystem_deployments} - ) - - def __setitem__(self, address: AddressType, contract_type: ContractType): - """ - Cache the given contract type. Contracts are cached in memory per session. - In live networks, contracts also get cached to disk at - ``.ape/{ecosystem_name}/{network_name}/contract_types/{address}.json`` - for faster look-up next time. - - Args: - address (AddressType): The on-chain address of the contract. - contract_type (ContractType): The contract's type. - """ - - if self.network_manager.active_provider: - address = self.provider.network.ecosystem.decode_address(int(address, 16)) - else: - logger.warning("Not connected to a provider. Assuming Ethereum-style checksums.") - ethereum = self.network_manager.ethereum - address = ethereum.decode_address(int(address, 16)) - - self._cache_contract_type(address, contract_type) - - # NOTE: The txn_hash is not included when caching this way. - if contract_type.name: - self._cache_deployment(address, contract_type) - - def __delitem__(self, address: AddressType): - """ - Delete a cached contract. - If using a live network, it will also delete the file-cache for the contract. - - Args: - address (AddressType): The address to remove from the cache. - """ - - if address in self._local_contract_types: - del self._local_contract_types[address] - - # Delete proxy. - if address in self._local_proxies: - info = self._local_proxies[address] - target = info.target - del self._local_proxies[address] - - # Also delete target. - if target in self._local_contract_types: - del self._local_contract_types[target] - - if self._is_live_network: - if self._contract_types_cache.is_dir(): - address_file = self._contract_types_cache / f"{address}.json" - address_file.unlink(missing_ok=True) - - if self._proxy_info_cache.is_dir(): - disk_info = self._get_proxy_info_from_disk(address) - if disk_info: - target = disk_info.target - address_file = self._proxy_info_cache / f"{address}.json" - address_file.unlink() - - # Also delete the target. - self.__delitem__(target) - - def __contains__(self, address: AddressType) -> bool: - return self.get(address) is not None - - def cache_deployment(self, contract_instance: ContractInstance): - """ - Cache the given contract instance's type and deployment information. - - Args: - contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract - to cache. - """ - address = contract_instance.address - contract_type = contract_instance.contract_type # may be a proxy - - # Cache contract type in memory before proxy check, - # in case it is needed somewhere. It may get overridden. - self._local_contract_types[address] = contract_type - - if proxy_info := self.provider.network.ecosystem.get_proxy_info(address): - # The user is caching a deployment of a proxy with the target already set. - self.cache_proxy_info(address, proxy_info) - if implementation_contract := self.get(proxy_info.target): - updated_proxy_contract = _get_combined_contract_type( - contract_type, proxy_info, implementation_contract - ) - self._cache_contract_type(address, updated_proxy_contract) - - # Use this contract type in the user's contract instance. - contract_instance.contract_type = updated_proxy_contract - - else: - # No implementation yet. Just cache proxy. - self._cache_contract_type(address, contract_type) - - else: - # Regular contract. Cache normally. - self._cache_contract_type(address, contract_type) - - # Cache the deployment now. - txn_hash = contract_instance.txn_hash - if contract_type.name: - self._cache_deployment(address, contract_type, txn_hash) - - return contract_type - - def cache_proxy_info(self, address: AddressType, proxy_info: ProxyInfoAPI): - """ - Cache proxy info for a particular address, useful for plugins adding already - deployed proxies. When you deploy a proxy locally, it will also call this method. - - Args: - address (AddressType): The address of the proxy contract. - proxy_info (:class:`~ape.api.networks.ProxyInfo`): The proxy info class - to cache. - """ - if self.get_proxy_info(address) and self._is_live_network: - return - - self._local_proxies[address] = proxy_info - - if self._is_live_network: - self._cache_proxy_info_to_disk(address, proxy_info) - - def cache_blueprint(self, blueprint_id: str, contract_type: ContractType): - """ - Cache a contract blueprint. - - Args: - blueprint_id (``str``): The ID of the blueprint. For example, in EIP-5202, - it would be the address of the deployed blueprint. For Starknet, it would - be the class identifier. - contract_type (``ContractType``): The contract type associated with the blueprint. - """ - - if self.get_blueprint(blueprint_id) and self._is_live_network: - return - - self._local_blueprints[blueprint_id] = contract_type - - if self._is_live_network: - self._cache_blueprint_to_disk(blueprint_id, contract_type) - - def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: - """ - Get proxy information about a contract using its address, - either from a local cache, a disk cache, or the provider. - - Args: - address (AddressType): The address of the proxy contract. - - Returns: - Optional[:class:`~ape.api.networks.ProxyInfoAPI`] - """ - return self._local_proxies.get(address) or self._get_proxy_info_from_disk(address) - - def get_creation_metadata(self, address: AddressType) -> Optional[ContractCreation]: - """ - Get contract creation metadata containing txn_hash, deployer, factory, block. - - Args: - address (AddressType): The address of the contract. - - Returns: - Optional[:class:`~ape.api.query.ContractCreation`] - """ - if creation := self._local_contract_creation.get(address): - return creation - - # read from disk - elif creation := self._get_contract_creation_from_disk(address): - self._local_contract_creation[address] = creation - return creation - - # query and cache - query = ContractCreationQuery(columns=["*"], contract=address) - get_creation = self.query_manager.query(query) - - try: - if not (creation := next(get_creation, None)): # type: ignore[arg-type] - return None - - except QueryEngineError: - return None - - if self._is_live_network: - self._cache_contract_creation_to_disk(address, creation) - - self._local_contract_creation[address] = creation - return creation - - def get_blueprint(self, blueprint_id: str) -> Optional[ContractType]: - """ - Get a cached blueprint contract type. - - Args: - blueprint_id (``str``): The unique identifier used when caching - the blueprint. - - Returns: - ``ContractType`` - """ - - return self._local_blueprints.get(blueprint_id) or self._get_blueprint_from_disk( - blueprint_id - ) - - def _get_errors( - self, address: AddressType, chain_id: Optional[int] = None - ) -> set[type[CustomError]]: - if chain_id is None and self.network_manager.active_provider is not None: - chain_id = self.provider.chain_id - elif chain_id is None: - raise ValueError("Missing chain ID.") - - if chain_id not in self._custom_error_types: - return set() - - errors = self._custom_error_types[chain_id] - if address in errors: - return errors[address] - - return set() - - def _cache_error( - self, address: AddressType, error: type[CustomError], chain_id: Optional[int] = None - ): - if chain_id is None and self.network_manager.active_provider is not None: - chain_id = self.provider.chain_id - elif chain_id is None: - raise ValueError("Missing chain ID.") - - if chain_id not in self._custom_error_types: - self._custom_error_types[chain_id] = {address: set()} - elif address not in self._custom_error_types[chain_id]: - self._custom_error_types[chain_id][address] = set() - - self._custom_error_types[chain_id][address].add(error) - - def _cache_contract_type(self, address: AddressType, contract_type: ContractType): - self._local_contract_types[address] = contract_type - if self._is_live_network: - # NOTE: We don't cache forked network contracts in this method to avoid caching - # deployments from a fork. However, if you retrieve a contract from an explorer - # when using a forked network, it will still get cached to disk. - self._cache_contract_to_disk(address, contract_type) - - def _cache_deployment( - self, address: AddressType, contract_type: ContractType, txn_hash: Optional[str] = None - ): - deployments = self._deployments - contract_deployments = deployments.get(contract_type.name or "", []) - new_deployment = {"address": address, "transaction_hash": txn_hash} - contract_deployments.append(new_deployment) - self._deployments = {**deployments, contract_type.name: contract_deployments} - - def __getitem__(self, address: AddressType) -> ContractType: - contract_type = self.get(address) - if not contract_type: - # Create error message from custom exception cls. - err = ContractNotFoundError( - address, self.provider.network.explorer is not None, self.provider.network_choice - ) - # Must raise KeyError. - raise KeyError(str(err)) - - return contract_type - - def get_multiple( - self, addresses: Collection[AddressType], concurrency: Optional[int] = None - ) -> dict[AddressType, ContractType]: - """ - Get contract types for all given addresses. - - Args: - addresses (list[AddressType): A list of addresses to get contract types for. - concurrency (Optional[int]): The number of threads to use. Defaults to - ``min(4, len(addresses))``. - - Returns: - dict[AddressType, ContractType]: A mapping of addresses to their respective - contract types. - """ - if not addresses: - logger.warning("No addresses provided.") - return {} - - def get_contract_type(addr: AddressType): - addr = self.conversion_manager.convert(addr, AddressType) - ct = self.get(addr) - - if not ct: - logger.warning(f"Failed to locate contract at '{addr}'.") - return addr, None - else: - return addr, ct - - converted_addresses: list[AddressType] = [] - for address in converted_addresses: - if not self.conversion_manager.is_type(address, AddressType): - converted_address = self.conversion_manager.convert(address, AddressType) - converted_addresses.append(converted_address) - else: - converted_addresses.append(address) - - contract_types = {} - default_max_threads = 4 - max_threads = ( - concurrency - if concurrency is not None - else min(len(addresses), default_max_threads) or default_max_threads - ) - with ThreadPoolExecutor(max_workers=max_threads) as pool: - for address, contract_type in pool.map(get_contract_type, addresses): - if contract_type is None: - continue - - contract_types[address] = contract_type - - return contract_types - - @nonreentrant(key_fn=lambda *args, **kwargs: args[1]) - def get( - self, - address: AddressType, - default: Optional[ContractType] = None, - fetch_from_explorer: bool = True, - ) -> Optional[ContractType]: - """ - Get a contract type by address. - If the contract is cached, it will return the contract from the cache. - Otherwise, if on a live network, it fetches it from the - :class:`~ape.api.explorers.ExplorerAPI`. - - Args: - address (AddressType): The address of the contract. - default (Optional[ContractType]): A default contract when none is found. - Defaults to ``None``. - fetch_from_explorer (bool): Set to ``False`` to avoid fetching from an - explorer. Defaults to ``True``. Only fetches if it needs to (uses disk - & memory caching otherwise). - - Returns: - Optional[ContractType]: The contract type if it was able to get one, - otherwise the default parameter. - """ - try: - address_key: AddressType = self.conversion_manager.convert(address, AddressType) - except ConversionError: - if not address.startswith("0x"): - # Still raise conversion errors for ENS and such. - raise - - # In this case, it at least _looked_ like an address. - return None - - if contract_type := self._local_contract_types.get(address_key): - if default and default != contract_type: - # Replacing contract type - self._local_contract_types[address_key] = default - return default - - return contract_type - - if self._network.is_local: - # Don't check disk-cache or explorer when using local - if default: - self._local_contract_types[address_key] = default - - return default - - if not (contract_type := self._get_contract_type_from_disk(address_key)): - # Contract is not cached yet. Check broader sources, such as an explorer. - # First, detect if this is a proxy. - proxy_info = self._local_proxies.get(address_key) or self._get_proxy_info_from_disk( - address_key - ) - if not proxy_info: - proxy_info = self.provider.network.ecosystem.get_proxy_info(address_key) - if proxy_info and self._is_live_network: - self._cache_proxy_info_to_disk(address_key, proxy_info) - - if proxy_info: - # Contract is a proxy. - self._local_proxies[address_key] = proxy_info - implementation_contract_type = self.get(proxy_info.target, default=default) - proxy_contract_type = ( - self._get_contract_type_from_explorer(address_key) - if fetch_from_explorer - else None - ) - if proxy_contract_type: - contract_type_to_cache = _get_combined_contract_type( - proxy_contract_type, proxy_info, implementation_contract_type - ) - else: - contract_type_to_cache = implementation_contract_type - - self._local_contract_types[address_key] = contract_type_to_cache - self._cache_contract_to_disk(address_key, contract_type_to_cache) - return contract_type_to_cache - - if not self.provider.get_code(address_key): - if default: - self._local_contract_types[address_key] = default - self._cache_contract_to_disk(address_key, default) - - return default - - # Also gets cached to disk for faster lookup next time. - if fetch_from_explorer: - contract_type = self._get_contract_type_from_explorer(address_key) - - # Cache locally for faster in-session look-up. - if contract_type: - self._local_contract_types[address_key] = contract_type - - if not contract_type: - if default: - self._local_contract_types[address_key] = default - self._cache_contract_to_disk(address_key, default) - - return default - - if default and default != contract_type: - # Replacing contract type - self._local_contract_types[address_key] = default - self._cache_contract_to_disk(address_key, default) - return default - - return contract_type - - @classmethod - def get_container(cls, contract_type: ContractType) -> ContractContainer: - """ - Get a contract container for the given contract type. - - Args: - contract_type (ContractType): The contract type to wrap. - - Returns: - ContractContainer: A container object you can deploy. - """ - - return ContractContainer(contract_type) - - def instance_at( - self, - address: Union[str, AddressType], - contract_type: Optional[ContractType] = None, - txn_hash: Optional[Union[str, "HexBytes"]] = None, - abi: Optional[Union[list[ABI], dict, str, Path]] = None, - fetch_from_explorer: bool = True, - ) -> ContractInstance: - """ - Get a contract at the given address. If the contract type of the contract is known, - either from a local deploy or a :class:`~ape.api.explorers.ExplorerAPI`, it will use that - contract type. You can also provide the contract type from which it will cache and use - next time. - - Raises: - TypeError: When passing an invalid type for the `contract_type` arguments - (expects `ContractType`). - :class:`~ape.exceptions.ContractNotFoundError`: When the contract type is not found. - - Args: - address (Union[str, AddressType]): The address of the plugin. If you are using the ENS - plugin, you can also provide an ENS domain name. - contract_type (Optional[``ContractType``]): Optionally provide the contract type - in case it is not already known. - txn_hash (Optional[Union[str, HexBytes]]): The hash of the transaction responsible for - deploying the contract, if known. Useful for publishing. Defaults to ``None``. - abi (Optional[Union[list[ABI], dict, str, Path]]): Use an ABI str, dict, path, - or ethpm models to create a contract instance class. - fetch_from_explorer (bool): Set to ``False`` to avoid fetching from the explorer. - Defaults to ``True``. Won't fetch unless it needs to (uses disk & memory caching - first). - - Returns: - :class:`~ape.contracts.base.ContractInstance` - """ - - if self.conversion_manager.is_type(address, AddressType): - contract_address = cast(AddressType, address) - else: - try: - contract_address = self.conversion_manager.convert(address, AddressType) - except ConversionError as err: - raise ValueError(f"Unknown address value '{address}'.") from err - - try: - # Always attempt to get an existing contract type to update caches - contract_type = self.get( - contract_address, default=contract_type, fetch_from_explorer=fetch_from_explorer - ) - except Exception as err: - if contract_type or abi: - # If a default contract type was provided, don't error and use it. - logger.error(str(err)) - else: - raise # Current exception - - if abi: - # if the ABI is a str then convert it to a JSON dictionary. - if isinstance(abi, Path) or ( - isinstance(abi, str) and "{" not in abi and Path(abi).is_file() - ): - # Handle both absolute and relative paths - abi_path = Path(abi) - if not abi_path.is_absolute(): - abi_path = self.local_project.path / abi - - try: - abi = json.loads(abi_path.read_text()) - except Exception as err: - if contract_type: - # If a default contract type was provided, don't error and use it. - logger.error(str(err)) - else: - raise # Current exception - - elif isinstance(abi, str): - # JSON str - try: - abi = json.loads(abi) - except Exception as err: - if contract_type: - # If a default contract type was provided, don't error and use it. - logger.error(str(err)) - else: - raise # Current exception - - # If the ABI was a str, it should be a list now. - if isinstance(abi, list): - contract_type = ContractType(abi=abi) - - # Ensure we cache the contract-types from ABI! - self[contract_address] = contract_type - - else: - raise TypeError( - f"Invalid ABI type '{type(abi)}', expecting str, list[ABI] or a JSON file." - ) - - if not contract_type: - raise ContractNotFoundError( - contract_address, - self.provider.network.explorer is not None, - self.provider.network_choice, - ) - - elif not isinstance(contract_type, ContractType): - raise TypeError( - f"Expected type '{ContractType.__name__}' for argument 'contract_type'." - ) - - if not txn_hash: - # Check for txn_hash in deployments. - deployments = self._deployments.get(contract_type.name) or [] - for deployment in deployments[::-1]: - if deployment["address"] == contract_address and "transaction_hash" in deployment: - txn_hash = deployment["transaction_hash"] - break - - return ContractInstance(contract_address, contract_type, txn_hash=txn_hash) - - def instance_from_receipt( - self, receipt: ReceiptAPI, contract_type: ContractType - ) -> ContractInstance: - """ - A convenience method for creating instances from receipts. - - Args: - receipt (:class:`~ape.api.transactions.ReceiptAPI`): The receipt. - contract_type (ContractType): The deployed contract type. - - Returns: - :class:`~ape.contracts.base.ContractInstance` - """ - # NOTE: Mostly just needed this method to avoid a local import. - return ContractInstance.from_receipt(receipt, contract_type) - - def get_deployments(self, contract_container: ContractContainer) -> list[ContractInstance]: - """ - Retrieves previous deployments of a contract container or contract type. - Locally deployed contracts are saved for the duration of the script and read from - ``_local_deployments_mapping``, while those deployed on a live network are written to - disk in ``deployments_map.json``. - - Args: - contract_container (:class:`~ape.contracts.ContractContainer`): The - ``ContractContainer`` with deployments. - - Returns: - list[:class:`~ape.contracts.ContractInstance`]: Returns a list of contracts that - have been deployed. - """ - - contract_type = contract_container.contract_type - contract_name = contract_type.name - if not contract_name: - return [] - - config_deployments = [] - if self.network_manager.active_provider: - ecosystem_name = self.provider.network.ecosystem.name - network_name = self.provider.network.name - all_config_deployments = ( - self.config_manager.deployments if self.config_manager.deployments else {} - ) - ecosystem_deployments = all_config_deployments.get(ecosystem_name, {}) - network_deployments = ecosystem_deployments.get(network_name, {}) - config_deployments = [ - c for c in network_deployments if c["contract_type"] == contract_name - ] - - deployments = [*config_deployments, *self._deployments.get(contract_name, [])] - if not deployments: - return [] - - instances: list[ContractInstance] = [] - for deployment in deployments: - address = deployment["address"] - txn_hash = deployment.get("transaction_hash") - instance = ContractInstance(address, contract_type, txn_hash=txn_hash) - instances.append(instance) - - return instances - - def clear_local_caches(self): - """ - Reset local caches to a blank state. - """ - self._local_contract_types = {} - self._local_proxies = {} - self._local_blueprints = {} - self._local_deployments_mapping = {} - self._local_contract_creation = {} - - def _get_contract_type_from_disk(self, address: AddressType) -> Optional[ContractType]: - address_file = self._contract_types_cache / f"{address}.json" - if not address_file.is_file(): - return None - - return ContractType.model_validate_json(address_file.read_text()) - - def _get_proxy_info_from_disk(self, address: AddressType) -> Optional[ProxyInfoAPI]: - address_file = self._proxy_info_cache / f"{address}.json" - if not address_file.is_file(): - return None - - return ProxyInfoAPI.model_validate_json(address_file.read_text(encoding="utf8")) - - def _get_blueprint_from_disk(self, blueprint_id: str) -> Optional[ContractType]: - contract_file = self._blueprint_cache / f"{blueprint_id}.json" - if not contract_file.is_file(): - return None - - return ContractType.model_validate_json(contract_file.read_text(encoding="utf8")) - - def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[ContractType]: - if not self._network.explorer: - return None - - try: - contract_type = self._network.explorer.get_contract_type(address) - except Exception as err: - explorer_name = self._network.explorer.name - if "rate limit" in str(err).lower(): - # Don't show any additional error message during rate limit errors, - # if it can be helped, as it may scare users into thinking their - # contracts are not verified. - message = str(err) - else: - # Carefully word this message in a way that doesn't hint at - # any one specific reason, such as un-verified source code, - # which is potentially a scare for users. - message = ( - f"Attempted to retrieve contract type from explorer '{explorer_name}' " - f"from address '{address}' but encountered an exception: {err}\n" - ) - - logger.error(message) - return None - - if contract_type: - # Cache contract so faster look-up next time. - self._cache_contract_to_disk(address, contract_type) - - return contract_type - - def _get_contract_creation_from_disk(self, address: AddressType) -> Optional[ContractCreation]: - path = self._contract_creation_cache / f"{address}.json" - if not path.is_file(): - return None - - return ContractCreation.model_validate_json(path.read_text()) - - def _cache_contract_to_disk(self, address: AddressType, contract_type: ContractType): - self._contract_types_cache.mkdir(exist_ok=True, parents=True) - address_file = self._contract_types_cache / f"{address}.json" - address_file.write_text(contract_type.model_dump_json(), encoding="utf8") - - def _cache_proxy_info_to_disk(self, address: AddressType, proxy_info: ProxyInfoAPI): - self._proxy_info_cache.mkdir(exist_ok=True, parents=True) - address_file = self._proxy_info_cache / f"{address}.json" - address_file.write_text(proxy_info.model_dump_json(), encoding="utf8") - - def _cache_blueprint_to_disk(self, blueprint_id: str, contract_type: ContractType): - self._blueprint_cache.mkdir(exist_ok=True, parents=True) - blueprint_file = self._blueprint_cache / f"{blueprint_id}.json" - blueprint_file.write_text(contract_type.model_dump_json(), encoding="utf8") - - def _cache_contract_creation_to_disk(self, address: AddressType, creation: ContractCreation): - self._contract_creation_cache.mkdir(exist_ok=True, parents=True) - path = self._contract_creation_cache / f"{address}.json" - path.write_text(creation.model_dump_json(), encoding="utf8") - - def _load_deployments_cache(self) -> dict: - return ( - json.loads(self._deployments_mapping_cache.read_text(encoding="utf8")) - if self._deployments_mapping_cache.is_file() - else {} - ) - - def _write_deployments_mapping(self, deployments_map: dict): - self._deployments_mapping_cache.parent.mkdir(exist_ok=True, parents=True) - with self._deployments_mapping_cache.open("w") as fp: - json.dump(deployments_map, fp, sort_keys=True, indent=2, default=sorted) - - class ReportManager(BaseManager): """ A class representing the active Ape session. Useful for tracking data and @@ -1549,9 +714,15 @@ class ChainManager(BaseManager): _chain_id_map: dict[str, int] = {} _block_container_map: dict[int, BlockContainer] = {} _transaction_history_map: dict[int, TransactionHistory] = {} - contracts: ContractCache = ContractCache() _reports: ReportManager = ReportManager() + @cached_property + def contracts(self) -> ContractCache: + """ + A manager for cached contract-types, proxy info, and more. + """ + return ContractCache() + @property def blocks(self) -> BlockContainer: """ @@ -1794,23 +965,3 @@ def get_receipt(self, transaction_hash: str) -> ReceiptAPI: raise TransactionNotFoundError(transaction_hash=transaction_hash) return receipt - - -def _get_combined_contract_type( - proxy_contract_type: ContractType, - proxy_info: ProxyInfoAPI, - implementation_contract_type: ContractType, -) -> ContractType: - proxy_abis = [ - abi for abi in proxy_contract_type.abi if abi.type in ("error", "event", "function") - ] - - # Include "hidden" ABIs, such as Safe's `masterCopy()`. - if proxy_info.abi and proxy_info.abi.signature not in [ - abi.signature for abi in implementation_contract_type.abi - ]: - proxy_abis.append(proxy_info.abi) - - combined_contract_type = implementation_contract_type.model_copy(deep=True) - combined_contract_type.abi.extend(proxy_abis) - return combined_contract_type diff --git a/src/ape/managers/networks.py b/src/ape/managers/networks.py index 773e947879..29322ca24b 100644 --- a/src/ape/managers/networks.py +++ b/src/ape/managers/networks.py @@ -62,6 +62,13 @@ def active_provider(self) -> Optional["ProviderAPI"]: def active_provider(self, new_value: "ProviderAPI"): self._active_provider = new_value + @property + def connected(self) -> bool: + """ + ``True`` when there is an active provider. + """ + return self.active_provider is not None + @property def network(self) -> "NetworkAPI": """ diff --git a/src/ape/managers/project.py b/src/ape/managers/project.py index 24b151ad5e..e2f1ee591b 100644 --- a/src/ape/managers/project.py +++ b/src/ape/managers/project.py @@ -2003,7 +2003,7 @@ def instance_map(self) -> dict[str, dict[str, EthPMContractInstance]]: return result - def track(self, contract: ContractInstance): + def track(self, contract: ContractInstance, allow_dev: bool = False): """ Indicate that a contract deployment should be included in the package manifest upon publication. @@ -2014,9 +2014,9 @@ def track(self, contract: ContractInstance): Args: contract (:class:`~ape.contracts.base.ContractInstance`): The contract to track as a deployment of the project. + allow_dev (bool): Set to ``True`` if simulating in a local dev environment. """ - - if self.provider.network.is_dev: + if not allow_dev and self.provider.network.is_dev: raise ProjectError("Can only publish deployments on a live network.") elif not (contract_name := contract.contract_type.name): diff --git a/src/ape/pytest/runners.py b/src/ape/pytest/runners.py index 94684b24b1..f4a7653af9 100644 --- a/src/ape/pytest/runners.py +++ b/src/ape/pytest/runners.py @@ -320,7 +320,7 @@ def pytest_terminal_summary(self, terminalreporter): def _show_gas_report(self, terminalreporter): terminalreporter.section("Gas Profile") - if not self.network_manager.active_provider: + if not self.network_manager.connected: # Happens if never needed to connect (no tests) return @@ -336,7 +336,7 @@ def _show_coverage_report(self, terminalreporter): if self.config_wrapper.ape_test_config.coverage.reports.terminal: terminalreporter.section("Coverage Profile") - if not self.network_manager.active_provider: + if not self.network_manager.connected: # Happens if never needed to connect (no tests) return diff --git a/src/ape/utils/basemodel.py b/src/ape/utils/basemodel.py index 3d80bf1e48..be4ecee9e4 100644 --- a/src/ape/utils/basemodel.py +++ b/src/ape/utils/basemodel.py @@ -614,3 +614,58 @@ def __dir__(self) -> list[str]: """ # Filter out protected/private members return [member for member in super().__dir__() if not member.startswith("_")] + + +class DiskCacheableModel(BaseModel): + """ + A model with extra utilities for caching to disk. + """ + + def __init__(self, *args, **kwargs): + path = kwargs.pop("path", None) + super().__init__(*args, **kwargs) + self._path = path + + def model_dump_file(self, path: Optional[Path] = None, **kwargs): + """ + Save this model to disk. + + Args: + path (Optional[Path]): Optionally provide the path now + if one wasn't declared at init time. If given a directory, + saves the file in that dir with the name of class with a + .json suffix. + **kwargs: Extra kwargs to pass to ``.model_dump_json()``. + """ + path = self._get_path(path=path) + json_str = self.model_dump_json(**kwargs) + path.unlink(missing_ok=True) + path.write_text(json_str) + + @classmethod + def model_validate_file(cls, path: Path, **kwargs): + """ + Validate a file. + + Args: + path (Optional[Path]): Optionally provide the path now + if one wasn't declared at init time. + **kwargs: Extra kwargs to pass to ``.model_validate_json()``. + """ + if json_str := path.read_text(encoding="utf8") if path.is_file() else "": + model = cls.model_validate_json(json_str, **kwargs) + else: + model = cls.model_validate({}) + + model._path = path + return model + + def _get_path(self, path: Optional[Path] = None) -> Path: + if save_path := (path or self._path): + return save_path + + elif save_path.is_dir(): + name = self.__class__.__name__ or "Model" + return save_path / f"{name}.json" + + raise ValueError("Unknown path for caching.") diff --git a/src/ape/utils/os.py b/src/ape/utils/os.py index bc6e3f37dd..6a3a4213df 100644 --- a/src/ape/utils/os.py +++ b/src/ape/utils/os.py @@ -1,3 +1,4 @@ +import json import os import re import sys @@ -369,3 +370,69 @@ def extract_archive(archive_file: Path, destination: Optional[Path] = None): else: raise ValueError(f"Unsupported zip format: '{archive_file.suffix}'.") + + +class CacheDirectory: + """ + A directory for caching data where each data item is named + ``.json`` and is in the directory. You can access the + items by their key like a dictionary. This type is used + in Ape's contract-caching for ContractTypes, ProxyInfoAPI, + and other model types. + """ + + def __init__(self, path: Path): + if path.is_file(): + raise ValueError("Expecting directory.") + + self._path = path + + def __getitem__(self, key: str) -> dict: + """ + Get the data from ``base_path / .json``. + + Returns: + The JSON dictionary + """ + return self.get_data(key) + + def __setitem__(self, key: str, value: dict): + """ + Cache the given data to ``base_path / .json``. + + Args: + key (str): The key, used as the file name ``{key}.json``. + value (dict): The JSON dictionary to cache. + """ + self.cache_data(key, value) + + def __delitem__(self, key: str): + """ + Delete the cache-file. + + Args: + key (str): The file stem of the JSON. + """ + self.delete_data(key) + + def get_file(self, key: str) -> Path: + return self._path / f"{key}.json" + + def cache_data(self, key: str, data: dict): + json_str = json.dumps(data) + file = self.get_file(key) + file.unlink(missing_ok=True) + file.parent.mkdir(parents=True, exist_ok=True) + file.write_text(json_str) + + def get_data(self, key: str) -> dict: + file = self.get_file(key) + if not file.is_file(): + return {} + + json_str = file.read_text(encoding="utf8") + return json.loads(json_str) + + def delete_data(self, key: str): + file = self.get_file(key) + file.unlink(missing_ok=True) diff --git a/src/ape_cache/query.py b/src/ape_cache/query.py index deccb1c651..159646b3a7 100644 --- a/src/ape_cache/query.py +++ b/src/ape_cache/query.py @@ -129,7 +129,7 @@ def database_connection(self): if self.provider.network.is_local: return None - if not self.network_manager.active_provider: + if not self.network_manager.connected: raise QueryEngineError("Not connected to a provider") database_file = self._get_database_file( diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index bd14077935..dde8c84cf8 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -1132,7 +1132,12 @@ def _send_transaction(self, txn: TransactionAPI) -> str: def _post_connect(self): # Register the console contract for trace enrichment - self.chain_manager.contracts._cache_contract_type(CONSOLE_ADDRESS, console_contract) + self.chain_manager.contracts.cache_contract_type( + CONSOLE_ADDRESS, + console_contract, + ecosystem_key=self.network.ecosystem.name, + network_key=self.network.name, + ) def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: return request_with_retry(lambda: self._make_request(rpc, parameters=parameters)) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 2e7b2e0cf2..5ae46f2f21 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -285,11 +285,9 @@ def project_with_source_files_contract(project_with_contract): @pytest.fixture -def clean_contracts_cache(chain): - original_cached_contracts = chain.contracts._local_contract_types - chain.contracts._local_contract_types = {} - yield - chain.contracts._local_contract_types = original_cached_contracts +def clean_contract_caches(chain): + with chain.contracts.use_temporary_caches(): + yield @pytest.fixture @@ -329,7 +327,7 @@ def contract_getter(address): / f"{address}.json" ) contract = ContractType.model_validate_json(path.read_text()) - chain.contracts._local_contract_types[address] = contract + chain.contracts[address] = contract return contract return contract_getter @@ -452,17 +450,6 @@ def _assert_log_values( return _assert_log_values -@pytest.fixture -def remove_disk_writes_deployments(chain): - if chain.contracts._deployments_mapping_cache.is_file(): - chain.contracts._deployments_mapping_cache.unlink() - - yield - - if chain.contracts._deployments_mapping_cache.is_file(): - chain.contracts._deployments_mapping_cache.unlink() - - @pytest.fixture(scope="session") def logger(): return _logger @@ -793,6 +780,6 @@ def fn(addr: AddressType): contract_type = ContractType(abi=abi) # Hack in contract-type. - chain.contracts._local_contract_types[addr] = contract_type + chain.contracts.contract_types[addr] = contract_type return fn diff --git a/tests/functional/geth/test_contracts_cache.py b/tests/functional/geth/test_contracts_cache.py index 1c3cb5a6de..92782c49fa 100644 --- a/tests/functional/geth/test_contracts_cache.py +++ b/tests/functional/geth/test_contracts_cache.py @@ -48,6 +48,8 @@ def get_contract_type(address, *args, **kwargs): # w/ both proxy ABIs and the target ABIs. contract_from_explorer = chain.contracts.instance_at(proxy_contract.address) + network.__dict__.pop("explorer", None) + # Ensure we can call proxy methods! assert contract_from_explorer.masterCopy # No attr error! diff --git a/tests/functional/geth/test_query.py b/tests/functional/geth/test_query.py index 04c9f27834..5d1d5c86f2 100644 --- a/tests/functional/geth/test_query.py +++ b/tests/functional/geth/test_query.py @@ -11,7 +11,7 @@ def test_get_contract_metadata( # hold onto block, setup mock. block = geth_provider.get_block(actual.block) - del chain.contracts._local_contract_creation[geth_contract.address] + del chain.contracts.contract_creations[geth_contract.address] mock_geth.web3.eth.get_block.return_value = block orig_web3 = chain.network_manager.active_provider._web3 diff --git a/tests/functional/test_accounts.py b/tests/functional/test_accounts.py index 7a7ecfe41c..e7f1422428 100644 --- a/tests/functional/test_accounts.py +++ b/tests/functional/test_accounts.py @@ -240,7 +240,7 @@ def test_transfer_value_of_0(sender, receiver): assert receiver.balance == initial_balance -def test_deploy(owner, contract_container, chain, clean_contracts_cache): +def test_deploy(owner, contract_container, clean_contract_caches): contract = owner.deploy(contract_container, 0) assert contract.address assert contract.txn_hash @@ -274,6 +274,7 @@ def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_ dummy_live_network.__dict__["explorer"] = mock_explorer contract = owner.deploy(contract_container, 0, publish=True, required_confirmations=0) mock_explorer.publish_contract.assert_called_once_with(contract.address) + dummy_live_network.__dict__["explorer"] = None @explorer_test @@ -281,6 +282,7 @@ def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, m dummy_live_network.__dict__["explorer"] = mock_explorer owner.deploy(contract_container, 0, publish=True, required_confirmations=0) assert not mock_explorer.call_count + dummy_live_network.__dict__["explorer"] = None def test_deploy_proxy(owner, vyper_contract_instance, proxy_contract_container, chain): @@ -292,11 +294,11 @@ def test_deploy_proxy(owner, vyper_contract_instance, proxy_contract_container, assert proxy.myNumber # No attr err # Ensure was properly cached. - assert proxy.address in chain.contracts._local_contract_types - assert proxy.address in chain.contracts._local_proxies + assert proxy.address in chain.contracts.contract_types + assert proxy.address in chain.contracts.proxy_infos # Show the cached proxy info is correct. - proxy_info = chain.contracts._local_proxies[proxy.address] + proxy_info = chain.contracts.proxy_infos[proxy.address] assert proxy_info.target == target assert proxy_info.type == ProxyType.Delegate assert proxy_info.abi.name == "implementation" @@ -345,7 +347,7 @@ def test_deploy_no_deployment_bytecode(owner, bytecode): owner.deploy(contract) -def test_deploy_contract_type(owner, vyper_contract_type, chain, clean_contracts_cache): +def test_deploy_contract_type(owner, vyper_contract_type, clean_contract_caches): contract = owner.deploy(vyper_contract_type, 0) assert contract.address assert contract.txn_hash diff --git a/tests/functional/test_compilers.py b/tests/functional/test_compilers.py index 5e9ed7ac2c..dd2fcced4f 100644 --- a/tests/functional/test_compilers.py +++ b/tests/functional/test_compilers.py @@ -209,7 +209,7 @@ def test_enrich_error_custom_error(chain, compilers): err = ContractLogicError("0x6a12f104", contract_address=addr) # Hack in contract-type. - chain.contracts._local_contract_types[addr] = contract_type + chain.contracts.contract_types[addr] = contract_type # Enriching the error should produce a custom error from the ABI. actual = compilers.enrich_error(err) diff --git a/tests/functional/test_contract_container.py b/tests/functional/test_contract_container.py index 870b21e5c7..e311b23a7b 100644 --- a/tests/functional/test_contract_container.py +++ b/tests/functional/test_contract_container.py @@ -16,7 +16,7 @@ def test_deploy( not_owner, contract_container, networks_connected_to_tester, - clean_contracts_cache, + clean_contract_caches, ): contract = contract_container.deploy(4, sender=not_owner, something_else="IGNORED") assert contract.txn_hash @@ -33,7 +33,7 @@ def test_deploy_wrong_number_of_arguments( not_owner, contract_container, networks_connected_to_tester, - clean_contracts_cache, + clean_contract_caches, ): expected = ( r"The number of the given arguments \(0\) do not match what is defined in the " @@ -59,12 +59,14 @@ def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_ dummy_live_network.__dict__["explorer"] = mock_explorer contract = contract_container.deploy(0, sender=owner, publish=True, required_confirmations=0) mock_explorer.publish_contract.assert_called_once_with(contract.address) + dummy_live_network.__dict__["explorer"] = None def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, mock_explorer): dummy_live_network.__dict__["explorer"] = mock_explorer contract_container.deploy(0, sender=owner, publish=False, required_confirmations=0) assert not mock_explorer.call_count + dummy_live_network.__dict__["explorer"] = None def test_deploy_privately(owner, contract_container): @@ -110,11 +112,11 @@ def test_deploy_proxy( assert proxy.myNumber # No attr err # Ensure caching works. - assert proxy.address in chain.contracts._local_contract_types - assert proxy.address in chain.contracts._local_proxies + assert proxy.address in chain.contracts.contract_types + assert proxy.address in chain.contracts.proxy_infos # Show the cached proxy info is correct. - proxy_info = chain.contracts._local_proxies[proxy.address] + proxy_info = chain.contracts.proxy_infos[proxy.address] assert proxy_info.target == target assert proxy_info.type == ProxyType.Delegate @@ -190,7 +192,7 @@ def test_at_fetch_from_explorer_false( project_with_contract.clean() # Simulate having an explorer plugin installed (e.g. ape-etherscan). - eth_tester_provider.network.explorer = mock_explorer + eth_tester_provider.network.__dict__["explorer"] = mock_explorer # Attempt to create an instance. It should NOT use the explorer at all! instance2 = container.at(instance.address, fetch_from_explorer=False) @@ -200,4 +202,4 @@ def test_at_fetch_from_explorer_false( assert mock_explorer.get_contract_type.call_count == 0 # Clean up test. - eth_tester_provider.network.explorer = None + eth_tester_provider.network.__dict__.pop("explorer") diff --git a/tests/functional/test_contracts_cache.py b/tests/functional/test_contracts_cache.py index 307668289b..e5827ba926 100644 --- a/tests/functional/test_contracts_cache.py +++ b/tests/functional/test_contracts_cache.py @@ -38,7 +38,7 @@ def test_instance_at_when_given_contract_type(chain, contract_instance): def test_instance_at_when_given_name_as_contract_type(chain, contract_instance): - expected_match = "Expected type 'ContractType' for argument 'contract_type'." + expected_match = "Expected type 'ContractType' for argument 'contract_type'; Given 'str'." with pytest.raises(TypeError, match=expected_match): address = str(contract_instance.address) bad_contract_type = contract_instance.contract_type.name @@ -123,25 +123,24 @@ def test_cache_deployment_live_network( chain, vyper_contract_instance, vyper_contract_container, - remove_disk_writes_deployments, + clean_contract_caches, dummy_live_network, ): # Arrange - Ensure the contract is not cached anywhere address = vyper_contract_instance.address contract_name = vyper_contract_instance.contract_type.name - deployments = chain.contracts._deployments - contract_types = chain.contracts._local_contract_types - chain.contracts._local_contract_types = { + contract_types = chain.contracts.contract_types.memory + chain.contracts.contract_types.memory = { a: ct for a, ct in contract_types.items() if a != address } - chain.contracts._deployments = {n: d for n, d in deployments.items() if n != contract_name} + del chain.contracts.deployments[contract_name] # Act chain.contracts.cache_deployment(vyper_contract_instance) # Assert actual_deployments = chain.contracts.get_deployments(vyper_contract_container) - actual_contract_type = chain.contracts._get_contract_type_from_disk(address) + actual_contract_type = chain.contracts.contract_types[address] expected = vyper_contract_instance.contract_type assert len(actual_deployments) == 1 assert actual_deployments[0].address == address @@ -155,13 +154,7 @@ def test_cache_default_contract_type_when_used(solidity_contract_instance, chain contract_type = solidity_contract_instance.contract_type # Delete contract from local cache if it's there - if address in chain.contracts._local_contract_types: - del chain.contracts._local_contract_types[address] - - # Delete cache file if it exists - cache_file = chain.contracts._contract_types_cache / f"{address}.json" - if cache_file.is_file(): - cache_file.unlink() + del chain.contracts[address] # Create a contract using the contract type when nothing is cached. contract = Contract(address, contract_type=contract_type) @@ -194,7 +187,7 @@ def test_contracts_getitem_contract_not_found(chain, eth_tester_provider): def test_deployments_mapping_cache_location(chain): # Arrange / Act - mapping_location = chain.contracts._deployments_mapping_cache + mapping_location = chain.contracts.deployments.cachefile split_mapping_location = str(mapping_location).split("/") # Assert @@ -211,10 +204,7 @@ def test_deployments_when_offline(chain, networks_disconnected, vyper_contract_c def test_get_deployments_local(chain, owner, contract_0, contract_1): # Arrange - chain.contracts._local_deployments_mapping = {} - chain.contracts._local_contract_types = {} - starting_contracts_list_0 = chain.contracts.get_deployments(contract_0) - starting_contracts_list_1 = chain.contracts.get_deployments(contract_1) + chain.contracts.clear_local_caches() deployed_contract_0 = owner.deploy(contract_0, 900000000) deployed_contract_1 = owner.deploy(contract_1, 900000001) @@ -226,16 +216,14 @@ def test_get_deployments_local(chain, owner, contract_0, contract_1): for contract_list in (contracts_list_0, contracts_list_1): assert type(contract_list[0]) is ContractInstance - index_0 = len(contracts_list_0) - len(starting_contracts_list_0) - 1 - index_1 = len(contracts_list_1) - len(starting_contracts_list_1) - 1 - actual_address_0 = contracts_list_0[index_0].address + actual_address_0 = contracts_list_0[-1].address assert actual_address_0 == deployed_contract_0.address - actual_address_1 = contracts_list_1[index_1].address + actual_address_1 = contracts_list_1[-1].address assert actual_address_1 == deployed_contract_1.address def test_get_deployments_live( - chain, owner, contract_0, contract_1, remove_disk_writes_deployments, dummy_live_network + chain, owner, contract_0, contract_1, clean_contract_caches, dummy_live_network ): deployed_contract_0 = owner.deploy(contract_0, 8000000, required_confirmations=0) deployed_contract_1 = owner.deploy(contract_1, 8000001, required_confirmations=0) @@ -252,7 +240,7 @@ def test_get_deployments_live( def test_get_multiple_deployments_live( - chain, owner, contract_0, contract_1, remove_disk_writes_deployments, dummy_live_network + chain, owner, contract_0, contract_1, clean_contract_caches, dummy_live_network ): starting_contracts_list_0 = chain.contracts.get_deployments(contract_0) starting_contracts_list_1 = chain.contracts.get_deployments(contract_1) @@ -354,9 +342,9 @@ def get_contract_type(addr): # Hack in a way to publish on this local network. with create_mock_sepolia() as network: + del chain.contracts[contract.address] mock_explorer.get_contract_type.side_effect = get_contract_type network.__dict__["explorer"] = mock_explorer - del chain.contracts[contract.address] try: actual = chain.contracts.get(contract.address) finally: @@ -373,6 +361,11 @@ def test_get_attempts_explorer_logs_errors_from_explorer( ): contract = owner.deploy(vyper_fallback_container) check_error_str = "__CHECK_FOR_THIS_ERROR__" + expected_log = ( + f"Attempted to retrieve contract type from explorer 'mock' " + f"from address '{contract.address}' but encountered an " + f"exception: {check_error_str}" + ) def get_contract_type(addr): if addr == contract.address: @@ -381,14 +374,9 @@ def get_contract_type(addr): raise ValueError("nope") with create_mock_sepolia() as network: + del chain.contracts[contract.address] mock_explorer.get_contract_type.side_effect = get_contract_type network.__dict__["explorer"] = mock_explorer - expected_log = ( - f"Attempted to retrieve contract type from explorer 'mock' " - f"from address '{contract.address}' but encountered an " - f"exception: {check_error_str}" - ) - del chain.contracts[contract.address] try: actual = chain.contracts.get(contract.address) finally: @@ -408,6 +396,8 @@ def test_get_attempts_explorer_logs_rate_limit_error_from_explorer( # Ensure is not cached locally. del chain.contracts[contract.address] + # For rate limit errors, we don't show anything else, + # as it may be confusing. check_error_str = "you have been rate limited" def get_contract_type(addr): @@ -419,16 +409,12 @@ def get_contract_type(addr): with create_mock_sepolia() as network: mock_explorer.get_contract_type.side_effect = get_contract_type network.__dict__["explorer"] = mock_explorer - - # For rate limit errors, we don't show anything else, - # as it may be confusing. - expected_log = "you have been rate limited" try: actual = chain.contracts.get(contract.address) finally: network.__dict__["explorer"] = None - assert expected_log in ape_caplog.head + assert check_error_str in ape_caplog.head assert actual is None mock_explorer.get_contract_type.reset_mock() @@ -458,8 +444,7 @@ def test_get_creation_metadata(chain, vyper_contract_instance, owner): def test_delete_contract(vyper_contract_instance, chain): # Ensure we start with it cached. - if vyper_contract_instance.address not in chain.contracts: - chain.contracts[vyper_contract_instance.address] = vyper_contract_instance + chain.contracts[vyper_contract_instance.address] = vyper_contract_instance del chain.contracts[vyper_contract_instance.address] assert vyper_contract_instance.address not in chain.contracts @@ -499,26 +484,22 @@ def test_clear_local_caches(chain, vyper_contract_instance, proxy_contract_conta # Ensure contract type exists. address = vyper_contract_instance.address # Ensure blueprint exists. - chain.contracts._local_blueprints[address] = vyper_contract_instance.contract_type + chain.contracts.blueprints[address] = vyper_contract_instance.contract_type # Ensure proxy exists. proxy = proxy_contract_container.deploy(address, sender=owner) # Ensure creation exists. _ = chain.contracts.get_creation_metadata(address) # Test setup verification. - assert ( - address in chain.contracts._local_contract_types - ), "Setup failed - no contract type(s) cached" - assert proxy.address in chain.contracts._local_proxies, "Setup failed - no proxy cached" - assert ( - address in chain.contracts._local_contract_creation - ), "Setup failed - no creation(s) cached" + assert address in chain.contracts.contract_types, "Setup failed - no contract type(s) cached" + assert proxy.address in chain.contracts.proxy_infos, "Setup failed - no proxy cached" + assert address in chain.contracts.contract_creations, "Setup failed - no creation(s) cached" # This is the method we are testing. chain.contracts.clear_local_caches() # Assertions - everything should be empty. - assert chain.contracts._local_proxies == {} - assert chain.contracts._local_blueprints == {} - assert chain.contracts._local_deployments_mapping == {} - assert chain.contracts._local_contract_creation == {} + assert chain.contracts.proxy_infos.memory == {} + assert chain.contracts.blueprints.memory == {} + assert chain.contracts.contract_types.memory == {} + assert chain.contracts.contract_creations.memory == {} diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index 304611d2a6..c7498af6a1 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -1080,7 +1080,7 @@ def test_decode_custom_error(chain, ethereum): addr = cast(AddressType, "0x3fC91A3afd70395Cd496C647d5a6CC9D4B2b7FAD") # Hack in contract-type. - chain.contracts._local_contract_types[addr] = contract_type + chain.contracts[addr] = contract_type actual = ethereum.decode_custom_error(data, addr) assert isinstance(actual, CustomError) @@ -1109,7 +1109,7 @@ def test_decode_custom_error_selector_not_found(chain, ethereum): addr = cast(AddressType, "0x3fC91A3afd70395Cd496C647d5a6CC9D4B2b7FAD") # Hack in contract-type. - chain.contracts._local_contract_types[addr] = contract_type + chain.contracts.contract_types[addr] = contract_type tx = ethereum.create_transaction() actual = ethereum.decode_custom_error(data, addr, txn=tx) diff --git a/tests/functional/test_history.py b/tests/functional/test_history.py index 1d0bacaba4..7009e372a6 100644 --- a/tests/functional/test_history.py +++ b/tests/functional/test_history.py @@ -56,5 +56,4 @@ def get_txns_patch(address): # Actual is 0 because the receipt was cached under the sender. assert len(actual) == 0 finally: - if "explorer" in network.__dict__: - del network.__dict__["explorer"] + network.__dict__.pop("explorer", None) diff --git a/tests/functional/test_project.py b/tests/functional/test_project.py index 561294857d..40482d0329 100644 --- a/tests/functional/test_project.py +++ b/tests/functional/test_project.py @@ -308,10 +308,10 @@ def test_meta(project): assert f"{project.meta.links['apeworx.io']}" == "https://apeworx.io/" -def test_extract_manifest(tmp_project, mock_sepolia, vyper_contract_instance): +def test_extract_manifest(tmp_project, vyper_contract_instance): contract_type = vyper_contract_instance.contract_type tmp_project.manifest.contract_types = {contract_type.name: contract_type} - tmp_project.deployments.track(vyper_contract_instance) + tmp_project.deployments.track(vyper_contract_instance, allow_dev=True) manifest = tmp_project.extract_manifest() assert type(manifest) is PackageManifest @@ -1030,20 +1030,20 @@ def test_values(self, tmp_project): class TestDeploymentManager: @pytest.fixture - def project(self, tmp_project, vyper_contract_instance, mock_sepolia): + def project(self, tmp_project, vyper_contract_instance): contract_type = vyper_contract_instance.contract_type tmp_project.manifest.contract_types = {contract_type.name: contract_type} return tmp_project - def test_track(self, project, vyper_contract_instance, mock_sepolia): - project.deployments.track(vyper_contract_instance) + def test_track(self, project, vyper_contract_instance): + project.deployments.track(vyper_contract_instance, allow_dev=True) deployment = next(iter(project.deployments), None) contract_type = vyper_contract_instance.contract_type assert deployment is not None assert deployment.contract_type == f"{contract_type.source_id}:{contract_type.name}" - def test_instance_map(self, project, vyper_contract_instance, mock_sepolia): - project.deployments.track(vyper_contract_instance) + def test_instance_map(self, project, vyper_contract_instance): + project.deployments.track(vyper_contract_instance, allow_dev=True) assert project.deployments.instance_map != {} bip122_chain_id = to_hex(project.provider.get_block(0).hash) diff --git a/tests/functional/utils/test_basemodel.py b/tests/functional/utils/test_basemodel.py index cd7392398e..24a7724b17 100644 --- a/tests/functional/utils/test_basemodel.py +++ b/tests/functional/utils/test_basemodel.py @@ -3,7 +3,8 @@ from ape.exceptions import ProviderNotConnectedError from ape.logging import logger from ape.managers.project import DependencyManager -from ape.utils.basemodel import ManagerAccessMixin, only_raise_attribute_error +from ape.utils.basemodel import DiskCacheableModel, ManagerAccessMixin, only_raise_attribute_error +from ape.utils.os import create_tempdir class CustomClass(ManagerAccessMixin): @@ -56,3 +57,38 @@ def fn(): def test_dependency_manager(): actual = ManagerAccessMixin.dependency_manager assert isinstance(actual, DependencyManager) + + +class TestDiskCacheableModel: + @pytest.fixture(scope="class") + def ExampleModel(self): + class _ExampleModel(DiskCacheableModel): + aa: int + bb: str + cc: dict[str, dict[str, int]] + + return _ExampleModel + + def test_model_validate_file(self, ExampleModel): + with create_tempdir() as path: + file = path / "example.json" + json_str = '{"aa":123,"bb":"Hello Pydantic!","cc":{"1":{"2":3}}}' + file.write_text(json_str) + instance = ExampleModel.model_validate_file(file) + file.unlink() + + assert instance.aa == 123 + assert instance.bb == "Hello Pydantic!" + assert instance.cc == {"1": {"2": 3}} + # Show the path was already set. + assert instance._path == file + + def test_model_dump_file(self, ExampleModel): + instance = ExampleModel(aa=123, bb="Hello Pydantic!", cc={"1": {"2": 3}}) + expected = '{"aa":123,"bb":"Hello Pydantic!","cc":{"1":{"2":3}}}' + with create_tempdir() as path: + file = path / "example.json" + instance.model_dump_file(file) + actual = file.read_text() + + assert actual == expected