Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dev = [

[project.optional-dependencies]
ledger = ["ledgereth==0.10.0"]
trezor = ["trezor==0.13.10"]
trezor = ["trezor==0.20.1"]

[project.scripts]
safe-cli = "safe_cli.main:main"
Expand Down
3 changes: 3 additions & 0 deletions src/safe_cli/operators/hw_wallets/trezor_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from trezorlib.exceptions import (
Cancelled,
DeviceLockedError,
OutdatedFirmwareError,
PinException,
TrezorFailure,
Expand All @@ -27,6 +28,8 @@ def wrapper(*args, **kwargs):
raise HardwareWalletException("Wrong PIN") from None
except Cancelled:
raise HardwareWalletException("Trezor operation was cancelled") from None
except DeviceLockedError:
raise HardwareWalletException("Trezor device is locked") from None
except TransportException:
raise HardwareWalletException("Trezor device is not connected") from None
except InvalidDerivationPath as e:
Expand Down
38 changes: 24 additions & 14 deletions src/safe_cli/operators/hw_wallets/trezor_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
from hexbytes import HexBytes
from safe_eth.safe.signatures import signature_split, signature_to_bytes
from trezorlib import tools
from trezorlib.client import TrezorClient, get_default_client
from trezorlib.cli import get_code_entry_code, get_passphrase
from trezorlib.cli.ui import ClickUI
from trezorlib.client import (
Session,
get_default_client,
get_default_session,
)
from trezorlib.ethereum import (
get_address,
sign_message,
sign_tx,
sign_tx_eip1559,
sign_typed_data_hash,
)
from trezorlib.ui import ClickUI
from web3.types import TxParams

from .hw_wallet import HwWallet
Expand All @@ -22,20 +27,25 @@

@cache
@raise_trezor_exception_as_hw_wallet_exception
def get_trezor_client() -> TrezorClient:
def get_trezor_session() -> Session:
"""
Return default trezor configuration that store passphrase on host.
This method is cached to share the same configuration between trezor calls while the class is not instantiated.
Return a default Trezor session, entering the passphrase on the host unless the device requires on-device entry.
This method is cached to share the same session between trezor calls while the class is not instantiated.
:return:
"""
ui = ClickUI(passphrase_on_host=True, always_prompt=True)
client = get_default_client(ui=ui)
return client
ui = ClickUI(always_prompt=True)
client = get_default_client(
"safe-cli",
button_callback=ui.button_request,
pin_callback=ui.get_pin,
code_entry_callback=get_code_entry_code,
)
return get_default_session(client, passphrase_callback=get_passphrase)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve host passphrase prompts for capable devices

For passphrase-enabled devices that advertise PassphraseEntry (for example the Trezor T/Safe family), trezorlib.get_default_session() chooses on-device passphrase entry before it ever calls the supplied passphrase_callback, so this line ignores get_passphrase in exactly the common devices where host entry used to be forced by ClickUI(passphrase_on_host=True). Users who rely on entering the passphrase in the CLI will be pushed to device entry instead; prompt on the host and pass that value to client.get_session() unless the device's passphrase_always_on_device setting requires otherwise.

Useful? React with 👍 / 👎.



class TrezorWallet(HwWallet):
def __init__(self, derivation_path: str):
self.client: TrezorClient = get_trezor_client()
self.session: Session = get_trezor_session()
self.address_n = tools.parse_path(derivation_path)
super().__init__(derivation_path)

Expand All @@ -44,7 +54,7 @@ def get_address(self) -> ChecksumAddress:
"""
:return: public address for derivation_path
"""
return get_address(client=self.client, n=self.address_n)
return get_address(self.session, n=self.address_n)

@raise_trezor_exception_as_hw_wallet_exception
def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes:
Expand All @@ -55,7 +65,7 @@ def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes:
:return: signature bytes
"""
signed = sign_typed_data_hash(
self.client,
self.session,
n=self.address_n,
domain_hash=domain_hash,
message_hash=message_hash,
Expand All @@ -75,7 +85,7 @@ def get_signed_raw_transaction(
if tx_parameters.get("maxPriorityFeePerGas"):
# EIP1559
v, r, s = sign_tx_eip1559(
self.client,
self.session,
n=self.address_n,
nonce=tx_parameters["nonce"],
gas_limit=tx_parameters["gas"],
Expand Down Expand Up @@ -109,7 +119,7 @@ def get_signed_raw_transaction(
else:
# Legacy transaction
v, r, s = sign_tx(
self.client,
self.session,
n=self.address_n,
nonce=tx_parameters["nonce"],
gas_price=tx_parameters["gasPrice"],
Expand Down Expand Up @@ -144,7 +154,7 @@ def sign_message(self, message: bytes) -> bytes:
:param message:
:return: bytes signature
"""
signed = sign_message(self.client, self.address_n, message)
signed = sign_message(self.session, self.address_n, message)
# V field must be greater than 30 for signed messages. https://github.com/safe-global/safe-smart-account/blob/main/contracts/Safe.sol#L309
v, r, s = signature_split(signed.signature)
return signature_to_bytes(v + 4, r, s)
48 changes: 12 additions & 36 deletions tests/test_trezor_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,28 @@
from safe_eth.safe import SafeTx
from safe_eth.safe.signatures import signature_split, signature_to_bytes
from safe_eth.safe.tests.safe_test_case import SafeTestCaseMixin
from trezorlib.client import TrezorClient
from trezorlib.exceptions import Cancelled, OutdatedFirmwareError, PinException
from trezorlib.messages import EthereumTypedDataSignature
from trezorlib.transport import TransportException
from trezorlib.ui import ClickUI

from safe_cli.operators.exceptions import HardwareWalletException
from safe_cli.operators.hw_wallets.trezor_wallet import TrezorWallet


class TestTrezorManager(SafeTestCaseMixin, unittest.TestCase):
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client",
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_session",
return_value=None,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_address",
return_value=None,
)
def test_setup_trezor_wallet(
self, mock_trezor_client: MagicMock, mock_get_address: MagicMock
self, mock_trezor_session: MagicMock, mock_get_address: MagicMock
):
trezor_wallet = TrezorWallet("44'/60'/0'/0")
self.assertIsNone(trezor_wallet.client)
self.assertIsNone(trezor_wallet.session)

@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.sign_typed_data_hash",
Expand All @@ -43,21 +41,16 @@ def test_setup_trezor_wallet(
autospec=True,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client",
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_session",
autospec=True,
)
def test_hw_device_exception(
self,
mock_trezor_client: MagicMock,
mock_trezor_session: MagicMock,
mock_trezor_get_address: MagicMock,
mock_trezor_sign: MagicMock,
):
derivation_path = "44'/60'/0'/0"
transport_mock = MagicMock(auto_spec=True)
mock_trezor_client.return_value = TrezorClient(
transport_mock, ui=ClickUI(), _init_device=False
)
mock_trezor_client.return_value.is_outdated = MagicMock(return_value=False)
random_domain_bytes = os.urandom(32)
random_message_bytes = os.urandom(32)

Expand Down Expand Up @@ -104,19 +97,14 @@ def test_hw_device_exception(
autospec=True,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client",
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_session",
autospec=True,
)
def test_sign_typed_hash(
self, mock_trezor_client: MagicMock, mock_get_address: MagicMock
self, mock_trezor_session: MagicMock, mock_get_address: MagicMock
):
owner = Account.create()
to = Account.create()
transport_mock = MagicMock(auto_spec=True)
mock_trezor_client.return_value = TrezorClient(
transport_mock, ui=ClickUI(), _init_device=False
)
mock_trezor_client.return_value.is_outdated = MagicMock(return_value=False)
mock_get_address.return_value = owner.address
trezor_wallet = TrezorWallet("44'/60'/0'/0")

Expand Down Expand Up @@ -145,9 +133,7 @@ def test_sign_typed_hash(
trezor_return_signature = EthereumTypedDataSignature(
signature=expected_signature, address=trezor_wallet.address
)
mock_trezor_client.return_value.call = MagicMock(
return_value=trezor_return_signature
)
mock_trezor_session.return_value.call.return_value = trezor_return_signature
signature = trezor_wallet.sign_typed_hash(encode_hash[1], encode_hash[2])
self.assertEqual(expected_signature, signature)

Expand All @@ -164,23 +150,18 @@ def test_sign_typed_hash(
autospec=True,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client",
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_session",
autospec=True,
)
def test_get_signed_raw_transaction(
self,
mock_trezor_client: MagicMock,
mock_trezor_session: MagicMock,
mock_get_address: MagicMock,
mock_sign_tx_eip1559: MagicMock,
mock_sign_tx: MagicMock,
):
owner = Account.create()
to = Account.create()
transport_mock = MagicMock(auto_spec=True)
mock_trezor_client.return_value = TrezorClient(
transport_mock, ui=ClickUI(), _init_device=False
)
mock_trezor_client.return_value.is_outdated = MagicMock(return_value=False)
mock_get_address.return_value = owner.address
trezor_wallet = TrezorWallet("44'/60'/0'/0")

Expand Down Expand Up @@ -261,21 +242,16 @@ def test_get_signed_raw_transaction(
autospec=True,
)
@mock.patch(
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client",
"safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_session",
autospec=True,
)
def test_get_sign_message(
self,
mock_trezor_client: MagicMock,
mock_trezor_session: MagicMock,
mock_get_address: MagicMock,
mock_sign_message: MagicMock,
):
owner = Account.create()
transport_mock = MagicMock(auto_spec=True)
mock_trezor_client.return_value = TrezorClient(
transport_mock, ui=ClickUI(), _init_device=False
)
mock_trezor_client.return_value.is_outdated = MagicMock(return_value=False)
mock_get_address.return_value = owner.address
trezor_wallet = TrezorWallet("44'/60'/0'/0")
expected_signature = HexBytes(
Expand Down
47 changes: 20 additions & 27 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading