diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 0be98ea1..c8c9bb4d 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -173,6 +173,7 @@ def create_web_api_wrapper( *, cache: Cache | None = None, session: aiohttp.ClientSession | None = None, + unauthorized_hook: SessionUnauthorizedHook | None = None, ) -> UserWebApiClient: """Create a home data API wrapper from an existing API client.""" @@ -180,7 +181,7 @@ def create_web_api_wrapper( # by caching this next to `UserData` if needed to avoid unnecessary API calls. client = RoborockApiClient(username=user_params.username, base_url=user_params.base_url, session=session) - return UserWebApiClient(client, user_params.user_data) + return UserWebApiClient(client, user_params.user_data, unauthorized_hook=unauthorized_hook) async def create_device_manager( @@ -212,7 +213,9 @@ async def create_device_manager( if cache is None: cache = NoCache() - web_api = create_web_api_wrapper(user_params, session=session, cache=cache) + web_api = create_web_api_wrapper( + user_params, session=session, cache=cache, unauthorized_hook=mqtt_session_unauthorized_hook + ) user_data = user_params.user_data diagnostics = Diagnostics() @@ -264,6 +267,12 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat dev.add_ready_callback(ready_callback) return dev - manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache, diagnostics=diagnostics) + manager = DeviceManager( + web_api, + device_creator, + mqtt_session=mqtt_session, + cache=cache, + diagnostics=diagnostics, + ) await manager.discover_devices(prefer_cache) return manager diff --git a/roborock/web_api.py b/roborock/web_api.py index a76d14c5..a6eb75ab 100644 --- a/roborock/web_api.py +++ b/roborock/web_api.py @@ -6,6 +6,7 @@ import secrets import string import time +from collections.abc import Callable from dataclasses import dataclass import aiohttp @@ -737,23 +738,46 @@ class UserWebApiClient: to avoid needing to pass UserData around and mock out the web API. """ - def __init__(self, web_api: RoborockApiClient, user_data: UserData) -> None: + def __init__( + self, web_api: RoborockApiClient, user_data: UserData, unauthorized_hook: Callable[[], None] | None = None + ) -> None: """Initialize the wrapper with the API client and user data.""" self._web_api = web_api self._user_data = user_data + self._unauthorized_hook = unauthorized_hook async def get_home_data(self) -> HomeData: """Fetch home data using the API client.""" - return await self._web_api.get_home_data_v3(self._user_data) + try: + return await self._web_api.get_home_data_v3(self._user_data) + except RoborockInvalidCredentials: + if self._unauthorized_hook: + self._unauthorized_hook() + raise async def get_routines(self, device_id: str) -> list[HomeDataScene]: """Fetch routines (scenes) for a specific device.""" - return await self._web_api.get_scenes(self._user_data, device_id) + try: + return await self._web_api.get_scenes(self._user_data, device_id) + except RoborockInvalidCredentials: + if self._unauthorized_hook: + self._unauthorized_hook() + raise async def get_rooms(self) -> list[HomeDataRoom]: """Fetch rooms using the API client.""" - return await self._web_api.get_rooms(self._user_data) + try: + return await self._web_api.get_rooms(self._user_data) + except RoborockInvalidCredentials: + if self._unauthorized_hook: + self._unauthorized_hook() + raise async def execute_routine(self, scene_id: int) -> None: """Execute a specific routine (scene) by its ID.""" - await self._web_api.execute_scene(self._user_data, scene_id) + try: + await self._web_api.execute_scene(self._user_data, scene_id) + except RoborockInvalidCredentials: + if self._unauthorized_hook: + self._unauthorized_hook() + raise diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index 397ae2c5..86ce2559 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -13,7 +13,7 @@ from roborock.devices.cache import InMemoryCache from roborock.devices.device import RoborockDevice from roborock.devices.device_manager import UserParams, create_device_manager, create_web_api_wrapper -from roborock.exceptions import RoborockException +from roborock.exceptions import RoborockException, RoborockInvalidCredentials from tests import mock_data USER_DATA = UserData.from_dict(mock_data.USER_DATA) @@ -150,6 +150,19 @@ async def test_create_home_data_api_exception() -> None: await api.get_home_data() +async def test_device_manager_unauthorized_hook() -> None: + """Test that unauthorized hook is called when RoborockInvalidCredentials is raised.""" + mock_hook = Mock() + with patch( + "roborock.devices.device_manager.RoborockApiClient.get_home_data_v3", + side_effect=RoborockInvalidCredentials("Unauthorized"), + ): + with pytest.raises(RoborockInvalidCredentials, match="Unauthorized"): + await create_device_manager(USER_PARAMS, mqtt_session_unauthorized_hook=mock_hook, prefer_cache=False) + + mock_hook.assert_called_once() + + @pytest.mark.parametrize(("prefer_cache", "expected_call_count"), [(True, 1), (False, 2)]) async def test_cache_logic(prefer_cache: bool, expected_call_count: int) -> None: """Test that the cache logic works correctly.""" diff --git a/tests/test_web_api.py b/tests/test_web_api.py index 429783fc..935b9a7e 100644 --- a/tests/test_web_api.py +++ b/tests/test_web_api.py @@ -1,13 +1,14 @@ import re from typing import Any +from unittest.mock import AsyncMock, Mock import aiohttp import pytest from aioresponses.compat import normalize_url from roborock import HomeData, HomeDataScene, UserData -from roborock.exceptions import RoborockAccountDoesNotExist -from roborock.web_api import IotLoginInfo, RoborockApiClient +from roborock.exceptions import RoborockAccountDoesNotExist, RoborockInvalidCredentials +from roborock.web_api import IotLoginInfo, RoborockApiClient, UserWebApiClient from tests.mock_data import HOME_DATA_RAW, USER_DATA pytest_plugins = [ @@ -374,3 +375,26 @@ async def test_get_schedules(mock_rest) -> None: assert schedule.cron == "03 13 15 12 ?" assert schedule.repeated is False assert schedule.enabled is True + + +async def test_user_web_api_client_unauthorized_hook() -> None: + """Test that UserWebApiClient triggers unauthorized hook on RoborockInvalidCredentials.""" + mock_hook = Mock() + mock_api = AsyncMock(spec=RoborockApiClient) + + # Setup mock to raise RoborockInvalidCredentials + mock_api.get_home_data_v3.side_effect = RoborockInvalidCredentials("Unauthorized") + + client = UserWebApiClient(mock_api, UserData.from_dict(USER_DATA), unauthorized_hook=mock_hook) + + with pytest.raises(RoborockInvalidCredentials): + await client.get_home_data() + + mock_hook.assert_called_once() + + # Test another method + mock_hook.reset_mock() + mock_api.get_rooms.side_effect = RoborockInvalidCredentials("Unauthorized") + with pytest.raises(RoborockInvalidCredentials): + await client.get_rooms() + mock_hook.assert_called_once()