From fee84bb03047ce80d15908b7a9686c658430fbc9 Mon Sep 17 00:00:00 2001 From: antazoey Date: Sat, 26 Oct 2024 12:10:17 -0500 Subject: [PATCH 1/7] fix: issue running tests when `mainnet` is default network and specifying a different network (#2343) --- src/ape/pytest/plugin.py | 9 --------- src/ape/pytest/runners.py | 21 +++++++++++++++++++-- tests/functional/test_test.py | 22 +++++++++++++--------- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/ape/pytest/plugin.py b/src/ape/pytest/plugin.py index 85d5febd5d..7cfe0c106f 100644 --- a/src/ape/pytest/plugin.py +++ b/src/ape/pytest/plugin.py @@ -16,15 +16,6 @@ def _get_default_network(ecosystem: Optional[EcosystemAPI] = None) -> str: if ecosystem is None: ecosystem = ManagerAccessMixin.network_manager.default_ecosystem - if ecosystem.default_network.is_mainnet: - # Don't use mainnet for tests, even if it configured as - # the default. - raise ConfigError( - "Default network is mainnet; unable to run tests on mainnet. " - "Please specify the network using the `--network` flag or " - "configure a different default network." - ) - return ecosystem.name diff --git a/src/ape/pytest/runners.py b/src/ape/pytest/runners.py index 0445e8f446..98a708e99b 100644 --- a/src/ape/pytest/runners.py +++ b/src/ape/pytest/runners.py @@ -7,6 +7,7 @@ from rich import print as rich_print from ape.api.networks import ProviderContextManager +from ape.exceptions import ConfigError from ape.logging import LogLevel from ape.pytest.config import ConfigWrapper from ape.pytest.coverage import CoverageTracker @@ -206,8 +207,24 @@ def pytest_collection_finish(self, session): # Only start provider if collected tests. if not outcome.get_result() and session.items: - self._provider_context.push_provider() - self._provider_is_connected = True + self._connect() + + def _connect(self): + if self._provider_context._provider.network.is_mainnet: + # Ensure is not only running on tests on mainnet because + # was configured as the default. + is_from_command_line = ( + "--network" in self.config_wrapper.pytest_config.invocation_params.args + ) + if not is_from_command_line: + raise ConfigError( + "Default network is mainnet; unable to run tests on mainnet. " + "Please specify the network using the `--network` flag or " + "configure a different default network." + ) + + self._provider_context.push_provider() + self._provider_is_connected = True def pytest_terminal_summary(self, terminalreporter): """ diff --git a/tests/functional/test_test.py b/tests/functional/test_test.py index 9666170c56..56ee09c88a 100644 --- a/tests/functional/test_test.py +++ b/tests/functional/test_test.py @@ -1,7 +1,7 @@ import pytest from ape.exceptions import ConfigError -from ape.pytest.plugin import _get_default_network +from ape.pytest.runners import PytestApeRunner from ape_test import ApeTestConfig @@ -15,17 +15,21 @@ def test_balance_set_from_currency_str(self): assert actual == expected -def test_get_default_network(mocker): - # NOTE: Using this weird test to avoid actually - # using mainnet in any test, even accidentally. - mock_ecosystem = mocker.MagicMock() - mock_mainnet = mocker.MagicMock() - mock_mainnet.name = "mainnet" - mock_ecosystem.default_network = mock_mainnet +def test_connect_to_mainnet_by_default(mocker): + """ + Tests the condition where mainnet is configured as the default network + and no --network option is passed. It should avoid running the tests + to be safe. + """ + + cfg = mocker.MagicMock() + cfg.network = "ethereum:mainnet:node" + runner = PytestApeRunner(cfg, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + expected = ( "Default network is mainnet; unable to run tests on mainnet. " "Please specify the network using the `--network` flag or " "configure a different default network." ) with pytest.raises(ConfigError, match=expected): - _get_default_network(mock_mainnet) + runner._connect() From 54dcc57857f6a7c7411eca61840dddea3fe3aa27 Mon Sep 17 00:00:00 2001 From: antazoey Date: Sat, 26 Oct 2024 12:51:17 -0500 Subject: [PATCH 2/7] fix: callable trace bug when estimating gas on revert-tx with no fail message (#2344) --- src/ape_ethereum/provider.py | 17 +++++++++++------ tests/functional/geth/test_provider.py | 11 ++++++++++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 63d4e16ba0..53b63b2bf2 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -42,6 +42,8 @@ from ape.api.trace import TraceAPI from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( + _SOURCE_TRACEBACK_ARG, + _TRACE_ARG, ApeException, APINotImplementedError, BlockNotFoundError, @@ -1245,9 +1247,9 @@ def _handle_execution_reverted( self, exception: Union[Exception, str], txn: Optional[TransactionAPI] = None, - trace: Optional[TraceAPI] = None, + trace: _TRACE_ARG = None, contract_address: Optional[AddressType] = None, - source_traceback: Optional[SourceTraceback] = None, + source_traceback: _SOURCE_TRACEBACK_ARG = None, set_ape_traceback: Optional[bool] = None, ) -> ContractLogicError: if hasattr(exception, "args") and len(exception.args) == 2: @@ -1277,10 +1279,13 @@ def _handle_execution_reverted( if trace is None and txn is not None: trace = self.provider.get_transaction_trace(to_hex(txn.txn_hash)) - if trace is not None and (revert_message := trace.revert_message): - message = revert_message - no_reason = False - if revert_message := trace.revert_message: + if trace is not None: + if callable(trace): + trace_called = params["trace"] = trace() + else: + trace_called = trace + + if trace_called is not None and (revert_message := trace_called.revert_message): message = revert_message no_reason = False diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index 27f2a7043d..57cb676451 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -707,12 +707,21 @@ def test_estimate_gas_cost_of_static_fee_txn(geth_contract, geth_provider, geth_ @geth_process_test -def test_estimate_gas_cost_reverts(geth_contract, geth_provider, geth_second_account): +def test_estimate_gas_cost_reverts_with_message(geth_contract, geth_provider, geth_second_account): + # NOTE: The error message from not-owner is "!authorized". txn = geth_contract.setNumber.as_transaction(900, sender=geth_second_account, type=0) with pytest.raises(ContractLogicError): geth_provider.estimate_gas_cost(txn) +@geth_process_test +def test_estimate_gas_cost_reverts_no_message(geth_contract, geth_provider, geth_account): + # NOTE: The error message from using `5` has no revert message. + txn = geth_contract.setNumber.as_transaction(5, sender=geth_account, type=0) + with pytest.raises(ContractLogicError): + geth_provider.estimate_gas_cost(txn) + + @geth_process_test @pytest.mark.parametrize("tx_type", TransactionType) def test_prepare_transaction_with_max_gas(tx_type, geth_provider, ethereum, geth_account): From 09492dcbe647c56322e6e491ee1d8c6db11602af Mon Sep 17 00:00:00 2001 From: antazoey Date: Sat, 26 Oct 2024 14:17:32 -0500 Subject: [PATCH 3/7] fix: delay hex-int validation error when None values for non-optional (#2346) --- src/ape/types/basic.py | 6 ++++++ tests/functional/test_types.py | 15 ++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/ape/types/basic.py b/src/ape/types/basic.py index cdd8a42698..56899b1209 100644 --- a/src/ape/types/basic.py +++ b/src/ape/types/basic.py @@ -6,7 +6,13 @@ def _hex_int_validator(value, info): + if value is None: + # If not optional, will allow pydantic to (better) handle the error. + return value + + # NOTE: Allows this module to load lazier. access = import_module("ape.utils.basemodel").ManagerAccessMixin + convert = access.conversion_manager.convert return convert(value, int) diff --git a/tests/functional/test_types.py b/tests/functional/test_types.py index 37c2a56b20..41e916c2c8 100644 --- a/tests/functional/test_types.py +++ b/tests/functional/test_types.py @@ -4,7 +4,7 @@ from eth_utils import to_hex from ethpm_types.abi import EventABI from hexbytes import HexBytes -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError from ape.types.address import AddressType from ape.types.basic import HexInt @@ -144,6 +144,19 @@ class MyModel(BaseModel): assert act.ual == expected assert act.ual_optional is None + def test_none(self): + """ + Was getting unhelpful conversion errors here. We should instead + let Pydantic fail as it normally does in this situation. + """ + + class MyModel(BaseModel): + an_int: HexInt + + expected = ".*Input should be a valid integer.*" + with pytest.raises(ValidationError, match=expected): + _ = MyModel(an_int=None) + class TestCurrencyValueComparable: def test_use_for_int_in_pydantic_model(self): From 6b584a26ff6217506e8b1fc03a29de9cabf2d51c Mon Sep 17 00:00:00 2001 From: antazoey Date: Sat, 26 Oct 2024 16:31:43 -0500 Subject: [PATCH 4/7] fix: mistaken blob receipt causing serialization problems (#2345) --- src/ape_ethereum/ecosystem.py | 16 ++++++++++++--- tests/functional/test_ecosystem.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index f08c7a980d..82636f65fa 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -598,9 +598,19 @@ def decode_receipt(self, data: dict) -> ReceiptAPI: "blob_gas_used", ) ): - receipt_cls = SharedBlobReceipt - receipt_kwargs["blobGasPrice"] = data.get("blob_gas_price", data.get("blobGasPrice")) - receipt_kwargs["blobGasUsed"] = data.get("blob_gas_used", data.get("blobGasUsed")) or 0 + blob_gas_price = data.get("blob_gas_price", data.get("blobGasPrice")) + if blob_gas_price is None: + # Not actually a blob-receipt? Some providers may give you + # empty values here when meaning the other types of receipts. + receipt_cls = Receipt + + else: + receipt_cls = SharedBlobReceipt + receipt_kwargs["blobGasPrice"] = blob_gas_price + receipt_kwargs["blobGasUsed"] = ( + data.get("blob_gas_used", data.get("blobGasUsed")) or 0 + ) + else: receipt_cls = Receipt diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index 08b3170c9a..981afecd80 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -631,6 +631,39 @@ def test_decode_receipt_shared_blob(ethereum, blob_gas_used, blob_gas_key): assert actual.blob_gas_used == 0 +def test_decode_receipt_misleading_blob_receipt(ethereum): + """ + Tests a strange situation (noticed on Tenderly nodes) where _some_ + of the keys indicate blob-related fields, set to ``0``, and others + are missing, because it's not actually a blob receipt. In this case, + don't use the blob-receipt class. + """ + data = { + "type": 2, + "status": 1, + "cumulativeGasUsed": 10565720, + "logsBloom": HexBytes( + "0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" # noqa: E501 + ), + "logs": [], + "transactionHash": HexBytes( + "0x62fc9991bc7fb0c76bc83faaa8d1c17fc5efb050542e58ac358932f80aa7a087" + ), + "from": "0x1f9090aaE28b8a3dCeaDf281B0F12828e676c326", + "to": "0xeBec795c9c8bBD61FFc14A6662944748F299cAcf", + "contractAddress": None, + "gasUsed": 21055, + "effectiveGasPrice": 7267406643, + "blockHash": HexBytes("0xa47fc133f829183b751488c1146f1085451bcccd247db42066dc6c89eaf5ebac"), + "blockNumber": 21051245, + "transactionIndex": 130, + "blobGasUsed": 0, + } + actual = ethereum.decode_receipt(data) + assert not isinstance(actual, SharedBlobReceipt) + assert isinstance(actual, Receipt) + + def test_default_transaction_type_not_connected_used_default_network(project, ethereum, networks): value = TransactionType.STATIC.value config_dict = {"ethereum": {"mainnet_fork": {"default_transaction_type": value}}} From 4199cac8ed7cf08a9039d1a242b39bd7cff6a8f0 Mon Sep 17 00:00:00 2001 From: antazoey Date: Sat, 26 Oct 2024 20:48:39 -0500 Subject: [PATCH 5/7] perf: hardcoded trusted plugin during `--help` check (#2347) --- src/ape/_cli.py | 2 +- src/ape/plugins/_utils.py | 63 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/ape/_cli.py b/src/ape/_cli.py index 50d5e1b186..4781af0a7a 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -91,7 +91,7 @@ def format_commands(self, ctx, formatter) -> None: if plugin.in_core: sections["Core"].append((cli_name, help)) - elif plugin.is_installed and not plugin.is_third_party: + elif plugin.check_trusted(use_web=False): sections["Plugin"].append((cli_name, help)) else: sections["3rd-Party Plugin"].append((cli_name, help)) diff --git a/src/ape/plugins/_utils.py b/src/ape/plugins/_utils.py index 66510d3cfb..123580e288 100644 --- a/src/ape/plugins/_utils.py +++ b/src/ape/plugins/_utils.py @@ -37,6 +37,47 @@ "ape_run", "ape_test", ] +# Hardcoded for performance reasons. Functionality in plugins commands +# and functions won't use this; they use GitHub to check directly. +# This hardcoded list is useful for `ape --help`. If ApeWorX adds a new +# trusted plugin, it should be added to this list; else it will show +# as 3rd-party in `ape --help`. +TRUSTED_PLUGINS = [ + "addressbook", + "alchemy", + "arbitrum", + "avalanche", + "aws", + "base", + "blast", + "blockscout", + "bsc", + "cairo", + "chainstack", + "ens", + "etherscan", + "fantom", + "farcaster", + "flashbots", + "foundry", + "frame", + "ganache", + "hardhat", + "infura", + "ledger", + "notebook", + "optimism", + "polygon", + "polygon_zkevm", + "safe", + "solidity", + "template", + "tenderly", + "titanoboa", + "tokens", + "trezor", + "vyper", +] def clean_plugin_name(name: str) -> str: @@ -378,14 +419,23 @@ def is_installed(self) -> bool: @property def is_third_party(self) -> bool: - return self.is_installed and not self.is_available + """ + ``True`` when is an installed plugin that is not from ApeWorX. + """ + return self.is_installed and not self.is_trusted + + @cached_property + def is_trusted(self) -> bool: + """ + ``True`` when is a plugin from ApeWorX. + """ + return self.check_trusted() @property def is_available(self) -> bool: """ Whether the plugin is maintained by the ApeWorX organization. """ - return self.module_name in _get_available_plugins() def __str__(self) -> str: @@ -404,6 +454,15 @@ def check_installed(self, use_cache: bool = True) -> bool: return any(n == self.package_name for n in get_plugin_dists()) + def check_trusted(self, use_web: bool = True) -> bool: + if use_web: + return self.is_available + + else: + # Sometimes (such as for --help commands), it is better + # to not check GitHub to see if the plugin is trusted. + return self.name in TRUSTED_PLUGINS + def _prepare_install( self, upgrade: bool = False, skip_confirmation: bool = False ) -> Optional[dict[str, Any]]: From 46a1d2cfbee44bba1bc1002bf5757bca80d90fb4 Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 29 Oct 2024 12:35:46 -0500 Subject: [PATCH 6/7] perf: more things that make `ape --help` way faster (#2351) --- src/ape/_cli.py | 33 ++++---- src/ape/cli/__init__.py | 2 + src/ape/cli/arguments.py | 21 +++-- src/ape/cli/choices.py | 57 +++++++++++--- src/ape/cli/commands.py | 3 +- src/ape/cli/options.py | 27 ++++--- src/ape/exceptions.py | 15 ++-- src/ape/plugins/account.py | 7 +- src/ape/plugins/compiler.py | 9 ++- src/ape/plugins/config.py | 7 +- src/ape/plugins/converter.py | 8 +- src/ape/plugins/network.py | 22 ++++-- src/ape/plugins/project.py | 10 ++- src/ape/utils/testing.py | 10 ++- src/ape_accounts/__init__.py | 19 ++--- src/ape_accounts/_cli.py | 31 ++++++-- src/ape_cache/__init__.py | 22 +++--- src/ape_cache/config.py | 5 ++ src/ape_compile/__init__.py | 101 ++++--------------------- src/ape_compile/_cli.py | 10 ++- src/ape_compile/config.py | 89 ++++++++++++++++++++++ src/ape_console/__init__.py | 7 +- src/ape_console/_cli.py | 10 ++- src/ape_networks/__init__.py | 48 ++++-------- src/ape_networks/_cli.py | 22 +++--- src/ape_networks/config.py | 38 ++++++++++ src/ape_plugins/__init__.py | 7 +- tests/functional/test_compilers.py | 2 +- tests/integration/cli/test_accounts.py | 2 +- 29 files changed, 398 insertions(+), 246 deletions(-) create mode 100644 src/ape_cache/config.py create mode 100644 src/ape_compile/config.py create mode 100644 src/ape_networks/config.py diff --git a/src/ape/_cli.py b/src/ape/_cli.py index 4781af0a7a..caf223ce0a 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -1,13 +1,14 @@ import difflib import re import sys -import warnings from collections.abc import Iterable +from functools import cached_property from gettext import gettext from importlib import import_module from importlib.metadata import entry_points from pathlib import Path from typing import Any, Optional +from warnings import catch_warnings, simplefilter import click import rich @@ -17,7 +18,6 @@ from ape.cli.options import ape_cli_context from ape.exceptions import Abort, ApeException, ConfigError, handle_ape_exception from ape.logging import logger -from ape.utils.basemodel import ManagerAccessMixin as access _DIFFLIB_CUT_OFF = 0.6 @@ -27,6 +27,8 @@ def display_config(ctx, param, value): if not value or ctx.resilient_parsing: return + from ape.utils.basemodel import ManagerAccessMixin as access + click.echo("# Current configuration") # NOTE: Using json-mode as yaml.dump requires JSON-like structure. @@ -37,6 +39,8 @@ def display_config(ctx, param, value): def _validate_config(): + from ape.utils.basemodel import ManagerAccessMixin as access + project = access.local_project try: _ = project.config @@ -47,7 +51,6 @@ def _validate_config(): class ApeCLI(click.MultiCommand): - _commands: Optional[dict] = None _CLI_GROUP_NAME = "ape_cli_subcommands" def parse_args(self, ctx: Context, args: list[str]) -> list[str]: @@ -60,6 +63,8 @@ def parse_args(self, ctx: Context, args: list[str]) -> list[str]: return super().parse_args(ctx, args) def format_commands(self, ctx, formatter) -> None: + from ape.utils.basemodel import ManagerAccessMixin as access + commands = [] for subcommand in self.list_commands(ctx): cmd = self.get_command(ctx, subcommand) @@ -142,25 +147,21 @@ def _suggest_cmd(usage_error): raise usage_error - @property + @cached_property def commands(self) -> dict: - if self._commands: - return self._commands - _entry_points = entry_points() eps: Iterable - if select_fn := getattr(_entry_points, "select", None): - # NOTE: Using getattr because mypy. - eps = select_fn(group=self._CLI_GROUP_NAME) - else: - # Python 3.9. Can remove once we drop support. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + + try: + eps = _entry_points.select(group=self._CLI_GROUP_NAME) + except AttributeError: + # Fallback for Python 3.9 + with catch_warnings(): + simplefilter("ignore") eps = _entry_points.get(self._CLI_GROUP_NAME, []) # type: ignore commands = {cmd.name.replace("_", "-").replace("ape-", ""): cmd.load for cmd in eps} - self._commands = {k: commands[k] for k in sorted(commands)} - return self._commands + return dict(sorted(commands.items())) def list_commands(self, ctx) -> list[str]: return [k for k in self.commands] diff --git a/src/ape/cli/__init__.py b/src/ape/cli/__init__.py index 4e0db3e853..262712c414 100644 --- a/src/ape/cli/__init__.py +++ b/src/ape/cli/__init__.py @@ -6,6 +6,7 @@ from ape.cli.choices import ( AccountAliasPromptChoice, Alias, + LazyChoice, NetworkChoice, OutputFormat, PromptChoice, @@ -42,6 +43,7 @@ "existing_alias_argument", "incompatible_with", "JSON", + "LazyChoice", "network_option", "NetworkChoice", "NetworkOption", diff --git a/src/ape/cli/arguments.py b/src/ape/cli/arguments.py index 75a09dc4d7..71701ad1a9 100644 --- a/src/ape/cli/arguments.py +++ b/src/ape/cli/arguments.py @@ -7,15 +7,14 @@ from ape.cli.choices import _ACCOUNT_TYPE_FILTER, Alias from ape.logging import logger -from ape.utils.basemodel import ManagerAccessMixin -from ape.utils.os import get_full_extension -from ape.utils.validators import _validate_account_alias if TYPE_CHECKING: from ape.managers.project import ProjectManager def _alias_callback(ctx, param, value): + from ape.utils.validators import _validate_account_alias + return _validate_account_alias(value) @@ -28,7 +27,6 @@ def existing_alias_argument(account_type: _ACCOUNT_TYPE_FILTER = None, **kwargs) If given, limits the type of account the user may choose from. **kwargs: click.argument overrides. """ - type_ = kwargs.pop("type", Alias(key=account_type)) return click.argument("alias", type=type_, **kwargs) @@ -45,12 +43,14 @@ def non_existing_alias_argument(**kwargs): return click.argument("alias", callback=callback, **kwargs) -class _ContractPaths(ManagerAccessMixin): +class _ContractPaths: """ Helper callback class for handling CLI-given contract paths. """ def __init__(self, value, project: Optional["ProjectManager"] = None): + from ape.utils.basemodel import ManagerAccessMixin + self.value = value self.missing_compilers: set[str] = set() # set of .ext self.project = project or ManagerAccessMixin.local_project @@ -105,14 +105,21 @@ def filtered_paths(self) -> set[Path]: @property def exclude_patterns(self) -> set[str]: - return self.config_manager.get_config("compile").exclude or set() + from ape.utils.basemodel import ManagerAccessMixin as access + + return access.config_manager.get_config("compile").exclude or set() def do_exclude(self, path: Union[Path, str]) -> bool: return self.project.sources.is_excluded(path) def compiler_is_unknown(self, path: Union[Path, str]) -> bool: + from ape.utils.basemodel import ManagerAccessMixin + from ape.utils.os import get_full_extension + ext = get_full_extension(path) - unknown_compiler = ext and ext not in self.compiler_manager.registered_compilers + unknown_compiler = ( + ext and ext not in ManagerAccessMixin.compiler_manager.registered_compilers + ) if unknown_compiler and ext not in self.missing_compilers: self.missing_compilers.add(ext) diff --git a/src/ape/cli/choices.py b/src/ape/cli/choices.py index 17a8bec427..e7dc752107 100644 --- a/src/ape/cli/choices.py +++ b/src/ape/cli/choices.py @@ -14,7 +14,6 @@ NetworkNotFoundError, ProviderNotFoundError, ) -from ape.utils.basemodel import ManagerAccessMixin as access if TYPE_CHECKING: from ape.api.accounts import AccountAPI @@ -26,6 +25,8 @@ def _get_accounts(key: _ACCOUNT_TYPE_FILTER) -> list["AccountAPI"]: + from ape.utils.basemodel import ManagerAccessMixin as access + accounts = access.account_manager add_test_accounts = False @@ -68,8 +69,11 @@ def __init__(self, key: _ACCOUNT_TYPE_FILTER = None): # NOTE: we purposely skip the constructor of `Choice` self.case_sensitive = False self._key_filter = key + + @cached_property + def choices(self) -> Sequence: # type: ignore[override] module = import_module("ape.types.basic") - self.choices = module._LazySequence(self._choices_iterator) + return module._LazySequence(self._choices_iterator) @property def _choices_iterator(self) -> Iterator[str]: @@ -206,6 +210,8 @@ def convert( else: alias = value + from ape.utils.basemodel import ManagerAccessMixin as access + accounts = access.account_manager if isinstance(alias, str) and alias.upper().startswith("TEST::"): idx_str = alias.upper().replace("TEST::", "") @@ -235,6 +241,8 @@ def print_choices(self): click.echo(f"{idx}. {choice}") did_print = True + from ape.utils.basemodel import ManagerAccessMixin as access + accounts = access.account_manager len_test_accounts = len(accounts.test_accounts) - 1 if len_test_accounts > 0: @@ -261,6 +269,7 @@ def select_account(self) -> "AccountAPI": Returns: :class:`~ape.api.accounts.AccountAPI` """ + from ape.utils.basemodel import ManagerAccessMixin as access accounts = access.account_manager if not self.choices or len(self.choices) == 0: @@ -348,12 +357,7 @@ def __init__( base_type: Optional[type] = None, callback: Optional[Callable] = None, ): - provider_module = import_module("ape.api.providers") - base_type = provider_module.ProviderAPI if base_type is None else base_type - if not issubclass(base_type, (provider_module.ProviderAPI, str)): - raise TypeError(f"Unhandled type '{base_type}' for NetworkChoice.") - - self.base_type = base_type + self._base_type = base_type self.callback = callback self.case_sensitive = case_sensitive self.ecosystem = ecosystem @@ -361,6 +365,21 @@ def __init__( self.provider = provider # NOTE: Purposely avoid super().init for performance reasons. + @property + def base_type(self) -> type["ProviderAPI"]: + # perf: property exists to delay import ProviderAPI at init time. + from ape.api.providers import ProviderAPI + + if self._base_type is not None: + return self._base_type + + self._base_type = ProviderAPI + return ProviderAPI + + @base_type.setter + def base_type(self, value): + self._base_type = value + @cached_property def choices(self) -> Sequence[Any]: # type: ignore[override] return get_networks(ecosystem=self.ecosystem, network=self.network, provider=self.provider) @@ -369,6 +388,8 @@ def get_metavar(self, param): return "[ecosystem-name][:[network-name][:[provider-name]]]" def convert(self, value: Any, param: Optional[Parameter], ctx: Optional[Context]) -> Any: + from ape.utils.basemodel import ManagerAccessMixin as access + choice: Optional[Union[str, "ProviderAPI"]] networks = access.network_manager if not value: @@ -406,8 +427,9 @@ def convert(self, value: Any, param: Optional[Parameter], ctx: Optional[Context] ) from err if choice not in (None, _NONE_NETWORK) and isinstance(choice, str): - provider_module = import_module("ape.api.providers") - if issubclass(self.base_type, provider_module.ProviderAPI): + from ape.api.providers import ProviderAPI + + if issubclass(self.base_type, ProviderAPI): # Return the provider. choice = networks.get_provider_from_choice(network_choice=value) @@ -454,3 +476,18 @@ def output_format_choice(options: Optional[list[OutputFormat]] = None) -> Choice # Uses `str` form of enum for CLI choices. return click.Choice([o.value for o in options], case_sensitive=False) + + +class LazyChoice(Choice): + """ + A simple lazy-choice where choices are evaluated lazily. + """ + + def __init__(self, get_choices: Callable[[], Sequence[str]], case_sensitive: bool = False): + self._get_choices = get_choices + self.case_sensitive = case_sensitive + # Note: Purposely avoid super init. + + @cached_property + def choices(self) -> Sequence[str]: # type: ignore[override] + return self._get_choices() diff --git a/src/ape/cli/commands.py b/src/ape/cli/commands.py index 00610d2792..fb4d305363 100644 --- a/src/ape/cli/commands.py +++ b/src/ape/cli/commands.py @@ -7,7 +7,6 @@ from ape.cli.choices import _NONE_NETWORK, NetworkChoice from ape.exceptions import NetworkError -from ape.utils.basemodel import ManagerAccessMixin as access if TYPE_CHECKING: from ape.api.networks import ProviderContextManager @@ -26,6 +25,8 @@ def get_param_from_ctx(ctx: Context, param: str) -> Optional[Any]: def parse_network(ctx: Context) -> Optional["ProviderContextManager"]: + from ape.utils.basemodel import ManagerAccessMixin as access + interactive = get_param_from_ctx(ctx, "interactive") # Handle if already parsed (as when using network-option) diff --git a/src/ape/cli/options.py b/src/ape/cli/options.py index 60d1da3b26..508051e13c 100644 --- a/src/ape/cli/options.py +++ b/src/ape/cli/options.py @@ -3,11 +3,10 @@ from functools import partial from importlib import import_module from pathlib import Path -from typing import Any, NoReturn, Optional, Union +from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union import click from click import Option -from ethpm_types import ContractType from ape.cli.choices import ( _ACCOUNT_TYPE_FILTER, @@ -21,12 +20,14 @@ from ape.cli.paramtype import JSON, Noop from ape.exceptions import Abort, ProjectError from ape.logging import DEFAULT_LOG_LEVEL, ApeLogger, LogLevel, logger -from ape.utils.basemodel import ManagerAccessMixin + +if TYPE_CHECKING: + from ethpm_types.contract_type import ContractType _VERBOSITY_VALUES = ("--verbosity", "-v") -class ApeCliContextObject(ManagerAccessMixin, dict): +class ApeCliContextObject(dict): """ A ``click`` context object class. Use via :meth:`~ape.cli.options.ape_cli_context()`. It provides common CLI utilities for ape, such as logging or @@ -45,6 +46,8 @@ def __getattr__(self, item: str) -> Any: try: return self.__getattribute__(item) except AttributeError: + from ape.utils.basemodel import ManagerAccessMixin + return getattr(ManagerAccessMixin, item) @staticmethod @@ -174,14 +177,12 @@ def __init__(self, *args, **kwargs) -> None: provider = kwargs.pop("provider", None) default = kwargs.pop("default", "auto") - provider_module = import_module("ape.api.providers") - base_type = kwargs.pop("base_type", provider_module.ProviderAPI) - callback = kwargs.pop("callback", None) # NOTE: If using network_option, this part is skipped # because parsing happens earlier to handle advanced usage. if not kwargs.get("type"): + base_type = kwargs.pop("base_type", None) kwargs["type"] = NetworkChoice( case_sensitive=False, ecosystem=ecosystem, @@ -204,6 +205,8 @@ def __init__(self, *args, **kwargs) -> None: else: # NOTE: Use a function as the default so it is calculated lazily def fn(): + from ape.utils.basemodel import ManagerAccessMixin + return ManagerAccessMixin.network_manager.default_ecosystem.name default = fn @@ -344,6 +347,8 @@ def _update_context_with_network(ctx, provider, requested_network_objects): def _get_provider(value, default, keep_as_choice_str): + from ape.utils.basemodel import ManagerAccessMixin + use_default = value is None and default == "auto" provider_module = import_module("ape.api.providers") ProviderAPI = provider_module.ProviderAPI @@ -431,10 +436,12 @@ def account_option(account_type: _ACCOUNT_TYPE_FILTER = None) -> Callable: ) -def _load_contracts(ctx, param, value) -> Optional[Union[ContractType, list[ContractType]]]: +def _load_contracts(ctx, param, value) -> Optional[Union["ContractType", list["ContractType"]]]: if not value: return None + from ape.utils.basemodel import ManagerAccessMixin + if len(ManagerAccessMixin.local_project.contracts) == 0: raise ProjectError("Project has no contracts.") @@ -442,7 +449,7 @@ def _load_contracts(ctx, param, value) -> Optional[Union[ContractType, list[Cont # and therefore we should also return a list. is_multiple = isinstance(value, (tuple, list)) - def get_contract(contract_name: str) -> ContractType: + def get_contract(contract_name: str) -> "ContractType": if contract_name not in ManagerAccessMixin.local_project.contracts: raise ProjectError(f"No contract named '{value}'") @@ -523,6 +530,8 @@ def handle_parse_result(self, ctx, opts, args): def _project_callback(ctx, param, val): + from ape.utils.basemodel import ManagerAccessMixin + pm = None if not val: pm = ManagerAccessMixin.local_project diff --git a/src/ape/exceptions.py b/src/ape/exceptions.py index 19ab04e988..db50583303 100644 --- a/src/ape/exceptions.py +++ b/src/ape/exceptions.py @@ -14,13 +14,14 @@ import click from eth_typing import Hash32, HexStr from eth_utils import humanize_hash, to_hex -from ethpm_types import ContractType -from ethpm_types.abi import ConstructorABI, ErrorABI, MethodABI from rich import print as rich_print from ape.logging import LogLevel, logger if TYPE_CHECKING: + from ethpm_types.abi import ConstructorABI, ErrorABI, MethodABI + from ethpm_types.contract_type import ContractType + from ape.api.networks import NetworkAPI from ape.api.providers import SubprocessProvider from ape.api.trace import TraceAPI @@ -90,7 +91,7 @@ class MissingDeploymentBytecodeError(ContractDataError): Raised when trying to deploy an interface or empty data. """ - def __init__(self, contract_type: ContractType): + def __init__(self, contract_type: "ContractType"): message = "Cannot deploy: contract" if name := contract_type.name: message = f"{message} '{name}'" @@ -109,7 +110,7 @@ class ArgumentsLengthError(ContractDataError): def __init__( self, arguments_length: int, - inputs: Union[MethodABI, ConstructorABI, int, list, None] = None, + inputs: Union["MethodABI", "ConstructorABI", int, list, None] = None, **kwargs, ): prefix = ( @@ -120,7 +121,7 @@ def __init__( super().__init__(f"{prefix}.") return - inputs_ls: list[Union[MethodABI, ConstructorABI, int]] = ( + inputs_ls: list[Union["MethodABI", "ConstructorABI", int]] = ( inputs if isinstance(inputs, list) else [inputs] ) if not inputs_ls: @@ -223,7 +224,7 @@ def address(self) -> Optional["AddressType"]: return receiver @cached_property - def contract_type(self) -> Optional[ContractType]: + def contract_type(self) -> Optional["ContractType"]: if not (address := self.address): # Contract address not found. return None @@ -849,7 +850,7 @@ class CustomError(ContractLogicError): def __init__( self, - abi: ErrorABI, + abi: "ErrorABI", inputs: dict[str, Any], txn: Optional[FailedTxn] = None, trace: _TRACE_ARG = None, diff --git a/src/ape/plugins/account.py b/src/ape/plugins/account.py index 0598ca625d..f3e78ae3f2 100644 --- a/src/ape/plugins/account.py +++ b/src/ape/plugins/account.py @@ -1,7 +1,10 @@ -from ape.api.accounts import AccountAPI, AccountContainerAPI +from typing import TYPE_CHECKING from .pluggy_patch import PluginType, hookspec +if TYPE_CHECKING: + from ape.api.accounts import AccountAPI, AccountContainerAPI + class AccountPlugin(PluginType): """ @@ -13,7 +16,7 @@ class AccountPlugin(PluginType): @hookspec def account_types( # type: ignore[empty-body] self, - ) -> tuple[type[AccountContainerAPI], type[AccountAPI]]: + ) -> tuple[type["AccountContainerAPI"], type["AccountAPI"]]: """ A hook for returning a tuple of an account container and an account type. Each account-base plugin defines and returns their own types here. diff --git a/src/ape/plugins/compiler.py b/src/ape/plugins/compiler.py index 650e8bee3f..7fbf151557 100644 --- a/src/ape/plugins/compiler.py +++ b/src/ape/plugins/compiler.py @@ -1,7 +1,10 @@ -from ape.api.compiler import CompilerAPI +from typing import TYPE_CHECKING from .pluggy_patch import PluginType, hookspec +if TYPE_CHECKING: + from ape.api.compiler import CompilerAPI + class CompilerPlugin(PluginType): """ @@ -11,7 +14,9 @@ class CompilerPlugin(PluginType): """ @hookspec - def register_compiler(self) -> tuple[tuple[str], type[CompilerAPI]]: # type: ignore[empty-body] + def register_compiler( # type: ignore[empty-body] + self, + ) -> tuple[tuple[str], type["CompilerAPI"]]: """ A hook for returning the set of file extensions the plugin handles and the compiler class that can be used to compile them. diff --git a/src/ape/plugins/config.py b/src/ape/plugins/config.py index 933d407729..338b74c1f1 100644 --- a/src/ape/plugins/config.py +++ b/src/ape/plugins/config.py @@ -1,7 +1,10 @@ -from ape.api.config import PluginConfig +from typing import TYPE_CHECKING from .pluggy_patch import PluginType, hookspec +if TYPE_CHECKING: + from ape.api.config import PluginConfig + class Config(PluginType): """ @@ -12,7 +15,7 @@ class Config(PluginType): """ @hookspec - def config_class(self) -> type[PluginConfig]: # type: ignore[empty-body] + def config_class(self) -> type["PluginConfig"]: # type: ignore[empty-body] """ A hook that returns a :class:`~ape.api.config.PluginConfig` parser class that can be used to deconstruct the user config options for this plugins. diff --git a/src/ape/plugins/converter.py b/src/ape/plugins/converter.py index ac2a01232d..8f6c02e513 100644 --- a/src/ape/plugins/converter.py +++ b/src/ape/plugins/converter.py @@ -1,9 +1,11 @@ from collections.abc import Iterator - -from ape.api.convert import ConverterAPI +from typing import TYPE_CHECKING from .pluggy_patch import PluginType, hookspec +if TYPE_CHECKING: + from ape.api.convert import ConverterAPI + class ConversionPlugin(PluginType): """ @@ -12,7 +14,7 @@ class ConversionPlugin(PluginType): """ @hookspec - def converters(self) -> Iterator[tuple[str, type[ConverterAPI]]]: # type: ignore[empty-body] + def converters(self) -> Iterator[tuple[str, type["ConverterAPI"]]]: # type: ignore[empty-body] """ A hook that returns an iterator of tuples of a string ABI type and a ``ConverterAPI`` subclass. diff --git a/src/ape/plugins/network.py b/src/ape/plugins/network.py index 45aa93e8c2..dbd58a5cac 100644 --- a/src/ape/plugins/network.py +++ b/src/ape/plugins/network.py @@ -1,11 +1,13 @@ from collections.abc import Iterator - -from ape.api.explorers import ExplorerAPI -from ape.api.networks import EcosystemAPI, NetworkAPI -from ape.api.providers import ProviderAPI +from typing import TYPE_CHECKING from .pluggy_patch import PluginType, hookspec +if TYPE_CHECKING: + from ape.api.explorers import ExplorerAPI + from ape.api.networks import EcosystemAPI, NetworkAPI + from ape.api.providers import ProviderAPI + class EcosystemPlugin(PluginType): """ @@ -15,7 +17,7 @@ class EcosystemPlugin(PluginType): """ @hookspec # type: ignore[empty-body] - def ecosystems(self) -> Iterator[type[EcosystemAPI]]: + def ecosystems(self) -> Iterator[type["EcosystemAPI"]]: """ A hook that must return an iterator of :class:`ape.api.networks.EcosystemAPI` subclasses. @@ -39,7 +41,7 @@ class NetworkPlugin(PluginType): """ @hookspec # type: ignore[empty-body] - def networks(self) -> Iterator[tuple[str, str, type[NetworkAPI]]]: + def networks(self) -> Iterator[tuple[str, str, type["NetworkAPI"]]]: """ A hook that must return an iterator of tuples of: @@ -67,7 +69,9 @@ class ProviderPlugin(PluginType): """ @hookspec - def providers(self) -> Iterator[tuple[str, str, type[ProviderAPI]]]: # type: ignore[empty-body] + def providers( # type: ignore[empty-body] + self, + ) -> Iterator[tuple[str, str, type["ProviderAPI"]]]: """ A hook that must return an iterator of tuples of: @@ -93,7 +97,9 @@ class ExplorerPlugin(PluginType): """ @hookspec - def explorers(self) -> Iterator[tuple[str, str, type[ExplorerAPI]]]: # type: ignore[empty-body] + def explorers( # type: ignore[empty-body] + self, + ) -> Iterator[tuple[str, str, type["ExplorerAPI"]]]: """ A hook that must return an iterator of tuples of: diff --git a/src/ape/plugins/project.py b/src/ape/plugins/project.py index 32c14a4f54..5b4d44d820 100644 --- a/src/ape/plugins/project.py +++ b/src/ape/plugins/project.py @@ -1,9 +1,11 @@ from collections.abc import Iterator - -from ape.api.projects import DependencyAPI, ProjectAPI +from typing import TYPE_CHECKING from .pluggy_patch import PluginType, hookspec +if TYPE_CHECKING: + from ape.api.projects import DependencyAPI, ProjectAPI + class ProjectPlugin(PluginType): """ @@ -15,7 +17,7 @@ class ProjectPlugin(PluginType): """ @hookspec # type: ignore[empty-body] - def projects(self) -> Iterator[type[ProjectAPI]]: + def projects(self) -> Iterator[type["ProjectAPI"]]: """ A hook that returns a :class:`~ape.api.projects.ProjectAPI` subclass type. @@ -31,7 +33,7 @@ class DependencyPlugin(PluginType): """ @hookspec - def dependencies(self) -> dict[str, type[DependencyAPI]]: # type: ignore[empty-body] + def dependencies(self) -> dict[str, type["DependencyAPI"]]: # type: ignore[empty-body] """ A hook that returns a :class:`~ape.api.projects.DependencyAPI` mapped to its ``ape-config.yaml`` file dependencies special key. For example, diff --git a/src/ape/utils/testing.py b/src/ape/utils/testing.py index 6d85efed26..cc2ed7d3d4 100644 --- a/src/ape/utils/testing.py +++ b/src/ape/utils/testing.py @@ -1,8 +1,5 @@ from collections import namedtuple -from eth_account import Account -from eth_account.hdaccount import HDPath -from eth_account.hdaccount.mnemonic import Mnemonic from eth_utils import to_hex DEFAULT_NUMBER_OF_TEST_ACCOUNTS = 10 @@ -47,6 +44,9 @@ def generate_dev_accounts( Returns: list[:class:`~ape.utils.GeneratedDevAccount`]: List of development accounts. """ + # perf: lazy imports so module loads faster. + from eth_account.hdaccount.mnemonic import Mnemonic + seed = Mnemonic.to_seed(mnemonic) hd_path_format = ( hd_path if "{}" in hd_path or "{0}" in hd_path else f"{hd_path.rstrip('/')}/{{}}" @@ -58,6 +58,10 @@ def generate_dev_accounts( def _generate_dev_account(hd_path, index: int, seed: bytes) -> GeneratedDevAccount: + # perf: lazy imports so module loads faster. + from eth_account.account import Account + from eth_account.hdaccount import HDPath + return GeneratedDevAccount( address=Account.from_key( private_key := to_hex(HDPath(hd_path.format(index)).derive(seed)) diff --git a/src/ape_accounts/__init__.py b/src/ape_accounts/__init__.py index 862b78819b..b3af3607f3 100644 --- a/src/ape_accounts/__init__.py +++ b/src/ape_accounts/__init__.py @@ -1,19 +1,20 @@ -from ape import plugins +from importlib import import_module +from typing import Any -from .accounts import ( - AccountContainer, - KeyfileAccount, - generate_account, - import_account_from_mnemonic, - import_account_from_private_key, -) +from ape.plugins import AccountPlugin, register -@plugins.register(plugins.AccountPlugin) +@register(AccountPlugin) def account_types(): + from ape_accounts.accounts import AccountContainer, KeyfileAccount + return AccountContainer, KeyfileAccount +def __getattr__(name: str) -> Any: + return getattr(import_module("ape_accounts.accounts"), name) + + __all__ = [ "AccountContainer", "KeyfileAccount", diff --git a/src/ape_accounts/_cli.py b/src/ape_accounts/_cli.py index 27fd75cb67..f6708a0417 100644 --- a/src/ape_accounts/_cli.py +++ b/src/ape_accounts/_cli.py @@ -3,21 +3,23 @@ from typing import TYPE_CHECKING, Optional import click -from eth_account import Account as EthAccount -from eth_account.hdaccount import ETHEREUM_DEFAULT_PATH from eth_utils import to_checksum_address, to_hex from ape.cli.arguments import existing_alias_argument, non_existing_alias_argument from ape.cli.options import ape_cli_context from ape.logging import HIDDEN_MESSAGE -from ape.utils.basemodel import ManagerAccessMixin as access if TYPE_CHECKING: from ape.api.accounts import AccountAPI from ape_accounts.accounts import AccountContainer, KeyfileAccount +ETHEREUM_DEFAULT_PATH = "m/44'/60'/0'/0/0" + + def _get_container() -> "AccountContainer": + from ape.utils.basemodel import ManagerAccessMixin as access + # NOTE: Must used the instantiated version of `AccountsContainer` in `accounts` return access.account_manager.containers["accounts"] @@ -144,15 +146,14 @@ def ask_for_passphrase(): confirmation_prompt=True, ) - account_module = import_module("ape_accounts.accounts") if import_from_mnemonic: + from eth_account import Account as EthAccount + mnemonic = click.prompt("Enter mnemonic seed phrase", hide_input=True) EthAccount.enable_unaudited_hdwallet_features() try: passphrase = ask_for_passphrase() - account = account_module.import_account_from_mnemonic( - alias, passphrase, mnemonic, custom_hd_path - ) + account = _account_from_mnemonic(alias, passphrase, mnemonic, hd_path=custom_hd_path) except Exception as error: error_msg = f"{error}".replace(mnemonic, HIDDEN_MESSAGE) cli_ctx.abort(f"Seed phrase can't be imported: {error_msg}") @@ -161,7 +162,7 @@ def ask_for_passphrase(): key = click.prompt("Enter Private Key", hide_input=True) try: passphrase = ask_for_passphrase() - account = account_module.import_account_from_private_key(alias, passphrase, key) + account = _account_from_key(alias, passphrase, key) except Exception as error: cli_ctx.abort(f"Key can't be imported: {error}") @@ -176,10 +177,24 @@ def _load_account_type(account: "AccountAPI") -> bool: return isinstance(account, module.KeyfileAccount) +def _account_from_mnemonic( + alias: str, passphrase: str, mnemonic: str, hd_path: str = ETHEREUM_DEFAULT_PATH +) -> "KeyfileAccount": + account_module = import_module("ape_accounts.accounts") + return account_module.import_account_from_mnemonic(alias, passphrase, mnemonic, hd_path=hd_path) + + +def _account_from_key(alias: str, passphrase: str, key: str) -> "KeyfileAccount": + account_module = import_module("ape_accounts.accounts") + return account_module.import_account_from_private_key(alias, passphrase, key) + + @cli.command(short_help="Export an account private key") @ape_cli_context() @existing_alias_argument(account_type=_load_account_type) def export(cli_ctx, alias): + from eth_account import Account as EthAccount + path = _get_container().data_folder.joinpath(f"{alias}.json") account = json.loads(path.read_text()) password = click.prompt("Enter password to decrypt account", hide_input=True) diff --git a/src/ape_cache/__init__.py b/src/ape_cache/__init__.py index 936e2ad8bb..516e7e6a91 100644 --- a/src/ape_cache/__init__.py +++ b/src/ape_cache/__init__.py @@ -1,19 +1,16 @@ from importlib import import_module -from ape import plugins -from ape.api.config import PluginConfig +from ape.plugins import Config, QueryPlugin, register -class CacheConfig(PluginConfig): - size: int = 1024**3 # 1gb - - -@plugins.register(plugins.Config) +@register(Config) def config_class(): + from ape_cache.config import CacheConfig + return CacheConfig -@plugins.register(plugins.QueryPlugin) +@register(QueryPlugin) def query_engines(): query = import_module("ape_cache.query") return query.CacheQueryProvider @@ -21,13 +18,18 @@ def query_engines(): def __getattr__(name): if name == "CacheQueryProvider": - query = import_module("ape_cache.query") - return query.CacheQueryProvider + module = import_module("ape_cache.query") + return module.CacheQueryProvider + + elif name == "CacheConfig": + module = import_module("ape_cache.config") + return module.CacheConfig else: raise AttributeError(name) __all__ = [ + "CacheConfig", "CacheQueryProvider", ] diff --git a/src/ape_cache/config.py b/src/ape_cache/config.py new file mode 100644 index 0000000000..264516b738 --- /dev/null +++ b/src/ape_cache/config.py @@ -0,0 +1,5 @@ +from ape.api.config import PluginConfig + + +class CacheConfig(PluginConfig): + size: int = 1024**3 # 1gb diff --git a/src/ape_compile/__init__.py b/src/ape_compile/__init__.py index f9cac3e8d6..30dad9fbd1 100644 --- a/src/ape_compile/__init__.py +++ b/src/ape_compile/__init__.py @@ -1,95 +1,26 @@ -import re -from re import Pattern -from typing import Union +from typing import Any -from pydantic import field_serializer, field_validator +from ape.plugins import Config as RConfig +from ape.plugins import register -from ape import plugins -from ape.api.config import ConfigEnum, PluginConfig -from ape.utils.misc import SOURCE_EXCLUDE_PATTERNS +@register(RConfig) +def config_class(): + from ape_compile.config import Config -class OutputExtras(ConfigEnum): - """ - Extra stuff you can output. It will - appear in ``.build/{key.lower()/`` - """ - - ABI = "ABI" - """ - Include this value to output the ABIs of your contracts - to minified JSONs. This is useful for hosting purposes - for web-apps. - """ - - -class Config(PluginConfig): - """ - Configure general compiler settings. - """ - - exclude: set[Union[str, Pattern]] = set() - """ - Source exclusion globs or regex patterns across all file types. - To use regex, start your values with ``r"`` and they'll be turned - into regex pattern objects. - - **NOTE**: ``ape.utils.misc.SOURCE_EXCLUDE_PATTERNS`` are automatically - included in this set. - """ - - include_dependencies: bool = False - """ - Set to ``True`` to compile dependencies during ``ape compile``. - Generally, dependencies are not compiled during ``ape compile`` - This is because dependencies may not compile in Ape on their own, - but you can still reference them in your project's contracts' imports. - Some projects may be more dependency-based and wish to have the - contract types always compiled during ``ape compile``, and these projects - should configure ``include_dependencies`` to be ``True``. - """ - - output_extra: list[OutputExtras] = [] - """ - Extra selections to output. Outputs to ``.build/{key.lower()}``. - """ - - @field_validator("exclude", mode="before") - @classmethod - def validate_exclude(cls, value): - given_values = [] - - # Convert regex to Patterns. - for given in value or []: - if (given.startswith('r"') and given.endswith('"')) or ( - given.startswith("r'") and given.endswith("'") - ): - value_clean = given[2:-1] - pattern = re.compile(value_clean) - given_values.append(pattern) + return Config - else: - given_values.append(given) - # Include defaults. - return {*given_values, *SOURCE_EXCLUDE_PATTERNS} +def __getattr__(name: str) -> Any: + if name == "Config": + from ape_compile.config import Config - @field_serializer("exclude", when_used="json") - def serialize_exclude(self, exclude, info): - """ - Exclude is put back with the weird r-prefix so we can - go to-and-from. - """ - result: list[str] = [] - for excl in exclude: - if isinstance(excl, Pattern): - result.append(f'r"{excl.pattern}"') - else: - result.append(excl) + return Config - return result + else: + raise AttributeError(name) -@plugins.register(plugins.Config) -def config_class(): - return Config +__all__ = [ + "Config", +] diff --git a/src/ape_compile/_cli.py b/src/ape_compile/_cli.py index 30b1425793..0251d77d3a 100644 --- a/src/ape_compile/_cli.py +++ b/src/ape_compile/_cli.py @@ -1,12 +1,14 @@ import sys from pathlib import Path +from typing import TYPE_CHECKING import click -from ethpm_types import ContractType from ape.cli.arguments import contract_file_paths_argument from ape.cli.options import ape_cli_context, config_override_option, project_option -from ape.utils.os import clean_path + +if TYPE_CHECKING: + from ethpm_types import ContractType def _include_dependencies_callback(ctx, param, value): @@ -93,6 +95,8 @@ def cli( _display_byte_code_sizes(cli_ctx, contract_types) if not compiled: + from ape.utils.os import clean_path # perf: lazy import + folder = clean_path(project.contracts_folder) cli_ctx.logger.warning(f"Nothing to compile ({folder}).") @@ -101,7 +105,7 @@ def cli( sys.exit(1) -def _display_byte_code_sizes(cli_ctx, contract_types: dict[str, ContractType]): +def _display_byte_code_sizes(cli_ctx, contract_types: dict[str, "ContractType"]): # Display bytecode size for *all* contract types (not just ones we compiled) code_size = [] for contract in contract_types.values(): diff --git a/src/ape_compile/config.py b/src/ape_compile/config.py new file mode 100644 index 0000000000..012e590b04 --- /dev/null +++ b/src/ape_compile/config.py @@ -0,0 +1,89 @@ +import re +from re import Pattern +from typing import Union + +from pydantic import field_serializer, field_validator + +from ape.api.config import ConfigEnum, PluginConfig +from ape.utils.misc import SOURCE_EXCLUDE_PATTERNS + + +class OutputExtras(ConfigEnum): + """ + Extra stuff you can output. It will + appear in ``.build/{key.lower()/`` + """ + + ABI = "ABI" + """ + Include this value to output the ABIs of your contracts + to minified JSONs. This is useful for hosting purposes + for web-apps. + """ + + +class Config(PluginConfig): + """ + Configure general compiler settings. + """ + + exclude: set[Union[str, Pattern]] = set() + """ + Source exclusion globs or regex patterns across all file types. + To use regex, start your values with ``r"`` and they'll be turned + into regex pattern objects. + + **NOTE**: ``ape.utils.misc.SOURCE_EXCLUDE_PATTERNS`` are automatically + included in this set. + """ + + include_dependencies: bool = False + """ + Set to ``True`` to compile dependencies during ``ape compile``. + Generally, dependencies are not compiled during ``ape compile`` + This is because dependencies may not compile in Ape on their own, + but you can still reference them in your project's contracts' imports. + Some projects may be more dependency-based and wish to have the + contract types always compiled during ``ape compile``, and these projects + should configure ``include_dependencies`` to be ``True``. + """ + + output_extra: list[OutputExtras] = [] + """ + Extra selections to output. Outputs to ``.build/{key.lower()}``. + """ + + @field_validator("exclude", mode="before") + @classmethod + def validate_exclude(cls, value): + given_values = [] + + # Convert regex to Patterns. + for given in value or []: + if (given.startswith('r"') and given.endswith('"')) or ( + given.startswith("r'") and given.endswith("'") + ): + value_clean = given[2:-1] + pattern = re.compile(value_clean) + given_values.append(pattern) + + else: + given_values.append(given) + + # Include defaults. + return {*given_values, *SOURCE_EXCLUDE_PATTERNS} + + @field_serializer("exclude", when_used="json") + def serialize_exclude(self, exclude, info): + """ + Exclude is put back with the weird r-prefix so we can + go to-and-from. + """ + result: list[str] = [] + for excl in exclude: + if isinstance(excl, Pattern): + result.append(f'r"{excl.pattern}"') + else: + result.append(excl) + + return result diff --git a/src/ape_console/__init__.py b/src/ape_console/__init__.py index 5bd99c7c67..b6ed602857 100644 --- a/src/ape_console/__init__.py +++ b/src/ape_console/__init__.py @@ -1,7 +1,8 @@ -from ape import plugins -from ape_console.config import ConsoleConfig +from ape.plugins import Config, register -@plugins.register(plugins.Config) +@register(Config) def config_class(): + from ape_console.config import ConsoleConfig + return ConsoleConfig diff --git a/src/ape_console/_cli.py b/src/ape_console/_cli.py index f04cdea48d..07855986d1 100644 --- a/src/ape_console/_cli.py +++ b/src/ape_console/_cli.py @@ -12,10 +12,6 @@ from ape.cli.commands import ConnectedProviderCommand from ape.cli.options import ape_cli_context, project_option -from ape.utils.basemodel import ManagerAccessMixin as access -from ape.utils.misc import _python_version -from ape.version import version as ape_version -from ape_console.config import ConsoleConfig if TYPE_CHECKING: from IPython.terminal.ipapp import Config as IPythonConfig @@ -54,6 +50,8 @@ def import_extras_file(file_path) -> ModuleType: def load_console_extras(**namespace: Any) -> dict[str, Any]: """load and return namespace updates from ape_console_extras.py files if they exist""" + from ape.utils.basemodel import ManagerAccessMixin as access + pm = namespace.get("project", access.local_project) global_extras = pm.config_manager.DATA_FOLDER.joinpath(CONSOLE_EXTRAS_FILENAME) project_extras = pm.path.joinpath(CONSOLE_EXTRAS_FILENAME) @@ -102,6 +100,8 @@ def console( from IPython.terminal.ipapp import Config as IPythonConfig import ape + from ape.utils.misc import _python_version + from ape.version import version as ape_version project = project or ape.project banner = "" @@ -155,6 +155,8 @@ def console( def _launch_console(namespace: dict, ipy_config: "IPythonConfig", embed: bool, banner: str): import IPython + from ape_console.config import ConsoleConfig + ipython_kwargs = {"user_ns": namespace, "config": ipy_config} if embed: IPython.embed(**ipython_kwargs, colors="Neutral", banner1=banner) diff --git a/src/ape_networks/__init__.py b/src/ape_networks/__init__.py index f82dc72c30..51e382e1f8 100644 --- a/src/ape_networks/__init__.py +++ b/src/ape_networks/__init__.py @@ -1,44 +1,22 @@ -from typing import Optional +from importlib import import_module +from typing import Any -from ape import plugins -from ape.api.config import PluginConfig +from ape.plugins import Config, register -class CustomNetwork(PluginConfig): - """ - A custom network config. - """ - - name: str - """Name of the network e.g. mainnet.""" - - chain_id: int - """Chain ID (required).""" - - ecosystem: str - """The name of the ecosystem.""" - - base_ecosystem_plugin: Optional[str] = None - """The base ecosystem plugin to use, when applicable. Defaults to the default ecosystem.""" - - default_provider: str = "node" - """The default provider plugin to use. Default is the default node provider.""" +@register(Config) +def config_class(): + from ape_networks.config import NetworksConfig - request_header: dict = {} - """The HTTP request header.""" + return NetworksConfig - @property - def is_fork(self) -> bool: - """ - ``True`` when the name of the network ends in ``"-fork"``. - """ - return self.name.endswith("-fork") +def __getattr__(name: str) -> Any: + if name in ("NetworksConfig", "CustomNetwork"): + return getattr(import_module("ape_networks.config"), name) -class NetworksConfig(PluginConfig): - custom: list[CustomNetwork] = [] + else: + raise AttributeError(name) -@plugins.register(plugins.Config) -def config_class(): - return NetworksConfig +__all__ = ["NetworksConfig"] diff --git a/src/ape_networks/_cli.py b/src/ape_networks/_cli.py index fd9c019787..41ff326e42 100644 --- a/src/ape_networks/_cli.py +++ b/src/ape_networks/_cli.py @@ -1,5 +1,5 @@ import json -from collections.abc import Callable +from collections.abc import Callable, Sequence from importlib import import_module from typing import TYPE_CHECKING @@ -8,24 +8,22 @@ from rich import print as echo_rich_text from rich.tree import Tree -from ape.cli.choices import OutputFormat +from ape.cli.choices import LazyChoice, OutputFormat from ape.cli.options import ape_cli_context, network_option, output_format_option from ape.exceptions import NetworkError from ape.logging import LogLevel -from ape.types.basic import _LazySequence -from ape.utils.basemodel import ManagerAccessMixin as access if TYPE_CHECKING: from ape.api.providers import SubprocessProvider -def _filter_option(name: str, options): +def _filter_option(name: str, get_options: Callable[[], Sequence[str]]): return click.option( f"--{name}", f"{name}_filter", multiple=True, help=f"Filter the results by {name}", - type=click.Choice(options), + type=LazyChoice(get_options), ) @@ -36,20 +34,24 @@ def cli(): """ -def _lazy_get(name: str) -> _LazySequence: +def _lazy_get(name: str) -> Sequence: # NOTE: Using fn generator to maintain laziness. def gen(): + from ape.utils.basemodel import ManagerAccessMixin as access + yield from getattr(access.network_manager, f"{name}_names") + from ape.types.basic import _LazySequence + return _LazySequence(gen) @cli.command(name="list", short_help="List registered networks") @ape_cli_context() @output_format_option() -@_filter_option("ecosystem", _lazy_get("ecosystem")) -@_filter_option("network", _lazy_get("network")) -@_filter_option("provider", _lazy_get("provider")) +@_filter_option("ecosystem", lambda: _lazy_get("ecosystem")) +@_filter_option("network", lambda: _lazy_get("network")) +@_filter_option("provider", lambda: _lazy_get("provider")) def _list(cli_ctx, output_format, ecosystem_filter, network_filter, provider_filter): """ List all the registered ecosystems, networks, and providers. diff --git a/src/ape_networks/config.py b/src/ape_networks/config.py new file mode 100644 index 0000000000..381cd268b2 --- /dev/null +++ b/src/ape_networks/config.py @@ -0,0 +1,38 @@ +from typing import Optional + +from ape.api.config import PluginConfig + + +class CustomNetwork(PluginConfig): + """ + A custom network config. + """ + + name: str + """Name of the network e.g. mainnet.""" + + chain_id: int + """Chain ID (required).""" + + ecosystem: str + """The name of the ecosystem.""" + + base_ecosystem_plugin: Optional[str] = None + """The base ecosystem plugin to use, when applicable. Defaults to the default ecosystem.""" + + default_provider: str = "node" + """The default provider plugin to use. Default is the default node provider.""" + + request_header: dict = {} + """The HTTP request header.""" + + @property + def is_fork(self) -> bool: + """ + ``True`` when the name of the network ends in ``"-fork"``. + """ + return self.name.endswith("-fork") + + +class NetworksConfig(PluginConfig): + custom: list[CustomNetwork] = [] diff --git a/src/ape_plugins/__init__.py b/src/ape_plugins/__init__.py index 8826fdd2ce..734889b268 100644 --- a/src/ape_plugins/__init__.py +++ b/src/ape_plugins/__init__.py @@ -1,7 +1,8 @@ -from ape import plugins -from ape.api.config import ConfigDict +from ape.plugins import Config, register -@plugins.register(plugins.Config) +@register(Config) def config_class(): + from ape.api.config import ConfigDict + return ConfigDict diff --git a/tests/functional/test_compilers.py b/tests/functional/test_compilers.py index 48ed938828..58db5f44b8 100644 --- a/tests/functional/test_compilers.py +++ b/tests/functional/test_compilers.py @@ -8,7 +8,7 @@ from ape.contracts import ContractContainer from ape.exceptions import APINotImplementedError, CompilerError, ContractLogicError, CustomError from ape.types.address import AddressType -from ape_compile import Config +from ape_compile.config import Config def test_get_imports(project, compilers): diff --git a/tests/integration/cli/test_accounts.py b/tests/integration/cli/test_accounts.py index a0087033e3..1fea64efd5 100644 --- a/tests/integration/cli/test_accounts.py +++ b/tests/integration/cli/test_accounts.py @@ -139,7 +139,7 @@ def invoke_import(): @run_once def test_import_account_instantiation_failure(mocker, ape_cli, runner): - eth_account_from_key_patch = mocker.patch("ape_accounts._cli.EthAccount.from_key") + eth_account_from_key_patch = mocker.patch("ape_accounts._cli._account_from_key") eth_account_from_key_patch.side_effect = Exception("Can't instantiate this account!") result = runner.invoke( ape_cli, From 4fdf8e9228defd54503b010b1b7dccd48276b490 Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 29 Oct 2024 16:06:17 -0500 Subject: [PATCH 7/7] perf: Enable flake8 type checks (#2352) --- .pre-commit-config.yaml | 2 +- setup.cfg | 5 +- setup.py | 1 + src/ape/_cli.py | 8 ++- src/ape/api/accounts.py | 5 +- src/ape/api/address.py | 4 +- src/ape/api/compiler.py | 36 +++++----- src/ape/api/explorers.py | 16 +++-- src/ape/api/networks.py | 35 +++++----- src/ape/api/providers.py | 87 +++++++++++++------------ src/ape/api/trace.py | 8 ++- src/ape/api/transactions.py | 25 +++---- src/ape/cli/commands.py | 13 ++-- src/ape/contracts/base.py | 44 +++++++------ src/ape/logging.py | 4 +- src/ape/managers/accounts.py | 6 +- src/ape/managers/chain.py | 29 +++++---- src/ape/managers/compilers.py | 17 ++--- src/ape/managers/config.py | 10 +-- src/ape/managers/converters.py | 8 ++- src/ape/managers/networks.py | 22 ++++--- src/ape/managers/query.py | 2 +- src/ape/pytest/config.py | 9 +-- src/ape/pytest/coverage.py | 39 ++++++----- src/ape/pytest/fixtures.py | 38 ++++++----- src/ape/pytest/gas.py | 26 ++++---- src/ape/pytest/plugin.py | 8 ++- src/ape/pytest/runners.py | 28 ++++---- src/ape/types/address.py | 9 ++- src/ape/types/coverage.py | 10 ++- src/ape/types/signatures.py | 22 ++++--- src/ape/types/trace.py | 9 +-- src/ape/types/units.py | 8 ++- src/ape/utils/misc.py | 4 +- src/ape/utils/os.py | 7 +- src/ape/utils/rpc.py | 2 +- src/ape_accounts/accounts.py | 21 ++++-- src/ape_cache/query.py | 6 +- src/ape_ethereum/_print.py | 15 +++-- src/ape_ethereum/ecosystem.py | 30 +++++---- src/ape_ethereum/multicall/handlers.py | 25 +++---- src/ape_ethereum/provider.py | 50 +++++++------- src/ape_ethereum/trace.py | 34 ++++++---- src/ape_ethereum/transactions.py | 11 ++-- src/ape_node/provider.py | 18 +++-- src/ape_pm/compiler.py | 8 ++- src/ape_pm/project.py | 5 +- src/ape_test/accounts.py | 18 +++-- src/ape_test/provider.py | 28 ++++---- tests/functional/conftest.py | 9 ++- tests/functional/test_config.py | 9 ++- tests/functional/test_contract_event.py | 8 ++- tests/functional/test_explorer.py | 15 +++-- tests/functional/test_receipt.py | 8 ++- 54 files changed, 519 insertions(+), 405 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c85fc0f1b..a0a339b0ec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: rev: 7.1.1 hooks: - id: flake8 - additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic] + additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic, flake8-type-checking] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.13.0 diff --git a/setup.cfg b/setup.cfg index 3272b9ceeb..0b71eb61a4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,10 +7,13 @@ exclude = build .eggs tests/integration/cli/projects -ignore = E704,W503,PYD002 +ignore = E704,W503,PYD002,TC003,TC006 per-file-ignores = # Need signal handler before imports src/ape/__init__.py: E402 # Test data causes long lines tests/functional/data/python/__init__.py: E501 tests/functional/utils/expected_traces.py: E501 + +type-checking-pydantic-enabled = True +type-checking-sqlalchemy-enabled = True diff --git a/setup.py b/setup.py index 6d516c59c8..932989e652 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code "flake8-print>=4.0.1,<5", # Detect print statements left in code "flake8-pydantic", # For detecting issues with Pydantic models + "flake8-type-checking", # Detect imports to move in/out of type-checking blocks "isort>=5.13.2,<6", # Import sorting linter "mdformat>=0.7.18", # Auto-formatter for markdown "mdformat-gfm>=0.3.5", # Needed for formatting GitHub-flavored markdown diff --git a/src/ape/_cli.py b/src/ape/_cli.py index caf223ce0a..8afe38318b 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -7,18 +7,20 @@ from importlib import import_module from importlib.metadata import entry_points from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from warnings import catch_warnings, simplefilter import click import rich import yaml -from click import Context from ape.cli.options import ape_cli_context from ape.exceptions import Abort, ApeException, ConfigError, handle_ape_exception from ape.logging import logger +if TYPE_CHECKING: + from click import Context + _DIFFLIB_CUT_OFF = 0.6 @@ -53,7 +55,7 @@ def _validate_config(): class ApeCLI(click.MultiCommand): _CLI_GROUP_NAME = "ape_cli_subcommands" - def parse_args(self, ctx: Context, args: list[str]) -> list[str]: + def parse_args(self, ctx: "Context", args: list[str]) -> list[str]: # Validate the config before any argument parsing, # as arguments may utilize config. if "--help" not in args and args != []: diff --git a/src/ape/api/accounts.py b/src/ape/api/accounts.py index 9320bfa5f1..c8c357ffb6 100644 --- a/src/ape/api/accounts.py +++ b/src/ape/api/accounts.py @@ -10,7 +10,6 @@ from eip712.messages import SignableMessage as EIP712SignableMessage from eth_account import Account from eth_account.messages import encode_defunct -from eth_pydantic_types import HexBytes from eth_utils import to_hex from ethpm_types import ContractType @@ -31,6 +30,8 @@ from ape.utils.misc import raises_not_implemented if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from ape.contracts import ContractContainer, ContractInstance @@ -65,7 +66,7 @@ def alias(self) -> Optional[str]: """ return None - def sign_raw_msghash(self, msghash: HexBytes) -> Optional[MessageSignature]: + def sign_raw_msghash(self, msghash: "HexBytes") -> Optional[MessageSignature]: """ Sign a raw message hash. diff --git a/src/ape/api/address.py b/src/ape/api/address.py index cd62661c41..7ab52a81c3 100644 --- a/src/ape/api/address.py +++ b/src/ape/api/address.py @@ -7,13 +7,13 @@ from ape.exceptions import ConversionError from ape.types.address import AddressType from ape.types.units import CurrencyValue -from ape.types.vm import ContractCode from ape.utils.basemodel import BaseInterface from ape.utils.misc import log_instead_of_fail if TYPE_CHECKING: from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.managers.chain import AccountHistory + from ape.types.vm import ContractCode class BaseAddress(BaseInterface): @@ -146,7 +146,7 @@ def __setattr__(self, attr: str, value: Any) -> None: super().__setattr__(attr, value) @property - def code(self) -> ContractCode: + def code(self) -> "ContractCode": """ The raw bytes of the smart-contract code at the address. """ diff --git a/src/ape/api/compiler.py b/src/ape/api/compiler.py index e870edc745..eec9c47265 100644 --- a/src/ape/api/compiler.py +++ b/src/ape/api/compiler.py @@ -4,21 +4,21 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from eth_pydantic_types import HexBytes -from ethpm_types import ContractType -from ethpm_types.source import Content, ContractSource -from packaging.version import Version - -from ape.api.config import PluginConfig -from ape.api.trace import TraceAPI from ape.exceptions import APINotImplementedError, ContractLogicError -from ape.types.coverage import ContractSourceCoverage -from ape.types.trace import SourceTraceback from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import log_instead_of_fail, raises_not_implemented if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from ethpm_types import ContractType + from ethpm_types.source import Content, ContractSource + from packaging.version import Version + + from ape.api.config import PluginConfig + from ape.api.trace import TraceAPI from ape.managers.project import ProjectManager + from ape.types.coverage import ContractSourceCoverage + from ape.types.trace import SourceTraceback class CompilerAPI(BaseInterfaceModel): @@ -44,7 +44,7 @@ def name(self) -> str: The name of the compiler. """ - def get_config(self, project: Optional["ProjectManager"] = None) -> PluginConfig: + def get_config(self, project: Optional["ProjectManager"] = None) -> "PluginConfig": """ The combination of settings from ``ape-config.yaml`` and ``.compiler_settings``. @@ -79,7 +79,7 @@ def get_compiler_settings( # type: ignore[empty-body] contract_filepaths: Iterable[Path], project: Optional["ProjectManager"] = None, **overrides, - ) -> dict[Version, dict]: + ) -> dict["Version", dict]: """ Get a mapping of the settings that would be used to compile each of the sources by the compiler version number. @@ -101,7 +101,7 @@ def compile( contract_filepaths: Iterable[Path], project: Optional["ProjectManager"], settings: Optional[dict] = None, - ) -> Iterator[ContractType]: + ) -> Iterator["ContractType"]: """ Compile the given source files. All compiler plugins must implement this function. @@ -123,7 +123,7 @@ def compile_code( # type: ignore[empty-body] project: Optional["ProjectManager"], settings: Optional[dict] = None, **kwargs, - ) -> ContractType: + ) -> "ContractType": """ Compile a program. @@ -162,7 +162,7 @@ def get_version_map( # type: ignore[empty-body] self, contract_filepaths: Iterable[Path], project: Optional["ProjectManager"] = None, - ) -> dict[Version, set[Path]]: + ) -> dict["Version", set[Path]]: """ Get a map of versions to source paths. @@ -218,8 +218,8 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: @raises_not_implemented def trace_source( # type: ignore[empty-body] - self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes - ) -> SourceTraceback: + self, contract_source: "ContractSource", trace: "TraceAPI", calldata: "HexBytes" + ) -> "SourceTraceback": """ Get a source-traceback for the given contract type. The source traceback object contains all the control paths taken in the transaction. @@ -239,7 +239,7 @@ def trace_source( # type: ignore[empty-body] @raises_not_implemented def flatten_contract( # type: ignore[empty-body] self, path: Path, project: Optional["ProjectManager"] = None, **kwargs - ) -> Content: + ) -> "Content": """ Get the content of a flattened contract via its source path. Plugin implementations handle import resolution, SPDX de-duplication, @@ -259,7 +259,7 @@ def flatten_contract( # type: ignore[empty-body] @raises_not_implemented def init_coverage_profile( - self, source_coverage: ContractSourceCoverage, contract_source: ContractSource + self, source_coverage: "ContractSourceCoverage", contract_source: "ContractSource" ): # type: ignore[empty-body] """ Initialize an empty report for the given source ID. Modifies the given source diff --git a/src/ape/api/explorers.py b/src/ape/api/explorers.py index 2344c06d05..e1901d4a0f 100644 --- a/src/ape/api/explorers.py +++ b/src/ape/api/explorers.py @@ -1,12 +1,14 @@ from abc import abstractmethod -from typing import Optional - -from ethpm_types import ContractType +from typing import TYPE_CHECKING, Optional from ape.api.networks import NetworkAPI -from ape.types.address import AddressType from ape.utils.basemodel import BaseInterfaceModel +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.types.address import AddressType + class ExplorerAPI(BaseInterfaceModel): """ @@ -18,7 +20,7 @@ class ExplorerAPI(BaseInterfaceModel): network: NetworkAPI @abstractmethod - def get_address_url(self, address: AddressType) -> str: + def get_address_url(self, address: "AddressType") -> str: """ Get an address URL, such as for a transaction. @@ -42,7 +44,7 @@ def get_transaction_url(self, transaction_hash: str) -> str: """ @abstractmethod - def get_contract_type(self, address: AddressType) -> Optional[ContractType]: + def get_contract_type(self, address: "AddressType") -> Optional["ContractType"]: """ Get the contract type for a given address if it has been published to this explorer. @@ -54,7 +56,7 @@ def get_contract_type(self, address: AddressType) -> Optional[ContractType]: """ @abstractmethod - def publish_contract(self, address: AddressType): + def publish_contract(self, address: "AddressType"): """ Publish a contract to the explorer. diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 92145dbbf6..5f965428f1 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -12,8 +12,6 @@ ) from eth_pydantic_types import HexBytes from eth_utils import keccak, to_int -from ethpm_types import ContractType -from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI from pydantic import model_validator from ape.exceptions import ( @@ -26,8 +24,7 @@ SignatureError, ) from ape.logging import logger -from ape.types.address import AddressType, RawAddress -from ape.types.events import ContractLog +from ape.types.address import AddressType from ape.types.gas import AutoGasLimit, GasLimit from ape.utils.basemodel import ( BaseInterfaceModel, @@ -47,6 +44,12 @@ from .config import PluginConfig if TYPE_CHECKING: + from ethpm_types import ContractType + from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI + + from ape.types.address import RawAddress + from ape.types.events import ContractLog + from .explorers import ExplorerAPI from .providers import BlockAPI, ProviderAPI, UpstreamProvider from .trace import TraceAPI @@ -135,7 +138,7 @@ def custom_network(self) -> "NetworkAPI": @classmethod @abstractmethod - def decode_address(cls, raw_address: RawAddress) -> AddressType: + def decode_address(cls, raw_address: "RawAddress") -> AddressType: """ Convert a raw address to the ecosystem's native address type. @@ -149,7 +152,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType: @classmethod @abstractmethod - def encode_address(cls, address: AddressType) -> RawAddress: + def encode_address(cls, address: AddressType) -> "RawAddress": """ Convert the ecosystem's native address type to a raw integer or str address. @@ -162,7 +165,7 @@ def encode_address(cls, address: AddressType) -> RawAddress: @raises_not_implemented def encode_contract_blueprint( # type: ignore[empty-body] - self, contract_type: ContractType, *args, **kwargs + self, contract_type: "ContractType", *args, **kwargs ) -> "TransactionAPI": """ Encode a unique type of transaction that allows contracts to be created @@ -386,7 +389,7 @@ def set_default_network(self, network_name: str): @abstractmethod def encode_deployment( - self, deployment_bytecode: HexBytes, abi: ConstructorABI, *args, **kwargs + self, deployment_bytecode: HexBytes, abi: "ConstructorABI", *args, **kwargs ) -> "TransactionAPI": """ Create a deployment transaction in the given ecosystem. @@ -404,7 +407,7 @@ def encode_deployment( @abstractmethod def encode_transaction( - self, address: AddressType, abi: MethodABI, *args, **kwargs + self, address: AddressType, abi: "MethodABI", *args, **kwargs ) -> "TransactionAPI": """ Encode a transaction object from a contract function's ABI and call arguments. @@ -421,12 +424,12 @@ def encode_transaction( """ @abstractmethod - def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator[ContractLog]: + def decode_logs(self, logs: Sequence[dict], *events: "EventABI") -> Iterator["ContractLog"]: """ Decode any contract logs that match the given event ABI from the raw log data. Args: - logs (Sequence[Dict]): A list of raw log data from the chain. + logs (Sequence[dict]): A list of raw log data from the chain. *events (EventABI): Event definitions to decode. Returns: @@ -464,7 +467,7 @@ def create_transaction(self, **kwargs) -> "TransactionAPI": """ @abstractmethod - def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes) -> dict: + def decode_calldata(self, abi: Union["ConstructorABI", "MethodABI"], calldata: bytes) -> dict: """ Decode method calldata. @@ -479,7 +482,7 @@ def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes """ @abstractmethod - def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args: Any) -> HexBytes: + def encode_calldata(self, abi: Union["ConstructorABI", "MethodABI"], *args: Any) -> HexBytes: """ Encode method calldata. @@ -492,7 +495,7 @@ def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args: Any) -> """ @abstractmethod - def decode_returndata(self, abi: MethodABI, raw_data: bytes) -> Any: + def decode_returndata(self, abi: "MethodABI", raw_data: bytes) -> Any: """ Get the result of a contract call. @@ -586,7 +589,7 @@ def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: """ return None - def get_method_selector(self, abi: MethodABI) -> HexBytes: + def get_method_selector(self, abi: "MethodABI") -> HexBytes: """ Get a contract method selector, typically via hashing such as ``keccak``. Defaults to using ``keccak`` but can be overridden in different ecosystems. @@ -626,7 +629,7 @@ def enrich_trace(self, trace: "TraceAPI", **kwargs) -> "TraceAPI": @raises_not_implemented def get_python_types( # type: ignore[empty-body] - self, abi_type: ABIType + self, abi_type: "ABIType" ) -> Union[type, Sequence]: """ Get the Python types for a given ABI type. diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index dd5bc51c15..b0e1695f0b 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -16,14 +16,10 @@ from subprocess import DEVNULL, PIPE, Popen from typing import TYPE_CHECKING, Any, Optional, Union, cast -from eth_pydantic_types import HexBytes -from ethpm_types.abi import EventABI from pydantic import Field, computed_field, field_serializer, model_validator -from ape.api.config import PluginConfig from ape.api.networks import NetworkAPI from ape.api.query import BlockTransactionQuery -from ape.api.trace import TraceAPI from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( APINotImplementedError, @@ -35,10 +31,7 @@ VirtualMachineError, ) from ape.logging import LogLevel, logger -from ape.types.address import AddressType from ape.types.basic import HexInt -from ape.types.events import ContractLog, LogFilter -from ape.types.vm import BlockID, ContractCode, SnapshotID from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import ( EMPTY_BYTES32, @@ -51,7 +44,15 @@ from ape.utils.rpc import RPCHeaders if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from ethpm_types.abi import EventABI + from ape.api.accounts import TestAccountAPI + from ape.api.config import PluginConfig + from ape.api.trace import TraceAPI + from ape.types.address import AddressType + from ape.types.events import ContractLog, LogFilter + from ape.types.vm import BlockID, ContractCode, SnapshotID class BlockAPI(BaseInterfaceModel): @@ -254,7 +255,7 @@ def ws_uri(self) -> Optional[str]: return None @property - def settings(self) -> PluginConfig: + def settings(self) -> "PluginConfig": """ The combination of settings from ``ape-config.yaml`` and ``.provider_settings``. """ @@ -303,7 +304,7 @@ def chain_id(self) -> int: """ @abstractmethod - def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_balance(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: """ Get the balance of an account. @@ -317,7 +318,9 @@ def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) """ @abstractmethod - def get_code(self, address: AddressType, block_id: Optional[BlockID] = None) -> ContractCode: + def get_code( + self, address: "AddressType", block_id: Optional["BlockID"] = None + ) -> "ContractCode": """ Get the bytes a contract. @@ -370,7 +373,7 @@ def stream_request( # type: ignore[empty-body] """ # TODO: In 0.9, delete this method. - def get_storage_at(self, *args, **kwargs) -> HexBytes: + def get_storage_at(self, *args, **kwargs) -> "HexBytes": warnings.warn( "'provider.get_storage_at()' is deprecated. Use 'provider.get_storage()'.", DeprecationWarning, @@ -379,8 +382,8 @@ def get_storage_at(self, *args, **kwargs) -> HexBytes: @raises_not_implemented def get_storage( # type: ignore[empty-body] - self, address: AddressType, slot: int, block_id: Optional[BlockID] = None - ) -> HexBytes: + self, address: "AddressType", slot: int, block_id: Optional["BlockID"] = None + ) -> "HexBytes": """ Gets the raw value of a storage slot of a contract. @@ -395,7 +398,7 @@ def get_storage( # type: ignore[empty-body] """ @abstractmethod - def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_nonce(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: """ Get the number of times an account has transacted. @@ -409,7 +412,7 @@ def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> """ @abstractmethod - def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = None) -> int: + def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] = None) -> int: """ Estimate the cost of gas for a transaction. @@ -444,7 +447,7 @@ def max_gas(self) -> int: """ @property - def config(self) -> PluginConfig: + def config(self) -> "PluginConfig": """ The provider's configuration. """ @@ -482,7 +485,7 @@ def base_fee(self) -> int: raise APINotImplementedError("base_fee is not implemented by this provider") @abstractmethod - def get_block(self, block_id: BlockID) -> BlockAPI: + def get_block(self, block_id: "BlockID") -> BlockAPI: """ Get a block. @@ -502,10 +505,10 @@ def get_block(self, block_id: BlockID) -> BlockAPI: def send_call( self, txn: TransactionAPI, - block_id: Optional[BlockID] = None, + block_id: Optional["BlockID"] = None, state: Optional[dict] = None, **kwargs, - ) -> HexBytes: # Return value of function + ) -> "HexBytes": # Return value of function """ Execute a new transaction call immediately without creating a transaction on the block chain. @@ -538,7 +541,7 @@ def get_receipt(self, txn_hash: str, **kwargs) -> ReceiptAPI: """ @abstractmethod - def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAPI]: + def get_transactions_by_block(self, block_id: "BlockID") -> Iterator[TransactionAPI]: """ Get the information about a set of transactions from a block. @@ -552,7 +555,7 @@ def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAP @raises_not_implemented def get_transactions_by_account_nonce( # type: ignore[empty-body] self, - account: AddressType, + account: "AddressType", start_nonce: int = 0, stop_nonce: int = -1, ) -> Iterator[ReceiptAPI]: @@ -581,7 +584,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: """ @abstractmethod - def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: + def get_contract_logs(self, log_filter: "LogFilter") -> Iterator["ContractLog"]: """ Get logs from contracts. @@ -622,25 +625,25 @@ def send_private_transaction(self, txn: TransactionAPI, **kwargs) -> ReceiptAPI: raise _create_raises_not_implemented_error(self.send_private_transaction) @raises_not_implemented - def snapshot(self) -> SnapshotID: # type: ignore[empty-body] + def snapshot(self) -> "SnapshotID": # type: ignore[empty-body] """ Defined to make the ``ProviderAPI`` interchangeable with a :class:`~ape.api.providers.TestProviderAPI`, as in :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): """ Defined to make the ``ProviderAPI`` interchangeable with a :class:`~ape.api.providers.TestProviderAPI`, as in :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented @@ -651,7 +654,7 @@ def set_timestamp(self, new_timestamp: int): :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented @@ -662,11 +665,11 @@ def mine(self, num_blocks: int = 1): :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented - def set_balance(self, address: AddressType, amount: int): + def set_balance(self, address: "AddressType", amount: int): """ Change the balance of an account. @@ -693,7 +696,7 @@ def __repr__(self) -> str: @raises_not_implemented def set_code( # type: ignore[empty-body] - self, address: AddressType, code: ContractCode + self, address: "AddressType", code: "ContractCode" ) -> bool: """ Change the code of a smart contract, for development purposes. @@ -706,7 +709,7 @@ def set_code( # type: ignore[empty-body] @raises_not_implemented def set_storage( # type: ignore[empty-body] - self, address: AddressType, slot: int, value: HexBytes + self, address: "AddressType", slot: int, value: "HexBytes" ): """ Sets the raw value of a storage slot of a contract. @@ -718,7 +721,7 @@ def set_storage( # type: ignore[empty-body] """ @raises_not_implemented - def unlock_account(self, address: AddressType) -> bool: # type: ignore[empty-body] + def unlock_account(self, address: "AddressType") -> bool: # type: ignore[empty-body] """ Ask the provider to allow an address to submit transactions without validating signatures. This feature is intended to be subclassed by a @@ -736,7 +739,7 @@ def unlock_account(self, address: AddressType) -> bool: # type: ignore[empty-bo """ @raises_not_implemented - def relock_account(self, address: AddressType): + def relock_account(self, address: "AddressType"): """ Stop impersonating an account. @@ -746,13 +749,13 @@ def relock_account(self, address: AddressType): @raises_not_implemented def get_transaction_trace( # type: ignore[empty-body] - self, txn_hash: Union[HexBytes, str] - ) -> TraceAPI: + self, txn_hash: Union["HexBytes", str] + ) -> "TraceAPI": """ Provide a detailed description of opcodes. Args: - transaction_hash (Union[HexBytes, str]): The hash of a transaction + txn_hash (Union[HexBytes, str]): The hash of a transaction to trace. Returns: @@ -794,12 +797,12 @@ def poll_blocks( # type: ignore[empty-body] def poll_logs( # type: ignore[empty-body] self, stop_block: Optional[int] = None, - address: Optional[AddressType] = None, + address: Optional["AddressType"] = None, topics: Optional[list[Union[str, list[str]]]] = None, required_confirmations: Optional[int] = None, new_block_timeout: Optional[int] = None, - events: Optional[list[EventABI]] = None, - ) -> Iterator[ContractLog]: + events: Optional[list["EventABI"]] = None, + ) -> Iterator["ContractLog"]: """ Poll new blocks. Optionally set a start block to include historical blocks. @@ -874,11 +877,11 @@ class TestProviderAPI(ProviderAPI): """ @cached_property - def test_config(self) -> PluginConfig: + def test_config(self) -> "PluginConfig": return self.config_manager.get_config("test") @abstractmethod - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": """ Record the current state of the blockchain with intent to later call the method :meth:`~ape.managers.chain.ChainManager.revert` @@ -889,7 +892,7 @@ def snapshot(self) -> SnapshotID: """ @abstractmethod - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): """ Regress the current call using the given snapshot ID. Allows developers to go back to a previous state. diff --git a/src/ape/api/trace.py b/src/ape/api/trace.py index f4398fb58b..009f65e3a8 100644 --- a/src/ape/api/trace.py +++ b/src/ape/api/trace.py @@ -1,11 +1,13 @@ import sys from abc import abstractmethod from collections.abc import Iterator, Sequence -from typing import IO, Any, Optional +from typing import IO, TYPE_CHECKING, Any, Optional -from ape.types.trace import ContractFunctionPath, GasReport from ape.utils.basemodel import BaseInterfaceModel +if TYPE_CHECKING: + from ape.types.trace import ContractFunctionPath, GasReport + class TraceAPI(BaseInterfaceModel): """ @@ -22,7 +24,7 @@ def show(self, verbose: bool = False, file: IO[str] = sys.stdout): @abstractmethod def get_gas_report( self, exclude: Optional[Sequence["ContractFunctionPath"]] = None - ) -> GasReport: + ) -> "GasReport": """ Get the gas report. """ diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 8a68da6163..930ac26b10 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -2,18 +2,16 @@ import time from abc import abstractmethod from collections.abc import Iterator -from datetime import datetime +from datetime import datetime as datetime_type from functools import cached_property from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union from eth_pydantic_types import HexBytes, HexStr from eth_utils import is_hex, to_hex, to_int -from ethpm_types.abi import EventABI, MethodABI from pydantic import ConfigDict, field_validator from pydantic.fields import Field from tqdm import tqdm # type: ignore -from ape.api.explorers import ExplorerAPI from ape.exceptions import ( NetworkError, ProviderNotConnectedError, @@ -24,17 +22,20 @@ from ape.logging import logger from ape.types.address import AddressType from ape.types.basic import HexInt -from ape.types.events import ContractLogContainer from ape.types.gas import AutoGasLimit from ape.types.signatures import TransactionSignature -from ape.types.trace import SourceTraceback from ape.utils.basemodel import BaseInterfaceModel, ExtraAttributesMixin, ExtraModelAttributes from ape.utils.misc import log_instead_of_fail, raises_not_implemented if TYPE_CHECKING: + from ethpm_types.abi import EventABI, MethodABI + + from ape.api.explorers import ExplorerAPI from ape.api.providers import BlockAPI from ape.api.trace import TraceAPI from ape.contracts import ContractEvent + from ape.types.events import ContractLogContainer + from ape.types.trace import SourceTraceback class TransactionAPI(BaseInterfaceModel): @@ -352,7 +353,7 @@ def trace(self) -> "TraceAPI": return self.provider.get_transaction_trace(self.txn_hash) @property - def _explorer(self) -> Optional[ExplorerAPI]: + def _explorer(self) -> Optional["ExplorerAPI"]: return self.provider.network.explorer @property @@ -377,11 +378,11 @@ def timestamp(self) -> int: return self.block.timestamp @property - def datetime(self) -> datetime: + def datetime(self) -> "datetime_type": return self.block.datetime @cached_property - def events(self) -> ContractLogContainer: + def events(self) -> "ContractLogContainer": """ All the events that were emitted from this call. """ @@ -392,9 +393,9 @@ def events(self) -> ContractLogContainer: def decode_logs( self, abi: Optional[ - Union[list[Union[EventABI, "ContractEvent"]], Union[EventABI, "ContractEvent"]] + Union[list[Union["EventABI", "ContractEvent"]], Union["EventABI", "ContractEvent"]] ] = None, - ) -> ContractLogContainer: + ) -> "ContractLogContainer": """ Decode the logs on the receipt. @@ -482,7 +483,7 @@ def _await_confirmations(self): time.sleep(time_to_sleep) @property - def method_called(self) -> Optional[MethodABI]: + def method_called(self) -> Optional["MethodABI"]: """ The method ABI of the method called to produce this receipt. """ @@ -502,7 +503,7 @@ def return_value(self) -> Any: @property @raises_not_implemented - def source_traceback(self) -> SourceTraceback: # type: ignore[empty-body] + def source_traceback(self) -> "SourceTraceback": # type: ignore[empty-body] """ A Pythonic style traceback for both failing and non-failing receipts. Requires a provider that implements diff --git a/src/ape/cli/commands.py b/src/ape/cli/commands.py index fb4d305363..63a8a2f246 100644 --- a/src/ape/cli/commands.py +++ b/src/ape/cli/commands.py @@ -3,17 +3,18 @@ from typing import TYPE_CHECKING, Any, Optional import click -from click import Context from ape.cli.choices import _NONE_NETWORK, NetworkChoice from ape.exceptions import NetworkError if TYPE_CHECKING: + from click import Context + from ape.api.networks import ProviderContextManager from ape.api.providers import ProviderAPI -def get_param_from_ctx(ctx: Context, param: str) -> Optional[Any]: +def get_param_from_ctx(ctx: "Context", param: str) -> Optional[Any]: if value := ctx.params.get(param): return value @@ -24,7 +25,7 @@ def get_param_from_ctx(ctx: Context, param: str) -> Optional[Any]: return None -def parse_network(ctx: Context) -> Optional["ProviderContextManager"]: +def parse_network(ctx: "Context") -> Optional["ProviderContextManager"]: from ape.utils.basemodel import ManagerAccessMixin as access interactive = get_param_from_ctx(ctx, "interactive") @@ -70,7 +71,7 @@ def __init__(self, *args, **kwargs): self._network_callback = kwargs.pop("network_callback", None) super().__init__(*args, **kwargs) - def parse_args(self, ctx: Context, args: list[str]) -> list[str]: + def parse_args(self, ctx: "Context", args: list[str]) -> list[str]: arguments = args # Renamed for better pdb support. provider_module = import_module("ape.api.providers") base_type = provider_module.ProviderAPI if self._use_cls_types else str @@ -96,7 +97,7 @@ def parse_args(self, ctx: Context, args: list[str]) -> list[str]: return super().parse_args(ctx, arguments) - def invoke(self, ctx: Context) -> Any: + def invoke(self, ctx: "Context") -> Any: if self.callback is None: return @@ -106,7 +107,7 @@ def invoke(self, ctx: Context) -> Any: else: return self._invoke(ctx) - def _invoke(self, ctx: Context, provider: Optional["ProviderAPI"] = None): + def _invoke(self, ctx: "Context", provider: Optional["ProviderAPI"] = None): # Will be put back with correct value if needed. # Else, causes issues. ctx.params.pop("network", None) diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 1dbe3c0d8f..9c73fb8158 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -4,13 +4,13 @@ from functools import cached_property, partial, singledispatchmethod from itertools import islice from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import click import pandas as pd from eth_pydantic_types import HexBytes from eth_utils import to_hex -from ethpm_types.abi import ConstructorABI, ErrorABI, EventABI, MethodABI +from ethpm_types.abi import EventABI, MethodABI from ethpm_types.contract_type import ABI_W_SELECTOR_T, ContractType from IPython.lib.pretty import for_type @@ -22,7 +22,6 @@ extract_fields, validate_and_expand_columns, ) -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( ApeAttributeError, ArgumentsLengthError, @@ -49,12 +48,17 @@ ) from ape.utils.misc import log_instead_of_fail +if TYPE_CHECKING: + from ethpm_types.abi import ConstructorABI, ErrorABI + + from ape.api.transactions import ReceiptAPI, TransactionAPI + class ContractConstructor(ManagerAccessMixin): def __init__( self, deployment_bytecode: HexBytes, - abi: ConstructorABI, + abi: "ConstructorABI", ) -> None: self.deployment_bytecode = deployment_bytecode self.abi = abi @@ -76,14 +80,14 @@ def decode_input(self, calldata: bytes) -> tuple[str, dict[str, Any]]: decoded_inputs = self.provider.network.ecosystem.decode_calldata(self.abi, calldata) return self.abi.selector, decoded_inputs - def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: + def serialize_transaction(self, *args, **kwargs) -> "TransactionAPI": arguments = self.conversion_manager.convert_method_args(self.abi, args) converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs) return self.provider.network.ecosystem.encode_deployment( self.deployment_bytecode, self.abi, *arguments, **converted_kwargs ) - def __call__(self, private: bool = False, *args, **kwargs) -> ReceiptAPI: + def __call__(self, private: bool = False, *args, **kwargs) -> "ReceiptAPI": txn = self.serialize_transaction(*args, **kwargs) if "sender" in kwargs and isinstance(kwargs["sender"], AccountAPI): @@ -109,7 +113,7 @@ def __init__(self, abi: MethodABI, address: AddressType) -> None: def __repr__(self) -> str: return self.abi.signature - def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: + def serialize_transaction(self, *args, **kwargs) -> "TransactionAPI": converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs) return self.provider.network.ecosystem.encode_transaction( self.address, self.abi, *args, **converted_kwargs @@ -343,7 +347,7 @@ def __init__(self, abi: MethodABI, address: AddressType) -> None: def __repr__(self) -> str: return self.abi.signature - def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: + def serialize_transaction(self, *args, **kwargs) -> "TransactionAPI": if "sender" in kwargs and isinstance(kwargs["sender"], (ContractInstance, Address)): # Automatically impersonate contracts (if API available) when sender kwargs["sender"] = self.account_manager.test_accounts[kwargs["sender"].address] @@ -354,7 +358,7 @@ def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: self.address, self.abi, *arguments, **converted_kwargs ) - def __call__(self, *args, **kwargs) -> ReceiptAPI: + def __call__(self, *args, **kwargs) -> "ReceiptAPI": txn = self.serialize_transaction(*args, **kwargs) private = kwargs.get("private", False) @@ -370,7 +374,7 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: class ContractTransactionHandler(ContractMethodHandler): - def as_transaction(self, *args, **kwargs) -> TransactionAPI: + def as_transaction(self, *args, **kwargs) -> "TransactionAPI": """ Get a :class:`~ape.api.transactions.TransactionAPI` for this contract method invocation. This is useful @@ -421,7 +425,7 @@ def call(self) -> ContractCallHandler: return ContractCallHandler(self.contract, self.abis) - def __call__(self, *args, **kwargs) -> ReceiptAPI: + def __call__(self, *args, **kwargs) -> "ReceiptAPI": contract_transaction = self._as_transaction(*args) if "sender" not in kwargs and self.account_manager.default_sender is not None: kwargs["sender"] = self.account_manager.default_sender @@ -727,7 +731,7 @@ def range( ) yield from self.query_manager.query(contract_event_query) # type: ignore - def from_receipt(self, receipt: ReceiptAPI) -> list[ContractLog]: + def from_receipt(self, receipt: "ReceiptAPI") -> list[ContractLog]: """ Get all the events from the given receipt. @@ -864,7 +868,7 @@ def decode_input(self, calldata: bytes) -> tuple[str, dict[str, Any]]: input_dict = ecosystem.decode_calldata(method, rest_calldata) return method.selector, input_dict - def _create_custom_error_type(self, abi: ErrorABI, **kwargs) -> type[CustomError]: + def _create_custom_error_type(self, abi: "ErrorABI", **kwargs) -> type[CustomError]: def exec_body(namespace): namespace["abi"] = abi namespace["contract"] = self @@ -929,7 +933,7 @@ def __init__( (txn_hash if isinstance(txn_hash, str) else to_hex(txn_hash)) if txn_hash else None ) - def __call__(self, *args, **kwargs) -> ReceiptAPI: + def __call__(self, *args, **kwargs) -> "ReceiptAPI": has_value = kwargs.get("value") has_data = kwargs.get("data") or kwargs.get("input") has_non_payable_fallback = ( @@ -953,7 +957,7 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: return super().__call__(*args, **kwargs) @classmethod - def from_receipt(cls, receipt: ReceiptAPI, contract_type: ContractType) -> "ContractInstance": + def from_receipt(cls, receipt: "ReceiptAPI", contract_type: ContractType) -> "ContractInstance": """ Create a contract instance from the contract deployment receipt. """ @@ -1074,7 +1078,7 @@ def call_view_method(self, method_name: str, *args, **kwargs) -> Any: name = self.contract_type.name or ContractType.__name__ raise ApeAttributeError(f"'{name}' has no attribute '{method_name}'.") - def invoke_transaction(self, method_name: str, *args, **kwargs) -> ReceiptAPI: + def invoke_transaction(self, method_name: str, *args, **kwargs) -> "ReceiptAPI": """ Call a contract's function directly using the method_name. This function is for non-view function's which may change @@ -1183,7 +1187,7 @@ def _events_(self) -> dict[str, list[ContractEvent]]: @cached_property def _errors_(self) -> dict[str, list[type[CustomError]]]: - abis: dict[str, list[ErrorABI]] = {} + abis: dict[str, list["ErrorABI"]] = {} try: for abi in self.contract_type.errors: @@ -1434,7 +1438,7 @@ def constructor(self) -> ContractConstructor: deployment_bytecode=self.contract_type.get_deployment_bytecode() or HexBytes(""), ) - def __call__(self, *args, **kwargs) -> TransactionAPI: + def __call__(self, *args, **kwargs) -> "TransactionAPI": args_length = len(args) inputs_length = ( len(self.constructor.abi.inputs) @@ -1500,7 +1504,7 @@ def deploy(self, *args, publish: bool = False, **kwargs) -> ContractInstance: instance.base_path = self.base_path or self.local_project.contracts_folder return instance - def _cache_wrap(self, function: Callable) -> ReceiptAPI: + def _cache_wrap(self, function: Callable) -> "ReceiptAPI": """ A helper method to ensure a contract type is cached as early on as possible to help enrich errors from ``deploy()`` transactions @@ -1525,7 +1529,7 @@ def _cache_wrap(self, function: Callable) -> ReceiptAPI: raise # The error after caching. - def declare(self, *args, **kwargs) -> ReceiptAPI: + def declare(self, *args, **kwargs) -> "ReceiptAPI": transaction = self.provider.network.ecosystem.encode_contract_blueprint( self.contract_type, *args, **kwargs ) diff --git a/src/ape/logging.py b/src/ape/logging.py index a874c3353f..34fa511e1e 100644 --- a/src/ape/logging.py +++ b/src/ape/logging.py @@ -287,7 +287,9 @@ def _format_logger( def get_logger( - name: str, fmt: Optional[str] = None, handlers: Optional[Sequence[Callable[[str], str]]] = None + name: str, + fmt: Optional[str] = None, + handlers: Optional[Sequence[Callable[[str], str]]] = None, ) -> logging.Logger: """ Get a logger with the given ``name`` and configure it for usage with Ape. diff --git a/src/ape/managers/accounts.py b/src/ape/managers/accounts.py index 858cf0de8a..e45939204e 100644 --- a/src/ape/managers/accounts.py +++ b/src/ape/managers/accounts.py @@ -25,7 +25,7 @@ @contextlib.contextmanager def _use_sender( account: Union[AccountAPI, TestAccountAPI] -) -> Generator[AccountAPI, TestAccountAPI, None]: +) -> "Generator[AccountAPI, TestAccountAPI, None]": try: _DEFAULT_SENDERS.append(account) yield account @@ -160,7 +160,7 @@ def stop_impersonating(self, address: AddressType): def generate_test_account(self, container_name: str = "test") -> TestAccountAPI: return self.containers[container_name].generate_account() - def use_sender(self, account_id: Union[TestAccountAPI, AddressType, int]) -> ContextManager: + def use_sender(self, account_id: Union[TestAccountAPI, AddressType, int]) -> "ContextManager": account = account_id if isinstance(account_id, TestAccountAPI) else self[account_id] return _use_sender(account) @@ -412,7 +412,7 @@ def __contains__(self, address: AddressType) -> bool: def use_sender( self, account_id: Union[AccountAPI, AddressType, str, int], - ) -> ContextManager: + ) -> "ContextManager": if not isinstance(account_id, AccountAPI): if isinstance(account_id, int) or is_hex(account_id): account = self[account_id] diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index fbfde6f326..4c14bad123 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -6,13 +6,11 @@ from functools import partial, singledispatchmethod from pathlib import Path from statistics import mean, median -from typing import IO, Optional, Union, cast +from typing import IO, TYPE_CHECKING, Optional, Union, cast import pandas as pd -from eth_pydantic_types import HexBytes from ethpm_types import ABI, ContractType from rich.box import SIMPLE -from rich.console import Console as RichConsole from rich.table import Table from ape.api.address import BaseAddress @@ -42,11 +40,16 @@ from ape.logging import get_rich_console, logger from ape.managers.base import BaseManager from ape.types.address import AddressType -from ape.types.trace import GasReport, SourceTraceback -from ape.types.vm import SnapshotID from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import is_evm_precompile, is_zero_hex, log_instead_of_fail, nonreentrant +if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from rich.console import Console as RichConsole + + from ape.types.trace import GasReport, SourceTraceback + from ape.types.vm import SnapshotID + class BlockContainer(BaseManager): """ @@ -1131,7 +1134,7 @@ def instance_at( self, address: Union[str, AddressType], contract_type: Optional[ContractType] = None, - txn_hash: Optional[Union[str, HexBytes]] = None, + txn_hash: Optional[Union[str, "HexBytes"]] = None, abi: Optional[Union[list[ABI], dict, str, Path]] = None, ) -> ContractInstance: """ @@ -1413,7 +1416,7 @@ class ReportManager(BaseManager): **NOTE**: This class is not part of the public API. """ - def show_gas(self, report: GasReport, file: Optional[IO[str]] = None): + def show_gas(self, report: "GasReport", file: Optional[IO[str]] = None): tables: list[Table] = [] for contract_id, method_calls in report.items(): @@ -1454,16 +1457,16 @@ def show_gas(self, report: GasReport, file: Optional[IO[str]] = None): self.echo(*tables, file=file) def echo( - self, *rich_items, file: Optional[IO[str]] = None, console: Optional[RichConsole] = None + self, *rich_items, file: Optional[IO[str]] = None, console: Optional["RichConsole"] = None ): console = console or get_rich_console(file) console.print(*rich_items) def show_source_traceback( self, - traceback: SourceTraceback, + traceback: "SourceTraceback", file: Optional[IO[str]] = None, - console: Optional[RichConsole] = None, + console: Optional["RichConsole"] = None, failing: bool = True, ): console = console or get_rich_console(file) @@ -1471,7 +1474,7 @@ def show_source_traceback( console.print(str(traceback), style=style) def show_events( - self, events: list, file: Optional[IO[str]] = None, console: Optional[RichConsole] = None + self, events: list, file: Optional[IO[str]] = None, console: Optional["RichConsole"] = None ): console = console or get_rich_console(file) console.print("Events emitted:") @@ -1587,7 +1590,7 @@ def __repr__(self) -> str: cls_name = getattr(type(self), "__name__", ChainManager.__name__) return f"<{cls_name} ({props})>" - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": """ Record the current state of the blockchain with intent to later call the method :meth:`~ape.managers.chain.ChainManager.revert` @@ -1607,7 +1610,7 @@ def snapshot(self) -> SnapshotID: return snapshot_id - def restore(self, snapshot_id: Optional[SnapshotID] = None): + def restore(self, snapshot_id: Optional["SnapshotID"] = None): """ Regress the current call using the given snapshot ID. Allows developers to go back to a previous state. diff --git a/src/ape/managers/compilers.py b/src/ape/managers/compilers.py index 09cc989fa4..b37a2cefec 100644 --- a/src/ape/managers/compilers.py +++ b/src/ape/managers/compilers.py @@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from eth_pydantic_types import HexBytes -from ethpm_types import ContractType -from ethpm_types.source import Content -from ape.api.compiler import CompilerAPI from ape.contracts import ContractContainer from ape.exceptions import CompilerError, ContractLogicError, CustomError from ape.logging import logger @@ -23,6 +20,10 @@ from ape.utils.os import get_full_extension if TYPE_CHECKING: + from ethpm_types.contract_type import ContractType + from ethpm_types.source import Content + + from ape.api.compiler import CompilerAPI from ape.managers.project import ProjectManager @@ -39,7 +40,7 @@ class CompilerManager(BaseManager, ExtraAttributesMixin): from ape import compilers # "compilers" is the CompilerManager singleton """ - _registered_compilers_cache: dict[Path, dict[str, CompilerAPI]] = {} + _registered_compilers_cache: dict[Path, dict[str, "CompilerAPI"]] = {} @log_instead_of_fail(default="") def __repr__(self) -> str: @@ -59,7 +60,7 @@ def __getattr__(self, attr_name: str) -> Any: return get_attribute_with_extras(self, attr_name) @cached_property - def registered_compilers(self) -> dict[str, CompilerAPI]: + def registered_compilers(self) -> dict[str, "CompilerAPI"]: """ Each compile-able file extension mapped to its respective :class:`~ape.api.compiler.CompilerAPI` instance. @@ -80,7 +81,7 @@ def registered_compilers(self) -> dict[str, CompilerAPI]: return registered_compilers - def get_compiler(self, name: str, settings: Optional[dict] = None) -> Optional[CompilerAPI]: + def get_compiler(self, name: str, settings: Optional[dict] = None) -> Optional["CompilerAPI"]: for compiler in self.registered_compilers.values(): if compiler.name != name: continue @@ -98,7 +99,7 @@ def compile( contract_filepaths: Union[Path, str, Iterable[Union[Path, str]]], project: Optional["ProjectManager"] = None, settings: Optional[dict] = None, - ) -> Iterator[ContractType]: + ) -> Iterator["ContractType"]: """ Invoke :meth:`ape.ape.compiler.CompilerAPI.compile` for each of the given files. For example, use the `ape-solidity plugin `__ @@ -333,7 +334,7 @@ def get_custom_error(self, err: ContractLogicError) -> Optional[CustomError]: except NotImplementedError: return None - def flatten_contract(self, path: Path, **kwargs) -> Content: + def flatten_contract(self, path: Path, **kwargs) -> "Content": """ Get the flattened version of a contract via its source path. Delegates to the matching :class:`~ape.api.compilers.CompilerAPI`. diff --git a/src/ape/managers/config.py b/src/ape/managers/config.py index 7739703008..92d94bf53a 100644 --- a/src/ape/managers/config.py +++ b/src/ape/managers/config.py @@ -3,9 +3,7 @@ from contextlib import contextmanager from functools import cached_property from pathlib import Path -from typing import Any, Optional - -from ethpm_types import PackageManifest +from typing import TYPE_CHECKING, Any, Optional from ape.api.config import ApeConfig from ape.managers.base import BaseManager @@ -20,6 +18,10 @@ from ape.utils.os import create_tempdir, in_tempdir from ape.utils.rpc import RPCHeaders +if TYPE_CHECKING: + from ethpm_types import PackageManifest + + CONFIG_FILE_NAME = "ape-config.yaml" @@ -93,7 +95,7 @@ def merge_with_global(self, project_config: ApeConfig) -> ApeConfig: return ApeConfig.model_validate(merged_data) @classmethod - def extract_config(cls, manifest: PackageManifest, **overrides) -> ApeConfig: + def extract_config(cls, manifest: "PackageManifest", **overrides) -> ApeConfig: """ Calculate the ape-config data from a package manifest. diff --git a/src/ape/managers/converters.py b/src/ape/managers/converters.py index ca141b53ff..85c48f6e69 100644 --- a/src/ape/managers/converters.py +++ b/src/ape/managers/converters.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta, timezone from decimal import Decimal from functools import cached_property -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union from dateutil.parser import parse from eth_pydantic_types import Address, HexBytes @@ -16,7 +16,6 @@ to_checksum_address, to_int, ) -from ethpm_types import ConstructorABI, EventABI, MethodABI from ape.api.address import BaseAddress from ape.api.convert import ConverterAPI @@ -28,6 +27,9 @@ from .base import BaseManager +if TYPE_CHECKING: + from ethpm_types import ConstructorABI, EventABI, MethodABI + class HexConverter(ConverterAPI): """ @@ -400,7 +402,7 @@ def convert(self, value: Any, to_type: Union[type, tuple, list]) -> Any: def convert_method_args( self, - abi: Union[MethodABI, ConstructorABI, EventABI], + abi: Union["MethodABI", "ConstructorABI", "EventABI"], arguments: Sequence[Any], ): input_types = [i.canonical_type for i in abi.inputs] diff --git a/src/ape/managers/networks.py b/src/ape/managers/networks.py index b116140690..8297ce43a3 100644 --- a/src/ape/managers/networks.py +++ b/src/ape/managers/networks.py @@ -1,9 +1,8 @@ from collections.abc import Collection, Iterator from functools import cached_property -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from ape.api.networks import EcosystemAPI, NetworkAPI, ProviderContextManager -from ape.api.providers import ProviderAPI from ape.exceptions import EcosystemNotFoundError, NetworkError, NetworkNotFoundError from ape.managers.base import BaseManager from ape.utils.basemodel import ( @@ -13,9 +12,12 @@ only_raise_attribute_error, ) from ape.utils.misc import _dict_overlay, log_instead_of_fail -from ape.utils.rpc import RPCHeaders from ape_ethereum.provider import EthereumNodeProvider +if TYPE_CHECKING: + from ape.api.providers import ProviderAPI + from ape.utils.rpc import RPCHeaders + class NetworkManager(BaseManager, ExtraAttributesMixin): """ @@ -32,7 +34,7 @@ class NetworkManager(BaseManager, ExtraAttributesMixin): ... """ - _active_provider: Optional[ProviderAPI] = None + _active_provider: Optional["ProviderAPI"] = None _default_ecosystem_name: Optional[str] = None # For adhoc adding custom networks, or incorporating some defined @@ -47,7 +49,7 @@ def __repr__(self) -> str: return f"<{content}>" @property - def active_provider(self) -> Optional[ProviderAPI]: + def active_provider(self) -> Optional["ProviderAPI"]: """ The currently connected provider if one exists. Otherwise, returns ``None``. """ @@ -55,7 +57,7 @@ def active_provider(self) -> Optional[ProviderAPI]: return self._active_provider @active_provider.setter - def active_provider(self, new_value: ProviderAPI): + def active_provider(self, new_value: "ProviderAPI"): self._active_provider = new_value @property @@ -88,7 +90,7 @@ def ecosystem(self) -> EcosystemAPI: def get_request_headers( self, ecosystem_name: str, network_name: str, provider_name: str - ) -> RPCHeaders: + ) -> "RPCHeaders": """ All request headers to be used when connecting to this network. """ @@ -249,9 +251,9 @@ def _plugin_ecosystems(self) -> dict[str, EcosystemAPI]: def create_custom_provider( self, connection_str: str, - provider_cls: type[ProviderAPI] = EthereumNodeProvider, + provider_cls: type["ProviderAPI"] = EthereumNodeProvider, provider_name: Optional[str] = None, - ) -> ProviderAPI: + ) -> "ProviderAPI": """ Create a custom connection to a URI using the EthereumNodeProvider provider. **NOTE**: This provider will assume EVM-like behavior and this is generally not recommended. @@ -444,7 +446,7 @@ def get_provider_from_choice( self, network_choice: Optional[str] = None, provider_settings: Optional[dict] = None, - ) -> ProviderAPI: + ) -> "ProviderAPI": """ Get a :class:`~ape.api.providers.ProviderAPI` from a network choice. A network choice is any value returned from diff --git a/src/ape/managers/query.py b/src/ape/managers/query.py index 86cd971b2c..c96cca06a6 100644 --- a/src/ape/managers/query.py +++ b/src/ape/managers/query.py @@ -14,7 +14,7 @@ QueryAPI, QueryType, ) -from ape.api.transactions import ReceiptAPI, TransactionAPI +from ape.api.transactions import ReceiptAPI, TransactionAPI # noqa: TC002 from ape.contracts.base import ContractLog, LogFilter from ape.exceptions import QueryEngineError from ape.logging import logger diff --git a/src/ape/pytest/config.py b/src/ape/pytest/config.py index 99f6db77c7..77825f338f 100644 --- a/src/ape/pytest/config.py +++ b/src/ape/pytest/config.py @@ -1,11 +1,12 @@ from functools import cached_property -from typing import Any, Optional, Union - -from _pytest.config import Config as PytestConfig +from typing import TYPE_CHECKING, Any, Optional, Union from ape.types.trace import ContractFunctionPath from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from _pytest.config import Config as PytestConfig + def _get_config_exclusions(config) -> list[ContractFunctionPath]: return [ @@ -21,7 +22,7 @@ class ConfigWrapper(ManagerAccessMixin): Pytest config object for ease-of-use and code-sharing. """ - def __init__(self, pytest_config: PytestConfig): + def __init__(self, pytest_config: "PytestConfig"): self.pytest_config = pytest_config @cached_property diff --git a/src/ape/pytest/coverage.py b/src/ape/pytest/coverage.py index 9adee6fa31..784a025f06 100644 --- a/src/ape/pytest/coverage.py +++ b/src/ape/pytest/coverage.py @@ -1,36 +1,39 @@ from collections.abc import Iterable from pathlib import Path -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import click -from ethpm_types.abi import MethodABI -from ethpm_types.source import ContractSource from ape.logging import logger -from ape.managers.project import ProjectManager -from ape.pytest.config import ConfigWrapper from ape.types.coverage import CoverageProject, CoverageReport -from ape.types.trace import ContractFunctionPath, ControlFlow, SourceTraceback from ape.utils.basemodel import ManagerAccessMixin from ape.utils.misc import get_current_timestamp_ms from ape.utils.os import get_full_extension, get_relative_path from ape.utils.trace import parse_coverage_tables +if TYPE_CHECKING: + from ethpm_types.abi import MethodABI + from ethpm_types.source import ContractSource + + from ape.managers.project import ProjectManager + from ape.pytest.config import ConfigWrapper + from ape.types.trace import ContractFunctionPath, ControlFlow, SourceTraceback + class CoverageData(ManagerAccessMixin): def __init__( self, - project: ProjectManager, - sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]], + project: "ProjectManager", + sources: Union[Iterable["ContractSource"], Callable[[], Iterable["ContractSource"]]], ): self.project = project - self._sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]] = ( - sources - ) + self._sources: Union[ + Iterable["ContractSource"], Callable[[], Iterable["ContractSource"]] + ] = sources self._report: Optional[CoverageReport] = None @property - def sources(self) -> list[ContractSource]: + def sources(self) -> list["ContractSource"]: if isinstance(self._sources, list): return self._sources @@ -138,8 +141,8 @@ def cover( class CoverageTracker(ManagerAccessMixin): def __init__( self, - config_wrapper: ConfigWrapper, - project: Optional[ProjectManager] = None, + config_wrapper: "ConfigWrapper", + project: Optional["ProjectManager"] = None, output_path: Optional[Path] = None, ): self.config_wrapper = config_wrapper @@ -173,7 +176,7 @@ def enabled(self) -> bool: return self.config_wrapper.track_coverage @property - def exclusions(self) -> list[ContractFunctionPath]: + def exclusions(self) -> list["ContractFunctionPath"]: return self.config_wrapper.coverage_exclusions def reset(self): @@ -182,7 +185,7 @@ def reset(self): def cover( self, - traceback: SourceTraceback, + traceback: "SourceTraceback", contract: Optional[str] = None, function: Optional[str] = None, ): @@ -259,7 +262,7 @@ def cover( def _cover( self, - control_flow: ControlFlow, + control_flow: "ControlFlow", last_path: Optional[Path] = None, last_pcs: Optional[set[int]] = None, last_call: Optional[str] = None, @@ -281,7 +284,7 @@ def _cover( inc_fn = last_call is None or last_call != control_flow.closure.full_name return self.data.cover(control_flow.source_path, new_pcs, inc_fn_hits=inc_fn) - def hit_function(self, contract_source: ContractSource, method: MethodABI): + def hit_function(self, contract_source: "ContractSource", method: "MethodABI"): """ Another way to increment a function's hit count. Providers may not offer a way to trace calls but this method is available to still increment function diff --git a/src/ape/pytest/fixtures.py b/src/ape/pytest/fixtures.py index 80a77b0279..925a10d903 100644 --- a/src/ape/pytest/fixtures.py +++ b/src/ape/pytest/fixtures.py @@ -1,23 +1,25 @@ from collections.abc import Iterator from fnmatch import fnmatch from functools import cached_property -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest from eth_utils import to_hex -from ape.api.accounts import TestAccountAPI -from ape.api.transactions import ReceiptAPI from ape.exceptions import BlockNotFoundError, ChainError from ape.logging import logger -from ape.managers.chain import ChainManager -from ape.managers.networks import NetworkManager -from ape.managers.project import ProjectManager -from ape.pytest.config import ConfigWrapper -from ape.types.vm import SnapshotID from ape.utils.basemodel import ManagerAccessMixin from ape.utils.rpc import allow_disconnected +if TYPE_CHECKING: + from ape.api.accounts import TestAccountAPI + from ape.api.transactions import ReceiptAPI + from ape.managers.chain import ChainManager + from ape.managers.networks import NetworkManager + from ape.managers.project import ProjectManager + from ape.pytest.config import ConfigWrapper + from ape.types.vm import SnapshotID + class PytestApeFixtures(ManagerAccessMixin): # NOTE: Avoid including links, markdown, or rst in method-docs @@ -27,7 +29,7 @@ class PytestApeFixtures(ManagerAccessMixin): _supports_snapshot: bool = True receipt_capture: "ReceiptCapture" - def __init__(self, config_wrapper: ConfigWrapper, receipt_capture: "ReceiptCapture"): + def __init__(self, config_wrapper: "ConfigWrapper", receipt_capture: "ReceiptCapture"): self.config_wrapper = config_wrapper self.receipt_capture = receipt_capture @@ -40,7 +42,7 @@ def _track_transactions(self) -> bool: ) @pytest.fixture(scope="session") - def accounts(self) -> list[TestAccountAPI]: + def accounts(self) -> list["TestAccountAPI"]: """ A collection of pre-funded accounts. """ @@ -54,21 +56,21 @@ def compilers(self): return self.compiler_manager @pytest.fixture(scope="session") - def chain(self) -> ChainManager: + def chain(self) -> "ChainManager": """ Manipulate the blockchain, such as mine or change the pending timestamp. """ return self.chain_manager @pytest.fixture(scope="session") - def networks(self) -> NetworkManager: + def networks(self) -> "NetworkManager": """ Connect to other networks in your tests. """ return self.network_manager @pytest.fixture(scope="session") - def project(self) -> ProjectManager: + def project(self) -> "ProjectManager": """ Access contract types and dependencies. """ @@ -121,7 +123,7 @@ def _isolation(self) -> Iterator[None]: _function_isolation = pytest.fixture(_isolation, scope="function") @allow_disconnected - def _snapshot(self) -> Optional[SnapshotID]: + def _snapshot(self) -> Optional["SnapshotID"]: try: return self.chain_manager.snapshot() except NotImplementedError: @@ -135,7 +137,7 @@ def _snapshot(self) -> Optional[SnapshotID]: return None @allow_disconnected - def _restore(self, snapshot_id: SnapshotID): + def _restore(self, snapshot_id: "SnapshotID"): if snapshot_id not in self.chain_manager._snapshots[self.provider.chain_id]: return try: @@ -150,11 +152,11 @@ def _restore(self, snapshot_id: SnapshotID): class ReceiptCapture(ManagerAccessMixin): - config_wrapper: ConfigWrapper - receipt_map: dict[str, dict[str, ReceiptAPI]] = {} + config_wrapper: "ConfigWrapper" + receipt_map: dict[str, dict[str, "ReceiptAPI"]] = {} enter_blocks: list[int] = [] - def __init__(self, config_wrapper: ConfigWrapper): + def __init__(self, config_wrapper: "ConfigWrapper"): self.config_wrapper = config_wrapper def __enter__(self): diff --git a/src/ape/pytest/gas.py b/src/ape/pytest/gas.py index 3b8e0e63ce..1f37af2b68 100644 --- a/src/ape/pytest/gas.py +++ b/src/ape/pytest/gas.py @@ -1,16 +1,20 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ethpm_types.abi import MethodABI -from ethpm_types.source import ContractSource from evm_trace.gas import merge_reports -from ape.api.trace import TraceAPI -from ape.pytest.config import ConfigWrapper -from ape.types.address import AddressType -from ape.types.trace import ContractFunctionPath, GasReport +from ape.types.trace import GasReport from ape.utils.basemodel import ManagerAccessMixin from ape.utils.trace import _exclude_gas, parse_gas_table +if TYPE_CHECKING: + from ethpm_types.abi import MethodABI + from ethpm_types.source import ContractSource + + from ape.api.trace import TraceAPI + from ape.pytest.config import ConfigWrapper + from ape.types.address import AddressType + from ape.types.trace import ContractFunctionPath + class GasTracker(ManagerAccessMixin): """ @@ -18,7 +22,7 @@ class GasTracker(ManagerAccessMixin): contracts in your test suite. """ - def __init__(self, config_wrapper: ConfigWrapper): + def __init__(self, config_wrapper: "ConfigWrapper"): self.config_wrapper = config_wrapper self.session_gas_report: Optional[GasReport] = None @@ -27,7 +31,7 @@ def enabled(self) -> bool: return self.config_wrapper.track_gas @property - def gas_exclusions(self) -> list[ContractFunctionPath]: + def gas_exclusions(self) -> list["ContractFunctionPath"]: return self.config_wrapper.gas_exclusions def show_session_gas(self) -> bool: @@ -38,7 +42,7 @@ def show_session_gas(self) -> bool: self.chain_manager._reports.echo(*tables) return True - def append_gas(self, trace: TraceAPI, contract_address: AddressType): + def append_gas(self, trace: "TraceAPI", contract_address: "AddressType"): contract_type = self.chain_manager.contracts.get(contract_address) if not contract_type: # Skip unknown contracts. @@ -47,7 +51,7 @@ def append_gas(self, trace: TraceAPI, contract_address: AddressType): report = trace.get_gas_report(exclude=self.gas_exclusions) self._merge(report) - def append_toplevel_gas(self, contract: ContractSource, method: MethodABI, gas_cost: int): + def append_toplevel_gas(self, contract: "ContractSource", method: "MethodABI", gas_cost: int): exclusions = self.gas_exclusions or [] if (contract_id := contract.contract_type.name) and not _exclude_gas( exclusions, contract_id, method.selector diff --git a/src/ape/pytest/plugin.py b/src/ape/pytest/plugin.py index 7cfe0c106f..72d09c1809 100644 --- a/src/ape/pytest/plugin.py +++ b/src/ape/pytest/plugin.py @@ -1,8 +1,7 @@ import sys from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ape.api.networks import EcosystemAPI from ape.exceptions import ConfigError from ape.pytest.config import ConfigWrapper from ape.pytest.coverage import CoverageTracker @@ -11,8 +10,11 @@ from ape.pytest.runners import PytestApeRunner from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from ape.api.networks import EcosystemAPI -def _get_default_network(ecosystem: Optional[EcosystemAPI] = None) -> str: + +def _get_default_network(ecosystem: Optional["EcosystemAPI"] = None) -> str: if ecosystem is None: ecosystem = ManagerAccessMixin.network_manager.default_ecosystem diff --git a/src/ape/pytest/runners.py b/src/ape/pytest/runners.py index 98a708e99b..e41f724027 100644 --- a/src/ape/pytest/runners.py +++ b/src/ape/pytest/runners.py @@ -1,30 +1,32 @@ from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional import click import pytest from _pytest._code.code import Traceback as PytestTraceback from rich import print as rich_print -from ape.api.networks import ProviderContextManager from ape.exceptions import ConfigError from ape.logging import LogLevel -from ape.pytest.config import ConfigWrapper -from ape.pytest.coverage import CoverageTracker -from ape.pytest.fixtures import ReceiptCapture -from ape.pytest.gas import GasTracker -from ape.types.coverage import CoverageReport from ape.utils.basemodel import ManagerAccessMixin from ape_console._cli import console +if TYPE_CHECKING: + from ape.api.networks import ProviderContextManager + from ape.pytest.config import ConfigWrapper + from ape.pytest.coverage import CoverageTracker + from ape.pytest.fixtures import ReceiptCapture + from ape.pytest.gas import GasTracker + from ape.types.coverage import CoverageReport + class PytestApeRunner(ManagerAccessMixin): def __init__( self, - config_wrapper: ConfigWrapper, - receipt_capture: ReceiptCapture, - gas_tracker: GasTracker, - coverage_tracker: CoverageTracker, + config_wrapper: "ConfigWrapper", + receipt_capture: "ReceiptCapture", + gas_tracker: "GasTracker", + coverage_tracker: "CoverageTracker", ): self.config_wrapper = config_wrapper self.receipt_capture = receipt_capture @@ -36,11 +38,11 @@ def __init__( self.coverage_tracker = coverage_tracker @property - def _provider_context(self) -> ProviderContextManager: + def _provider_context(self) -> "ProviderContextManager": return self.network_manager.parse_network_choice(self.config_wrapper.network) @property - def _coverage_report(self) -> Optional[CoverageReport]: + def _coverage_report(self) -> Optional["CoverageReport"]: return self.coverage_tracker.data.report if self.coverage_tracker.data else None def pytest_exception_interact(self, report, call): diff --git a/src/ape/types/address.py b/src/ape/types/address.py index 1b68301a85..8572825edd 100644 --- a/src/ape/types/address.py +++ b/src/ape/types/address.py @@ -1,12 +1,15 @@ -from typing import Annotated, Any, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union from eth_pydantic_types import Address as _Address from eth_pydantic_types import HashBytes20, HashStr20 from eth_typing import ChecksumAddress -from pydantic_core.core_schema import ValidationInfo from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from pydantic_core.core_schema import ValidationInfo + + RawAddress = Union[str, int, HashStr20, HashBytes20] """ A raw data-type representation of an address. @@ -23,7 +26,7 @@ class _AddressValidator(_Address, ManagerAccessMixin): """ @classmethod - def __eth_pydantic_validate__(cls, value: Any, info: Optional[ValidationInfo] = None) -> str: + def __eth_pydantic_validate__(cls, value: Any, info: Optional["ValidationInfo"] = None) -> str: if type(value) in (list, tuple): return cls.conversion_manager.convert(value, list[AddressType]) diff --git a/src/ape/types/coverage.py b/src/ape/types/coverage.py index 760d20935a..1179fe7758 100644 --- a/src/ape/types/coverage.py +++ b/src/ape/types/coverage.py @@ -2,12 +2,12 @@ from datetime import datetime from html.parser import HTMLParser from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from xml.dom.minidom import getDOMImplementation from xml.etree.ElementTree import Element, SubElement, tostring import requests -from ethpm_types.source import ContractSource, SourceLocation +from ethpm_types.source import SourceLocation from pydantic import NonNegativeInt, field_validator from ape.logging import logger @@ -15,6 +15,10 @@ from ape.utils.misc import get_current_timestamp_ms from ape.version import version as ape_version +if TYPE_CHECKING: + from ethpm_types.source import ContractSource + + _APE_DOCS_URL = "https://docs.apeworx.io/ape/stable/index.html" _DTD_URL = "https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd" _CSS = """ @@ -545,7 +549,7 @@ def model_dump(self, *args, **kwargs) -> dict: return attribs - def include(self, contract_source: ContractSource) -> ContractSourceCoverage: + def include(self, contract_source: "ContractSource") -> ContractSourceCoverage: for src in self.sources: if src.source_id == contract_source.source_id: return src diff --git a/src/ape/types/signatures.py b/src/ape/types/signatures.py index 60db85fcb1..c3f857d919 100644 --- a/src/ape/types/signatures.py +++ b/src/ape/types/signatures.py @@ -1,5 +1,5 @@ from collections.abc import Iterator -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from eth_account import Account from eth_account.messages import SignableMessage @@ -9,13 +9,15 @@ from ape.utils.misc import as_our_module, log_instead_of_fail -try: - # Only on Python 3.11 - from typing import Self # type: ignore -except ImportError: - from typing_extensions import Self # type: ignore +if TYPE_CHECKING: + from ape.types.address import AddressType + + try: + # Only on Python 3.11 + from typing import Self # type: ignore + except ImportError: + from typing_extensions import Self # type: ignore -from ape.types.address import AddressType # Fix 404 in doc link. as_our_module( @@ -89,7 +91,7 @@ def __iter__(self) -> Iterator[Union[int, bytes]]: yield self.s @classmethod - def from_rsv(cls, rsv: HexBytes) -> Self: + def from_rsv(cls, rsv: HexBytes) -> "Self": # NOTE: Values may be padded. if len(rsv) != 65: raise ValueError("Length of RSV bytes must be 65.") @@ -97,7 +99,7 @@ def from_rsv(cls, rsv: HexBytes) -> Self: return cls(r=HexBytes(rsv[:32]), s=HexBytes(rsv[32:64]), v=rsv[64]) @classmethod - def from_vrs(cls, vrs: HexBytes) -> Self: + def from_vrs(cls, vrs: HexBytes) -> "Self": # NOTE: Values may be padded. if len(vrs) != 65: raise ValueError("Length of VRS bytes must be 65.") @@ -122,7 +124,7 @@ class MessageSignature(_Signature): """ -def recover_signer(msg: SignableMessage, sig: MessageSignature) -> AddressType: +def recover_signer(msg: SignableMessage, sig: MessageSignature) -> "AddressType": """ Get the address of the signer. diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index dfa65ddeb7..1ac708edf9 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -5,7 +5,6 @@ from eth_pydantic_types import HexBytes from ethpm_types import ASTNode, BaseModel -from ethpm_types.ast import SourceLocation from ethpm_types.source import ( Closure, Content, @@ -20,6 +19,8 @@ from ape.utils.misc import log_instead_of_fail if TYPE_CHECKING: + from ethpm_types.ast import SourceLocation + from ape.api.trace import TraceAPI @@ -162,7 +163,7 @@ def pcs(self) -> set[int]: def extend( self, - location: SourceLocation, + location: "SourceLocation", pcs: Optional[set[int]] = None, ws_start: Optional[int] = None, ): @@ -441,7 +442,7 @@ def format(self) -> str: def add_jump( self, - location: SourceLocation, + location: "SourceLocation", function: Function, depth: int, pcs: Optional[set[int]] = None, @@ -469,7 +470,7 @@ def add_jump( ControlFlow.model_rebuild() self._add(asts, content, pcs, function, depth, source_path=source_path) - def extend_last(self, location: SourceLocation, pcs: Optional[set[int]] = None): + def extend_last(self, location: "SourceLocation", pcs: Optional[set[int]] = None): """ Extend the last node with more content. diff --git a/src/ape/types/units.py b/src/ape/types/units.py index c81d122a4c..22a2c1b1e1 100644 --- a/src/ape/types/units.py +++ b/src/ape/types/units.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from pydantic_core.core_schema import ( CoreSchema, @@ -7,11 +7,13 @@ no_info_plain_validator_function, plain_serializer_function_ser_schema, ) -from typing_extensions import TypeAlias from ape.exceptions import ConversionError from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from typing_extensions import TypeAlias + class CurrencyValueComparable(int): """ @@ -72,7 +74,7 @@ def _serialize(value): CurrencyValueComparable.__name__ = int.__name__ -CurrencyValue: TypeAlias = CurrencyValueComparable +CurrencyValue: "TypeAlias" = CurrencyValueComparable """ An alias to :class:`~ape.types.CurrencyValueComparable` for situations when you know for sure the type is a currency-value diff --git a/src/ape/utils/misc.py b/src/ape/utils/misc.py index 6fae748b61..b369f7b1c5 100644 --- a/src/ape/utils/misc.py +++ b/src/ape/utils/misc.py @@ -63,7 +63,7 @@ ) -_python_version = ( +_python_version: str = ( f"{sys.version_info.major}.{sys.version_info.minor}" f".{sys.version_info.micro} {sys.version_info.releaselevel}" ) @@ -193,7 +193,7 @@ def get_package_version(obj: Any) -> str: return "" -__version__ = get_package_version(__name__) +__version__: str = get_package_version(__name__) def load_config(path: Path, expand_envars=True, must_exist=False) -> dict: diff --git a/src/ape/utils/os.py b/src/ape/utils/os.py index 822ea2387b..f66d9736c1 100644 --- a/src/ape/utils/os.py +++ b/src/ape/utils/os.py @@ -211,12 +211,7 @@ def create_tempdir(name: Optional[str] = None) -> Iterator[Path]: def run_in_tempdir( - fn: Callable[ - [ - Path, - ], - Any, - ], + fn: Callable[[Path], Any], name: Optional[str] = None, ): """ diff --git a/src/ape/utils/rpc.py b/src/ape/utils/rpc.py index 3cfa7b54e2..f552fb1c86 100644 --- a/src/ape/utils/rpc.py +++ b/src/ape/utils/rpc.py @@ -8,7 +8,7 @@ from ape.logging import logger from ape.utils.misc import __version__, _python_version -USER_AGENT = f"Ape/{__version__} (Python/{_python_version})" +USER_AGENT: str = f"Ape/{__version__} (Python/{_python_version})" def allow_disconnected(fn: Callable): diff --git a/src/ape_accounts/accounts.py b/src/ape_accounts/accounts.py index aea2f66f60..7dd0bae941 100644 --- a/src/ape_accounts/accounts.py +++ b/src/ape_accounts/accounts.py @@ -3,28 +3,31 @@ from collections.abc import Iterator from os import environ from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import click from eip712.messages import EIP712Message from eth_account import Account as EthAccount from eth_account.hdaccount import ETHEREUM_DEFAULT_PATH from eth_account.messages import encode_defunct -from eth_account.signers.local import LocalAccount from eth_keys import keys # type: ignore from eth_pydantic_types import HexBytes from eth_utils import to_bytes, to_hex from ape.api.accounts import AccountAPI, AccountContainerAPI -from ape.api.transactions import TransactionAPI from ape.exceptions import AccountsError from ape.logging import logger -from ape.types.address import AddressType from ape.types.signatures import MessageSignature, SignableMessage, TransactionSignature from ape.utils.basemodel import ManagerAccessMixin from ape.utils.misc import log_instead_of_fail from ape.utils.validators import _validate_account_alias, _validate_account_passphrase +if TYPE_CHECKING: + from eth_account.signers.local import LocalAccount + + from ape.api.transactions import TransactionAPI + from ape.types.address import AddressType + class InvalidPasswordError(AccountsError): """ @@ -83,7 +86,7 @@ def keyfile(self) -> dict: return json.loads(self.keyfile_path.read_text()) @property - def address(self) -> AddressType: + def address(self) -> "AddressType": return self.network_manager.ethereum.decode_address(self.keyfile["address"]) @property @@ -220,7 +223,9 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] s=to_bytes(signed_msg.s), ) - def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]: + def sign_transaction( + self, txn: "TransactionAPI", **signer_options + ) -> Optional["TransactionAPI"]: user_approves = self.__autosign or click.confirm(f"{txn}\n\nSign: ") if not user_approves: return None @@ -292,7 +297,9 @@ def __decrypt_keyfile(self, passphrase: str) -> bytes: raise InvalidPasswordError() from err -def _write_and_return_account(alias: str, passphrase: str, account: LocalAccount) -> KeyfileAccount: +def _write_and_return_account( + alias: str, passphrase: str, account: "LocalAccount" +) -> KeyfileAccount: """Write an account to disk and return an Ape KeyfileAccount""" path = ManagerAccessMixin.account_manager.containers["accounts"].data_folder.joinpath( f"{alias}.json" diff --git a/src/ape_cache/query.py b/src/ape_cache/query.py index 30e32aeab7..deccb1c651 100644 --- a/src/ape_cache/query.py +++ b/src/ape_cache/query.py @@ -4,11 +4,11 @@ from typing import Any, Optional, cast from sqlalchemy import create_engine, func -from sqlalchemy.engine import CursorResult +from sqlalchemy.engine import CursorResult # noqa: TC002 from sqlalchemy.sql import column, insert, select -from sqlalchemy.sql.expression import Insert, Select +from sqlalchemy.sql.expression import Insert, Select # noqa: TC002 -from ape.api.providers import BlockAPI +from ape.api.providers import BlockAPI # noqa: TC002 from ape.api.query import ( BaseInterfaceModel, BlockQuery, diff --git a/src/ape_ethereum/_print.py b/src/ape_ethereum/_print.py index e2fc56bc7a..4102765d63 100644 --- a/src/ape_ethereum/_print.py +++ b/src/ape_ethereum/_print.py @@ -20,26 +20,29 @@ """ from collections.abc import Iterable -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from eth_abi import decode from eth_typing import ChecksumAddress from eth_utils import add_0x_prefix, decode_hex, to_hex from ethpm_types import ContractType, MethodABI -from evm_trace import CallTreeNode from hexbytes import HexBytes -from typing_extensions import TypeGuard import ape from ape_ethereum._console_log_abi import CONSOLE_LOG_ABI +if TYPE_CHECKING: + from evm_trace import CallTreeNode + from typing_extensions import TypeGuard + + CONSOLE_ADDRESS = cast(ChecksumAddress, "0x000000000000000000636F6e736F6c652e6c6f67") VYPER_PRINT_METHOD_ID = HexBytes("0x23cdd8e8") # log(string,bytes) console_contract = ContractType(abi=CONSOLE_LOG_ABI, contractName="console") -def is_console_log(call: CallTreeNode) -> TypeGuard[CallTreeNode]: +def is_console_log(call: "CallTreeNode") -> "TypeGuard[CallTreeNode]": """Determine if a call is a standard console.log() call""" return ( call.address == HexBytes(CONSOLE_ADDRESS) @@ -47,7 +50,7 @@ def is_console_log(call: CallTreeNode) -> TypeGuard[CallTreeNode]: ) -def is_vyper_print(call: CallTreeNode) -> TypeGuard[CallTreeNode]: +def is_vyper_print(call: "CallTreeNode") -> "TypeGuard[CallTreeNode]": """Determine if a call is a standard Vyper print() call""" if call.address != HexBytes(CONSOLE_ADDRESS) or call.calldata[:4] != VYPER_PRINT_METHOD_ID: return False @@ -79,7 +82,7 @@ def vyper_print(calldata: str) -> tuple[Any]: return tuple(data) -def extract_debug_logs(call: CallTreeNode) -> Iterable[tuple[Any]]: +def extract_debug_logs(call: "CallTreeNode") -> Iterable[tuple[Any]]: """Filter calls to console.log() and print() from a transactions call tree""" if is_vyper_print(call) and call.calldata is not None: yield vyper_print(add_0x_prefix(to_hex(call.calldata[4:]))) diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 82636f65fa..86ea1b3d16 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -2,7 +2,7 @@ from collections.abc import Iterator, Sequence from decimal import Decimal from functools import cached_property -from typing import Any, ClassVar, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast from eth_abi import decode, encode from eth_abi.exceptions import InsufficientDataBytes, NonEmptyPaddingBytes @@ -20,7 +20,6 @@ to_checksum_address, to_hex, ) -from ethpm_types import ContractType from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI from pydantic import Field, computed_field, field_validator, model_validator from pydantic_settings import SettingsConfigDict @@ -28,8 +27,6 @@ from ape.api.config import PluginConfig from ape.api.networks import EcosystemAPI from ape.api.providers import BlockAPI -from ape.api.trace import TraceAPI -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.contracts.base import ContractCall from ape.exceptions import ( ApeException, @@ -80,6 +77,13 @@ TransactionType, ) +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.api.trace import TraceAPI + from ape.api.transactions import ReceiptAPI, TransactionAPI + + NETWORKS = { # chain_id, network_id "mainnet": (1, 1), @@ -418,7 +422,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType: def encode_address(cls, address: AddressType) -> RawAddress: return f"{address}" - def decode_transaction_type(self, transaction_type_id: Any) -> type[TransactionAPI]: + def decode_transaction_type(self, transaction_type_id: Any) -> type["TransactionAPI"]: if isinstance(transaction_type_id, TransactionType): tx_type = transaction_type_id elif isinstance(transaction_type_id, int): @@ -435,8 +439,8 @@ def decode_transaction_type(self, transaction_type_id: Any) -> type[TransactionA return DynamicFeeTransaction def encode_contract_blueprint( - self, contract_type: ContractType, *args, **kwargs - ) -> TransactionAPI: + self, contract_type: "ContractType", *args, **kwargs + ) -> "TransactionAPI": # EIP-5202 implementation. bytes_obj = contract_type.deployment_bytecode contract_bytes = (bytes_obj.to_bytes() or b"") if bytes_obj else b"" @@ -546,7 +550,7 @@ def str_to_slot(text): return None - def decode_receipt(self, data: dict) -> ReceiptAPI: + def decode_receipt(self, data: dict) -> "ReceiptAPI": status = data.get("status") if status is not None: status = self.conversion_manager.convert(status, int) @@ -864,7 +868,7 @@ def encode_transaction( return cast(BaseTransaction, txn) - def create_transaction(self, **kwargs) -> TransactionAPI: + def create_transaction(self, **kwargs) -> "TransactionAPI": """ Returns a transaction using the given constructor kwargs. @@ -902,7 +906,7 @@ def create_transaction(self, **kwargs) -> TransactionAPI: tx_data["data"] = b"" # Deduce the transaction type. - transaction_types: dict[TransactionType, type[TransactionAPI]] = { + transaction_types: dict[TransactionType, type["TransactionAPI"]] = { TransactionType.STATIC: StaticFeeTransaction, TransactionType.ACCESS_LIST: AccessListTransaction, TransactionType.DYNAMIC: DynamicFeeTransaction, @@ -973,7 +977,7 @@ def create_transaction(self, **kwargs) -> TransactionAPI: return txn_class.model_validate(tx_data) - def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator["ContractLog"]: + def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator[ContractLog]: if not logs: return @@ -1052,7 +1056,7 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]: ), ) - def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI: + def enrich_trace(self, trace: "TraceAPI", **kwargs) -> "TraceAPI": kwargs["trace"] = trace if not isinstance(trace, Trace): # Can only enrich `ape_ethereum.trace.Trace` (or subclass) implementations. @@ -1416,7 +1420,7 @@ def _enrich_revert_message(self, call: dict) -> dict: def _get_contract_type_for_enrichment( self, address: AddressType, **kwargs - ) -> Optional[ContractType]: + ) -> Optional["ContractType"]: if not (contract_type := kwargs.get("contract_type")): try: contract_type = self.chain_manager.contracts.get(address) diff --git a/src/ape_ethereum/multicall/handlers.py b/src/ape_ethereum/multicall/handlers.py index b8722fe7e5..9e48208834 100644 --- a/src/ape_ethereum/multicall/handlers.py +++ b/src/ape_ethereum/multicall/handlers.py @@ -1,12 +1,10 @@ from collections.abc import Iterator from functools import cached_property from types import ModuleType -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union -from eth_pydantic_types import HexBytes from ethpm_types import ContractType -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.contracts.base import ( ContractCallHandler, ContractInstance, @@ -16,7 +14,6 @@ ) from ape.exceptions import ChainError, DecodingError from ape.logging import logger -from ape.types.address import AddressType from ape.utils.abi import MethodABI from ape.utils.basemodel import ManagerAccessMixin @@ -28,11 +25,17 @@ ) from .exceptions import InvalidOption, UnsupportedChainError, ValueRequired +if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + + from ape.api.transactions import ReceiptAPI, TransactionAPI + from ape.types.address import AddressType + class BaseMulticall(ManagerAccessMixin): def __init__( self, - address: AddressType = MULTICALL3_ADDRESS, + address: "AddressType" = MULTICALL3_ADDRESS, supported_chains: Optional[list[int]] = None, ) -> None: """ @@ -159,13 +162,13 @@ class Call(BaseMulticall): def __init__( self, - address: AddressType = MULTICALL3_ADDRESS, + address: "AddressType" = MULTICALL3_ADDRESS, supported_chains: Optional[list[int]] = None, ) -> None: super().__init__(address=address, supported_chains=supported_chains) self.abis: list[MethodABI] = [] - self._result: Union[None, list[tuple[bool, HexBytes]]] = None + self._result: Union[None, list[tuple[bool, "HexBytes"]]] = None @property def handler(self) -> ContractCallHandler: # type: ignore[override] @@ -180,7 +183,7 @@ def add(self, call: ContractMethodHandler, *args, **kwargs): return self @property - def returnData(self) -> list[HexBytes]: + def returnData(self) -> list["HexBytes"]: # NOTE: this property is kept camelCase to align with the raw EVM struct result = self._result # Declare for typing reasons. return [res.returnData if res.success else None for res in result] # type: ignore @@ -225,7 +228,7 @@ def __call__(self, **call_kwargs) -> Iterator[Any]: self._result = self.handler(self.calls, **call_kwargs) return self._decode_results() - def as_transaction(self, **txn_kwargs) -> TransactionAPI: + def as_transaction(self, **txn_kwargs) -> "TransactionAPI": """ Encode the Multicall transaction as a ``TransactionAPI`` object, but do not execute it. @@ -272,7 +275,7 @@ def _validate_calls(self, **txn_kwargs) -> None: # NOTE: Won't fail if `value` is provided otherwise (won't do anything either) - def __call__(self, **txn_kwargs) -> ReceiptAPI: + def __call__(self, **txn_kwargs) -> "ReceiptAPI": """ Execute the Multicall transaction. The transaction will broadcast again every time the ``Transaction`` object is called. @@ -290,7 +293,7 @@ def __call__(self, **txn_kwargs) -> ReceiptAPI: self._validate_calls(**txn_kwargs) return self.handler(self.calls, **txn_kwargs) - def as_transaction(self, **txn_kwargs) -> TransactionAPI: + def as_transaction(self, **txn_kwargs) -> "TransactionAPI": """ Encode the Multicall transaction as a ``TransactionAPI`` object, but do not execute it. diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 53b63b2bf2..c1ff49e704 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -9,14 +9,13 @@ from copy import copy from functools import cached_property, wraps from pathlib import Path -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast import ijson # type: ignore import requests from eth_pydantic_types import HexBytes from eth_typing import BlockNumber, HexStr from eth_utils import add_0x_prefix, is_hex, to_hex -from ethpm_types import EventABI from evmchains import get_random_rpc from pydantic.dataclasses import dataclass from requests import HTTPError @@ -39,7 +38,6 @@ from ape.api.address import Address from ape.api.providers import BlockAPI, ProviderAPI -from ape.api.trace import TraceAPI from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( _SOURCE_TRACEBACK_ARG, @@ -58,17 +56,23 @@ VirtualMachineError, ) from ape.logging import logger, sanitize_url -from ape.types.address import AddressType from ape.types.events import ContractLog, LogFilter from ape.types.gas import AutoGasLimit from ape.types.trace import SourceTraceback -from ape.types.vm import BlockID, ContractCode from ape.utils.basemodel import ManagerAccessMixin from ape.utils.misc import DEFAULT_MAX_RETRIES_TX, gas_estimation_error_message, to_int from ape_ethereum._print import CONSOLE_ADDRESS, console_contract from ape_ethereum.trace import CallTrace, TraceApproach, TransactionTrace from ape_ethereum.transactions import AccessList, AccessListTransaction, TransactionStatusEnum +if TYPE_CHECKING: + from ethpm_types import EventABI + + from ape.api.trace import TraceAPI + from ape.types.address import AddressType + from ape.types.vm import BlockID, ContractCode + + DEFAULT_PORT = 8545 DEFAULT_HOSTNAME = "localhost" DEFAULT_SETTINGS = {"uri": f"http://{DEFAULT_HOSTNAME}:{DEFAULT_PORT}"} @@ -322,7 +326,7 @@ def update_settings(self, new_settings: dict): self.provider_settings.update(new_settings) self.connect() - def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = None) -> int: + def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] = None) -> int: # NOTE: Using JSON mode since used as request data. txn_dict = txn.model_dump(by_alias=True, mode="json") @@ -410,7 +414,7 @@ def priority_fee(self) -> int: "eth_maxPriorityFeePerGas not supported in this RPC. Please specify manually." ) from err - def get_block(self, block_id: BlockID) -> BlockAPI: + def get_block(self, block_id: "BlockID") -> BlockAPI: if isinstance(block_id, str) and block_id.isnumeric(): block_id = int(block_id) @@ -429,17 +433,19 @@ def _get_latest_block(self) -> BlockAPI: def _get_latest_block_rpc(self) -> dict: return self.make_request("eth_getBlockByNumber", ["latest", False]) - def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_nonce(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: return self.web3.eth.get_transaction_count(address, block_identifier=block_id) - def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_balance(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: return self.web3.eth.get_balance(address, block_identifier=block_id) - def get_code(self, address: AddressType, block_id: Optional[BlockID] = None) -> ContractCode: + def get_code( + self, address: "AddressType", block_id: Optional["BlockID"] = None + ) -> "ContractCode": return self.web3.eth.get_code(address, block_identifier=block_id) def get_storage( - self, address: AddressType, slot: int, block_id: Optional[BlockID] = None + self, address: "AddressType", slot: int, block_id: Optional["BlockID"] = None ) -> HexBytes: try: return HexBytes(self.web3.eth.get_storage_at(address, slot, block_identifier=block_id)) @@ -449,7 +455,7 @@ def get_storage( raise # Raise original error - def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + def get_transaction_trace(self, transaction_hash: str, **kwargs) -> "TraceAPI": if transaction_hash in self._transaction_trace_cache: return self._transaction_trace_cache[transaction_hash] @@ -463,7 +469,7 @@ def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: def send_call( self, txn: TransactionAPI, - block_id: Optional[BlockID] = None, + block_id: Optional["BlockID"] = None, state: Optional[dict] = None, **kwargs: Any, ) -> HexBytes: @@ -694,7 +700,7 @@ def _create_receipt(self, **kwargs) -> ReceiptAPI: data = {"provider": self, **kwargs} return self.network.ecosystem.decode_receipt(data) - def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAPI]: + def get_transactions_by_block(self, block_id: "BlockID") -> Iterator[TransactionAPI]: if isinstance(block_id, str): block_id = HexStr(block_id) @@ -707,7 +713,7 @@ def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAP def get_transactions_by_account_nonce( self, - account: AddressType, + account: "AddressType", start_nonce: int = 0, stop_nonce: int = -1, ) -> Iterator[ReceiptAPI]: @@ -732,7 +738,7 @@ def get_transactions_by_account_nonce( def _find_txn_by_account_and_nonce( self, - account: AddressType, + account: "AddressType", start_nonce: int, stop_nonce: int, start_block: int, @@ -878,11 +884,11 @@ def assert_chain_activity(): def poll_logs( self, stop_block: Optional[int] = None, - address: Optional[AddressType] = None, + address: Optional["AddressType"] = None, topics: Optional[list[Union[str, list[str]]]] = None, required_confirmations: Optional[int] = None, new_block_timeout: Optional[int] = None, - events: Optional[list[EventABI]] = None, + events: Optional[list["EventABI"]] = None, ) -> Iterator[ContractLog]: events = events or [] if required_confirmations is None: @@ -1169,7 +1175,7 @@ def stream_request(self, method: str, params: Iterable, iter_path: str = "result del results[:] def create_access_list( - self, transaction: TransactionAPI, block_id: Optional[BlockID] = None + self, transaction: TransactionAPI, block_id: Optional["BlockID"] = None ) -> list[AccessList]: """ Get the access list for a transaction use ``eth_createAccessList``. @@ -1248,7 +1254,7 @@ def _handle_execution_reverted( exception: Union[Exception, str], txn: Optional[TransactionAPI] = None, trace: _TRACE_ARG = None, - contract_address: Optional[AddressType] = None, + contract_address: Optional["AddressType"] = None, source_traceback: _SOURCE_TRACEBACK_ARG = None, set_ape_traceback: Optional[bool] = None, ) -> ContractLogicError: @@ -1548,7 +1554,7 @@ def _log_connection(self, client_name: str): ) logger.info(f"{msg} {suffix}.") - def ots_get_contract_creator(self, address: AddressType) -> Optional[dict]: + def ots_get_contract_creator(self, address: "AddressType") -> Optional[dict]: if self._ots_api_level is None: return None @@ -1559,7 +1565,7 @@ def ots_get_contract_creator(self, address: AddressType) -> Optional[dict]: return result - def _get_contract_creation_receipt(self, address: AddressType) -> Optional[ReceiptAPI]: + def _get_contract_creation_receipt(self, address: "AddressType") -> Optional[ReceiptAPI]: if result := self.ots_get_contract_creator(address): tx_hash = result["hash"] return self.get_receipt(tx_hash) diff --git a/src/ape_ethereum/trace.py b/src/ape_ethereum/trace.py index 0812624f96..e67463cb0e 100644 --- a/src/ape_ethereum/trace.py +++ b/src/ape_ethereum/trace.py @@ -5,11 +5,10 @@ from collections.abc import Iterable, Iterator, Sequence from enum import Enum from functools import cached_property -from typing import IO, Any, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Optional, Union from eth_pydantic_types import HexStr from eth_utils import is_0x_prefixed, to_hex -from ethpm_types import ContractType, MethodABI from evm_trace import ( CallTreeNode, CallType, @@ -25,17 +24,22 @@ from pydantic import field_validator from rich.tree import Tree -from ape.api.networks import EcosystemAPI from ape.api.trace import TraceAPI from ape.api.transactions import TransactionAPI from ape.exceptions import ContractLogicError, ProviderError, TransactionNotFoundError from ape.logging import get_rich_console, logger -from ape.types.address import AddressType -from ape.types.trace import ContractFunctionPath, GasReport from ape.utils.misc import ZERO_ADDRESS, is_evm_precompile, is_zero_hex, log_instead_of_fail from ape.utils.trace import TraceStyles, _exclude_gas from ape_ethereum._print import extract_debug_logs +if TYPE_CHECKING: + from ethpm_types import ContractType, MethodABI + + from ape.api.networks import EcosystemAPI + from ape.types.address import AddressType + from ape.types.trace import ContractFunctionPath, GasReport + + _INDENT = 2 _WRAP_THRESHOLD = 50 _REVERT_PREFIX = "0x08c379a00000000000000000000000000000000000000000000000000000000000000020" @@ -174,11 +178,11 @@ def frames(self) -> Iterator[TraceFrame]: yield from create_trace_frames(iter(self.raw_trace_frames)) @property - def addresses(self) -> Iterator[AddressType]: + def addresses(self) -> Iterator["AddressType"]: yield from self.get_addresses_used() @cached_property - def root_contract_type(self) -> Optional[ContractType]: + def root_contract_type(self) -> Optional["ContractType"]: if address := self.transaction.get("to"): try: return self.chain_manager.contracts.get(address) @@ -188,7 +192,7 @@ def root_contract_type(self) -> Optional[ContractType]: return None @cached_property - def root_method_abi(self) -> Optional[MethodABI]: + def root_method_abi(self) -> Optional["MethodABI"]: method_id = self.transaction.get("data", b"")[:10] if ct := self.root_contract_type: try: @@ -199,7 +203,7 @@ def root_method_abi(self) -> Optional[MethodABI]: return None @property - def _ecosystem(self) -> EcosystemAPI: + def _ecosystem(self) -> "EcosystemAPI": if provider := self.network_manager.active_provider: return provider.network.ecosystem @@ -357,13 +361,15 @@ def show(self, verbose: bool = False, file: IO[str] = sys.stdout): console.print(root) - def get_gas_report(self, exclude: Optional[Sequence[ContractFunctionPath]] = None) -> GasReport: + def get_gas_report( + self, exclude: Optional[Sequence["ContractFunctionPath"]] = None + ) -> "GasReport": call = self.enriched_calltree return self._get_gas_report_from_call(call, exclude=exclude) def _get_gas_report_from_call( - self, call: dict, exclude: Optional[Sequence[ContractFunctionPath]] = None - ) -> GasReport: + self, call: dict, exclude: Optional[Sequence["ContractFunctionPath"]] = None + ) -> "GasReport": tx = self.transaction # Enrich transfers. @@ -388,7 +394,7 @@ def _get_gas_report_from_call( return merge_reports(*sub_reports) elif not is_zero_hex(call["method_id"]) and not is_evm_precompile(call["method_id"]): - report: GasReport = { + report: "GasReport" = { call["contract_id"]: { call["method_id"]: ( [int(call["gas_cost"])] if call.get("gas_cost") is not None else [] @@ -434,7 +440,7 @@ def _debug_trace_transaction_struct_logs_to_call(self) -> CallTreeNode: def _get_tree(self, verbose: bool = False) -> Tree: return parse_rich_tree(self.enriched_calltree, verbose=verbose) - def _get_abi(self, call: dict) -> Optional[MethodABI]: + def _get_abi(self, call: dict) -> Optional["MethodABI"]: if not (addr := call.get("address")): return self.root_method_abi if not (calldata := call.get("calldata")): diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 4385973469..83ebbd57a3 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -1,7 +1,7 @@ import sys from enum import Enum, IntEnum from functools import cached_property -from typing import IO, Any, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Optional, Union from eth_abi import decode from eth_account import Account as EthAccount @@ -11,12 +11,10 @@ ) from eth_pydantic_types import HexBytes from eth_utils import decode_hex, encode_hex, keccak, to_hex, to_int -from ethpm_types import ContractType from ethpm_types.abi import EventABI, MethodABI from pydantic import BaseModel, Field, field_validator, model_validator from ape.api.transactions import ReceiptAPI, TransactionAPI -from ape.contracts import ContractEvent from ape.exceptions import OutOfGasError, SignatureError, TransactionError from ape.logging import logger from ape.types.address import AddressType @@ -26,6 +24,11 @@ from ape.utils.misc import ZERO_ADDRESS from ape_ethereum.trace import Trace, _events_to_trees +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.contracts import ContractEvent + class TransactionStatusEnum(IntEnum): """ @@ -221,7 +224,7 @@ def debug_logs_typed(self) -> list[tuple[Any]]: return list(trace.debug_logs) @cached_property - def contract_type(self) -> Optional[ContractType]: + def contract_type(self) -> Optional["ContractType"]: if address := (self.receiver or self.contract_address): return self.chain_manager.contracts.get(address) diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index 6f459324ff..95bd54d2b7 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -2,13 +2,12 @@ import shutil from pathlib import Path from subprocess import DEVNULL, PIPE, Popen -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from eth_utils import add_0x_prefix, to_hex from evmchains import get_random_rpc from geth.chain import initialize_chain from geth.process import BaseGethProcess -from geth.types import GenesisDataTypedDict from geth.wrapper import construct_test_chain_kwargs from pydantic import field_validator from pydantic_settings import SettingsConfigDict @@ -16,11 +15,9 @@ from web3.middleware import geth_poa_middleware as ExtraDataToPOAMiddleware from yarl import URL -from ape.api.accounts import TestAccountAPI from ape.api.config import PluginConfig from ape.api.providers import SubprocessProvider, TestProviderAPI from ape.logging import LogLevel, logger -from ape.types.vm import SnapshotID from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, raises_not_implemented from ape.utils.process import JoinableQueue, spawn from ape.utils.testing import ( @@ -39,10 +36,17 @@ ) from ape_ethereum.trace import TraceApproach +if TYPE_CHECKING: + from geth.types import GenesisDataTypedDict + + from ape.api.accounts import TestAccountAPI + from ape.types.vm import SnapshotID + + Alloc = dict[str, dict[str, Any]] -def create_genesis_data(alloc: Alloc, chain_id: int) -> GenesisDataTypedDict: +def create_genesis_data(alloc: Alloc, chain_id: int) -> "GenesisDataTypedDict": """ A wrapper around genesis data for py-geth that fills in more defaults. @@ -398,10 +402,10 @@ def disconnect(self): super().disconnect() - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": return self._get_latest_block().number or 0 - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): if isinstance(snapshot_id, int): block_number_int = snapshot_id block_number_hex_str = str(to_hex(snapshot_id)) diff --git a/src/ape_pm/compiler.py b/src/ape_pm/compiler.py index 297af2ad25..ebd5ea97ee 100644 --- a/src/ape_pm/compiler.py +++ b/src/ape_pm/compiler.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, Iterator from json import JSONDecodeError from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed @@ -11,9 +11,11 @@ from ape.api.compiler import CompilerAPI from ape.exceptions import CompilerError, ContractLogicError from ape.logging import logger -from ape.managers.project import ProjectManager from ape.utils.os import get_relative_path +if TYPE_CHECKING: + from ape.managers.project import ProjectManager + class InterfaceCompiler(CompilerAPI): """ @@ -64,7 +66,7 @@ def compile( def compile_code( self, code: str, - project: Optional[ProjectManager] = None, + project: Optional["ProjectManager"] = None, **kwargs, ) -> ContractType: code = code or "[]" diff --git a/src/ape_pm/project.py b/src/ape_pm/project.py index 7e9a53cc4c..0091bbe142 100644 --- a/src/ape_pm/project.py +++ b/src/ape_pm/project.py @@ -1,5 +1,7 @@ import sys from collections.abc import Iterable +from pathlib import Path +from typing import Any, Optional from ape.utils._github import _GithubClient, github_client @@ -10,9 +12,6 @@ else: import toml as tomllib # type: ignore[no-redef] -from pathlib import Path -from typing import Any, Optional - from yaml import safe_load from ape.api.config import ApeConfig diff --git a/src/ape_test/accounts.py b/src/ape_test/accounts.py index 0dbe558250..43c46a29fd 100644 --- a/src/ape_test/accounts.py +++ b/src/ape_test/accounts.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Iterator -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from eip712.messages import EIP712Message from eth_account import Account as EthAccount @@ -11,9 +11,7 @@ from eth_utils import to_bytes, to_hex from ape.api.accounts import TestAccountAPI, TestAccountContainerAPI -from ape.api.transactions import TransactionAPI from ape.exceptions import ProviderNotConnectedError, SignatureError -from ape.types.address import AddressType from ape.types.signatures import MessageSignature, TransactionSignature from ape.utils.testing import ( DEFAULT_NUMBER_OF_TEST_ACCOUNTS, @@ -22,6 +20,10 @@ generate_dev_accounts, ) +if TYPE_CHECKING: + from ape.api.transactions import TransactionAPI + from ape.types.address import AddressType + class TestAccountContainer(TestAccountContainerAPI): generated_accounts: list["TestAccount"] = [] @@ -82,7 +84,9 @@ def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI": return account @classmethod - def init_test_account(cls, index: int, address: AddressType, private_key: str) -> "TestAccount": + def init_test_account( + cls, index: int, address: "AddressType", private_key: str + ) -> "TestAccount": return TestAccount( index=index, address_str=address, @@ -105,7 +109,7 @@ def alias(self) -> str: return f"TEST::{self.index}" @property - def address(self) -> AddressType: + def address(self) -> "AddressType": return self.network_manager.ethereum.decode_address(self.address_str) def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]: @@ -129,7 +133,9 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] ) return None - def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]: + def sign_transaction( + self, txn: "TransactionAPI", **signer_options + ) -> Optional["TransactionAPI"]: # Signs any transaction that's given to it. # NOTE: Using JSON mode, as only primitive types can be signed. tx_data = txn.model_dump(mode="json", by_alias=True, exclude={"sender"}) diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index afd12e227a..f6c63e8060 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -18,8 +18,6 @@ from web3.types import TxParams from ape.api.providers import BlockAPI, TestProviderAPI -from ape.api.trace import TraceAPI -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( APINotImplementedError, ContractLogicError, @@ -31,8 +29,6 @@ ) from ape.logging import logger from ape.types.address import AddressType -from ape.types.events import ContractLog, LogFilter -from ape.types.vm import BlockID, SnapshotID from ape.utils.misc import gas_estimation_error_message from ape.utils.testing import DEFAULT_TEST_HD_PATH from ape_ethereum.provider import Web3Provider @@ -41,6 +37,10 @@ if TYPE_CHECKING: from ape.api.accounts import TestAccountAPI + from ape.api.trace import TraceAPI + from ape.api.transactions import ReceiptAPI, TransactionAPI + from ape.types.events import ContractLog, LogFilter + from ape.types.vm import BlockID, SnapshotID class LocalProvider(TestProviderAPI, Web3Provider): @@ -121,7 +121,7 @@ def update_settings(self, new_settings: dict): self.connect() def estimate_gas_cost( - self, txn: TransactionAPI, block_id: Optional[BlockID] = None, **kwargs + self, txn: "TransactionAPI", block_id: Optional["BlockID"] = None, **kwargs ) -> int: if isinstance(self.network.gas_limit, int): return self.network.gas_limit @@ -201,8 +201,8 @@ def base_fee(self) -> int: def send_call( self, - txn: TransactionAPI, - block_id: Optional[BlockID] = None, + txn: "TransactionAPI", + block_id: Optional["BlockID"] = None, state: Optional[dict] = None, **kwargs, ) -> HexBytes: @@ -244,7 +244,7 @@ def send_call( return HexBytes(result) - def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: + def send_transaction(self, txn: "TransactionAPI") -> "ReceiptAPI": vm_err = None txn_dict = None try: @@ -304,10 +304,10 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: return receipt - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": return self.evm_backend.take_snapshot() - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): if snapshot_id: current_hash = self._get_latest_block_rpc().get("hash") if current_hash != snapshot_id: @@ -341,18 +341,18 @@ def set_timestamp(self, new_timestamp: int): def mine(self, num_blocks: int = 1): self.evm_backend.mine_blocks(num_blocks) - def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_balance(self, address: AddressType, block_id: Optional["BlockID"] = None) -> int: # perf: Using evm_backend directly instead of going through web3. return self.evm_backend.get_balance( HexBytes(address), block_number="latest" if block_id is None else block_id ) - def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_nonce(self, address: AddressType, block_id: Optional["BlockID"] = None) -> int: return self.evm_backend.get_nonce( HexBytes(address), block_number="latest" if block_id is None else block_id ) - def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: + def get_contract_logs(self, log_filter: "LogFilter") -> Iterator["ContractLog"]: from_block = max(0, log_filter.start_block) if log_filter.stop_block is None: @@ -397,7 +397,7 @@ def _get_last_base_fee(self) -> int: raise APINotImplementedError("No base fee found in block.") - def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + def get_transaction_trace(self, transaction_hash: str, **kwargs) -> "TraceAPI": if "call_trace_approach" not in kwargs: kwargs["call_trace_approach"] = TraceApproach.BASIC diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index d06dd7bee0..2e7b2e0cf2 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from pathlib import Path from shutil import copytree -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast import pytest from eth_pydantic_types import HexBytes @@ -18,10 +18,13 @@ from ape.logging import LogLevel from ape.logging import logger as _logger from ape.types.address import AddressType -from ape.types.events import ContractLog from ape.utils.misc import LOCAL_NETWORK_NAME from ape_ethereum.proxies import minimal_proxy as _minimal_proxy_container +if TYPE_CHECKING: + from ape.types.events import ContractLog + + ALIAS_2 = "__FUNCTIONAL_TESTS_ALIAS_2__" TEST_ADDRESS = cast(AddressType, "0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045") BASE_PROJECTS_DIRECTORY = (Path(__file__).parent / "data" / "projects").absolute() @@ -431,7 +434,7 @@ def PollDaemon(): @pytest.fixture def assert_log_values(contract_instance): def _assert_log_values( - log: ContractLog, + log: "ContractLog", number: int, previous_number: Optional[int] = None, address: Optional[AddressType] = None, diff --git a/tests/functional/test_config.py b/tests/functional/test_config.py index 2c90a5ba06..ec49a34c0f 100644 --- a/tests/functional/test_config.py +++ b/tests/functional/test_config.py @@ -1,7 +1,7 @@ import os import re from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import pytest from pydantic import ValidationError @@ -10,12 +10,15 @@ from ape.api.config import ApeConfig, ConfigEnum, PluginConfig from ape.exceptions import ConfigError from ape.managers.config import CONFIG_FILE_NAME, merge_configs -from ape.types.gas import GasLimit from ape.utils.os import create_tempdir from ape_ethereum.ecosystem import EthereumConfig, NetworkConfig from ape_networks import CustomNetwork from tests.functional.conftest import PROJECT_WITH_LONG_CONTRACTS_FOLDER +if TYPE_CHECKING: + from ape.types.gas import GasLimit + + CONTRACTS_FOLDER = "pathsomewhwere" NUMBER_OF_TEST_ACCOUNTS = 31 YAML_CONTENT = rf""" @@ -277,7 +280,7 @@ def test_network_gas_limit_default(config): assert eth_config.local.gas_limit == "max" -def _sepolia_with_gas_limit(gas_limit: GasLimit) -> dict: +def _sepolia_with_gas_limit(gas_limit: "GasLimit") -> dict: return { "ethereum": { "sepolia": { diff --git a/tests/functional/test_contract_event.py b/tests/functional/test_contract_event.py index 531366dedf..8dfa555b7e 100644 --- a/tests/functional/test_contract_event.py +++ b/tests/functional/test_contract_event.py @@ -1,6 +1,6 @@ import time from queue import Queue -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest from eth_pydantic_types import HexBytes @@ -8,11 +8,13 @@ from eth_utils import to_hex from ethpm_types import ContractType -from ape.api.transactions import ReceiptAPI from ape.exceptions import ProviderError from ape.types.events import ContractLog from ape.types.units import CurrencyValueComparable +if TYPE_CHECKING: + from ape.api.transactions import ReceiptAPI + @pytest.fixture def assert_log_values(owner, chain): @@ -38,7 +40,7 @@ def test_contract_logs_from_receipts(owner, contract_instance, assert_log_values receipt_1 = contract_instance.setNumber(2, sender=owner) receipt_2 = contract_instance.setNumber(3, sender=owner) - def assert_receipt_logs(receipt: ReceiptAPI, num: int): + def assert_receipt_logs(receipt: "ReceiptAPI", num: int): logs = event_type.from_receipt(receipt) assert len(logs) == 1 assert_log_values(logs[0], num) diff --git a/tests/functional/test_explorer.py b/tests/functional/test_explorer.py index db9e35e725..44214112eb 100644 --- a/tests/functional/test_explorer.py +++ b/tests/functional/test_explorer.py @@ -1,23 +1,26 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest -from ethpm_types import ContractType from ape.api.explorers import ExplorerAPI -from ape.types.address import AddressType + +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.types.address import AddressType class MyExplorer(ExplorerAPI): def get_transaction_url(self, transaction_hash: str) -> str: return "" - def get_address_url(self, address: AddressType) -> str: + def get_address_url(self, address: "AddressType") -> str: return "" - def get_contract_type(self, address: AddressType) -> Optional[ContractType]: + def get_contract_type(self, address: "AddressType") -> Optional["ContractType"]: return None - def publish_contract(self, address: AddressType): + def publish_contract(self, address: "AddressType"): return diff --git a/tests/functional/test_receipt.py b/tests/functional/test_receipt.py index 318b65e562..1d99da80ec 100644 --- a/tests/functional/test_receipt.py +++ b/tests/functional/test_receipt.py @@ -1,12 +1,16 @@ +from typing import TYPE_CHECKING + import pytest from rich.table import Table from rich.tree import Tree -from ape.api import ReceiptAPI from ape.exceptions import ContractLogicError, OutOfGasError from ape.utils import ManagerAccessMixin from ape_ethereum.transactions import DynamicFeeTransaction, Receipt, TransactionStatusEnum +if TYPE_CHECKING: + from ape.api import ReceiptAPI + @pytest.fixture def deploy_receipt(vyper_contract_instance): @@ -147,7 +151,7 @@ def test_decode_logs(owner, contract_instance, assert_log_values): receipt_1 = contract_instance.setNumber(2, sender=owner) receipt_2 = contract_instance.setNumber(3, sender=owner) - def assert_receipt_logs(receipt: ReceiptAPI, num: int): + def assert_receipt_logs(receipt: "ReceiptAPI", num: int): logs = receipt.decode_logs(event_type) assert len(logs) == 1 assert_log_values(logs[0], num)