diff --git a/pyhilo/api.py b/pyhilo/api.py index 761e23f..2d111e2 100755 --- a/pyhilo/api.py +++ b/pyhilo/api.py @@ -27,7 +27,7 @@ API_NOTIFICATIONS_ENDPOINT, API_REGISTRATION_ENDPOINT, API_REGISTRATION_HEADERS, - AUTOMATION_DEVICEHUB_ENDPOINT, + AUTOMATION_CHALLENGE_ENDPOINT, DEFAULT_STATE_FILE, DEFAULT_USER_AGENT, FB_APP_ID, @@ -51,7 +51,7 @@ get_state, set_state, ) -from pyhilo.websocket import WebsocketClient +from pyhilo.websocket import WebsocketClient, WebsocketManager class API: @@ -81,9 +81,17 @@ def __init__( self.device_attributes = get_device_attributes() self.session: ClientSession = session self._oauth_session = oauth_session + self.websocket_devices: WebsocketClient + # Backward compatibility during transition to websocket for challenges. Currently the HA Hilo integration + # uses the .websocket attribute. Re-added this attribute and point to the same object as websocket_devices. + # Should be removed once the transition to the challenge websocket is completed everywhere. self.websocket: WebsocketClient + self.websocket_challenges: WebsocketClient self.log_traces = log_traces self._get_device_callbacks: list[Callable[..., Any]] = [] + self.ws_url: str = "" + self.ws_token: str = "" + self.endpoint: str = "" @classmethod async def async_create( @@ -132,6 +140,9 @@ async def async_get_access_token(self) -> str: if not self._oauth_session.valid_token: await self._oauth_session.async_ensure_token_valid() + access_token = str(self._oauth_session.token["access_token"]) + LOG.debug(f"ic-dev21 access token is {access_token}") + return str(self._oauth_session.token["access_token"]) def dev_atts( @@ -216,6 +227,8 @@ async def _async_request( :rtype: dict[str, Any] """ kwargs.setdefault("headers", self.headers) + access_token = await self.async_get_access_token() + if endpoint.startswith(API_REGISTRATION_ENDPOINT): kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS} if endpoint.startswith(FB_INSTALL_ENDPOINT): @@ -223,10 +236,15 @@ async def _async_request( if endpoint.startswith(ANDROID_CLIENT_ENDPOINT): kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS} if host == API_HOSTNAME: - access_token = await self.async_get_access_token() kwargs["headers"]["authorization"] = f"Bearer {access_token}" kwargs["headers"]["Host"] = host + # ic-dev21 trying Leicas suggestion + if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT): + # remove Ocp-Apim-Subscription-Key header to avoid 401 error + kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None) + kwargs["headers"]["authorization"] = f"Bearer {access_token}" + data: dict[str, Any] = {} url = parse.urljoin(f"https://{host}", endpoint) if self.log_traces: @@ -303,8 +321,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None: LOG.info( "401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}" ) + LOG.info(f"401 detected on {err.request_info.url}") async with self._backoff_refresh_lock_ws: - (self.ws_url, self.ws_token) = await self.post_devicehub_negociate() + await self.refresh_ws_token() await self.get_websocket_params() return @@ -354,30 +373,26 @@ async def _async_post_init(self) -> None: LOG.debug("Websocket postinit") await self._get_fid() await self._get_device_token() - await self.refresh_ws_token() - self.websocket = WebsocketClient(self) - async def refresh_ws_token(self) -> None: - (self.ws_url, self.ws_token) = await self.post_devicehub_negociate() - await self.get_websocket_params() - - async def post_devicehub_negociate(self) -> tuple[str, str]: - LOG.debug("Getting websocket url") - url = f"{AUTOMATION_DEVICEHUB_ENDPOINT}/negotiate" - LOG.debug(f"devicehub URL is {url}") - resp = await self.async_request("post", url) - ws_url = resp.get("url") - ws_token = resp.get("accessToken") - LOG.debug("Calling set_state devicehub_negotiate") - await set_state( - self._state_yaml, - "websocket", - { - "url": ws_url, - "token": ws_token, - }, + # Initialize WebsocketManager ic-dev21 + self.websocket_manager = WebsocketManager( + self.session, self.async_request, self._state_yaml, set_state ) - return (ws_url, ws_token) + await self.websocket_manager.initialize_websockets() + + # Create both websocket clients + # ic-dev21 need to work on this as it can't lint as is, may need to + # instantiate differently + self.websocket_devices = WebsocketClient(self.websocket_manager.devicehub) + + # For backward compatibility during the transition to challengehub websocket + self.websocket = self.websocket_devices + self.websocket_challenges = WebsocketClient(self.websocket_manager.challengehub) + + async def refresh_ws_token(self) -> None: + """Refresh the websocket token.""" + await self.websocket_manager.refresh_token(self.websocket_manager.devicehub) + await self.websocket_manager.refresh_token(self.websocket_manager.challengehub) async def get_websocket_params(self) -> None: uri = parse.urlparse(self.ws_url) diff --git a/pyhilo/const.py b/pyhilo/const.py old mode 100644 new mode 100755 index 6441f6d..a53edb1 --- a/pyhilo/const.py +++ b/pyhilo/const.py @@ -42,6 +42,8 @@ # Automation server constant AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub" +AUTOMATION_CHALLENGE_ENDPOINT: Final = "/ChallengeHub" + # Request constants DEFAULT_USER_AGENT: Final = f"PyHilo/{PYHILO_VERSION} HomeAssistant/{homeassistant.core.__version__} aiohttp/{aiohttp.__version__} Python/{platform.python_version()}" diff --git a/pyhilo/event.py b/pyhilo/event.py old mode 100644 new mode 100755 index 5044262..d6bf32a --- a/pyhilo/event.py +++ b/pyhilo/event.py @@ -1,10 +1,13 @@ """Event object """ from datetime import datetime, timedelta, timezone +import logging import re from typing import Any, cast from pyhilo.util import camel_to_snake, from_utc_timestamp +LOG = logging.getLogger(__package__) + class Event: setting_deadline: datetime @@ -126,9 +129,12 @@ def current_phase_times(self) -> dict[str, datetime]: @property def state(self) -> str: now = datetime.now(self.preheat_start.tzinfo) - if self.pre_cold_start <= now < self.pre_cold_end: + if self.pre_cold_start and self.pre_cold_start <= now < self.pre_cold_end: return "pre_cold" - elif self.appreciation_start <= now < self.appreciation_end: + elif ( + self.appreciation_start + and self.appreciation_start <= now < self.appreciation_end + ): return "appreciation" elif self.preheat_start > now: return "scheduled" @@ -138,9 +144,12 @@ def state(self) -> str: return "reduction" elif self.recovery_start <= now < self.recovery_end: return "recovery" + elif now >= self.recovery_end + timedelta(minutes=5): + return "off" elif now >= self.recovery_end: return "completed" elif self.progress: return self.progress + else: return "unknown" diff --git a/pyhilo/util/__init__.py b/pyhilo/util/__init__.py old mode 100644 new mode 100755 index d4bd2e0..be18023 --- a/pyhilo/util/__init__.py +++ b/pyhilo/util/__init__.py @@ -1,6 +1,7 @@ """Define utility modules.""" import asyncio from datetime import datetime, timedelta +import logging import re from typing import Any, Callable @@ -9,6 +10,9 @@ from pyhilo.const import LOG # noqa: F401 +LOG = logging.getLogger(__package__) + + CAMEL_REX_1 = re.compile("(.)([A-Z][a-z]+)") CAMEL_REX_2 = re.compile("([a-z0-9])([A-Z])") @@ -35,7 +39,11 @@ def snake_to_camel(string: str) -> str: def from_utc_timestamp(date_string: str) -> datetime: from_zone = tz.tzutc() to_zone = tz.tzlocal() - return parse(date_string).replace(tzinfo=from_zone).astimezone(to_zone) + dt = parse(date_string) + if dt.tzinfo is None: # Only replace tzinfo if not already set + dt = dt.replace(tzinfo=from_zone) + output = dt.astimezone(to_zone) + return output def time_diff(ts1: datetime, ts2: datetime) -> timedelta: diff --git a/pyhilo/websocket.py b/pyhilo/websocket.py index ee2b6fc..fe0a506 100755 --- a/pyhilo/websocket.py +++ b/pyhilo/websocket.py @@ -7,9 +7,10 @@ from enum import IntEnum import json from os import environ -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple +from urllib import parse -from aiohttp import ClientWebSocketResponse, WSMsgType +from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType from aiohttp.client_exceptions import ( ClientError, ServerDisconnectedError, @@ -17,7 +18,12 @@ ) from yarl import URL -from pyhilo.const import DEFAULT_USER_AGENT, LOG +from pyhilo.const import ( + AUTOMATION_CHALLENGE_ENDPOINT, + AUTOMATION_DEVICEHUB_ENDPOINT, + DEFAULT_USER_AGENT, + LOG, +) from pyhilo.exceptions import ( CannotConnectError, ConnectionClosedError, @@ -208,16 +214,19 @@ async def _async_send_json(self, payload: dict[str, Any]) -> None: if self._api.log_traces: LOG.debug( - f"[TRACE] Sending data to websocket server: {json.dumps(payload)}" + f"[TRACE] Sending data to websocket {self._api.endpoint} : {json.dumps(payload)}" ) # Hilo added a control character (chr(30)) at the end of each payload they send. # They also expect this char to be there at the end of every payload we send them. + LOG.debug(f"ic-dev21 send_json {payload}") await self._client.send_str(json.dumps(payload) + chr(30)) def _parse_message(self, msg: dict[str, Any]) -> None: """Parse an incoming message.""" if self._api.log_traces: - LOG.debug(f"[TRACE] Received message from websocket: {msg}") + LOG.debug( + f"[TRACE] Received message on websocket(_parse_message) {self._api.endpoint}: {msg}" + ) if msg.get("type") == SignalRMsgType.PING: schedule_callback(self._async_pong) return @@ -247,7 +256,7 @@ def add_disconnect_callback( return self._add_callback(self._disconnect_callbacks, callback) def add_event_callback(self, callback: Callable[..., Any]) -> Callable[..., None]: - """Add a callback callback to be called upon receiving an event. + """Add a callback to be called upon receiving an event. Note that callbacks should expect to receive a WebsocketEvent object as a parameter. :param callback: The method to call after receiving an event. @@ -261,7 +270,7 @@ async def async_connect(self) -> None: LOG.debug("Websocket: async_connect() called but already connected") return - LOG.info("Websocket: Connecting to server") + LOG.info("Websocket: Connecting to server %s", self._api.endpoint) if self._api.log_traces: LOG.debug(f"[TRACE] Websocket URL: {self._api.full_ws_url}") headers = { @@ -296,7 +305,7 @@ async def async_connect(self) -> None: LOG.error(f"Unable to connect to WS server {err}") raise CannotConnectError(err) from err - LOG.info("Connected to websocket server") + LOG.info(f"Connected to websocket server {self._api.endpoint}") self._watchdog.trigger() for callback in self._connect_callbacks: schedule_callback(callback) @@ -368,6 +377,9 @@ async def async_invoke( except asyncio.TimeoutError: return self._ready_event.clear() + LOG.debug( + f"ic-dev21 invoke argument: {arg}, invocationId: {inv_id}, target: {target}, type: {type}" + ) await self._async_send_json( { "arguments": arg, @@ -376,3 +388,144 @@ async def async_invoke( "type": inv_type, } ) + + +@dataclass +class WebsocketConfig: + """Configuration for a websocket connection""" + + endpoint: str + url: Optional[str] = None + token: Optional[str] = None + connection_id: Optional[str] = None + full_ws_url: Optional[str] = None + log_traces: bool = True + session: ClientSession | None = None + + +class WebsocketManager: + """Manages multiple websocket connections for the Hilo API""" + + def __init__( + self, + session: ClientSession, + async_request: Callable[..., Any], + state_yaml: str, + set_state_callback: Callable[..., Any], + ) -> None: + """Initialize the websocket manager. + + Args: + session: The aiohttp client session + async_request: The async request method from the API class + state_yaml: Path to the state file + set_state_callback: Callback to save state + """ + self.session = session + self.async_request = async_request + self._state_yaml = state_yaml + self._set_state = set_state_callback + self._shared_token: Optional[str] = None + # Initialize websocket configurations, more can be added here + self.devicehub = WebsocketConfig( + endpoint=AUTOMATION_DEVICEHUB_ENDPOINT, session=session + ) + self.challengehub = WebsocketConfig( + endpoint=AUTOMATION_CHALLENGE_ENDPOINT, session=session + ) + + async def initialize_websockets(self) -> None: + """Initialize both websocket connections""" + # ic-dev21 get token from device hub + await self.refresh_token(self.devicehub, get_new_token=True) + # ic-dev21 get token from challenge hub + await self.refresh_token(self.challengehub, get_new_token=True) + + async def refresh_token( + self, config: WebsocketConfig, get_new_token: bool = True + ) -> None: + """Refresh token for a specific websocket configuration. + Args: + config: The websocket configuration to refresh + """ + if get_new_token: + config.url, self._shared_token = await self._negotiate(config) + config.token = self._shared_token + else: + config.url, _ = await self._negotiate(config) + config.token = self._shared_token + + await self._get_websocket_params(config) + + async def _negotiate(self, config: WebsocketConfig) -> Tuple[str, str]: + """Negotiate websocket connection and get URL and token. + Args: + config: The websocket configuration to negotiate + Returns: + Tuple containing the websocket URL and access token + """ + LOG.debug(f"Getting websocket url for {config.endpoint}") + url = f"{config.endpoint}/negotiate" + LOG.debug(f"Negotiate URL is {url}") + + resp = await self.async_request("post", url) + ws_url = resp.get("url") + ws_token = resp.get("accessToken") + + # Save state + state_key = ( + "websocketDevices" + if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT + else "websocketChallenges" + ) + await self._set_state( + self._state_yaml, + state_key, + { + "url": ws_url, + "token": ws_token, + }, + ) + + return ws_url, ws_token + + async def _get_websocket_params(self, config: WebsocketConfig) -> None: + """Get websocket parameters including connection ID. + + Args: + config: The websocket configuration to get parameters for + """ + uri = parse.urlparse(config.url) + LOG.debug(f"Getting websocket params for {config.endpoint}") + LOG.debug(f"Getting uri {uri}") + + resp = await self.async_request( + "post", + f"{uri.path}negotiate?{uri.query}", # type: ignore + host=uri.netloc, + headers={ + "authorization": f"Bearer {config.token}", + }, + ) + + config.connection_id = resp.get("connectionId", "") + config.full_ws_url = ( + f"{config.url}&id={config.connection_id}&access_token={config.token}" + ) + LOG.debug(f"Getting full ws URL {config.full_ws_url}") + + transport_dict = resp.get("availableTransports", []) + websocket_dict = { + "connection_id": config.connection_id, + "available_transports": transport_dict, + "full_url": config.full_ws_url, + } + + # Save state + state_key = ( + "websocketDevices" + if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT + else "websocketChallenges" + ) + LOG.debug(f"Calling set_state {state_key}_params") + await self._set_state(self._state_yaml, state_key, websocket_dict)