diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36f15ea..e4d3516 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,11 +24,14 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pycodestyle + pip install pytest pycodestyle pyright websockets if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Run tests with pytest run: | pytest - name: Lint with PyCodeStyle run: | - find . -name \*.py -exec pycodestyle {} + \ No newline at end of file + find . -name \*.py -exec pycodestyle {} + + - name: Code validation using Pyright + run: | + pyright \ No newline at end of file diff --git a/examples/bookstore/webserver.py b/examples/bookstore/webserver.py index 54f7ea9..74e69f3 100644 --- a/examples/bookstore/webserver.py +++ b/examples/bookstore/webserver.py @@ -28,7 +28,7 @@ import asyncio from sys import argv from functools import partial -from aiohttp import web +from aiohttp import web # type: ignore from thingsdb.client import Client from thingsdb.room import Room, event @@ -62,12 +62,14 @@ def on_cleanup(): async def add_book(request): book = await request.json() # Use the procedure to add the book + assert bookstore await bookstore.add_book(book) return web.HTTPNoContent() # We have the books in memory, no need for a query async def get_books(request): + assert bookstore return web.json_response({ "book_titles": [book['title'] for book in bookstore.books] }) diff --git a/thingsdb/client/buildin.py b/thingsdb/client/buildin.py index cb0d3b8..ef163a7 100644 --- a/thingsdb/client/buildin.py +++ b/thingsdb/client/buildin.py @@ -1,4 +1,6 @@ +import asyncio import datetime +from abc import ABC, abstractmethod from typing import Union as U from typing import Optional from typing import Any @@ -9,6 +11,15 @@ class Buildin: # # Build-in functions from the @thingsdb scope # + @abstractmethod + def query( + self, + code: str, + scope: Optional[str] = None, + timeout: Optional[int] = None, + skip_strip_code: bool = False, + **kwargs: Any) -> asyncio.Future[Any]: + ... async def collection_info(self, collection: U[int, str]) -> dict: """Returns information about a specific collection. @@ -245,8 +256,8 @@ async def new_token( expiration_time: Optional[datetime.datetime] = None, description: str = ''): - if expiration_time is not None: - expiration_time = int(datetime.datetime.timestamp(expiration_time)) + ts = None if expiration_time is None \ + else int(expiration_time.timestamp()) return await self.query( """//ti @@ -254,7 +265,7 @@ async def new_token( new_token(user, et, description); """, user=user, - expiration_time=expiration_time, + expiration_time=ts, description=description, scope='@t') @@ -334,7 +345,8 @@ async def set_module_scope( module_scope=scope, scope='@t') - async def set_password(self, user: str, new_password: str = None) -> None: + async def set_password(self, user: str, + new_password: Optional[str] = None) -> None: return await self.query( 'set_password(user, new_password)', user=user, @@ -412,8 +424,7 @@ async def new_backup( max_files: Optional[int] = 7, scope='@n'): - if start_ts is not None: - start_ts = int(datetime.datetime.timestamp(start_ts)) + ts = None if start_ts is None else int(start_ts.timestamp()) return await self.query( """//ti @@ -421,7 +432,7 @@ async def new_backup( new_backup(file_template, start_ts, repeat, max_files); """, file_template=file_template, - start_ts=start_ts, + start_ts=ts, repeat=repeat, max_files=max_files, scope=scope) @@ -454,14 +465,14 @@ async def restart_module(self, name: str) -> None: return await self.query('restart_module(name)', name=name, scope='@t') async def set_log_level(self, log_level: str, scope='@n') -> None: - log_level = ( + level = ( 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL').index(log_level) return await self.query( - 'set_log_level(log_level)', log_level=log_level, scope=scope) + 'set_log_level(log_level)', log_level=level, scope=scope) async def shutdown(self, scope='@n') -> None: """Shutdown the node in the selected scope. diff --git a/thingsdb/client/client.py b/thingsdb/client/client.py index 2267206..c766a9d 100644 --- a/thingsdb/client/client.py +++ b/thingsdb/client/client.py @@ -5,7 +5,7 @@ import time from collections import defaultdict from ssl import SSLContext, PROTOCOL_TLS -from typing import Optional, Union, Any +from typing import Optional, Union, Any, List from concurrent.futures import CancelledError from .buildin import Buildin from .protocol import Proto, Protocol, ProtocolWS @@ -135,7 +135,7 @@ def connection_info(self) -> str: """ if not self.is_connected(): return 'disconnected' - socket = self._protocol.info() + socket = self._protocol.info() # type: ignore if socket is None: return 'unknown_addr' addr, port = socket.getpeername()[:2] @@ -145,7 +145,7 @@ def connect_pool( self, pool: list, *auth: Union[str, tuple] - ) -> asyncio.Future: + ) -> asyncio.Future[None]: """Connect using a connection pool. When using a connection pool, the client will randomly choose a node @@ -183,21 +183,23 @@ def connect_pool( assert self._reconnecting is False assert len(pool), 'pool must contain at least one node' if len(auth) == 1: - auth = auth[0] + auth = auth[0] # type: ignore self._pool = tuple(( (address, 9200) if isinstance(address, str) else address for address in pool)) self._auth = self._auth_check(auth) self._pool_idx = random.randint(0, len(pool) - 1) - return self.reconnect() + fut = self.reconnect() + if fut is None: + raise ConnectionError('client already connecting') + return fut - def connect( + async def connect( self, host: str, port: int = 9200, - timeout: Optional[int] = 5 - ) -> asyncio.Future: + timeout: Optional[int] = 5): """Connect to ThingsDB. This method will *only* create a connection, so the connection is not @@ -231,9 +233,9 @@ def connect( assert self.is_connected() is False self._pool = ((host, port),) self._pool_idx = 0 - return self._connect(timeout=timeout) + await self._connect(timeout=timeout) - def reconnect(self) -> Optional[asyncio.Future]: + def reconnect(self) -> Optional[asyncio.Future[Any]]: """Re-connect to ThingsDB. This method can be used, even when a connection still exists. In case @@ -286,7 +288,7 @@ async def authenticate( wait forever on a response. Defaults to 5. """ if len(auth) == 1: - auth = auth[0] + auth = auth[0] # type: ignore self._auth = self._auth_check(auth) await self._authenticate(timeout) @@ -297,7 +299,7 @@ def query( timeout: Optional[int] = None, skip_strip_code: bool = False, **kwargs: Any - ) -> asyncio.Future: + ) -> asyncio.Future[Any]: """Query ThingsDB. Use this method to run `code` in a scope. @@ -348,9 +350,10 @@ def query( data = [scope, code] if kwargs: - data.append(kwargs) + data.append(kwargs) # type: ignore - return self._write_pkg(Proto.REQ_QUERY, data, timeout=timeout) + return self._write_pkg( + Proto.REQ_QUERY, data, timeout=timeout) # type: ignore async def _ensure_write( self, @@ -358,7 +361,7 @@ async def _ensure_write( data: Any = None, is_bin: bool = False, timeout: Optional[int] = None - ) -> asyncio.Future: + ) -> asyncio.Future[Any]: if not self._pool: raise ConnectionError('no connection') @@ -372,6 +375,7 @@ async def _ensure_write( continue try: + assert self._protocol # we're connected res = await self._protocol.write(tp, data, is_bin, timeout) except (asyncio.exceptions.CancelledError, CancelledError, NodeError, AuthError) as e: @@ -394,9 +398,10 @@ async def _write( data: Any = None, is_bin: bool = False, timeout: Optional[int] = None - ) -> asyncio.Future: + ) -> asyncio.Future[Any]: if not self.is_connected(): raise ConnectionError('no connection') + assert self._protocol # we are connected return await self._protocol.write(tp, data, is_bin, timeout) def run( @@ -406,7 +411,7 @@ def run( scope: Optional[str] = None, timeout: Optional[int] = None, **kwargs: Any, - ) -> asyncio.Future: + ) -> asyncio.Future[Any]: """Run a procedure. Use this method to run a stored procedure in a scope. @@ -449,23 +454,23 @@ def run( data = [scope, procedure] if args: - data.append(args) + data.append(args) # type: ignore if kwargs: raise ValueError( 'it is not possible to use both keyword arguments ' 'and positional arguments at the same time') elif kwargs: - data.append(kwargs) + data.append(kwargs) # type: ignore - return self._write_pkg(Proto.REQ_RUN, data, timeout=timeout) + return self._write_pkg( + Proto.REQ_RUN, data, timeout=timeout) # type: ignore - def _emit( + async def _emit( self, room_id: Union[int, str], event: str, *args: Optional[Any], - scope: Optional[str] = None, - ) -> asyncio.Future: + scope: Optional[str] = None): """Emit an event. Use Room(room_id, scope=scope).emit(..) instead of this function to @@ -492,10 +497,11 @@ def _emit( """ if scope is None: scope = self._scope - return self._write_pkg(Proto.REQ_EMIT, [scope, room_id, event, *args]) + await self._write_pkg(Proto.REQ_EMIT, [scope, room_id, event, *args]) def _join(self, *ids: Union[int, str], - scope: Optional[str] = None) -> asyncio.Future: + scope: Optional[str] = None + ) -> asyncio.Future[List[Optional[int]]]: """Join one or more rooms. Args: @@ -521,10 +527,11 @@ def _join(self, *ids: Union[int, str], if scope is None: scope = self._scope - return self._write_pkg(Proto.REQ_JOIN, [scope, *ids]) + return self._write_pkg(Proto.REQ_JOIN, [scope, *ids]) # type: ignore def _leave(self, *ids: Union[int, str], - scope: Optional[str] = None) -> asyncio.Future: + scope: Optional[str] = None + ) -> asyncio.Future[List[Optional[int]]]: """Leave one or more rooms. Stop receiving events for the rooms given by one or more ids. It is @@ -553,7 +560,7 @@ def _leave(self, *ids: Union[int, str], if scope is None: scope = self._scope - return self._write_pkg(Proto.REQ_LEAVE, [scope, *ids]) + return self._write_pkg(Proto.REQ_LEAVE, [scope, *ids]) # type: ignore @staticmethod def _auth_check(auth): @@ -574,7 +581,9 @@ def _auth_check(auth): def _is_websocket_host(host): return host.startswith('ws://') or host.startswith('wss://') - async def _connect(self, timeout=5): + async def _connect(self, timeout: Optional[int] = 5): + if not self._pool: + return host, port = self._pool[self._pool_idx] try: if self._is_websocket_host(host): @@ -646,6 +655,7 @@ def _on_connection_lost(self, protocol, exc): self.reconnect() async def _reconnect_loop(self): + assert self._pool # only when we have a pool try: wait_time = 1 timeout = 2 diff --git a/thingsdb/client/package.py b/thingsdb/client/package.py index 02e6862..9580b72 100644 --- a/thingsdb/client/package.py +++ b/thingsdb/client/package.py @@ -48,11 +48,11 @@ def _handle_fail_file(self, message: bytes): def extract_data_from(self, barray: bytearray) -> None: try: self.data = msgpack.unpackb( - barray[self.__class__.st_package.size:self.total], + bytes(barray[self.__class__.st_package.size:self.total]), raw=False) \ if self.length else None except Exception as e: - self._handle_fail_file(barray) + self._handle_fail_file(bytes(barray)) raise e finally: del barray[:self.total] diff --git a/thingsdb/client/protocol.py b/thingsdb/client/protocol.py index d5dd2ac..a9ac3e7 100644 --- a/thingsdb/client/protocol.py +++ b/thingsdb/client/protocol.py @@ -30,10 +30,18 @@ from ..exceptions import ZeroDivisionError try: import websockets - from websockets.client import connect, WebSocketClientProtocol - from websockets.exceptions import ConnectionClosed -except ImportError: - pass + from websockets.client import connect # type: ignore + from websockets.client import WebSocketClientProtocol # type: ignore + from websockets.exceptions import ConnectionClosed # type: ignore +except (ImportError, ModuleNotFoundError): + websockets = None + connect = None + + class WebSocketClientProtocol: + pass + + class ConnectionClosed(Exception): + pass WEBSOCKET_MAX_SIZE = 2**24 # default from websocket is 2**20 @@ -136,7 +144,7 @@ class Err(enum.IntEnum): ) -def proto_unkown(f, d): +def proto_unknown(f, d): f.set_exception(TypeError('unknown package type received ({})'.format(d))) @@ -150,23 +158,22 @@ def __init__( self._on_connection_lost = on_connection_lost self._on_event = on_event - async def _timer(self, pid: int, timeout: Optional[int]) -> None: + async def _timer(self, pid: int, timeout: int) -> None: await asyncio.sleep(timeout) try: future, task = self._requests.pop(pid) except KeyError: - logging.error('Timed out package Id not found: {}'.format( - self._data_package.pid)) + logging.error(f'Timed out package Id not found: {pid}') return None future.set_exception(TimeoutError( - 'request timed out on package Id {}'.format(pid))) + f'request timed out on package Id {pid}')) def _on_response(self, pkg: Package) -> None: try: future, task = self._requests.pop(pkg.pid) except KeyError: - logging.error('Received package id not found: {}'.format(pkg.pid)) + logging.error(f'Received package id not found: {pkg.pid}') return None # cancel the timeout task @@ -176,7 +183,7 @@ def _on_response(self, pkg: Package) -> None: if future.cancelled(): return - _PROTO_RESPONSE_MAP.get(pkg.tp, proto_unkown)(future, pkg.data) + _PROTO_RESPONSE_MAP.get(pkg.tp, proto_unknown)(future, pkg.data) def _handle_package(self, pkg: Package): tp = pkg.tp @@ -196,7 +203,7 @@ def write( data: Any = None, is_bin: bool = False, timeout: Optional[int] = None - ) -> asyncio.Future: + ) -> asyncio.Future[Any]: """Write data to ThingsDB. This will create a new PID and returns a Future which will be set when a response is received from ThingsDB, or time-out is reached. @@ -269,23 +276,26 @@ def __init__( self.package = None self.transport = None self.loop = asyncio.get_running_loop() if loop is None else loop - self.close_future = None + self.close_future: Optional[asyncio.Future[Any]] = None - def connection_made(self, transport: asyncio.Transport) -> None: + def connection_made(self, transport): ''' override asyncio.Protocol ''' self.close_future = self.loop.create_future() self.transport = transport - def connection_lost(self, exc: Exception) -> None: + def connection_lost(self, exc) -> None: ''' override asyncio.Protocol ''' self.cancel_requests() - self.close_future.set_result(None) - self.close_future = None + if self.close_future: + self.close_future.set_result(None) + self.close_future = None self.transport = None + if not isinstance(exc, Exception): + exc = Exception(f'connection lost ({exc})') self._on_connection_lost(self, exc) def data_received(self, data: bytes) -> None: @@ -317,7 +327,7 @@ def data_received(self, data: bytes) -> None: def _write(self, data: Any): if self.transport is None: raise ConnectionError('no connection') - self.transport.write(data) + self.transport.write(data) # type: ignore def close(self): if self.transport: @@ -327,14 +337,17 @@ def is_closing(self) -> bool: return self.close_future is not None async def wait_closed(self): - await self.close_future + if self.close_future: + await self.close_future async def close_and_wait(self): self.close() - await self.close_future + if self.close_future: + await self.close_future - def info(self): - return self.transport.get_extra_info('socket', None) + def info(self) -> Any: + if self.transport: + return self.transport.get_extra_info('socket', None) def is_connected(self) -> bool: return self.transport is not None @@ -355,10 +368,11 @@ def __init__( 'missing `websockets` module; ' 'please install the `websockets` module: ' '\n\n pip install websockets\n\n') - self._proto: WebSocketClientProtocol = None + self._proto: Optional[WebSocketClientProtocol] = None self._is_closing = False - async def connect(self, uri, ssl: SSLContext): + async def connect(self, uri, ssl: Optional[SSLContext]): + assert connect, 'websockets required, please install websockets' self._proto = await connect(uri, ssl=ssl, max_size=WEBSOCKET_MAX_SIZE) asyncio.create_task(self._recv_loop()) self._is_closing = False @@ -367,7 +381,7 @@ async def connect(self, uri, ssl: SSLContext): async def _recv_loop(self): try: while True: - data = await self._proto.recv() + data = await self._proto.recv() # type: ignore pkg = None try: pkg = Package(data) @@ -383,32 +397,34 @@ async def _recv_loop(self): except ConnectionClosed as exc: self.cancel_requests() - self._proto = None - self._on_connection_lost(self, exc) + self._proto = None # type: ignore + self._on_connection_lost(self, exc) # type: ignore def _write(self, data: Any): if self._proto is None: raise ConnectionError('no connection') - asyncio.create_task(self._proto.send(data)) + asyncio.create_task(self._proto.send(data)) # type: ignore def close(self): self._is_closing = True if self._proto: - asyncio.create_task(self._proto.close()) + asyncio.create_task(self._proto.close()) # type: ignore def is_closing(self) -> bool: - self._is_closing + return self._is_closing async def wait_closed(self): if self._proto: - await self._proto.wait_closed() + await self._proto.wait_closed() # type: ignore async def close_and_wait(self): if self._proto: - await self._proto.close() + await self._proto.close() # type: ignore - def info(self): - return self._proto.transport.get_extra_info('socket', None) + def info(self) -> Any: + return self._proto.transport.get_extra_info( # type: ignore + 'socket', + None) def is_connected(self) -> bool: return self._proto is not None diff --git a/thingsdb/room/roombase.py b/thingsdb/room/roombase.py index d4d858f..68b1711 100644 --- a/thingsdb/room/roombase.py +++ b/thingsdb/room/roombase.py @@ -15,12 +15,12 @@ def __init_subclass__(cls): for key, val in cls.__dict__.items(): if not key.startswith('__') and \ callable(val) and hasattr(val, '_event'): - cls._event_handlers[val._event] = val + cls._event_handlers[val._event] = val # type: ignore def __init__( self, room: Union[int, str], - scope: str = None): + scope: Optional[str] = None): """Initializes a room. Args: @@ -29,7 +29,7 @@ def __init__( Examples are: - 123 - '.my_room.id();' - scope (str): + scope (str?): Collection scope. If no scope is given, the scope will later be set to the default client scope once the room is joined. """ @@ -47,7 +47,10 @@ def scope(self): return self._scope @property - def client(self): + def client(self) -> Client: + if self._client is None: + raise RuntimeError( + 'must call join(..) or no_join(..) before we have a client') return self._client async def no_join(self, client: Client): @@ -106,8 +109,9 @@ async def join(self, client: Client, wait: Optional[float] = 60.0): if isinstance(self._id, str): if is_name(self._id): + code = "room(name).id();" id = await client.query( - "room(name).id();", + code, name=self._id, scope=self._scope) else: @@ -144,6 +148,7 @@ async def join(self, client: Client, wait: Optional[float] = 60.0): if wait: # wait for the first join to finish + assert isinstance(self._wait_join, asyncio.Future) await asyncio.wait_for(self._wait_join, wait) async def leave(self): @@ -155,11 +160,14 @@ async def leave(self): raise TypeError( 'room Id is not an integer; most likely `join()` has never ' 'been called') + if self._client is None: + raise RuntimeError( + 'must call join(..) or no_join(..) before using emit') res = await self._client._leave(self._id, scope=self._scope) if res[0] is None: raise LookupError(f'room Id {self._id} is not found (anymore)') - def emit(self, event: str, *args) -> asyncio.Future: + async def emit(self, event: str, *args): """Emit an event. Args: @@ -176,7 +184,7 @@ def emit(self, event: str, *args) -> asyncio.Future: if self._client is None: raise RuntimeError( 'must call join(..) or no_join(..) before using emit') - return self._client._emit(self._id, event, *args, scope=self._scope) + await self._client._emit(self._id, event, *args, scope=self._scope) def _on_event(self, pkg) -> Optional[asyncio.Task]: return self.__class__._ROOM_EVENT_MAP[pkg.tp](self, pkg.data) @@ -198,6 +206,7 @@ def on_emit(self, event: str, *args) -> None: pass async def _on_first_join(self): + assert isinstance(self._wait_join, asyncio.Future) fut = self._wait_join self._wait_join = None # Instead of using finally to set the result, we could also catch the @@ -223,11 +232,13 @@ def _on_join(self, _data): # User has decided not to wait for the join. Thus we can asume that # event handlers do not depend on the on_join to be finished assert self._wait_join is False - loop = self.client.get_event_loop() + assert self._client + loop = self._client.get_event_loop() asyncio.ensure_future(self.on_join(), loop=loop) def _on_stop(self, func): try: + assert self._client del self._client._rooms[self._id] except KeyError: pass diff --git a/thingsdb/util/id.py b/thingsdb/util/id.py index ddcbcc7..9bf6a45 100644 --- a/thingsdb/util/id.py +++ b/thingsdb/util/id.py @@ -13,6 +13,7 @@ def id(val): return id except Exception: return None + assert isinstance(val, dict) return val.get('#') diff --git a/thingsdb/version.py b/thingsdb/version.py index 99d2a6f..6ebd335 100644 --- a/thingsdb/version.py +++ b/thingsdb/version.py @@ -1 +1 @@ -__version__ = '1.1.5' +__version__ = '1.1.6'