Skip to content

Commit

Permalink
Fix update not working for Safes > 1.1.1
Browse files Browse the repository at this point in the history
- Update master copy function was removed after 1.1.1
- Migration contracts to update from 1.3.0 to 1.4.1 are still not available
- Test update
  • Loading branch information
Uxio0 committed Nov 6, 2023
1 parent a4ce370 commit a98b3e2
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 17 deletions.
4 changes: 4 additions & 0 deletions safe_cli/operators/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class SafeAlreadyUpdatedException(SafeOperatorException):
pass


class SafeVersionNotSupportedException(SafeOperatorException):
pass


class UpdateAddressesNotValid(SafeOperatorException):
pass

Expand Down
32 changes: 25 additions & 7 deletions safe_cli/operators/safe_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
NotEnoughEtherToSend,
NotEnoughSignatures,
SafeAlreadyUpdatedException,
SafeVersionNotSupportedException,
SameFallbackHandlerException,
SameGuardException,
SameMasterCopyException,
Expand Down Expand Up @@ -213,26 +214,30 @@ def safe_cli_info(self) -> SafeCliInfo:
self._safe_cli_info = self.refresh_safe_cli_info()
return self._safe_cli_info

def refresh_safe_cli_info(self) -> SafeCliInfo:
self._safe_cli_info = self.get_safe_cli_info()
return self._safe_cli_info

def is_version_updated(self) -> bool:
"""
:return: True if Safe Master Copy is updated, False otherwise
"""

if self._safe_cli_info.master_copy == self.last_safe_contract_address:
last_safe_contract_address = self.last_safe_contract_address
if self.safe_cli_info.master_copy == last_safe_contract_address:
return True
else: # Check versions, maybe safe-cli addresses were not updated
try:
safe_contract_version = self.safe.retrieve_version()
safe_contract_version = Safe(
last_safe_contract_address, self.ethereum_client
).retrieve_version()
except BadFunctionCallOutput: # Safe master copy is not deployed or errored, maybe custom network
return True # We cannot say you are not updated ¯\_(ツ)_/¯

return semantic_version.parse(
self.safe_cli_info.version
) >= semantic_version.parse(safe_contract_version)

def refresh_safe_cli_info(self) -> SafeCliInfo:
self._safe_cli_info = self.get_safe_cli_info()
return self._safe_cli_info

def load_cli_owners_from_words(self, words: List[str]):
if len(words) == 1: # Reading seed from Environment Variable
words = os.environ.get(words[0], default="").strip().split(" ")
Expand Down Expand Up @@ -531,6 +536,12 @@ def change_master_copy(self, new_master_copy: str) -> bool:
if new_master_copy == self.safe_cli_info.master_copy:
raise SameMasterCopyException(new_master_copy)
else:
safe_version = self.safe.retrieve_version()
if semantic_version.parse(safe_version) >= semantic_version.parse("1.3.0"):
raise SafeVersionNotSupportedException(
f"{safe_version} cannot be updated (yet)"
)

try:
Safe(new_master_copy, self.ethereum_client).retrieve_version()
except BadFunctionCallOutput:
Expand All @@ -550,8 +561,15 @@ def update_version(self) -> Optional[bool]:
:return:
"""

safe_version = self.safe.retrieve_version()
if semantic_version.parse(safe_version) >= semantic_version.parse("1.3.0"):
raise SafeVersionNotSupportedException(
f"{safe_version} cannot be updated (yet)"
)

if self.is_version_updated():
raise SafeAlreadyUpdatedException()
raise SafeAlreadyUpdatedException(f"{safe_version} already updated")

addresses = (
self.last_safe_contract_address,
Expand Down
6 changes: 6 additions & 0 deletions safe_cli/prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
NotEnoughSignatures,
NotEnoughTokenToSend,
SafeAlreadyUpdatedException,
SafeOperatorException,
SafeVersionNotSupportedException,
SameFallbackHandlerException,
SameMasterCopyException,
SenderRequiredException,
Expand Down Expand Up @@ -110,6 +112,8 @@ def wrapper(*args, **kwargs):
print_formatted_text(HTML(f"<ansired>{e.args[0]}</ansired>"))
except SafeAlreadyUpdatedException:
print_formatted_text(HTML("<ansired>Safe is already updated</ansired>"))
except SafeVersionNotSupportedException as e:
print_formatted_text(HTML(f"<ansired>{e.args[0]}</ansired>"))
except (NotEnoughEtherToSend, NotEnoughTokenToSend) as e:
print_formatted_text(
HTML(
Expand All @@ -127,6 +131,8 @@ def wrapper(*args, **kwargs):
print_formatted_text(
HTML(f"<ansired>HwDevice exception: {e.args[0]}</ansired>")
)
except SafeOperatorException as e:
print_formatted_text(HTML(f"<ansired>{e.args[0]}</ansired>"))

return wrapper

Expand Down
62 changes: 52 additions & 10 deletions tests/test_safe_operator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from functools import lru_cache
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import MagicMock, PropertyMock

from eth_account import Account
from eth_typing import ChecksumAddress
Expand All @@ -25,6 +25,7 @@
NonExistingOwnerException,
NotEnoughEtherToSend,
NotEnoughSignatures,
SafeVersionNotSupportedException,
SameFallbackHandlerException,
SameGuardException,
SameMasterCopyException,
Expand All @@ -38,6 +39,18 @@


class TestSafeOperator(SafeCliTestCaseMixin, unittest.TestCase):
@lru_cache(maxsize=None)
def _deploy_l2_migration_contract(self) -> ChecksumAddress:
# Deploy L2 migration contract
safe_to_l2_migration_contract = self.w3.eth.contract(
abi=safe_to_l2_migration["abi"], bytecode=safe_to_l2_migration["bytecode"]
)
tx_hash = safe_to_l2_migration_contract.constructor().transact(
{"from": self.ethereum_test_account.address}
)
tx_receipt = self.w3.eth.wait_for_transaction_receipt(tx_hash)
return tx_receipt["contractAddress"]

def test_setup_operator(self):
for number_owners in range(1, 4):
safe_operator = self.setup_operator(number_owners=number_owners)
Expand Down Expand Up @@ -211,6 +224,10 @@ def test_change_guard(self):
self.assertEqual(safe.retrieve_guard(), new_guard)

def test_change_master_copy(self):
safe_operator = self.setup_operator(version="1.3.0")
with self.assertRaises(SafeVersionNotSupportedException):
safe_operator.change_master_copy(self.safe_contract_V1_4_1.address)

safe_operator = self.setup_operator(version="1.1.1")
safe = Safe(safe_operator.address, self.ethereum_client)
current_master_copy = safe.retrieve_master_copy_address()
Expand Down Expand Up @@ -246,17 +263,42 @@ def test_send_ether(self):
self.assertTrue(safe_operator.send_ether(random_address, value))
self.assertEqual(self.ethereum_client.get_balance(random_address), value)

@lru_cache(maxsize=None)
def _deploy_l2_migration_contract(self) -> ChecksumAddress:
# Deploy L2 migration contract
safe_to_l2_migration_contract = self.w3.eth.contract(
abi=safe_to_l2_migration["abi"], bytecode=safe_to_l2_migration["bytecode"]
@mock.patch.object(
SafeOperator, "last_default_fallback_handler_address", new_callable=PropertyMock
)
@mock.patch.object(
SafeOperator, "last_safe_contract_address", new_callable=PropertyMock
)
def test_update_version(
self,
last_safe_contract_address_mock: PropertyMock,
last_default_fallback_handler_address: PropertyMock,
):
last_safe_contract_address_mock.return_value = self.safe_contract_V1_4_1.address
last_default_fallback_handler_address.return_value = (
self.compatibility_fallback_handler.address
)
tx_hash = safe_to_l2_migration_contract.constructor().transact(
{"from": self.ethereum_test_account.address}

safe_operator_v130 = self.setup_operator(version="1.3.0")
with self.assertRaises(SafeVersionNotSupportedException):
safe_operator_v130.update_version()

safe_operator_v111 = self.setup_operator(version="1.1.1")
with mock.patch.object(
MultiSend,
"MULTISEND_CALL_ONLY_ADDRESSES",
[self.multi_send_contract.address],
):
safe_operator_v111.update_version()

self.assertEqual(
safe_operator_v111.safe.retrieve_master_copy_address(),
last_safe_contract_address_mock.return_value,
)
self.assertEqual(
safe_operator_v111.safe.retrieve_fallback_handler(),
last_default_fallback_handler_address.return_value.return_value,
)
tx_receipt = self.w3.eth.wait_for_transaction_receipt(tx_hash)
return tx_receipt["contractAddress"]

def test_update_to_l2_v111(self):
migration_contract_address = self._deploy_l2_migration_contract()
Expand Down

0 comments on commit a98b3e2

Please sign in to comment.