11# copyright (c) 2020, Matthias Dellweg
22# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt)
33
4+ import asyncio
45import base64
56import datetime
67import json
78import os
9+ import ssl
810import typing as t
911from collections import defaultdict
1012from contextlib import suppress
1113from io import BufferedReader
1214from urllib .parse import urljoin
1315
16+ import aiohttp
1417import requests
1518import urllib3
1619
@@ -174,6 +177,9 @@ def __init__(
174177 self ._safe_calls_only : bool = safe_calls_only
175178 self ._headers = headers or {}
176179 self ._verify = verify
180+ # Shall we make that a parameter?
181+ self ._ssl_context : t .Optional [t .Union [ssl .SSLContext , bool ]] = None
182+
177183 self ._auth_provider = auth_provider
178184 self ._cert = cert
179185 self ._key = key
@@ -225,6 +231,22 @@ def base_url(self) -> str:
225231 def cid (self ) -> t .Optional [str ]:
226232 return self ._headers .get ("Correlation-Id" )
227233
234+ @property
235+ def ssl_context (self ) -> t .Union [ssl .SSLContext , bool ]:
236+ if self ._ssl_context is None :
237+ if self ._verify is False :
238+ self ._ssl_context = False
239+ else :
240+ if isinstance (self ._verify , str ):
241+ self ._ssl_context = ssl .create_default_context (cafile = self ._verify )
242+ else :
243+ self ._ssl_context = ssl .create_default_context ()
244+ if self ._cert is not None :
245+ self ._ssl_context .load_cert_chain (self ._cert , self ._key )
246+ # Type inference is failing here.
247+ self ._ssl_context = t .cast (t .Union [ssl .SSLContext | bool ], self ._ssl_context )
248+ return self ._ssl_context
249+
228250 def load_api (self , refresh_cache : bool = False ) -> None :
229251 # TODO: Find a way to invalidate caches on upstream change
230252 xdg_cache_home : str = os .environ .get ("XDG_CACHE_HOME" ) or "~/.cache"
@@ -242,7 +264,7 @@ def load_api(self, refresh_cache: bool = False) -> None:
242264 self ._parse_api (data )
243265 except Exception :
244266 # Try again with a freshly downloaded version
245- data = self ._download_api ()
267+ data = asyncio . run ( self ._download_api () )
246268 self ._parse_api (data )
247269 # Write to cache as it seems to be valid
248270 os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
@@ -262,15 +284,18 @@ def _parse_api(self, data: bytes) -> None:
262284 if method in {"get" , "put" , "post" , "delete" , "options" , "head" , "patch" , "trace" }
263285 }
264286
265- def _download_api (self ) -> bytes :
287+ async def _download_api (self ) -> bytes :
266288 try :
267- response : requests .Response = self ._session .get (urljoin (self ._base_url , self ._doc_path ))
268- except requests .RequestException as e :
289+ connector = aiohttp .TCPConnector (ssl = self .ssl_context )
290+ async with aiohttp .ClientSession (connector = connector , headers = self ._headers ) as session :
291+ async with session .get (urljoin (self ._base_url , self ._doc_path )) as response :
292+ response .raise_for_status ()
293+ data = await response .read ()
294+ except aiohttp .ClientError as e :
269295 raise OpenAPIError (str (e ))
270- response .raise_for_status ()
271296 if "Correlation-ID" in response .headers :
272297 self ._set_correlation_id (response .headers ["Correlation-ID" ])
273- return response . content
298+ return data
274299
275300 def _set_correlation_id (self , correlation_id : str ) -> None :
276301 if "Correlation-ID" in self ._headers :
0 commit comments