Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: negative block number support in ContractLog.range queries #2388

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 49 additions & 19 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,19 +634,19 @@ def query(
# perf: pandas import is really slow. Avoid importing at module level.
import pandas as pd

HEAD = self.chain_manager.blocks.height
if start_block < 0:
start_block = self.chain_manager.blocks.height + start_block
start_block = HEAD + start_block

if stop_block is None:
stop_block = self.chain_manager.blocks.height
stop_block = HEAD

elif stop_block < 0:
stop_block = self.chain_manager.blocks.height + stop_block
stop_block = HEAD + stop_block

elif stop_block > self.chain_manager.blocks.height:
elif stop_block > HEAD:
raise ChainError(
f"'stop={stop_block}' cannot be greater than "
f"the chain length ({self.chain_manager.blocks.height})."
f"'stop={stop_block}' cannot be greater than the chain length ({HEAD})."
)
query: dict = {
"columns": list(ContractLog.__pydantic_fields__) if columns[0] == "*" else columns,
Expand Down Expand Up @@ -692,12 +692,12 @@ def range(
Returns:
Iterator[:class:`~ape.contracts.base.ContractLog`]
"""

if not (contract_address := getattr(self.contract, "address", None)):
return

start_block = None
stop_block = None
HEAD = self.chain_manager.blocks.height # Current block height

if stop is None:
contract = None
Expand All @@ -706,27 +706,57 @@ def range(
except Exception:
pass

if contract:
if creation := contract.creation_metadata:
start_block = creation.block

stop_block = start_or_stop
# Determine the start block from contract creation metadata
if contract and (creation := contract.creation_metadata):
start_block = creation.block

# Handle single parameter usage (like Python's range(stop))
if start_or_stop == 0:
# stop==0 is the same as stop==HEAD
# because of the -1 (turns to negative).
stop_block = HEAD + 1
elif start_or_stop >= 0:
# Given like range(1)
stop_block = min(start_or_stop - 1, HEAD)
else:
# Give like range(-1)
stop_block = HEAD + start_or_stop

elif start_or_stop is not None and stop is not None:
start_block = start_or_stop
stop_block = stop - 1

stop_block = min(stop_block, self.chain_manager.blocks.height)
# Handle cases where both start and stop are provided
if start_or_stop >= 0:
start_block = min(start_or_stop, HEAD)
else:
# Negative start relative to HEAD
adjusted_value = HEAD + start_or_stop + 1
start_block = max(adjusted_value, 0)

if stop == 0:
# stop==0 is the same as stop==HEAD
# because of the -1 (turns to negative).
stop_block = HEAD
elif stop > 0:
# Positive stop, capped to the chain HEAD
stop_block = min(stop - 1, HEAD)
else:
# Negative stop.
adjusted_value = HEAD + stop
stop_block = max(adjusted_value, 0)

# Gather all addresses to query (contract and any extra ones provided)
addresses = list(set([contract_address] + (extra_addresses or [])))

# Construct the event query
contract_event_query = ContractEventQuery(
columns=list(ContractLog.__pydantic_fields__),
columns=list(ContractLog.__pydantic_fields__), # Ensure all necessary columns
contract=addresses,
event=self.abi,
search_topics=search_topics,
start_block=start_block or 0,
stop_block=stop_block,
start_block=start_block or 0, # Default to block 0 if not set
stop_block=stop_block, # None means query to the current HEAD
)

# Execute the query and yield results
yield from self.query_manager.query(contract_event_query) # type: ignore

def from_receipt(self, receipt: "ReceiptAPI") -> list[ContractLog]:
Expand Down
81 changes: 63 additions & 18 deletions tests/functional/test_contract_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


@pytest.fixture
def assert_log_values(owner, chain):
def assert_log_values(owner):
def _assert_log_values(log: ContractLog, number: int, previous_number: Optional[int] = None):
assert isinstance(log.b, bytes)
expected_previous_number = number - 1 if previous_number is None else previous_number
Expand All @@ -32,7 +32,7 @@ def _assert_log_values(log: ContractLog, number: int, previous_number: Optional[
return _assert_log_values


def test_contract_logs_from_receipts(owner, contract_instance, assert_log_values):
def test_from_receipts(owner, contract_instance, assert_log_values):
event_type = contract_instance.NumberChange

# Invoke a transaction 3 times that generates 3 logs.
Expand All @@ -55,7 +55,7 @@ def assert_receipt_logs(receipt: "ReceiptAPI", num: int):
assert_receipt_logs(receipt_2, 3)


def test_contract_logs_from_event_type(contract_instance, owner, assert_log_values):
def test_from_event_type(contract_instance, owner, assert_log_values):
event_type = contract_instance.NumberChange
start_num = 6
size = 20
Expand All @@ -76,7 +76,7 @@ def test_contract_logs_from_event_type(contract_instance, owner, assert_log_valu
assert_log_values(log, num)


def test_contract_logs_index_access(contract_instance, owner, assert_log_values):
def test_index_access(contract_instance, owner, assert_log_values):
event_type = contract_instance.NumberChange

contract_instance.setNumber(1, sender=owner)
Expand All @@ -93,7 +93,7 @@ def test_contract_logs_index_access(contract_instance, owner, assert_log_values)
assert event_type[-1] == contract_instance.NumberChange(newNum=3, prevNum=2)


def test_contract_logs_splicing(contract_instance, owner, assert_log_values):
def test_splicing(contract_instance, owner, assert_log_values):
event_type = contract_instance.NumberChange

contract_instance.setNumber(1, sender=owner)
Expand All @@ -113,7 +113,7 @@ def test_contract_logs_splicing(contract_instance, owner, assert_log_values):
assert_log_values(log, 2)


def test_contract_logs_range(chain, contract_instance, owner, assert_log_values):
def test_range(chain, contract_instance, owner, assert_log_values):
contract_instance.setNumber(1, sender=owner)
start = chain.blocks.height
logs = [
Expand All @@ -126,7 +126,7 @@ def test_contract_logs_range(chain, contract_instance, owner, assert_log_values)
assert_log_values(logs[0], 1)


def test_contract_logs_range_by_address(
def test_range_by_address(
mocker, chain, eth_tester_provider, accounts, contract_instance, owner, assert_log_values
):
get_logs_spy = mocker.spy(eth_tester_provider.tester.ethereum_tester, "get_logs")
Expand Down Expand Up @@ -157,29 +157,27 @@ def test_contract_logs_range_by_address(
assert logs == [contract_instance.AddressChange(newAddress=accounts[1])]


def test_contracts_log_multiple_addresses(
def test_range_multiple_addresses(
chain, contract_instance, contract_container, owner, assert_log_values
):
another_instance = contract_container.deploy(0, sender=owner)
start_block = chain.blocks.height
contract_instance.setNumber(1, sender=owner)
another_instance.setNumber(1, sender=owner)

logs = [
log
for log in contract_instance.NumberChange.range(
logs = list(
contract_instance.NumberChange.range(
start_block,
start_block + 100,
search_topics={"newNum": 1},
extra_addresses=[another_instance.address],
)
]
assert len(logs) == 2, "Unexpected number of logs"
)
assert len(logs) == 2, f"Unexpected number of logs: {len(logs)}"
assert logs[0] == contract_instance.NumberChange(newNum=1, prevNum=0)
assert logs[1] == another_instance.NumberChange(newNum=1, prevNum=0)


def test_contract_logs_range_start_and_stop(contract_instance, owner, chain):
def test_range_start_and_stop(contract_instance, owner, chain):
# Create 1 event
contract_instance.setNumber(1, sender=owner)

Expand All @@ -194,16 +192,63 @@ def test_contract_logs_range_start_and_stop(contract_instance, owner, chain):
assert len(logs) == 3, "Unexpected number of logs"


def test_contract_logs_range_only_stop(contract_instance, owner, chain):
# Create 1 event
def test_range_only_stop(contract_instance, owner, chain):
# Create 3 events
start = chain.blocks.height
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)
contract_instance.setNumber(3, sender=owner)

stop = start + 100 # Stop can be bigger than height, it doesn't not matter
logs = [log for log in contract_instance.NumberChange.range(stop)]
assert len(logs) >= 3, "Unexpected number of logs"
assert len(logs) >= 3, f"Unexpected number of logs: {len(logs)}"


def test_range_negative_start(contract_instance, owner):
# Create 2 events
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)
logs = [log for log in contract_instance.NumberChange.range(-2, 0)]
assert len(logs) == 2


def test_range_negative_start_and_stop(contract_instance, owner):
# Create 3 events
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)
contract_instance.setNumber(3, sender=owner)

query_result = [log for log in contract_instance.NumberChange.range(-1, 0)]
assert len(query_result) == 1, "Should only be 1"
assert query_result[0].newNum == 3 # Was the last parameter.
query_result = [log for log in contract_instance.NumberChange.range(-2, -1)]
assert len(query_result) == 1, "Should only be 1"
assert query_result[0].newNum == 2 # Was the penultimate parameter.
query_result = [log for log in contract_instance.NumberChange.range(-3, -2)]
assert len(query_result) == 1, "Should only be 1"
assert query_result[0].newNum == 1 # Was the penultimate parameter.
logs = [log for log in contract_instance.NumberChange.range(-3, -1)]
assert len(logs) == 2
assert [x.newNum for x in logs] == [1, 2]
logs = [log for log in contract_instance.NumberChange.range(-3, 0)]
assert len(logs) == 3
assert [x.newNum for x in logs] == [1, 2, 3]


def test_range_negative_stop_only(contract_instance, owner):
# Create 2 events
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)

# Get _all_ logs.
logs = [log for log in contract_instance.NumberChange.range(0)]
assert len(logs) == 2
assert [x.newNum for x in logs] == [1, 2]

# Basically means go from 0 to the second to last
logs = [log for log in contract_instance.NumberChange.range(-1)]
assert len(logs) == 1
assert logs[0].newNum == 1


def test_poll_logs_stop_block_not_in_future(
Expand Down
Loading