diff --git a/docs/userguides/clis.md b/docs/userguides/clis.md index 09c63ffc99..395a9f6fe6 100644 --- a/docs/userguides/clis.md +++ b/docs/userguides/clis.md @@ -63,28 +63,60 @@ def cli(cli_ctx): ## Network Tools -The [@network_option()](../methoddocs/cli.html#ape.cli.options.network_option) allows you to select an ecosystem / network / provider combination. -When using with the [NetworkBoundCommand](../methoddocs/cli.html#ape.cli.commands.NetworkBoundCommand) class, you can cause your CLI to connect before any of your code executes. -This is useful if your script or command requires a provider connection in order for it to run. +The [@network_option()](../methoddocs/cli.html#ape.cli.options.network_option) allows you to select an ecosystem, network, and provider. +To specify the network option, use values like: + +```shell +--network ethereum +--network ethereum:sepolia +--network ethereum:mainnet:alchemy +--network ::foundry +``` + +To use default values automatically, omit sections of the choice, but leave the semi-colons for parsing. +For example, `::test` means use the default ecosystem and network and the `test` provider. + +Use `ecosystem`, `network`, and `provider` argument names in your command implementation to access their corresponding class instances: ```python import click -from ape import networks -from ape.cli import network_option, NetworkBoundCommand +from ape.cli import network_option +@click.command() +@network_option() +def cmd(provider): + # This command only needs the provider. + click.echo(provider.name) @click.command() @network_option() -def cmd(network): - # Choices like "ethereum" or "polygon:local:test". - click.echo(network) +def cmd_2(ecosystem, network, provider): + # This command uses all parts of the parsed network choice. + click.echo(ecosystem.name) + click.echo(network.name) + click.echo(provider.name) +``` +The [ConnectedProviderCommand](../methoddocs/cli.html#ape.cli.commands.ConnectedProviderCommand) automatically uses the `--network` option and connects to the network before any of your code executes and then disconnects afterward. +This is useful if your script or command requires a provider connection in order for it to run. +Additionally, specify `ecosystem`, `network`, or `provider` in your command function if you need any of those instances in your `ConnectedProviderCommand`, just like when using `network_option`. -@click.command(cls=NetworkBoundCommand) -@network_option() -def cmd(network): - # Fails if we are not connected. - click.echo(networks.provider.network.name) +```python +import click +from ape.cli import ConnectedProviderCommand + +@click.command(cls=ConnectedProviderCommand) +def cmd(network, provider): + click.echo(network.name) + click.echo(provider.is_connected) # True + +@click.command(cls=ConnectedProviderCommand) +def cmd(provider): + click.echo(provider.is_connected) # True + +@click.command(cls=ConnectedProviderCommand) +def cmd(): + click.echo("Using params is from ConnectedProviderCommand is optional") ``` ## Account Tools diff --git a/docs/userguides/networks.md b/docs/userguides/networks.md index 6e5b949834..4a554cd5b0 100644 --- a/docs/userguides/networks.md +++ b/docs/userguides/networks.md @@ -12,7 +12,7 @@ ape test --network ethereum:local:foundry ape console --network arbitrum:testnet:alchemy ``` -You can also use the `--network` option on scripts that use the `main()` method approach or scripts that implement that `NetworkBoundCommand` command type. +You can also use the `--network` option on scripts that use the `main()` method approach or scripts that implement that `ConnectedProviderCommand` command type. See [the scripting guide](./scripts.html) to learn more about scripts and how to add the network option. **NOTE**: You can omit values to use defaults. @@ -85,9 +85,9 @@ geth: uri: https://foo.node.bar ``` -## Ad-hoc Network Connection +## Custom Network Connection -If you would like to connect to a URI using the `geth` provider, you can specify a URI for the provider name in the `--network` option: +If you would like to connect to a URI using the default Ethereum node provider, you can specify a URI for the provider name in the `--network` option: ```bash ape run script --network ethereum:mainnet:https://foo.bar diff --git a/docs/userguides/scripts.md b/docs/userguides/scripts.md index 28fc79b27e..ffb363c47e 100644 --- a/docs/userguides/scripts.md +++ b/docs/userguides/scripts.md @@ -33,17 +33,25 @@ ape run hello helloworld ``` Note that by default, `cli` scripts do not have [`ape.cli.network_option`](../methoddocs/cli.html?highlight=options#ape.cli.options.network_option) installed, giving you more flexibility in how you define your scripts. -However, you can add the `network_option` to your scripts by importing both the `NetworkBoundCommand` and the `network_option` from the `ape.cli` namespace: +However, you can add the `network_option` or `ConnectedProviderCommand` to your scripts by importing them from the `ape.cli` namespace: ```python import click -from ape.cli import network_option, NetworkBoundCommand +from ape.cli import network_option, ConnectedProviderCommand -@click.command(cls=NetworkBoundCommand) -@network_option() -def cli(network): - click.echo(f"You are connected to network '{network}'.") +@click.command(cls=ConnectedProviderCommand) +def cli(ecosystem, network): + click(f"You selected a provider on ecosystem '{ecosystem.name}' and {network.name}.") + +@click.command(cls=ConnectedProviderCommand) +def cli(network, provider): + click.echo(f"You are connected to network '{network.name}'.") + click.echo(provider.chain_id) + +@click.command(cls=ConnectedProviderCommand) +def cli_2(): + click.echo(f"Using any network-based argument is completely optional.") ``` Assume we saved this script as `shownet.py` and have the [ape-alchemy](https://github.com/ApeWorX/ape-alchemy) plugin installed. diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 13616a371b..a24dc0256a 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -24,7 +24,6 @@ from eth_utils import keccak, to_int from ethpm_types import BaseModel, ContractType from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI -from pydantic import computed_field from ape.exceptions import ( NetworkError, @@ -94,6 +93,29 @@ class EcosystemAPI(BaseInterfaceModel): def __repr__(self) -> str: return f"<{self.name}>" + @cached_property + def custom_network(self) -> "NetworkAPI": + ethereum_class = None + for plugin_name, ecosystem_class in self.plugin_manager.ecosystems: + if plugin_name == "ethereum": + ethereum_class = ecosystem_class + break + + if ethereum_class is None: + raise NetworkError("Core Ethereum plugin missing.") + + data_folder = mkdtemp() + request_header = self.config_manager.REQUEST_HEADER + init_kwargs = {"data_folder": data_folder, "request_header": request_header} + ethereum = ethereum_class(**init_kwargs) # type: ignore + return NetworkAPI( + name="custom", + ecosystem=ethereum, + data_folder=Path(data_folder), + request_header=request_header, + _default_provider="geth", + ) + @classmethod @abstractmethod def decode_address(cls, raw_address: RawAddress) -> AddressType: @@ -258,7 +280,7 @@ def add_network(self, network_name: str, network: "NetworkAPI"): self.networks[network_name] = network @property - def default_network(self) -> str: + def default_network_name(self) -> str: """ The name of the default network in this ecosystem. @@ -284,6 +306,10 @@ def default_network(self) -> str: # Very unlikely scenario. raise NetworkError("No networks found.") + @property + def default_network(self) -> "NetworkAPI": + return self.get_network(self.default_network_name) + def set_default_network(self, network_name: str): """ Change the default network. @@ -424,19 +450,22 @@ def get_network(self, network_name: str) -> "NetworkAPI": Get the network for the given name. Args: - network_name (str): The name of the network to get. + network_name (str): The name of the network to get. Raises: - :class:`~ape.exceptions.NetworkNotFoundError`: When the network is not present. + :class:`~ape.exceptions.NetworkNotFoundError`: When the network is not present. Returns: - :class:`~ape.api.networks.NetworkAPI` + :class:`~ape.api.networks.NetworkAPI` """ name = network_name.replace("_", "-") if name in self.networks: return self.networks[name] + elif name == "custom": + return self.custom_network + raise NetworkNotFoundError(network_name, ecosystem=self.name, options=self.networks) def get_network_data( @@ -459,7 +488,7 @@ def get_network_data( data: Dict[str, Any] = {"name": str(network_name)} # Only add isDefault key when True - if network_name == self.default_network: + if network_name == self.default_network_name: data["isDefault"] = True data["providers"] = [] @@ -475,7 +504,7 @@ def get_network_data( provider_data: Dict = {"name": str(provider_name)} # Only add isDefault key when True - if provider_name == network.default_provider: + if provider_name == network.default_provider_name: provider_data["isDefault"] = True data["providers"].append(provider_data) @@ -605,6 +634,7 @@ def __enter__(self, *args, **kwargs): # set inner var to the recycled provider for use in push_provider() self._provider = self._recycled_provider ProviderContextManager._recycled_provider = None + return self.push_provider() def __exit__(self, exception, *args, **kwargs): @@ -700,29 +730,6 @@ class NetworkAPI(BaseInterfaceModel): # See ``.default_provider`` which is the proper field. _default_provider: str = "" - @classmethod - def create_adhoc_network(cls) -> "NetworkAPI": - ethereum_class = None - for plugin_name, ecosystem_class in cls.plugin_manager.ecosystems: - if plugin_name == "ethereum": - ethereum_class = ecosystem_class - break - - if ethereum_class is None: - raise NetworkError("Core Ethereum plugin missing.") - - data_folder = mkdtemp() - request_header = cls.config_manager.REQUEST_HEADER - init_kwargs = {"data_folder": data_folder, "request_header": request_header} - ethereum = ethereum_class(**init_kwargs) # type: ignore - return cls( - name="adhoc", - ecosystem=ethereum, - data_folder=Path(data_folder), - request_header=request_header, - _default_provider="geth", - ) - def __repr__(self) -> str: try: chain_id = self.chain_id @@ -879,7 +886,10 @@ def providers(self): # -> Dict[str, Partial[ProviderAPI]] ecosystem_name, network_name, provider_class = plugin_tuple provider_name = clean_plugin_name(provider_class.__module__.split(".")[0]) - if self.ecosystem.name == ecosystem_name and self.name == network_name: + # NOTE: Custom networks work with any provider. + if self.name == "custom" or ( + self.ecosystem.name == ecosystem_name and self.name == network_name + ): # NOTE: Lazily load provider config providers[provider_name] = partial( provider_class, @@ -910,7 +920,7 @@ def get_provider( :class:`~ape.api.providers.ProviderAPI` """ - provider_name = provider_name or self.default_provider + provider_name = provider_name or self.default_provider_name if not provider_name: from ape.managers.config import CONFIG_FILE_NAME @@ -947,9 +957,10 @@ def get_provider( def use_provider( self, - provider_name: str, + provider: Union[str, "ProviderAPI"], provider_settings: Optional[Dict] = None, disconnect_after: bool = False, + disconnect_on_exit: bool = True, ) -> ProviderContextManager: """ Use and connect to a provider in a temporary context. When entering the context, it calls @@ -965,24 +976,37 @@ def use_provider( ... Args: - provider_name (str): The name of the provider to use. + provider (str): The provider instance or the name of the provider to use. + provider_settings (dict, optional): Settings to apply to the provider. + Defaults to ``None``. disconnect_after (bool): Set to ``True`` to force a disconnect after ending the context. This defaults to ``False`` so you can re-connect to the same network, such as in a multi-chain testing scenario. - provider_settings (dict, optional): Settings to apply to the provider. - Defaults to ``None``. + disconnect_on_exit (bool): Whether to disconnect on the exit of the python + session. Defaults to ``True``. Returns: :class:`~ape.api.networks.ProviderContextManager` """ settings = provider_settings or {} - provider = self.get_provider(provider_name=provider_name, provider_settings=settings) - return ProviderContextManager(provider=provider, disconnect_after=disconnect_after) - @computed_field() # type: ignore[misc] + # NOTE: The main reason we allow a provider instance here is to avoid unnecessarily + # re-initializing the class. + provider_obj = ( + self.get_provider(provider_name=provider, provider_settings=settings) + if isinstance(provider, str) + else provider + ) + + return ProviderContextManager( + provider=provider_obj, + disconnect_after=disconnect_after, + disconnect_on_exit=disconnect_on_exit, + ) + @property - def default_provider(self) -> Optional[str]: + def default_provider_name(self) -> Optional[str]: """ The name of the default provider or ``None``. @@ -1005,6 +1029,13 @@ def default_provider(self) -> Optional[str]: # There are no providers at all for this network. return None + @property + def default_provider(self) -> Optional["ProviderAPI"]: + if (name := self.default_provider_name) and name in self.providers: + return self.get_provider(name) + + return None + @property def choice(self) -> str: return f"{self.ecosystem.name}:{self.name}" @@ -1056,7 +1087,9 @@ def use_default_provider( if self.default_provider: settings = provider_settings or {} return self.use_provider( - self.default_provider, provider_settings=settings, disconnect_after=disconnect_after + self.default_provider.name, + provider_settings=settings, + disconnect_after=disconnect_after, ) raise NetworkError(f"No providers for network '{self.name}'.") @@ -1089,7 +1122,7 @@ def verify_chain_id(self, chain_id: int): not local or adhoc and has a different hardcoded chain ID than the given one. """ - if self.name not in ("adhoc", LOCAL_NETWORK_NAME) and self.chain_id != chain_id: + if self.name not in ("custom", LOCAL_NETWORK_NAME) and self.chain_id != chain_id: raise NetworkMismatchError(chain_id, self) @@ -1112,7 +1145,7 @@ def upstream_provider(self) -> "UpstreamProvider": """ config_choice = self._network_config.get("upstream_provider") - if provider_name := config_choice or self.upstream_network.default_provider: + if provider_name := config_choice or self.upstream_network.default_provider_name: return self.upstream_network.get_provider(provider_name) raise NetworkError(f"Upstream network '{self.upstream_network}' has no providers.") @@ -1135,7 +1168,7 @@ def use_upstream_provider(self) -> ProviderContextManager: Returns: :class:`~ape.api.networks.ProviderContextManager` """ - return self.upstream_network.use_provider(self.upstream_provider.name) + return self.upstream_network.use_provider(self.upstream_provider) def create_network_type(chain_id: int, network_id: int) -> Type[NetworkAPI]: diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index 57874edefd..4ebd3bd568 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -6,6 +6,7 @@ import shutil import sys import time +import warnings from logging import FileHandler, Formatter, Logger, getLogger from pathlib import Path from signal import SIGINT, SIGTERM, signal @@ -235,6 +236,13 @@ def network_choice(self) -> str: """ return f"{self.network.choice}:{self.name}" + def get_storage_at(self, *args, **kwargs) -> HexBytes: + warnings.warn( + "'provider.get_storage_at()' is deprecated. Use 'provider.get_storage()'.", + DeprecationWarning, + ) + return self.get_storage(*args, **kwargs) + @raises_not_implemented def get_storage( # type: ignore[empty-body] self, address: AddressType, slot: int, block_id: Optional[BlockID] = None diff --git a/src/ape/cli/__init__.py b/src/ape/cli/__init__.py index c73ab5a10b..952d9aec67 100644 --- a/src/ape/cli/__init__.py +++ b/src/ape/cli/__init__.py @@ -12,9 +12,10 @@ output_format_choice, select_account, ) -from ape.cli.commands import NetworkBoundCommand +from ape.cli.commands import ConnectedProviderCommand, NetworkBoundCommand from ape.cli.options import ( ApeCliContextObject, + NetworkOption, account_option, ape_cli_context, contract_option, @@ -33,20 +34,22 @@ "AllFilePaths", "ape_cli_context", "ApeCliContextObject", + "ConnectedProviderCommand", "contract_file_paths_argument", "contract_option", "existing_alias_argument", - "select_account", "incompatible_with", "network_option", "NetworkBoundCommand", "NetworkChoice", + "NetworkOption", "non_existing_alias_argument", "output_format_choice", "output_format_option", "OutputFormat", "Path", "PromptChoice", + "select_account", "skip_confirmation_option", "verbosity_option", ] diff --git a/src/ape/cli/choices.py b/src/ape/cli/choices.py index 8e330933ea..f1f84b4d6d 100644 --- a/src/ape/cli/choices.py +++ b/src/ape/cli/choices.py @@ -1,4 +1,5 @@ import re +import warnings from enum import Enum from functools import lru_cache from typing import Any, Callable, Iterator, List, Optional, Sequence, Type, Union @@ -8,10 +9,10 @@ from ape import accounts, networks from ape.api.accounts import AccountAPI +from ape.api.providers import ProviderAPI from ape.exceptions import AccountsError from ape.types import _LazySequence -ADHOC_NETWORK_PATTERN = re.compile(r"\w*:\w*:https?://\w*.*") _ACCOUNT_TYPE_FILTER = Union[ None, Sequence[AccountAPI], Type[AccountAPI], Callable[[AccountAPI], bool] ] @@ -138,6 +139,18 @@ def select(self) -> str: raise IndexError(f"Choice index '{choice_idx}' out of range.") +def get_user_selected_account( + prompt_message: Optional[str] = None, account_type: _ACCOUNT_TYPE_FILTER = None +) -> AccountAPI: + """ + **DEPRECATED**: Use :meth:`~ape.cli.choices.select_account` instead. + """ + warnings.warn( + "'get_user_selected_account' is deprecated. Use 'select_account'.", DeprecationWarning + ) + return select_account(prompt_message=prompt_message, key=account_type) + + def select_account( prompt_message: Optional[str] = None, key: _ACCOUNT_TYPE_FILTER = None ) -> AccountAPI: @@ -315,13 +328,17 @@ class NetworkChoice(click.Choice): This is used in :meth:`~ape.cli.options.network_option`. """ + CUSTOM_NETWORK_PATTERN = re.compile(r"\w*:\w*:https?://\w*.*") + def __init__( self, case_sensitive=True, ecosystem: _NETWORK_FILTER = None, network: _NETWORK_FILTER = None, provider: _NETWORK_FILTER = None, + base_type: Type = ProviderAPI, ): + self.base_type = base_type super().__init__( get_networks(ecosystem=ecosystem, network=network, provider=provider), case_sensitive ) @@ -330,26 +347,44 @@ def get_metavar(self, param): return "[ecosystem-name][:[network-name][:[provider-name]]]" def convert(self, value: Any, param: Optional[Parameter], ctx: Optional[Context]) -> Any: - if ( - ADHOC_NETWORK_PATTERN.match(value) + if not value or value in ("None", "none"): + return None + + if not self.is_custom_value(value): + try: + # Validate result. + choice = super().convert(value, param, ctx) + except BadParameter as err: + # If an error was not raised for some reason, raise a simpler error. + # NOTE: Still avoid showing the massive network options list. + raise click.BadParameter( + "Invalid network choice. Use `ape networks list` to see options." + ) from err + + else: + # By-pass choice constraints when using custom network. + choice = value + + if issubclass(self.base_type, ProviderAPI): + # Return the provider. + return networks.get_provider_from_choice(network_choice=value) + + elif isinstance(self.base_type, str): + # The user wants the regular choice back. + return choice + + else: + raise TypeError(f"Unhandled type '{self.base_type}' for NetworkChoice.") + + @classmethod + def is_custom_value(cls, value) -> bool: + return ( + value is not None + and isinstance(value, str) + and cls.CUSTOM_NETWORK_PATTERN.match(value) is not None or str(value).startswith("http://") or str(value).startswith("https://") - ): - # By-pass choice constraints when using adhoc network - return value - - try: - return super().convert(value, param, ctx) - except BadParameter as err: - # Find out actual bad parts of the value to show better error. - # The following line should raise a nicer error. - networks.get_provider_from_choice(network_choice=value) - - # If an error was not raised for some reason, raise a simpler error. - # NOTE: Still avoid showing the massive network options list. - raise click.BadParameter( - "Invalid network choice. Use `ape networks list` to see options." - ) from err + ) class OutputFormat(Enum): diff --git a/src/ape/cli/commands.py b/src/ape/cli/commands.py index 13a3abbf6c..4c60108ad9 100644 --- a/src/ape/cli/commands.py +++ b/src/ape/cli/commands.py @@ -1,29 +1,125 @@ -from typing import Any +import inspect +import warnings +from typing import Any, List import click from click import Context from ape import networks +from ape.api import ProviderAPI +from ape.exceptions import NetworkError def check_parents_for_interactive(ctx: Context) -> bool: interactive: bool = ctx.params.get("interactive", False) if interactive: return True + # If not found, check the parent context. if interactive is None and ctx.parent: return check_parents_for_interactive(ctx.parent) + return False -class NetworkBoundCommand(click.Command): +class ConnectedProviderCommand(click.Command): """ A command that uses the :meth:`~ape.cli.options.network_option`. It will automatically set the network for the duration of the command execution. """ + def __init__(self, *args, **kwargs): + self._use_cls_types = kwargs.pop("use_cls_types", True) + super().__init__(*args, **kwargs) + + def parse_args(self, ctx: Context, args: List[str]) -> List[str]: + if not any( + isinstance(param, click.core.Option) and param.name == "network" + for param in self.params + ): + from ape.cli.options import NetworkOption + + base_type = ProviderAPI if self._use_cls_types else str + option = NetworkOption(base_type=base_type) + self.params.append(option) + + return super().parse_args(ctx, args) + def invoke(self, ctx: Context) -> Any: - value = ctx.params.get("network") or networks.default_ecosystem.name interactive = check_parents_for_interactive(ctx) - with networks.parse_network_choice(value, disconnect_on_exit=not interactive): - super().invoke(ctx) + param = ctx.params.get("network") + if param is not None and isinstance(param, ProviderAPI): + provider = param + network_context = provider.network.use_provider( + provider, disconnect_on_exit=not interactive + ) + elif param is not None and isinstance(param, str): + network_context = networks.parse_network_choice(param) + elif param is None: + ecosystem = networks.default_ecosystem + network = ecosystem.default_network + if provider_name := network.default_provider_name: + network_context = network.use_provider(provider_name) + else: + raise NetworkError(f"Network {network.name} has no providers.") + else: + raise TypeError(f"Unknown type for network choice: '{param}'.") + + with network_context: + if self.callback is not None: + signature = inspect.signature(self.callback) + callback_args = [x.name for x in signature.parameters.values()] + + opt_name = "network" + param = ctx.params.pop(opt_name, None) + if param is None: + ecosystem = networks.default_ecosystem + network = ecosystem.default_network + # Use default + if default_provider := network.default_provider: + provider = default_provider + else: + # Unlikely to get here. + raise ValueError( + f"Missing default provider for network '{network.choice}'. " + f"Using 'ethereum:local:test'." + ) + + elif isinstance(param, ProviderAPI): + provider = param + + elif isinstance(param, str): + # Is a choice str + provider = networks.parse_network_choice(param)._provider + else: + raise TypeError(f"Can't handle type of parameter '{param}'.") + + if self._use_cls_types: + if "ecosystem" in callback_args: + ctx.params["ecosystem"] = provider.network.ecosystem + if "network" in callback_args: + ctx.params["network"] = provider.network + if "provider" in callback_args: + ctx.params["provider"] = provider + + # If none of the above, the user doesn't use any network value. + + else: + # Legacy behavior, but may have a purpose. + ctx.params[opt_name] = provider.network_choice + + return ctx.invoke(self.callback, **ctx.params) + + +# TODO: 0.8 delete +class NetworkBoundCommand(ConnectedProviderCommand): + def __init__(self, *args, **kwargs): + warnings.warn( + "'NetworkBoundCommand' is deprecated. Use 'ConnectedProviderCommand'.", + DeprecationWarning, + ) + + # Disable the advanced network class types so it behaves legacy. + kwargs["use_cls_types"] = False + + super().__init__(*args, **kwargs) diff --git a/src/ape/cli/options.py b/src/ape/cli/options.py index efdc3b236a..0bf1c8d6f3 100644 --- a/src/ape/cli/options.py +++ b/src/ape/cli/options.py @@ -1,9 +1,14 @@ -from typing import Callable, Dict, List, NoReturn, Optional, Type, Union +import inspect +import sys +from functools import partial +from typing import Callable, Dict, List, NoReturn, Optional, Sequence, Type, Union import click +from click import Option from ethpm_types import ContractType from ape import networks, project +from ape.api import ProviderAPI from ape.cli.choices import ( _ACCOUNT_TYPE_FILTER, AccountAliasPromptChoice, @@ -101,6 +106,59 @@ def decorator(f): return decorator +class NetworkOption(Option): + """ + The class used in `:meth:~ape.cli.options.network_option`. + """ + + # NOTE: Has to be kwargs only to avoid multiple-values for arg error. + def __init__(self, *args, **kwargs) -> None: + ecosystem = kwargs.pop("ecosystem", None) + network = kwargs.pop("network", None) + provider = kwargs.pop("provider", None) + default = kwargs.pop("default", "auto") + base_type = kwargs.pop("base_type", ProviderAPI) + + # NOTE: If using network_option, this part is skipped + # because parsing happens earlier to handle advanced usage. + if "type" not in kwargs: + kwargs["type"] = NetworkChoice( + case_sensitive=False, + ecosystem=ecosystem, + network=network, + provider=provider, + base_type=base_type, + ) + auto = default == "auto" + required = kwargs.get("required", False) + + if auto and not required: + if ecosystem: + default = ecosystem[0] if isinstance(ecosystem, (list, tuple)) else ecosystem + + else: + # NOTE: Use a function as the default so it is calculated lazily + def fn(): + return networks.default_ecosystem.name + + default = fn + + elif auto: + default = None + + help_msg = ( + "Override the default network and provider. (see `ape networks list` for options)" + ) + kwargs = { + "param_decls": ("--network",), + "help": help_msg, + "default": default, + "required": required, + **kwargs, + } + super().__init__(**kwargs) + + def network_option( default: Optional[Union[str, Callable]] = "auto", ecosystem: Optional[Union[List[str], str]] = None, @@ -127,34 +185,88 @@ def network_option( kwargs: Additional overrides to ``click.option``. """ - auto = default == "auto" - - if auto and not required: - if ecosystem: - default = ecosystem[0] if isinstance(ecosystem, (list, tuple)) else ecosystem + def decorator(f): + # When using network_option, handle parsing now so we can pass to + # callback outside of command context. + kwargs["type"] = kwargs.pop("type", None) or NetworkChoice( + case_sensitive=False, ecosystem=ecosystem, network=network, provider=provider + ) + + # Find passed --network value + network_choice = None + arguments = _get_sys_argv() + for idx, val in enumerate(arguments): + if val == "--network": + next_idx = idx + 1 + if next_idx < len(arguments): + network_choice = arguments[next_idx] + break + + if network_choice in ("None", "none"): + # Specified None in cmd line- ignore. + # Or is a custom value. + choice_classes = None + + elif network_choice is None: + # Unspecified. + ecosystem_obj = networks.default_ecosystem + network_obj = ecosystem_obj.default_network + choice_classes = { + "ecosystem": networks.default_ecosystem, + "network": ecosystem_obj.default_network, + "provider": network_obj.default_provider, + } else: - # NOTE: Use a function as the default so it is calculated lazily - def fn(): - return networks.default_ecosystem.name + network_ctx = networks.parse_network_choice(network_choice) + provider_obj = network_ctx._provider + network_obj = provider_obj.network + ecosystem_obj = network_obj.ecosystem + choice_classes = { + "ecosystem": ecosystem_obj, + "network": network_obj, + "provider": provider_obj, + } + + # Create the callback using values from the parsed network choice. + # Only pass in values requested. Exposure is optional! + if choice_classes is None: + # Was told to use None explicitly. + partial_f = f - default = fn - - elif auto: - default = None + else: + partial_kwargs = {} + signature = inspect.signature(f) + requested_data = [x.name for x in signature.parameters.values()] + for arg_type in ("ecosystem", "network", "provider"): + if arg_type in requested_data: + partial_kwargs[arg_type] = choice_classes[arg_type] + + partial_f = partial(f, **partial_kwargs) + partial_f.__name__ = f.__name__ # type: ignore[attr-defined] + + # Skip NetworkChoice parsing since we did it already here. + kwargs["type"] = None + + # Set this to false to avoid click passing in a str value for network. + # This happens with `kwargs["type"] = None` and we are already handling + # `network` via the partial. + kwargs["expose_value"] = False + + # Create the actual option. + option = click.option( + default=default, + ecosystem=ecosystem, + network=network, + provider=provider, + required=required, + cls=NetworkOption, + **kwargs, + )(partial_f) + + return option - return click.option( - "--network", - type=NetworkChoice( - case_sensitive=False, ecosystem=ecosystem, network=network, provider=provider - ), - default=default, - help="Override the default network and provider. (see `ape networks list` for options)", - show_default=True, - show_choices=False, - required=required, - **kwargs, - ) + return decorator def skip_confirmation_option(help=""): @@ -284,3 +396,8 @@ def handle_parse_result(self, ctx, opts, args): return super().handle_parse_result(ctx, opts, args) return IncompatibleOption + + +def _get_sys_argv() -> Sequence[str]: + # Is separate fn so can be mocked in tests. + return sys.argv diff --git a/src/ape/managers/networks.py b/src/ape/managers/networks.py index a389dcc804..952371749a 100644 --- a/src/ape/managers/networks.py +++ b/src/ape/managers/networks.py @@ -1,6 +1,6 @@ import json from functools import cached_property -from typing import Collection, Dict, Iterator, List, Optional, Set, Union +from typing import Collection, Dict, Iterator, List, Optional, Set, Type, Union import yaml @@ -8,6 +8,7 @@ from ape.api.networks import NetworkAPI from ape.exceptions import ApeAttributeError, EcosystemNotFoundError, NetworkError from ape.managers.base import BaseManager +from ape_ethereum.provider import EthereumNodeProvider class NetworkManager(BaseManager): @@ -152,33 +153,55 @@ def to_kwargs(name: str) -> Dict: ecosystems = self.plugin_manager.ecosystems return {n: cls(**to_kwargs(n)) for n, cls in ecosystems} # type: ignore - def create_adhoc_geth_provider(self, uri: str) -> ProviderAPI: + def create_custom_provider( + self, + connection_str: str, + provider_cls: Type[ProviderAPI] = EthereumNodeProvider, + provider_name: Optional[str] = None, + ) -> ProviderAPI: """ - Create an ad-hoc connection to a URI using the GethProvider core plugin. + Create a custom connection to a URI using the EthereumNodeProvider provider. **NOTE**: This provider will assume EVM-like behavior and this is generally not recommended. Use plugins when possible! Args: - uri (str): The URI of the node. + connection_str (str): The connection string of the node, such as its URI + when using HTTP. + provider_cls (Type[:class:`~ape.api.providers.ProviderAPI`]): Defaults to + :class:`~ape_ethereum.providers.EthereumNodeProvider`. + provider_name (Optional[str]): The name of the provider. Defaults to best guess. Returns: :class:`~ape.api.providers.ProviderAPI`: The Geth provider implementation that comes with Ape. """ - geth_class = None - for plugin_name, (_, _, provider_class) in self.plugin_manager.providers: - if plugin_name == "geth": - geth_class = provider_class - break + network = self.ethereum.custom_network + + if provider_name is None: + if issubclass(provider_cls, EthereumNodeProvider): + name = "geth" + + elif cls_name := getattr(provider_cls, "name", None): + name = cls_name + + elif cls_name := getattr(provider_cls, "__name__"): + name = cls_name.lower() - if geth_class is None: - raise NetworkError("Core Geth plugin missing.") + else: + # Would be unusual for this to happen though. + name = "provider" - network = NetworkAPI.create_adhoc_network() - return geth_class( + else: + name = provider_name + + if not connection_str.startswith("http"): + raise ValueError("Currently, only HTTP-based custom nodes are supported.") + + return (provider_cls or EthereumNodeProvider)( + name=name, network=network, - provider_settings={"uri": uri}, + provider_settings={"uri": connection_str}, data_folder=network.data_folder, request_header=network.request_header, ) @@ -301,14 +324,14 @@ def get_network_choices( for provider_name in providers: if ( ecosystem_name == self.default_ecosystem.name - and network_name == ecosystem.default_network + and network_name == ecosystem.default_network_name ): yield f"::{provider_name}" if ecosystem_name == self.default_ecosystem.name: yield f":{network_name}:{provider_name}" - if network_name == ecosystem.default_network: + if network_name == ecosystem.default_network_name: yield f"{ecosystem_name}::{provider_name}" # Always include the full path as an option. @@ -370,20 +393,22 @@ def get_provider_from_choice( if network_choice is None: default_network = self.default_ecosystem.default_network - return self.default_ecosystem[default_network].get_provider( - provider_settings=provider_settings - ) + return default_network.get_provider(provider_settings=provider_settings) elif network_choice.startswith("http://") or network_choice.startswith("https://"): - return self.create_adhoc_geth_provider(network_choice) + return self.create_custom_provider(network_choice) selections = network_choice.split(":") # NOTE: Handle case when URI is passed e.g. "http://..." if len(selections) > 3: - selections[2] = ":".join(selections[2:]) + provider_value = ":".join(selections[2:]) + selections[2] = provider_value selections = selections[:3] + if provider_value.startswith("https://") or provider_value.startswith("https://"): + selections[1] = selections[1] or "custom" + if selections == network_choice or len(selections) == 1: # Either split didn't work (in which case it matches the start) # or there was nothing after the ``:`` (e.g. "ethereum:") @@ -391,20 +416,20 @@ def get_provider_from_choice( # By default, the "local" network should be specified for # any ecosystem (this should not correspond to a production chain) default_network = ecosystem.default_network - return ecosystem[default_network].get_provider(provider_settings=provider_settings) + return default_network.get_provider(provider_settings=provider_settings) elif len(selections) == 2: # Only ecosystem and network were specified, not provider ecosystem_name, network_name = selections ecosystem = self.get_ecosystem(ecosystem_name or self.default_ecosystem.name) - network = ecosystem.get_network(network_name or ecosystem.default_network) + network = ecosystem.get_network(network_name or ecosystem.default_network_name) return network.get_provider(provider_settings=provider_settings) elif len(selections) == 3: # Everything is specified, use specified provider for ecosystem and network ecosystem_name, network_name, provider_name = selections ecosystem = self.get_ecosystem(ecosystem_name or self.default_ecosystem.name) - network = ecosystem.get_network(network_name or ecosystem.default_network) + network = ecosystem.get_network(network_name or ecosystem.default_network_name) return network.get_provider( provider_name=provider_name, provider_settings=provider_settings ) @@ -438,6 +463,8 @@ def parse_network_choice( disconnect_after (bool): Set to True to terminate the connection completely at the end of context. NOTE: May only work if the network was also started from this session. + disconnect_on_exit (bool): Whether to disconnect on the exit of the python + session. Defaults to ``True``. Returns: :class:`~api.api.networks.ProviderContextManager` diff --git a/src/ape_cache/_cli.py b/src/ape_cache/_cli.py index b88bf9ffdd..8e859ed286 100644 --- a/src/ape_cache/_cli.py +++ b/src/ape_cache/_cli.py @@ -1,8 +1,7 @@ import click import pandas as pd -from ape import networks -from ape.cli import NetworkBoundCommand, network_option +from ape.cli import ConnectedProviderCommand, network_option from ape.logging import logger from ape.utils import ManagerAccessMixin @@ -20,7 +19,7 @@ def cli(): @cli.command(short_help="Initialize a new cache database") @network_option(required=True) -def init(network): +def init(ecosystem, network): """ Initializes an SQLite database and creates a file to store data from the provider. @@ -29,21 +28,16 @@ def init(network): give an ecosystem name and a network name to initialize the database. """ - provider = networks.get_provider_from_choice(network) - ecosystem_name = provider.network.ecosystem.name - network_name = provider.network.name - - get_engine().init_database(ecosystem_name, network_name) - logger.success(f"Caching database initialized for {ecosystem_name}:{network_name}.") + get_engine().init_database(ecosystem.name, network.name) + logger.success(f"Caching database initialized for {ecosystem.name}:{network.name}.") @cli.command( - cls=NetworkBoundCommand, + cls=ConnectedProviderCommand, short_help="Call and print SQL statement to the cache database", ) -@network_option() @click.argument("query_str") -def query(query_str, network): +def query(query_str): """ Allows for a query of the database from an SQL statement. @@ -62,7 +56,7 @@ def query(query_str, network): @cli.command(short_help="Purges entire database") @network_option(required=True) -def purge(network): +def purge(ecosystem, network): """ Purges data from the selected database instance. @@ -74,9 +68,7 @@ def purge(network): purge the database of choice. """ - provider = networks.get_provider_from_choice(network) - ecosystem_name = provider.network.ecosystem.name - network_name = provider.network.name - + ecosystem_name = network.ecosystem.name + network_name = network.name get_engine().purge_database(ecosystem_name, network_name) logger.success(f"Caching database purged for {ecosystem_name}:{network_name}.") diff --git a/src/ape_console/_cli.py b/src/ape_console/_cli.py index 10f05f68f5..39ecd64012 100644 --- a/src/ape_console/_cli.py +++ b/src/ape_console/_cli.py @@ -14,7 +14,7 @@ from ape import config from ape import project as default_project -from ape.cli import NetworkBoundCommand, ape_cli_context, network_option +from ape.cli import ConnectedProviderCommand, ape_cli_context from ape.utils.misc import _python_version from ape.version import version as ape_version from ape_console.config import ConsoleConfig @@ -23,13 +23,12 @@ @click.command( - cls=NetworkBoundCommand, + cls=ConnectedProviderCommand, short_help="Load the console", context_settings=dict(ignore_unknown_options=True), ) -@network_option() @ape_cli_context() -def cli(cli_ctx, network): +def cli(cli_ctx): """Opens a console for the local project.""" verbose = cli_ctx.logger.level == logging.DEBUG return console(verbose=verbose) diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 2b85779b7d..dc42201895 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -214,7 +214,7 @@ def config(self) -> EthereumConfig: @property def default_transaction_type(self) -> TransactionType: - network = self.default_network.replace("-", "_") + network = self.default_network_name.replace("-", "_") return self.config[network].default_transaction_type @classmethod diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 07eb489b69..fbe33b9342 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -1,21 +1,42 @@ +import os +import sys import time from abc import ABC from concurrent.futures import ThreadPoolExecutor from copy import copy from functools import cached_property from itertools import tee +from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Union, cast +import ijson # type: ignore +import requests from eth_pydantic_types import HexBytes from eth_typing import BlockNumber, HexStr from eth_utils import add_0x_prefix, to_hex from ethpm_types import EventABI from evm_trace import CallTreeNode as EvmCallTreeNode +from evm_trace import ParityTraceList from evm_trace import TraceFrame as EvmTraceFrame +from evm_trace import ( + create_trace_frames, + get_calltree_from_geth_call_trace, + get_calltree_from_parity_trace, +) from pydantic.dataclasses import dataclass -from web3 import Web3 +from web3 import HTTPProvider, IPCProvider, Web3 from web3.exceptions import ContractLogicError as Web3ContractLogicError -from web3.exceptions import MethodUnavailable, TimeExhausted, TransactionNotFound +from web3.exceptions import ( + ExtraDataLengthError, + MethodUnavailable, + TimeExhausted, + TransactionNotFound, +) +from web3.gas_strategies.rpc import rpc_gas_price_strategy +from web3.middleware import geth_poa_middleware +from web3.middleware.validation import MAX_EXTRADATA_LENGTH +from web3.providers import AutoProvider +from web3.providers.auto import load_provider_from_environment from web3.types import RPCEndpoint, TxParams from ape.api import BlockAPI, ProviderAPI, ReceiptAPI, TransactionAPI @@ -25,6 +46,7 @@ APINotImplementedError, BlockNotFoundError, ContractLogicError, + ContractNotFoundError, OutOfGasError, ProviderError, ProviderNotConnectedError, @@ -47,6 +69,10 @@ from ape.utils import gas_estimation_error_message, run_until_complete, to_int from ape.utils.misc import DEFAULT_MAX_RETRIES_TX +DEFAULT_PORT = 8545 +DEFAULT_HOSTNAME = "localhost" +DEFAULT_SETTINGS = {"uri": f"http://{DEFAULT_HOSTNAME}:{DEFAULT_PORT}"} + def _sanitize_web3_url(msg: str) -> str: if "URI: " not in msg: @@ -231,7 +257,7 @@ def chain_id(self) -> int: if ( self.network.name not in ( - "adhoc", + "custom", LOCAL_NETWORK_NAME, ) and not self.network.is_fork @@ -994,7 +1020,12 @@ def _handle_execution_reverted( elif txn: err_trace = self.provider.get_transaction_trace(txn.txn_hash.hex()) - trace_ls: List[TraceFrame] = list(err_trace) if err_trace else [] + try: + trace_ls: List[TraceFrame] = list(err_trace) if err_trace else [] + except Exception as err: + logger.error(f"Failed getting traceback: {err}") + trace_ls = [] + data = trace_ls[-1].raw if len(trace_ls) > 0 else {} memory = data.get("memory", []) return_value = "".join([x[2:] for x in memory[4:]]) @@ -1017,3 +1048,271 @@ def _handle_execution_reverted( ) ) return self.compiler_manager.enrich_error(result) + + +class EthereumNodeProvider(Web3Provider, ABC): + # optimal values for geth + block_page_size: int = 5000 + concurrency: int = 16 + + name: str = "geth" + + """Is ``None`` until known.""" + can_use_parity_traces: Optional[bool] = None + + @property + def uri(self) -> str: + if "uri" in self.provider_settings: + # Use adhoc, scripted value + return self.provider_settings["uri"] + + config = self.config.model_dump(mode="json").get(self.network.ecosystem.name, None) + if config is None: + return DEFAULT_SETTINGS["uri"] + + # Use value from config file + network_config = config.get(self.network.name) or DEFAULT_SETTINGS + return network_config.get("uri", DEFAULT_SETTINGS["uri"]) + + @property + def connection_str(self) -> str: + return self.uri + + @property + def connection_id(self) -> Optional[str]: + return f"{self.network_choice}:{self.uri}" + + @property + def _clean_uri(self) -> str: + return sanitize_url(self.uri) + + @property + def ipc_path(self) -> Path: + return self.settings.ipc_path or self.data_dir / "geth.ipc" + + @property + def data_dir(self) -> Path: + if self.settings.data_dir: + return self.settings.data_dir.expanduser() + + return _get_default_data_dir() + + @cached_property + def _ots_api_level(self) -> Optional[int]: + # NOTE: Returns None when OTS namespace is not enabled. + try: + result = self._make_request("ots_getApiLevel") + except (NotImplementedError, ApeException, ValueError): + return None + + if isinstance(result, int): + return result + + elif isinstance(result, str) and result.isnumeric(): + return int(result) + + return None + + def _set_web3(self): + # Clear cached version when connecting to another URI. + self._client_version = None # type: ignore + self._web3 = _create_web3(self.uri, ipc_path=self.ipc_path) + + def _complete_connect(self): + client_version = self.client_version.lower() + if "geth" in client_version: + self._log_connection("Geth") + elif "reth" in client_version: + self._log_connection("Reth") + elif "erigon" in client_version: + self._log_connection("Erigon") + self.concurrency = 8 + self.block_page_size = 40_000 + elif "nethermind" in client_version: + self._log_connection("Nethermind") + self.concurrency = 32 + self.block_page_size = 50_000 + else: + client_name = client_version.split("/")[0] + logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.") + logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.") + + self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) + + # Check for chain errors, including syncing + try: + chain_id = self.web3.eth.chain_id + except ValueError as err: + raise ProviderError( + err.args[0].get("message") + if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict))) + else "Error getting chain id." + ) + + try: + block = self.web3.eth.get_block("latest") + except ExtraDataLengthError: + is_likely_poa = True + else: + is_likely_poa = ( + "proofOfAuthorityData" in block + or len(block.get("extraData", "")) > MAX_EXTRADATA_LENGTH + ) + + if is_likely_poa and geth_poa_middleware not in self.web3.middleware_onion: + self.web3.middleware_onion.inject(geth_poa_middleware, layer=0) + + self.network.verify_chain_id(chain_id) + + def disconnect(self): + self.can_use_parity_traces = None + self._web3 = None # type: ignore + self._client_version = None # type: ignore + + def get_transaction_trace(self, txn_hash: str) -> Iterator[TraceFrame]: + frames = self._stream_request( + "debug_traceTransaction", [txn_hash, {"enableMemory": True}], "result.structLogs.item" + ) + for frame in create_trace_frames(frames): + yield self._create_trace_frame(frame) + + def _get_transaction_trace_using_call_tracer(self, txn_hash: str) -> Dict: + return self._make_request( + "debug_traceTransaction", [txn_hash, {"enableMemory": True, "tracer": "callTracer"}] + ) + + def get_call_tree(self, txn_hash: str) -> CallTreeNode: + if self.can_use_parity_traces is True: + return self._get_parity_call_tree(txn_hash) + + elif self.can_use_parity_traces is False: + return self._get_geth_call_tree(txn_hash) + + elif "erigon" in self.client_version.lower(): + tree = self._get_parity_call_tree(txn_hash) + self.can_use_parity_traces = True + return tree + + try: + # Try the Parity traces first, in case node client supports it. + tree = self._get_parity_call_tree(txn_hash) + except (ValueError, APINotImplementedError, ProviderError): + self.can_use_parity_traces = False + return self._get_geth_call_tree(txn_hash) + + # Parity style works. + self.can_use_parity_traces = True + return tree + + def _get_parity_call_tree(self, txn_hash: str) -> CallTreeNode: + result = self._make_request("trace_transaction", [txn_hash]) + if not result: + raise ProviderError(f"Failed to get trace for '{txn_hash}'.") + + traces = ParityTraceList.model_validate(result) + evm_call = get_calltree_from_parity_trace(traces) + return self._create_call_tree_node(evm_call, txn_hash=txn_hash) + + def _get_geth_call_tree(self, txn_hash: str) -> CallTreeNode: + calls = self._get_transaction_trace_using_call_tracer(txn_hash) + evm_call = get_calltree_from_geth_call_trace(calls) + return self._create_call_tree_node(evm_call, txn_hash=txn_hash) + + def _log_connection(self, client_name: str): + msg = f"Connecting to existing {client_name.strip()} node at" + suffix = ( + self.ipc_path.as_posix().replace(Path.home().as_posix(), "$HOME") + if self.ipc_path.exists() + else self._clean_uri + ) + logger.info(f"{msg} {suffix}.") + + def ots_get_contract_creator(self, address: AddressType) -> Optional[Dict]: + if self._ots_api_level is None: + return None + + result = self._make_request("ots_getContractCreator", [address]) + if result is None: + # NOTE: Skip the explorer part of the error message via `has_explorer=True`. + raise ContractNotFoundError(address, has_explorer=True, provider_name=self.name) + + return result + + def _get_contract_creation_receipt(self, address: AddressType) -> Optional[ReceiptAPI]: + if result := self.ots_get_contract_creator(address): + tx_hash = result["hash"] + return self.get_receipt(tx_hash) + + return None + + def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: + parameters = parameters or [] + try: + return super()._make_request(endpoint, parameters) + except ProviderError as err: + if "does not exist/is not available" in str(err): + raise APINotImplementedError( + f"RPC method '{endpoint}' is not implemented by this node instance." + ) from err + + raise # Original error + + def _stream_request(self, method: str, params: List, iter_path="result.item"): + payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params} + results = ijson.sendable_list() + coroutine = ijson.items_coro(results, iter_path) + resp = requests.post(self.uri, json=payload, stream=True) + resp.raise_for_status() + + for chunk in resp.iter_content(chunk_size=2**17): + coroutine.send(chunk) + yield from results + del results[:] + + def connect(self): + self._set_web3() + if not self.is_connected: + raise ProviderError(f"No node found on '{self._clean_uri}'.") + + self._complete_connect() + + +def _create_web3(uri: str, ipc_path: Optional[Path] = None): + # Separated into helper method for testing purposes. + def http_provider(): + return HTTPProvider(uri, request_kwargs={"timeout": 30 * 60}) + + def ipc_provider(): + # NOTE: This mypy complaint seems incorrect. + if not (path := ipc_path): + raise ValueError("IPC Path required.") + + return IPCProvider(ipc_path=path) + + # NOTE: This tuple is ordered by try-attempt. + # Try ENV, then IPC, and then HTTP last. + providers = ( + load_provider_from_environment, + ipc_provider, + http_provider, + ) + provider = AutoProvider(potential_providers=providers) + return Web3(provider) + + +def _get_default_data_dir() -> Path: + # Modified from web3.py package to always return IPC even when none exist. + if sys.platform == "darwin": + return Path.home() / "Library" / "Ethereum" + + elif sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): + return Path.home() / "ethereum" + + elif sys.platform == "win32": + return Path(os.path.join("\\\\", ".", "pipe")) + + else: + raise ValueError( + f"Unsupported platform '{sys.platform}'. Only darwin/linux/win32/" + "freebsd are supported. You must specify the data_dir." + ) diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 0242cfc211..75badbeb57 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -18,6 +18,7 @@ from ape.api import ReceiptAPI, TransactionAPI from ape.contracts import ContractEvent from ape.exceptions import OutOfGasError, SignatureError, TransactionError +from ape.logging import logger from ape.types import CallTreeNode, ContractLog, ContractLogContainer, SourceTraceback from ape.utils import ZERO_ADDRESS @@ -179,7 +180,13 @@ def method_called(self) -> Optional[MethodABI]: @cached_property def source_traceback(self) -> SourceTraceback: if contract_type := self.contract_type: - return SourceTraceback.create(contract_type, self.trace, HexBytes(self.data)) + try: + return SourceTraceback.create(contract_type, self.trace, HexBytes(self.data)) + except Exception as err: + # Failing to get a traceback should not halt an Ape application. + # Sometimes, a node crashes and we are left with nothing. + logger.error(f"Problem retrieving traceback: {err}") + pass return SourceTraceback.model_validate([]) diff --git a/src/ape_geth/provider.py b/src/ape_geth/provider.py index d047dba355..8564f2e7f2 100644 --- a/src/ape_geth/provider.py +++ b/src/ape_geth/provider.py @@ -1,58 +1,29 @@ import atexit -import os import shutil -import sys -from abc import ABC -from functools import cached_property from itertools import tee from pathlib import Path from subprocess import DEVNULL, PIPE, Popen from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -import ijson # type: ignore -import requests from eth_pydantic_types import HexBytes from eth_typing import HexStr from eth_utils import add_0x_prefix, to_hex, to_wei -from evm_trace import CallType, ParityTraceList +from evm_trace import CallType from evm_trace import TraceFrame as EvmTraceFrame -from evm_trace import ( - create_trace_frames, - get_calltree_from_geth_call_trace, - get_calltree_from_geth_trace, - get_calltree_from_parity_trace, -) +from evm_trace import create_trace_frames, get_calltree_from_geth_trace from geth.accounts import ensure_account_exists # type: ignore from geth.chain import initialize_chain # type: ignore from geth.process import BaseGethProcess # type: ignore from geth.wrapper import construct_test_chain_kwargs # type: ignore from pydantic_settings import SettingsConfigDict from requests.exceptions import ConnectionError -from web3 import HTTPProvider, Web3 -from web3.exceptions import ExtraDataLengthError -from web3.gas_strategies.rpc import rpc_gas_price_strategy from web3.middleware import geth_poa_middleware -from web3.middleware.validation import MAX_EXTRADATA_LENGTH -from web3.providers import AutoProvider, IPCProvider -from web3.providers.auto import load_provider_from_environment from yarl import URL -from ape.api import ( - PluginConfig, - ReceiptAPI, - SubprocessProvider, - TestProviderAPI, - TransactionAPI, - UpstreamProvider, -) -from ape.exceptions import ( - ApeException, - APINotImplementedError, - ContractNotFoundError, - ProviderError, -) -from ape.logging import LogLevel, logger, sanitize_url -from ape.types import AddressType, BlockID, CallTreeNode, SnapshotID, SourceTraceback, TraceFrame +from ape.api import PluginConfig, SubprocessProvider, TestProviderAPI, TransactionAPI +from ape.exceptions import ProviderError +from ape.logging import LogLevel, logger +from ape.types import BlockID, CallTreeNode, SnapshotID, SourceTraceback from ape.utils import ( DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_CHAIN_ID, @@ -62,11 +33,12 @@ raises_not_implemented, spawn, ) -from ape_ethereum.provider import Web3Provider - -DEFAULT_PORT = 8545 -DEFAULT_HOSTNAME = "localhost" -DEFAULT_SETTINGS = {"uri": f"http://{DEFAULT_HOSTNAME}:{DEFAULT_PORT}"} +from ape_ethereum.provider import ( + DEFAULT_HOSTNAME, + DEFAULT_PORT, + DEFAULT_SETTINGS, + EthereumNodeProvider, +) class GethDevProcess(BaseGethProcess): @@ -247,224 +219,8 @@ def __init__(self): ) -class BaseGethProvider(Web3Provider, ABC): - # optimal values for geth - block_page_size: int = 5000 - concurrency: int = 16 - - name: str = "geth" - - """Is ``None`` until known.""" - can_use_parity_traces: Optional[bool] = None - - @property - def uri(self) -> str: - if "uri" in self.provider_settings: - # Use adhoc, scripted value - return self.provider_settings["uri"] - - config = self.config.model_dump(mode="json").get(self.network.ecosystem.name, None) - if config is None: - return DEFAULT_SETTINGS["uri"] - - # Use value from config file - network_config = config.get(self.network.name) or DEFAULT_SETTINGS - return network_config.get("uri", DEFAULT_SETTINGS["uri"]) - - @property - def connection_id(self) -> Optional[str]: - return f"{self.network_choice}:{self.uri}" - - @property - def _clean_uri(self) -> str: - return sanitize_url(self.uri) - - @property - def ipc_path(self) -> Path: - return self.settings.ipc_path or self.data_dir / "geth.ipc" - - @property - def data_dir(self) -> Path: - if self.settings.data_dir: - return self.settings.data_dir.expanduser() - - return _get_default_data_dir() - - @cached_property - def _ots_api_level(self) -> Optional[int]: - # NOTE: Returns None when OTS namespace is not enabled. - try: - result = self._make_request("ots_getApiLevel") - except (NotImplementedError, ApeException, ValueError): - return None - - if isinstance(result, int): - return result - - elif isinstance(result, str) and result.isnumeric(): - return int(result) - - return None - - def _set_web3(self): - # Clear cached version when connecting to another URI. - self._client_version = None # type: ignore - self._web3 = _create_web3(self.uri, ipc_path=self.ipc_path) - - def _complete_connect(self): - client_version = self.client_version.lower() - if "geth" in client_version: - self._log_connection("Geth") - elif "reth" in client_version: - self._log_connection("Reth") - elif "erigon" in client_version: - self._log_connection("Erigon") - self.concurrency = 8 - self.block_page_size = 40_000 - elif "nethermind" in client_version: - self._log_connection("Nethermind") - self.concurrency = 32 - self.block_page_size = 50_000 - else: - client_name = client_version.split("/")[0] - logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.") - logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.") - - self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) - - # Check for chain errors, including syncing - try: - chain_id = self.web3.eth.chain_id - except ValueError as err: - raise ProviderError( - err.args[0].get("message") - if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict))) - else "Error getting chain id." - ) - - try: - block = self.web3.eth.get_block("latest") - except ExtraDataLengthError: - is_likely_poa = True - else: - is_likely_poa = ( - "proofOfAuthorityData" in block - or len(block.get("extraData", "")) > MAX_EXTRADATA_LENGTH - ) - - if is_likely_poa and geth_poa_middleware not in self.web3.middleware_onion: - self.web3.middleware_onion.inject(geth_poa_middleware, layer=0) - - self.network.verify_chain_id(chain_id) - - def disconnect(self): - self.can_use_parity_traces = None - self._web3 = None # type: ignore - self._client_version = None # type: ignore - - def get_transaction_trace(self, txn_hash: str) -> Iterator[TraceFrame]: - frames = self._stream_request( - "debug_traceTransaction", [txn_hash, {"enableMemory": True}], "result.structLogs.item" - ) - for frame in create_trace_frames(frames): - yield self._create_trace_frame(frame) - - def _get_transaction_trace_using_call_tracer(self, txn_hash: str) -> Dict: - return self._make_request( - "debug_traceTransaction", [txn_hash, {"enableMemory": True, "tracer": "callTracer"}] - ) - - def get_call_tree(self, txn_hash: str) -> CallTreeNode: - if self.can_use_parity_traces is True: - return self._get_parity_call_tree(txn_hash) - - elif self.can_use_parity_traces is False: - return self._get_geth_call_tree(txn_hash) - - elif "erigon" in self.client_version.lower(): - tree = self._get_parity_call_tree(txn_hash) - self.can_use_parity_traces = True - return tree - - try: - # Try the Parity traces first, in case node client supports it. - tree = self._get_parity_call_tree(txn_hash) - except (ValueError, APINotImplementedError, ProviderError): - self.can_use_parity_traces = False - return self._get_geth_call_tree(txn_hash) - - # Parity style works. - self.can_use_parity_traces = True - return tree - - def _get_parity_call_tree(self, txn_hash: str) -> CallTreeNode: - result = self._make_request("trace_transaction", [txn_hash]) - if not result: - raise ProviderError(f"Failed to get trace for '{txn_hash}'.") - - traces = ParityTraceList.model_validate(result) - evm_call = get_calltree_from_parity_trace(traces) - return self._create_call_tree_node(evm_call, txn_hash=txn_hash) - - def _get_geth_call_tree(self, txn_hash: str) -> CallTreeNode: - calls = self._get_transaction_trace_using_call_tracer(txn_hash) - evm_call = get_calltree_from_geth_call_trace(calls) - return self._create_call_tree_node(evm_call, txn_hash=txn_hash) - - def _log_connection(self, client_name: str): - msg = f"Connecting to existing {client_name.strip()} node at" - suffix = ( - self.ipc_path.as_posix().replace(Path.home().as_posix(), "$HOME") - if self.ipc_path.exists() - else self._clean_uri - ) - logger.info(f"{msg} {suffix}.") - - def ots_get_contract_creator(self, address: AddressType) -> Optional[Dict]: - if self._ots_api_level is None: - return None - - result = self._make_request("ots_getContractCreator", [address]) - if result is None: - # NOTE: Skip the explorer part of the error message via `has_explorer=True`. - raise ContractNotFoundError(address, has_explorer=True, provider_name=self.name) - - return result - - def _get_contract_creation_receipt(self, address: AddressType) -> Optional[ReceiptAPI]: - if result := self.ots_get_contract_creator(address): - tx_hash = result["hash"] - return self.get_receipt(tx_hash) - - return None - - def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: - parameters = parameters or [] - try: - return super()._make_request(endpoint, parameters) - except ProviderError as err: - if "does not exist/is not available" in str(err): - raise APINotImplementedError( - f"RPC method '{endpoint}' is not implemented by this node instance." - ) from err - - raise # Original error - - def _stream_request(self, method: str, params: List, iter_path="result.item"): - payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params} - results = ijson.sendable_list() - coroutine = ijson.items_coro(results, iter_path) - - resp = requests.post(self.uri, json=payload, stream=True) - resp.raise_for_status() - - for chunk in resp.iter_content(chunk_size=2**17): - coroutine.send(chunk) - yield from results - del results[:] - - -class GethDev(BaseGethProvider, TestProviderAPI, SubprocessProvider): +# NOTE: Using EthereumNodeProvider because of it's geth-derived default behavior. +class GethDev(EthereumNodeProvider, TestProviderAPI, SubprocessProvider): _process: Optional[GethDevProcess] = None name: str = "geth" can_use_parity_traces: Optional[bool] = False @@ -703,55 +459,6 @@ def build_command(self) -> List[str]: return self._process.command if self._process else [] -class Geth(BaseGethProvider, UpstreamProvider): - @property - def connection_str(self) -> str: - return self.uri - - def connect(self): - self._set_web3() - if not self.is_connected: - raise ProviderError(f"No node found on '{self._clean_uri}'.") - - self._complete_connect() - - -def _create_web3(uri: str, ipc_path: Optional[Path] = None): - # Separated into helper method for testing purposes. - def http_provider(): - return HTTPProvider(uri, request_kwargs={"timeout": 30 * 60}) - - def ipc_provider(): - # NOTE: This mypy complaint seems incorrect. - if not (path := ipc_path): - raise ValueError("IPC Path required.") - - return IPCProvider(ipc_path=path) - - # NOTE: This tuple is ordered by try-attempt. - # Try ENV, then IPC, and then HTTP last. - providers = ( - load_provider_from_environment, - ipc_provider, - http_provider, - ) - provider = AutoProvider(potential_providers=providers) - return Web3(provider) - - -def _get_default_data_dir() -> Path: - # Modified from web3.py package to always return IPC even when none exist. - if sys.platform == "darwin": - return Path.home() / "Library" / "Ethereum" - - elif sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): - return Path.home() / "ethereum" - - elif sys.platform == "win32": - return Path(os.path.join("\\\\", ".", "pipe")) - - else: - raise ValueError( - f"Unsupported platform '{sys.platform}'. Only darwin/linux/win32/" - "freebsd are supported. You must specify the data_dir." - ) +# NOTE: The default behavior of EthereumNodeBehavior assumes geth. +class Geth(EthereumNodeProvider): + pass diff --git a/src/ape_geth/query.py b/src/ape_geth/query.py index 98d6de374f..912a1c9204 100644 --- a/src/ape_geth/query.py +++ b/src/ape_geth/query.py @@ -4,7 +4,7 @@ from ape.api import ReceiptAPI from ape.api.query import ContractCreationQuery, QueryAPI, QueryType from ape.exceptions import QueryEngineError -from ape_geth.provider import BaseGethProvider +from ape_ethereum.provider import EthereumNodeProvider class OTSQueryEngine(QueryAPI): @@ -21,7 +21,7 @@ def perform_query(self, query: QueryType) -> Iterator: # type: ignore[override] @estimate_query.register def estimate_contract_creation_query(self, query: ContractCreationQuery) -> Optional[int]: if provider := self.network_manager.active_provider: - if not isinstance(provider, BaseGethProvider): + if not isinstance(provider, EthereumNodeProvider): return None elif uri := provider.http_uri: return 225 if uri.startswith("http://") else 600 @@ -30,6 +30,6 @@ def estimate_contract_creation_query(self, query: ContractCreationQuery) -> Opti @perform_query.register def get_contract_creation_receipt(self, query: ContractCreationQuery) -> Iterator[ReceiptAPI]: - if self.network_manager.active_provider and isinstance(self.provider, BaseGethProvider): + if self.network_manager.active_provider and isinstance(self.provider, EthereumNodeProvider): if receipt := self.provider._get_contract_creation_receipt(query.contract): yield receipt diff --git a/src/ape_networks/_cli.py b/src/ape_networks/_cli.py index 9606b5d030..32297ca8df 100644 --- a/src/ape_networks/_cli.py +++ b/src/ape_networks/_cli.py @@ -105,16 +105,13 @@ def make_sub_tree(data: Dict, create_tree: Callable) -> Tree: @cli.command() @ape_cli_context() @network_option(default="ethereum:local:geth") -def run(cli_ctx, network): +def run(cli_ctx, provider): """ Start a node process """ - # Ignore extra loggers, such as web3 loggers. cli_ctx.logger._extra_loggers = {} - network_ctx = cli_ctx.network_manager.parse_network_choice(network) - provider = network_ctx._provider if not isinstance(provider, SubprocessProvider): cli_ctx.abort( f"`ape networks run` requires a provider that manages a process, not '{provider.name}'." diff --git a/src/ape_run/_cli.py b/src/ape_run/_cli.py index e9f180ce94..2955d360e4 100644 --- a/src/ape_run/_cli.py +++ b/src/ape_run/_cli.py @@ -11,7 +11,7 @@ from click import Command, Context, Option from ape import networks, project -from ape.cli import NetworkBoundCommand, network_option, verbosity_option +from ape.cli import ConnectedProviderCommand, verbosity_option from ape.cli.options import _VERBOSITY_VALUES, _create_verbosity_kwargs from ape.exceptions import ApeException, handle_ape_exception from ape.logging import logger @@ -134,11 +134,10 @@ def _get_command(self, filepath: Path) -> Union[click.Command, click.Group, None logger.debug(f"Found 'main' method in script: {relative_filepath}") @click.command( - cls=NetworkBoundCommand, + cls=ConnectedProviderCommand, short_help=f"Run '{relative_filepath}:main'", name=relative_filepath.stem, ) - @network_option() @verbosity_option() def call(network): _ = network # Downstream might use this @@ -154,13 +153,11 @@ def call(network): logger.warning(f"No 'main' method or 'cli' command in script: {relative_filepath}") @click.command( - cls=NetworkBoundCommand, + cls=ConnectedProviderCommand, short_help=f"Run '{relative_filepath}:main'", name=relative_filepath.stem, ) - @network_option() - def call(network): - _ = network # Downstream might use this + def call(): with use_scripts_sys_path(filepath.parent.parent): empty_ns = run_script_module(filepath) diff --git a/tests/conftest.py b/tests/conftest.py index 5abed692ae..5247182774 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,13 @@ import json import shutil +import subprocess +import sys import tempfile import time from contextlib import contextmanager from pathlib import Path from tempfile import mkdtemp -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Sequence import pytest import yaml @@ -181,9 +183,18 @@ def temp_accounts_path(config): shutil.rmtree(path) -@pytest.fixture(scope="session") -def runner(): - yield CliRunner() +@pytest.fixture +def runner(mock_sys_argv): + class ApeRunner(CliRunner): + def invoke(self, *args, **kwargs): + # Also make the command appear in sys.argv + mock_sys_argv.return_value = ("ape", *(args[1] if len(args) > 1 else [])) + try: + return super().invoke(*args, **kwargs) + finally: + mock_sys_argv.return_value = sys.argv + + yield ApeRunner() @pytest.fixture @@ -468,3 +479,49 @@ def mock_home_directory(tmp_path): Path.home = lambda: tmp_path # type: ignore[method-assign] yield tmp_path Path.home = lambda: original_home # type: ignore[method-assign] + + +class SubprocessRunner: + """ + Same CLI commands are better tested using a python subprocess, + such as `ape test` commands because duplicate pytest main methods + do not run well together, or `ape plugins` commands, which may + modify installed plugins. + """ + + def __init__(self, root_cmd: Optional[Sequence[str]] = None): + self.root_cmd = root_cmd or [] + + def invoke(self, subcommand: Optional[Sequence[str]] = None): + subcommand = subcommand or [] + cmd_ls = [*self.root_cmd, *subcommand] + completed_process = subprocess.run(cmd_ls, capture_output=True, text=True) + return SubprocessResult(completed_process) + + +class ApeSubprocessRunner(SubprocessRunner): + """ + Subprocess runner for Ape-specific commands. + """ + + def __init__(self, root_cmd: Optional[Sequence[str]] = None): + ape_path = Path(sys.executable).parent / "ape" + super().__init__([str(ape_path), *(root_cmd or [])]) + + +class SubprocessResult: + def __init__(self, completed_process: subprocess.CompletedProcess): + self._completed_process = completed_process + + @property + def exit_code(self) -> int: + return self._completed_process.returncode + + @property + def output(self) -> str: + return self._completed_process.stdout + + +@pytest.fixture +def mock_sys_argv(mocker): + return mocker.patch("ape.cli.options._get_sys_argv") diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index 6d30fbbb9d..1ed643b566 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -88,7 +88,7 @@ def test_connect_wrong_chain_id(mocker, ethereum, geth_provider): geth_provider.network = ethereum.get_network("goerli") # Ensure when reconnecting, it does not use HTTP - factory = mocker.patch("ape_geth.provider._create_web3") + factory = mocker.patch("ape_ethereum.provider._create_web3") factory.return_value = geth_provider._web3 expected_error_message = ( f"Provider connected to chain ID '{geth_provider._web3.eth.chain_id}', " diff --git a/tests/functional/test_cli.py b/tests/functional/test_cli.py index 90fc5fe161..e6b1bd52ed 100644 --- a/tests/functional/test_cli.py +++ b/tests/functional/test_cli.py @@ -1,10 +1,12 @@ import shutil +from typing import Optional, Sequence import click import pytest from ape.cli import ( AccountAliasPromptChoice, + ConnectedProviderCommand, NetworkBoundCommand, PromptChoice, account_option, @@ -15,10 +17,10 @@ select_account, verbosity_option, ) -from ape.exceptions import AccountsError +from ape.exceptions import AccountsError, EcosystemNotFoundError from ape.logging import logger -OUTPUT_FORMAT = "__TEST__{}__" +OUTPUT_FORMAT = "__TEST__{0}:{1}:{2}_" @pytest.fixture @@ -55,13 +57,20 @@ def one_keyfile_account(keyfile_swap_paths, keyfile_account, temp_config): @pytest.fixture -def network_cmd(): - @click.command() - @network_option() - def cmd(network): - click.echo(OUTPUT_FORMAT.format(network)) +def network_cmd(mock_sys_argv): + def fn(cli_args: Optional[Sequence[str]] = None): + if cli_args is not None: + mock_sys_argv.return_value = cli_args + + @click.command() + @network_option() + def cmd(ecosystem, network, provider): + output = OUTPUT_FORMAT.format(ecosystem.name, network.name, provider.name) + click.echo(output) + + return cmd - return cmd + return fn def _setup_temp_acct_number_change(accounts, num_accounts: int): @@ -159,41 +168,54 @@ def test_select_account_with_account_list(runner, keyfile_account, second_keyfil def test_network_option_default(runner, network_cmd): - result = runner.invoke(network_cmd) + cmd = network_cmd() + result = runner.invoke(cmd) assert result.exit_code == 0, result.output - assert OUTPUT_FORMAT.format("ethereum") in result.output + assert OUTPUT_FORMAT.format("ethereum", "local", "test") in result.output def test_network_option_specified(runner, network_cmd): - result = runner.invoke(network_cmd, ["--network", "ethereum:local:test"]) + network_part = ("--network", "ethereum:local:test") + cmd = network_cmd(network_part) + result = runner.invoke(cmd, network_part) assert result.exit_code == 0, result.output - assert OUTPUT_FORMAT.format("ethereum:local:test") in result.output + assert OUTPUT_FORMAT.format("ethereum", "local", "test") in result.output -def test_network_option_unknown(runner, network_cmd): - result = runner.invoke(network_cmd, ["--network", "UNKNOWN"]) - assert result.exit_code != 0, result.output - assert "No ecosystem named 'UNKNOWN'" in str(result.exception) +def test_network_option_unknown(network_cmd): + network_part = ("--network", "UNKNOWN") + with pytest.raises(EcosystemNotFoundError): + network_cmd(network_part) @pytest.mark.parametrize( "network_input", ( - "something:else:https://127.0.0.1:4545", - "something:else:https://127.0.0.1", - "something:else:http://127.0.0.1:4545", - "something:else:http://127.0.0.1", - "something:else:http://foo.bar", - "something:else:https://foo.bar:8000", - ":else:https://foo.bar:8000", + "ethereum:custom:https://127.0.0.1:4545", + "ethereum:custom:https://127.0.0.1", + "ethereum:custom:http://127.0.0.1:4545", + "ethereum:custom:http://127.0.0.1", + "ethereum:custom:http://foo.bar", + "ethereum:custom:https://foo.bar:8000", + ":custom:https://foo.bar:8000", "::https://foo.bar:8000", "https://foo.bar:8000", ), ) -def test_network_option_adhoc(runner, network_cmd, network_input): - result = runner.invoke(network_cmd, ["--network", network_input]) +def test_network_option_custom_uri(runner, network_cmd, network_input): + network_part = ("--network", network_input) + cmd = network_cmd(network_part) + result = runner.invoke(cmd, network_part) assert result.exit_code == 0, result.output - assert OUTPUT_FORMAT.format(network_input) in result.output + assert "custom" in result.output + + +def test_network_option_existing_network_with_custom_uri(runner, network_cmd): + network_part = ("--network", "ethereum:sepolia:https://foo.bar:8000") + cmd = network_cmd(network_part) + result = runner.invoke(cmd, network_part) + assert result.exit_code == 0, result.output + assert "sepolia" in result.output def test_network_option_make_required(runner): @@ -207,25 +229,19 @@ def cmd(network): assert "Error: Missing option '--network'." in result.output -def test_network_option_can_be_none(runner): +def test_network_option_can_be_none(runner, mock_sys_argv): + network_part = ("--network", "None") + mock_sys_argv.return_value = ("cmd", *network_part) + @click.command() @network_option(default=None) def cmd(network): click.echo(f"Value is '{network}'") - result = runner.invoke(cmd, []) + result = runner.invoke(cmd, network_part) assert "Value is 'None'" in result.output -def test_network_option_not_needed_on_network_bound_command(runner): - @click.command(cls=NetworkBoundCommand) - def cmd(): - click.echo("Success!") - - result = runner.invoke(cmd, []) - assert "Success" in result.output - - def test_account_option(runner, keyfile_account): @click.command() @account_option() @@ -396,3 +412,71 @@ def cmd(alias): result = runner.invoke(cmd, ["non-exists"]) assert magic_value in result.output + + +def test_connected_provider_command_no_args_or_network_specified(runner): + @click.command(cls=ConnectedProviderCommand) + def cmd(): + from ape import chain + + click.echo(chain.provider.is_connected) + + result = runner.invoke(cmd) + assert result.exit_code == 0 + assert "True" in result.output, result.output + + +def test_connected_provider_command_invalid_value(runner): + @click.command(cls=ConnectedProviderCommand) + def cmd(): + pass + + result = runner.invoke(cmd, ["--network", "OOGA_BOOGA"], catch_exceptions=False) + assert result.exit_code != 0 + assert "Invalid value for '--network'" in result.output + + +def test_connected_provider_use_provider(runner): + @click.command(cls=ConnectedProviderCommand) + def cmd(provider): + click.echo(provider.is_connected) + + result = runner.invoke(cmd) + assert result.exit_code == 0 + assert "True" in result.output, result.output + + +def test_connected_provider_use_ecosystem_network_and_provider(runner): + @click.command(cls=ConnectedProviderCommand) + def cmd(ecosystem, network, provider): + click.echo(f"{ecosystem.name}:{network.name}:{provider.name}") + + result = runner.invoke(cmd) + assert result.exit_code == 0 + assert "ethereum:local:test" in result.output, result.output + + +def test_connected_provider_use_ecosystem_network_and_provider_with_network_specified(runner): + @click.command(cls=ConnectedProviderCommand) + def cmd(ecosystem, network, provider): + click.echo(f"{ecosystem.name}:{network.name}:{provider.name}") + + result = runner.invoke(cmd, ["--network", "ethereum:local:test"]) + assert result.exit_code == 0 + assert "ethereum:local:test" in result.output, result.output + + +def test_deprecated_network_bound_command(runner): + with pytest.warns( + DeprecationWarning, + match=r"'NetworkBoundCommand' is deprecated\. Use 'ConnectedProviderCommand'\.", + ): + + @click.command(cls=NetworkBoundCommand) + @network_option() + def cmd(network): + click.echo(network) + + result = runner.invoke(cmd, ["--network", "ethereum:local:test"]) + assert result.exit_code == 0, result.output + assert "ethereum:local:test" in result.output, result.output diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index 4e34eb23f2..985314551a 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -318,7 +318,7 @@ def test_configure_default_txn_type(temp_config, ethereum): with temp_config(config_dict): ethereum._default_network = "mainnet-fork" - assert ethereum.default_network == "mainnet-fork" + assert ethereum.default_network_name == "mainnet-fork" assert ethereum.default_transaction_type == TransactionType.STATIC ethereum._default_network = LOCAL_NETWORK_NAME diff --git a/tests/functional/test_network_manager.py b/tests/functional/test_network_manager.py index 579c4db34a..ab39d20f7a 100644 --- a/tests/functional/test_network_manager.py +++ b/tests/functional/test_network_manager.py @@ -146,22 +146,22 @@ def test_repr_disconnected(networks_disconnected): assert repr(networks_disconnected.ethereum.goerli) == "" -def test_get_provider_from_choice_adhoc_provider(networks_connected_to_tester): +def test_get_provider_from_choice_custom_provider(networks_connected_to_tester): uri = "https://geth:1234567890abcdef@geth.foo.bar/" provider = networks_connected_to_tester.get_provider_from_choice(f"ethereum:local:{uri}") assert uri in provider.connection_id assert provider.name == "geth" assert provider.uri == uri - assert provider.network.name == "local" + assert provider.network.name == "local" # Network was specified to be local! assert provider.network.ecosystem.name == "ethereum" -def test_get_provider_from_choice_adhoc_ecosystem(networks_connected_to_tester): +def test_get_provider_from_choice_custom_ecosystem(networks_connected_to_tester): uri = "https://geth:1234567890abcdef@geth.foo.bar/" provider = networks_connected_to_tester.get_provider_from_choice(uri) assert provider.name == "geth" assert provider.uri == uri - assert provider.network.name == "adhoc" + assert provider.network.name == "custom" assert provider.network.ecosystem.name == "ethereum" diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index 4ac944ea4e..f7c52154ae 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -270,3 +270,9 @@ def test_no_comma_in_rpc_url(): sanitised_url = _sanitize_web3_url(test_url) assert "," not in sanitised_url + + +def test_use_provider_using_provider_instance(eth_tester_provider): + network = eth_tester_provider.network + with network.use_provider(eth_tester_provider) as provider: + assert id(provider) == id(eth_tester_provider) diff --git a/tests/integration/cli/conftest.py b/tests/integration/cli/conftest.py index ff672e30a7..33a5eda3bd 100644 --- a/tests/integration/cli/conftest.py +++ b/tests/integration/cli/conftest.py @@ -1,5 +1,3 @@ -import subprocess -import sys from contextlib import contextmanager from distutils.dir_util import copy_tree from importlib import import_module @@ -9,6 +7,7 @@ import pytest from ape.managers.config import CONFIG_FILE_NAME +from tests.conftest import ApeSubprocessRunner from .test_plugins import ListResult from .utils import NodeId, __project_names__, __projects_directory__, project_skipper @@ -147,43 +146,6 @@ def clean_cache(project): cache_file.unlink() -class ApeSubprocessRunner: - """ - Same CLI commands are better tested using a python subprocess, - such as `ape test` commands because duplicate pytest main methods - do not run well together, or `ape plugins` commands, which may - modify installed plugins. - """ - - def __init__(self, root_cmd: Optional[List[str]] = None): - ape_path = Path(sys.executable).parent / "ape" - self.root_cmd = [str(ape_path), *(root_cmd or [])] - - def invoke(self, subcommand: Optional[List[str]] = None): - subcommand = subcommand or [] - cmd_ls = [*self.root_cmd, *subcommand] - completed_process = subprocess.run(cmd_ls, capture_output=True, text=True) - return SubprocessResult(completed_process) - - -class SubprocessResult: - def __init__(self, completed_process: subprocess.CompletedProcess): - self._completed_process = completed_process - - @property - def exit_code(self) -> int: - return self._completed_process.returncode - - @property - def output(self) -> str: - return self._completed_process.stdout - - -@pytest.fixture(scope="session") -def subprocess_runner(subprocess_runner_cls): - return subprocess_runner_cls() - - @pytest.fixture def switch_config(config): """ @@ -222,7 +184,7 @@ def switch(project, new_content: str): return switch -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def ape_plugins_runner(): """ Use subprocess runner so can manipulate site packages and see results. diff --git a/tests/integration/cli/test_cache.py b/tests/integration/cli/test_cache.py index 5cb651b4bd..4c1aa9c482 100644 --- a/tests/integration/cli/test_cache.py +++ b/tests/integration/cli/test_cache.py @@ -1,5 +1,6 @@ def test_cache_init_purge(ape_cli, runner): - result = runner.invoke(ape_cli, ["cache", "init", "--network", "ethereum:goerli"]) + cmd = ("cache", "init", "--network", "ethereum:goerli") + result = runner.invoke(ape_cli, cmd) assert result.output == "SUCCESS: Caching database initialized for ethereum:goerli.\n" result = runner.invoke(ape_cli, ["cache", "purge", "--network", "ethereum:goerli"]) assert result.output == "SUCCESS: Caching database purged for ethereum:goerli.\n" diff --git a/tests/integration/cli/test_networks.py b/tests/integration/cli/test_networks.py index 966eede1e1..93f594c976 100644 --- a/tests/integration/cli/test_networks.py +++ b/tests/integration/cli/test_networks.py @@ -123,9 +123,6 @@ def test_list_yaml(ape_cli, runner): @skip_projects_except("geth") def test_geth(ape_cli, runner, networks, project): result = runner.invoke(ape_cli, ["networks", "list"]) - assert ( - networks.provider.network.default_provider == "geth" - ), "Setup failed - default provider didn't apply from config" # Grab ethereum actual = "ethereum (default)\n" + "".join(result.output.split("ethereum (default)\n")[-1]) @@ -161,7 +158,8 @@ def test_filter_providers(ape_cli, runner, networks): @run_once def test_node_not_subprocess_provider(ape_cli, runner): - result = runner.invoke(ape_cli, ["networks", "run", "--network", "ethereum:local:test"]) + cmd = ("networks", "run", "--network", "ethereum:local:test") + result = runner.invoke(ape_cli, ["networks", "run", "--network", cmd]) assert result.exit_code != 0 assert ( result.output diff --git a/tests/integration/cli/test_run.py b/tests/integration/cli/test_run.py index 78ed2340d0..bc3a4a72d8 100644 --- a/tests/integration/cli/test_run.py +++ b/tests/integration/cli/test_run.py @@ -111,9 +111,11 @@ def test_run_interactive(ape_cli, runner, project): @skip_projects_except("script") -def test_run_adhoc_provider(ape_cli, runner, project): +def test_run_custom_provider(ape_cli, runner, project): result = runner.invoke( - ape_cli, ["run", "deploy", "--network", "ethereum:mainnet:http://127.0.0.1:9545"] + ape_cli, + ["run", "deploy", "--network", "ethereum:mainnet:http://127.0.0.1:9545"], + catch_exceptions=False, ) # Show that it attempts to connect @@ -122,7 +124,7 @@ def test_run_adhoc_provider(ape_cli, runner, project): @skip_projects_except("script") -def test_run_adhoc_network(ape_cli, runner, project): +def test_run_custom_network(ape_cli, runner, project): result = runner.invoke(ape_cli, ["run", "deploy", "--network", "http://127.0.0.1:9545"]) # Show that it attempts to connect