Skip to content

Commit

Permalink
fix: private props and recursive err
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Dec 7, 2023
1 parent 2ee9be1 commit 27820bc
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 50 deletions.
2 changes: 1 addition & 1 deletion docs/userguides/developing_plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class MyProvider(ProviderAPI):
_web3: Web3 = None # type: ignore

def connect(self):
self.cached_web3 = Web3(HTTPProvider(str("https://localhost:1337")))
self._web3 = Web3(HTTPProvider(str("https://localhost:1337")))

"""Implement rest of abstract methods"""
```
Expand Down
5 changes: 4 additions & 1 deletion src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from eth_utils import keccak, to_int
from ethpm_types import BaseModel, ContractType
from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI
from pydantic import computed_field

from ape.exceptions import (
NetworkError,
Expand Down Expand Up @@ -700,6 +701,7 @@ class NetworkAPI(BaseInterfaceModel):
request_header: Dict
"""A shareable network HTTP header."""

# See ``.default_provider`` which is the proper field.
_default_provider: str = ""

@classmethod
Expand Down Expand Up @@ -982,7 +984,8 @@ def use_provider(
provider = self.get_provider(provider_name=provider_name, provider_settings=settings)
return ProviderContextManager(provider=provider, disconnect_after=disconnect_after)

@property
@computed_field()
@cached_property
def default_provider(self) -> Optional[str]:
"""
The name of the default provider or ``None``.
Expand Down
23 changes: 10 additions & 13 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ class ProviderAPI(BaseInterfaceModel):
request_header: Dict
"""A header to set on HTTP/RPC requests."""

cached_chain_id: Optional[int] = Field(None, exclude=True)
"""Implementation providers may use this to cache and re-use chain ID."""

block_page_size: int = 100
"""
The amount of blocks to fetch in a response, as a default.
Expand Down Expand Up @@ -758,8 +755,8 @@ class Web3Provider(ProviderAPI, ABC):
`web3.py <https://web3py.readthedocs.io/en/stable/>`__ python package.
"""

cached_web3: Optional[Web3] = None
cached_client_version: Optional[str] = None
_web3: Optional[Web3] = None
_client_version: Optional[str] = None

def __init__(self, *args, **kwargs):
logger.create_logger("web3.RequestManager", handlers=(_sanitize_web3_url,))
Expand All @@ -772,10 +769,10 @@ def web3(self) -> Web3:
Access to the ``web3`` object as if you did ``Web3(HTTPProvider(uri))``.
"""

if not self.cached_web3:
if not self._web3:
raise ProviderNotConnectedError()

return self.cached_web3
return self._web3

@property
def http_uri(self) -> Optional[str]:
Expand Down Expand Up @@ -805,14 +802,14 @@ def ws_uri(self) -> Optional[str]:

@property
def client_version(self) -> str:
if not self.cached_web3:
if not self._web3:
return ""

# NOTE: Gets reset to `None` on `connect()` and `disconnect()`.
if self.cached_client_version is None:
self.cached_client_version = self.web3.client_version
if self._client_version is None:
self._client_version = self.web3.client_version

return self.cached_client_version
return self._client_version

@property
def base_fee(self) -> int:
Expand Down Expand Up @@ -853,10 +850,10 @@ def _get_last_base_fee(self) -> int:

@property
def is_connected(self) -> bool:
if self.cached_web3 is None:
if self._web3 is None:
return False

return run_until_complete(self.cached_web3.is_connected())
return run_until_complete(self._web3.is_connected())

@property
def max_gas(self) -> int:
Expand Down
14 changes: 12 additions & 2 deletions src/ape/utils/basemodel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Set, Union, cast

from ethpm_types import BaseModel as EthpmTypesBaseModel
from pydantic import BaseModel as RootBaseModel
Expand Down Expand Up @@ -193,6 +193,8 @@ class BaseModel(EthpmTypesBaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

__checking__: Set[str] = set()

def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]:
"""
Override this method to supply extra attributes
Expand All @@ -214,8 +216,13 @@ def __getattr__(self, name: str) -> Any:
if name in self.__private_attributes__:
return super().__getattr__(name)

if name in self.__checking__:
# Prevent recursive error.
raise AttributeError(name)

self.__checking__.add(name)
try:
return super().__getattribute__(name)
res = super().__getattribute__(name)
except AttributeError:
extras_checked = set()
for ape_extra in self.__ape_extra_attributes__():
Expand All @@ -238,6 +245,9 @@ def __getattr__(self, name: str) -> Any:

raise ApeAttributeError(message)

self.__checking__.remove(name)
return res

def __getitem__(self, name: Any) -> Any:
# For __getitem__, we first try the extra (unlike `__getattr__`).
extras_checked = set()
Expand Down
10 changes: 5 additions & 5 deletions src/ape_geth/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def _ots_api_level(self) -> Optional[int]:

def _set_web3(self):
# Clear cached version when connecting to another URI.
self.cached_client_version = None # type: ignore
self.cached_web3 = _create_web3(self.uri, ipc_path=self.ipc_path)
self._client_version = None # type: ignore
self._web3 = _create_web3(self.uri, ipc_path=self.ipc_path)

def _complete_connect(self):
client_version = self.client_version.lower()
Expand Down Expand Up @@ -359,8 +359,8 @@ def _complete_connect(self):

def disconnect(self):
self.can_use_parity_traces = None
self.cached_web3 = None # type: ignore
self.cached_client_version = None # type: ignore
self._web3 = None # type: ignore
self._client_version = None # type: ignore

def get_transaction_trace(self, txn_hash: str) -> Iterator[TraceFrame]:
frames = self._stream_request(
Expand Down Expand Up @@ -465,7 +465,7 @@ def _stream_request(self, method: str, params: List, iter_path="result.item"):


class GethDev(BaseGethProvider, TestProviderAPI, SubprocessProvider):
cached_process: Optional[GethDevProcess] = None
_process: Optional[GethDevProcess] = None
name: str = "geth"
can_use_parity_traces: Optional[bool] = False

Expand Down
30 changes: 12 additions & 18 deletions src/ape_test/provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from ast import literal_eval
from functools import cached_property
from re import Pattern
from typing import Dict, Optional, cast

Expand Down Expand Up @@ -33,7 +34,7 @@ class EthTesterProviderConfig(PluginConfig):


class LocalProvider(TestProviderAPI, Web3Provider):
cached_evm_backend: Optional[PyEVMBackend] = None
_evm_backend: Optional[PyEVMBackend] = None
CANNOT_AFFORD_GAS_PATTERN: Pattern = re.compile(
r"Sender b'[\\*|\w]*' cannot afford txn gas (\d+) with account balance (\d+)"
)
Expand All @@ -43,39 +44,36 @@ class LocalProvider(TestProviderAPI, Web3Provider):

@property
def evm_backend(self) -> PyEVMBackend:
if self.cached_evm_backend is None:
if self._evm_backend is None:
raise ProviderNotConnectedError()

return self.cached_evm_backend
return self._evm_backend

def connect(self):
chain_id = self.settings.chain_id
if self.cached_web3 is not None:
if self._web3 is not None:
connected_chain_id = self._make_request("eth_chainId")
if connected_chain_id == chain_id:
# Is already connected and settings have not changed.
return

self.cached_evm_backend = PyEVMBackend.from_mnemonic(
self._evm_backend = PyEVMBackend.from_mnemonic(
mnemonic=self.config.mnemonic,
num_accounts=self.config.number_of_accounts,
)
endpoints = {**API_ENDPOINTS}
endpoints["eth"] = merge(endpoints["eth"], {"chainId": static_return(chain_id)})
tester = EthereumTesterProvider(
ethereum_tester=self.cached_evm_backend, api_endpoints=endpoints
)
self.cached_web3 = Web3(tester)
tester = EthereumTesterProvider(ethereum_tester=self._evm_backend, api_endpoints=endpoints)
self._web3 = Web3(tester)

def disconnect(self):
# NOTE: This type ignore seems like a bug in pydantic.
self.cached_chain_id = None # type: ignore
self.cached_web3 = None # type: ignore
self.cached_evm_backend = None # type: ignore
self._web3 = None # type: ignore
self._evm_backend = None # type: ignore
self.provider_settings = {}

def update_settings(self, new_settings: Dict):
self.cached_chain_id = None # type: ignore[assignment]
self._cached_chain_id = None # type: ignore[assignment]
self.provider_settings = {**self.provider_settings, **new_settings}
self.disconnect()
self.connect()
Expand Down Expand Up @@ -124,17 +122,13 @@ def settings(self) -> EthTesterProviderConfig:
{**self.config.provider.model_dump(mode="json"), **self.provider_settings}
)

@property
@cached_property
def chain_id(self) -> int:
if self.cached_chain_id is not None:
return self.cached_chain_id

try:
result = self._make_request("eth_chainId")
except ProviderNotConnectedError:
result = self.settings.chain_id

self.cached_chain_id = result
return result

@property
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def eth_tester_provider(ethereum):
@pytest.fixture
def mock_provider(mock_web3, eth_tester_provider):
web3 = eth_tester_provider.web3
eth_tester_provider.cached_web3 = mock_web3
eth_tester_provider._web3 = mock_web3
yield eth_tester_provider
eth_tester_provider.cached_web3 = web3
eth_tester_provider._web3 = web3


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/geth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def mock_geth(geth_provider, mock_web3):
data_folder=Path("."),
request_header={},
)
original_web3 = provider.cached_web3
provider.cached_web3 = mock_web3
original_web3 = provider._web3
provider._web3 = mock_web3
yield provider
provider.cached_web3 = original_web3
provider._web3 = original_web3


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_chain_id_live_network_not_connected(networks):
@geth_process_test
def test_chain_id_live_network_connected_uses_web3_chain_id(mocker, geth_provider):
mock_network = mocker.MagicMock()
mock_network.chain_id = 999999999 # Shouldn't use hardcoded network
mock_network._chain_id = 999999999 # Shouldn't use hardcoded network
mock_network.name = "mock"
orig_network = geth_provider.network

Expand Down
8 changes: 4 additions & 4 deletions tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ def test_chain_id(eth_tester_provider):
def test_chain_id_is_cached(eth_tester_provider):
_ = eth_tester_provider.chain_id

# Unset `cached_web3` to show that it is not used in a second call to `chain_id`.
web3 = eth_tester_provider.cached_web3
eth_tester_provider.cached_web3 = None
# Unset `_web3` to show that it is not used in a second call to `chain_id`.
web3 = eth_tester_provider._web3
eth_tester_provider._web3 = None
chain_id = eth_tester_provider.chain_id
assert chain_id == DEFAULT_TEST_CHAIN_ID
eth_tester_provider.cached_web3 = web3 # Undo
eth_tester_provider._web3 = web3 # Undo


def test_chain_id_when_disconnected(eth_tester_provider):
Expand Down

0 comments on commit 27820bc

Please sign in to comment.