Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 35 additions & 27 deletions pyhilo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
API_NOTIFICATIONS_ENDPOINT,
API_REGISTRATION_ENDPOINT,
API_REGISTRATION_HEADERS,
AUTOMATION_DEVICEHUB_ENDPOINT,
AUTOMATION_CHALLENGE_ENDPOINT,
DEFAULT_STATE_FILE,
DEFAULT_USER_AGENT,
FB_APP_ID,
Expand All @@ -51,7 +51,7 @@
get_state,
set_state,
)
from pyhilo.websocket import WebsocketClient
from pyhilo.websocket import WebsocketClient, WebsocketManager


class API:
Expand Down Expand Up @@ -81,9 +81,13 @@ def __init__(
self.device_attributes = get_device_attributes()
self.session: ClientSession = session
self._oauth_session = oauth_session
self.websocket: WebsocketClient
self.websocket_devices: WebsocketClient
self.websocket_challenges: WebsocketClient
self.log_traces = log_traces
self._get_device_callbacks: list[Callable[..., Any]] = []
self.ws_url: str = ""
self.ws_token: str = ""
self.endpoint: str = ""

@classmethod
async def async_create(
Expand Down Expand Up @@ -132,6 +136,9 @@ async def async_get_access_token(self) -> str:
if not self._oauth_session.valid_token:
await self._oauth_session.async_ensure_token_valid()

access_token = str(self._oauth_session.token["access_token"])
LOG.debug(f"ic-dev21 access token is {access_token}")

return str(self._oauth_session.token["access_token"])

def dev_atts(
Expand Down Expand Up @@ -216,17 +223,24 @@ async def _async_request(
:rtype: dict[str, Any]
"""
kwargs.setdefault("headers", self.headers)
access_token = await self.async_get_access_token()

if endpoint.startswith(API_REGISTRATION_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS}
if endpoint.startswith(FB_INSTALL_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **FB_INSTALL_HEADERS}
if endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS}
if host == API_HOSTNAME:
access_token = await self.async_get_access_token()
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
kwargs["headers"]["Host"] = host

# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
Comment on lines +238 to +242
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the 'Ocp-Apim-Subscription-Key' header when the endpoint is 'AUTOMATION_CHALLENGE_ENDPOINT' might introduce issues if this header is required elsewhere in the system. It's crucial to document this change within the function's docstring or as comments explaining the broader context and its necessity.

Suggested change
# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
// Consider adding more comments or documentation here to explain why removing the header is necessary

Comment on lines +239 to +242
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling different authorization methods for various endpoints in the same function increases complexity and potential for errors. Consider extracting this logic into a separate function to improve readability and maintainability.

Suggested change
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
# Define a new function to handle request headers
def add_authorization_headers(endpoint, headers):
if endpoint.startswith(API_REGISTRATION_ENDPOINT):
headers.update(API_REGISTRATION_HEADERS)
elif endpoint.startswith(FB_INSTALL_ENDPOINT):
headers.update(FB_INSTALL_HEADERS)
elif endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
headers.update(ANDROID_CLIENT_HEADERS)
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
headers.pop("Ocp-Apim-Subscription-Key", None)


data: dict[str, Any] = {}
url = parse.urljoin(f"https://{host}", endpoint)
if self.log_traces:
Expand Down Expand Up @@ -303,8 +317,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None:
LOG.info(
"401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}"
)
LOG.info(f"401 detected on {err.request_info.url}")
async with self._backoff_refresh_lock_ws:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.refresh_ws_token()
await self.get_websocket_params()
return

Expand Down Expand Up @@ -354,30 +369,23 @@ async def _async_post_init(self) -> None:
LOG.debug("Websocket postinit")
await self._get_fid()
await self._get_device_token()
await self.refresh_ws_token()
self.websocket = WebsocketClient(self)

async def refresh_ws_token(self) -> None:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.get_websocket_params()

async def post_devicehub_negociate(self) -> tuple[str, str]:
LOG.debug("Getting websocket url")
url = f"{AUTOMATION_DEVICEHUB_ENDPOINT}/negotiate"
LOG.debug(f"devicehub URL is {url}")
resp = await self.async_request("post", url)
ws_url = resp.get("url")
ws_token = resp.get("accessToken")
LOG.debug("Calling set_state devicehub_negotiate")
await set_state(
self._state_yaml,
"websocket",
{
"url": ws_url,
"token": ws_token,
},
# Initialize WebsocketManager ic-dev21
self.websocket_manager = WebsocketManager(
self.session, self.async_request, self._state_yaml, set_state
)
return (ws_url, ws_token)
await self.websocket_manager.initialize_websockets()

# Create both websocket clients
# ic-dev21 need to work on this as it can't lint as is, may need to
# instantiate differently
self.websocket_devices = WebsocketClient(self.websocket_manager.devicehub)
self.websocket_challenges = WebsocketClient(self.websocket_manager.challengehub)

async def refresh_ws_token(self) -> None:
"""Refresh the websocket token."""
await self.websocket_manager.refresh_token(self.websocket_manager.devicehub)
await self.websocket_manager.refresh_token(self.websocket_manager.challengehub)

async def get_websocket_params(self) -> None:
uri = parse.urlparse(self.ws_url)
Expand Down
2 changes: 2 additions & 0 deletions pyhilo/const.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

# Automation server constant
AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub"
AUTOMATION_CHALLENGE_ENDPOINT: Final = "/ChallengeHub"


# Request constants
DEFAULT_USER_AGENT: Final = f"PyHilo/{PYHILO_VERSION} HomeAssistant/{homeassistant.core.__version__} aiohttp/{aiohttp.__version__} Python/{platform.python_version()}"
Expand Down
13 changes: 11 additions & 2 deletions pyhilo/event.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Event object """
from datetime import datetime, timedelta, timezone
import logging
import re
from typing import Any, cast

from pyhilo.util import camel_to_snake, from_utc_timestamp

LOG = logging.getLogger(__package__)


class Event:
setting_deadline: datetime
Expand Down Expand Up @@ -126,9 +129,12 @@ def current_phase_times(self) -> dict[str, datetime]:
@property
def state(self) -> str:
now = datetime.now(self.preheat_start.tzinfo)
if self.pre_cold_start <= now < self.pre_cold_end:
if self.pre_cold_start and self.pre_cold_start <= now < self.pre_cold_end:
return "pre_cold"
elif self.appreciation_start <= now < self.appreciation_end:
elif (
self.appreciation_start
and self.appreciation_start <= now < self.appreciation_end
):
return "appreciation"
elif self.preheat_start > now:
return "scheduled"
Expand All @@ -138,9 +144,12 @@ def state(self) -> str:
return "reduction"
elif self.recovery_start <= now < self.recovery_end:
return "recovery"
elif now >= self.recovery_end + timedelta(minutes=5):
return "off"
elif now >= self.recovery_end:
return "completed"
elif self.progress:
return self.progress

else:
return "unknown"
10 changes: 9 additions & 1 deletion pyhilo/util/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Define utility modules."""
import asyncio
from datetime import datetime, timedelta
import logging
import re
from typing import Any, Callable

Expand All @@ -9,6 +10,9 @@

from pyhilo.const import LOG # noqa: F401

LOG = logging.getLogger(__package__)


CAMEL_REX_1 = re.compile("(.)([A-Z][a-z]+)")
CAMEL_REX_2 = re.compile("([a-z0-9])([A-Z])")

Expand All @@ -35,7 +39,11 @@ def snake_to_camel(string: str) -> str:
def from_utc_timestamp(date_string: str) -> datetime:
from_zone = tz.tzutc()
to_zone = tz.tzlocal()
return parse(date_string).replace(tzinfo=from_zone).astimezone(to_zone)
dt = parse(date_string)
if dt.tzinfo is None: # Only replace tzinfo if not already set
dt = dt.replace(tzinfo=from_zone)
output = dt.astimezone(to_zone)
return output


def time_diff(ts1: datetime, ts2: datetime) -> timedelta:
Expand Down
Loading
Loading