-
Notifications
You must be signed in to change notification settings - Fork 8
switch get_client to async #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: richo/client-factory
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| and retrieve statistics from FranklinWH energy gateway devices. | ||
| """ | ||
|
|
||
| from collections.abc import Awaitable, Callable | ||
| from dataclasses import dataclass | ||
| from enum import Enum | ||
| import hashlib | ||
|
|
@@ -291,16 +292,28 @@ class DeviceTimeoutException(Exception): | |
| class GatewayOfflineException(Exception): | ||
| """raised when the gateway is offline.""" | ||
|
|
||
|
|
||
| class HttpClientFactory: | ||
| # If you store a function in an attribute, it becomes a bound method | ||
| factory = (lambda: httpx.AsyncClient(http2=True),) | ||
| """Factory for creating httpx.AsyncClient.""" | ||
|
|
||
| @staticmethod | ||
| async def _default_get_client() -> httpx.AsyncClient: | ||
| return httpx.AsyncClient(http2=True) | ||
|
|
||
| factory: Callable[..., Awaitable[httpx.AsyncClient]] = _default_get_client | ||
|
|
||
| @classmethod | ||
| def set_client_factory( | ||
| cls, factory: Callable[..., Awaitable[httpx.AsyncClient]] | ||
| ) -> None: | ||
| """Set the async factory method for creating HTTP/2 clients.""" | ||
| cls.factory = factory | ||
|
|
||
| @classmethod | ||
| def set_client_factory(cls, factory): | ||
| cls.factory = (factory,) | ||
| async def get_client(cls) -> httpx.AsyncClient: | ||
| """Create a new httpx.AsyncClient using the configured async factory method.""" | ||
| return await cls.factory() | ||
|
|
||
| def get_client(self): | ||
| return self.factory[0]() | ||
|
|
||
| class TokenFetcher(HttpClientFactory): | ||
| """Fetches and refreshes authentication tokens for FranklinWH API.""" | ||
|
|
@@ -311,7 +324,7 @@ def __init__(self, username: str, password: str) -> None: | |
| self.password = password | ||
| self.info: dict | None = None | ||
|
|
||
| async def get_token(self): | ||
| async def get_token(self) -> str: | ||
| """Fetch a new authentication token using the stored credentials. | ||
|
|
||
| Store the intermediate account information in self.info. | ||
|
|
@@ -320,15 +333,11 @@ async def get_token(self): | |
| return self.info["token"] | ||
|
|
||
| @staticmethod | ||
| async def login(username: str, password: str): | ||
| async def login(username: str, password: str) -> None: | ||
| """Log in to the FranklinWH API and retrieve an authentication token.""" | ||
| await TokenFetcher(username, password).get_token() | ||
|
|
||
| @staticmethod | ||
| async def _login(username: str, password: str) -> dict: | ||
| await TokenFetcher(username, password).get_token() | ||
|
|
||
| async def fetch_token(self): | ||
| async def fetch_token(self) -> dict: | ||
| """Log in to the FranklinWH API and retrieve account information.""" | ||
| url = ( | ||
| DEFAULT_URL_BASE + "hes-gateway/terminal/initialize/appUserOrInstallerLogin" | ||
|
|
@@ -339,7 +348,7 @@ async def fetch_token(self): | |
| "lang": "en_US", | ||
| "type": 1, | ||
| } | ||
| async with self.get_client() as client: | ||
| async with await self.get_client() as client: | ||
| res = await client.post(url, data=form, timeout=10) | ||
| res.raise_for_status() | ||
| js = res.json() | ||
|
|
@@ -374,54 +383,61 @@ def __init__( | |
| self.url_base = url_base | ||
| self.token = "" | ||
| self.snno = 0 | ||
| self.session = self.get_client() | ||
| self.session: httpx.AsyncClient | None = None | ||
|
|
||
| # to enable detailed logging add this to configuration.yaml: | ||
| # logger: | ||
| # logs: | ||
| # franklinwh: debug | ||
|
|
||
| logger = logging.getLogger("franklinwh") | ||
| logger.warning("Session class: %s" % type(self.session)) | ||
| self.logger = logger | ||
| if logger.isEnabledFor(logging.DEBUG): | ||
|
|
||
| async def debug_request(request: httpx.Request): | ||
| body = request.content | ||
| if body and request.headers.get("Content-Type", "").startswith( | ||
| "application/json" | ||
| ): | ||
| body = json.dumps(json.loads(body), ensure_ascii=False) | ||
| self.logger.debug( | ||
| "Request: %s %s %s %s", | ||
| request.method, | ||
| request.url, | ||
| request.headers, | ||
| body, | ||
| ) | ||
| return request | ||
|
|
||
| async def debug_response(response: httpx.Response): | ||
| await response.aread() | ||
| self.logger.debug( | ||
| "Response: %s %s %s %s", | ||
| response.status_code, | ||
| response.url, | ||
| response.headers, | ||
| response.json(), | ||
| ) | ||
| return response | ||
|
|
||
| self.logger = logging.getLogger("franklinwh") | ||
|
|
||
| async def get_client(self) -> httpx.AsyncClient: | ||
| """Return the session or create a new session with optional debug logging.""" | ||
| if self.session is None: | ||
| self.session = await super().get_client() | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you explicitly don't want to use super here? As it stands you can set different factories in the subclasses if that's seomthing you care about for whatever reason.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Client IS a subclass of HttpClientFactory and this override of
thus the explicit call to i guess the same effect could be achieved by setting a different factory but how would that then conflict with hass, which also needs to
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm with you now, sorry I was misreading what's going on a little. I'm with you now. |
||
| if self.logger.isEnabledFor(logging.DEBUG): | ||
|
|
||
| async def debug_request(request: httpx.Request): | ||
| body = request.content | ||
| if body and request.headers.get("Content-Type", "").startswith( | ||
| "application/json" | ||
| ): | ||
| body = json.dumps(json.loads(body), ensure_ascii=False) | ||
| self.logger.debug( | ||
| "Request: %s %s %s %s", | ||
| request.method, | ||
| request.url, | ||
| request.headers, | ||
| body, | ||
| ) | ||
| return request | ||
|
|
||
| async def debug_response(response: httpx.Response): | ||
| await response.aread() | ||
| self.logger.debug( | ||
| "Response: %s %s %s %s", | ||
| response.status_code, | ||
| response.url, | ||
| response.headers, | ||
| response.json(), | ||
| ) | ||
| return response | ||
|
|
||
| self.session.event_hooks["request"].append(debug_request) | ||
| self.session.event_hooks["response"].append(debug_response) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was just rooting around in httpx internals and realised we can do this instead of trying to build a new client, I'll make sure I cherry pick this regardless of which session injection approach merges.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i already did this in #13. |
||
| return self.session | ||
|
|
||
| # TODO(richo) Setup timeouts and deal with them gracefully. | ||
| async def _post(self, url, payload, params: dict | None = None): | ||
| session = await self.get_client() | ||
| if params is not None: | ||
| params = params.copy() | ||
| params.update({"gatewayId": self.gateway, "lang": "en_US"}) | ||
|
|
||
| async def __post(): | ||
| return ( | ||
| await self.session.post( | ||
| await session.post( | ||
| url, | ||
| params=params, | ||
| headers={ | ||
|
|
@@ -435,9 +451,11 @@ async def __post(): | |
| return await retry(__post, lambda j: j["code"] != 401, self.refresh_token) | ||
|
|
||
| async def _post_form(self, url, payload): | ||
| session = await self.get_client() | ||
|
|
||
| async def __post(): | ||
| return ( | ||
| await self.session.post( | ||
| await session.post( | ||
| url, | ||
| headers={ | ||
| "loginToken": self.token, | ||
|
|
@@ -451,6 +469,7 @@ async def __post(): | |
| return await retry(__post, lambda j: j["code"] != 401, self.refresh_token) | ||
|
|
||
| async def _get(self, url, params: dict | None = None): | ||
| session = await self.get_client() | ||
| if params is None: | ||
| params = {} | ||
| else: | ||
|
|
@@ -459,7 +478,7 @@ async def _get(self, url, params: dict | None = None): | |
|
|
||
| async def __get(): | ||
| return ( | ||
| await self.session.get( | ||
| await session.get( | ||
| url, params=params, headers={"loginToken": self.token} | ||
| ) | ||
| ).json() | ||
|
|
@@ -579,7 +598,6 @@ async def get_stats(self) -> Stats: | |
|
|
||
| This includes instantaneous measurements for current power, as well as totals for today (in local time) | ||
| """ | ||
| self.logger.warning("get_stats: Session class: %s" % type(self.session)) | ||
| data = await self._status() | ||
| grid_status: GridStatus = GridStatus.NORMAL | ||
| if "offgridreason" in data: | ||
|
|
@@ -620,8 +638,8 @@ def next_snno(self): | |
| return self.snno | ||
|
|
||
| def _build_payload(self, ty, data): | ||
| blob = json.dumps(data, separators=(",", ":")).encode("utf-8") | ||
| # crc = to_hex(zlib.crc32(blob.encode("ascii"))) | ||
| raw = json.dumps(data, separators=(",", ":")) | ||
| blob = raw.encode("utf-8") | ||
| crc = to_hex(zlib.crc32(blob)) | ||
| ts = int(time.time()) | ||
|
|
||
|
|
@@ -639,7 +657,7 @@ def _build_payload(self, ty, data): | |
| } | ||
| ) | ||
| # We do it this way because without a canonical way to generate JSON we can't risk reordering breaking the CRC. | ||
| return temp.replace('"DATA"', blob.decode("utf-8")) | ||
| return temp.replace('"DATA"', raw) | ||
|
|
||
| async def _mqtt_send(self, payload): | ||
| url = DEFAULT_URL_BASE + "hes-gateway/terminal/sendMqtt" | ||
|
|
@@ -680,21 +698,24 @@ async def get_controllable_loads(self): | |
| ) | ||
| params = {"id": self.gateway, "lang": "en_US"} | ||
| headers = {"loginToken": self.token} | ||
| res = await self.session.get(url, params=params, headers=headers) | ||
| session = await self.get_client() | ||
| res = await session.get(url, params=params, headers=headers) | ||
| return res.json() | ||
|
|
||
| async def get_accessory_list(self): | ||
| """Get the list of accessories connected to the gateway.""" | ||
| url = self.url_base + "hes-gateway/terminal/getIotAccessoryList" | ||
| params = {"gatewayId": self.gateway, "lang": "en_US"} | ||
| headers = {"loginToken": self.token} | ||
| res = await self.session.get(url, params=params, headers=headers) | ||
| session = await self.get_client() | ||
| res = await session.get(url, params=params, headers=headers) | ||
| return res.json() | ||
|
|
||
| async def get_equipment_list(self): | ||
| """Get the list of equipment connected to the gateway.""" | ||
| url = self.url_base + "hes-gateway/manage/getEquipmentList" | ||
| params = {"gatewayId": self.gateway, "lang": "en_US"} | ||
| headers = {"loginToken": self.token} | ||
| res = await self.session.get(url, params=params, headers=headers) | ||
| session = await self.get_client() | ||
| res = await session.get(url, params=params, headers=headers) | ||
| return res.json() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I understand what's made better by making this all be async? Is the idea here that it will loan cleanly in hass now and we don't have to use their helper?
I feel like at best creating a new client shouldn't do io and at worst it'll do it once.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with the API shift to async all session activity occurs in async context, which also works best for hass. EXCEPT
Client.__init__, which CANNOT be async, created the session in sync context and, to your point, did I/O to load the SSL certificates from filesystem. this causes the warning from hass because Client is created inasync_setup_platformand hass notices the blocking I/O in async context.this change pushes session creation to async context uniformly, removing it from
__init__, soget_client()becomes async. but look at https://github.com/richo/homeassistant-franklinwh/pull/57/changes#diff-adf648fcc09089e05b962654454a3cbdfcfe9884f8feee8c10fb6bee086a859dR73-R76: due to the I/O we still require hass' helper.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again, contrast with #13 where we create the session in known (hass, async) context.