Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pycodestyle
pip install pytest pycodestyle pyright websockets
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Run tests with pytest
run: |
pytest
- name: Lint with PyCodeStyle
run: |
find . -name \*.py -exec pycodestyle {} +
find . -name \*.py -exec pycodestyle {} +
- name: Code validation using Pyright
run: |
pyright
4 changes: 3 additions & 1 deletion examples/bookstore/webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import asyncio
from sys import argv
from functools import partial
from aiohttp import web
from aiohttp import web # type: ignore
from thingsdb.client import Client
from thingsdb.room import Room, event

Expand Down Expand Up @@ -62,12 +62,14 @@ def on_cleanup():
async def add_book(request):
book = await request.json()
# Use the procedure to add the book
assert bookstore
await bookstore.add_book(book)
return web.HTTPNoContent()


# We have the books in memory, no need for a query
async def get_books(request):
assert bookstore
return web.json_response({
"book_titles": [book['title'] for book in bookstore.books]
})
Expand Down
29 changes: 20 additions & 9 deletions thingsdb/client/buildin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import datetime
from abc import ABC, abstractmethod
from typing import Union as U
from typing import Optional
from typing import Any
Expand All @@ -9,6 +11,15 @@ class Buildin:
#
# Build-in functions from the @thingsdb scope
#
@abstractmethod
def query(
self,
code: str,
scope: Optional[str] = None,
timeout: Optional[int] = None,
skip_strip_code: bool = False,
**kwargs: Any) -> asyncio.Future[Any]:
...

async def collection_info(self, collection: U[int, str]) -> dict:
"""Returns information about a specific collection.
Expand Down Expand Up @@ -245,16 +256,16 @@ async def new_token(
expiration_time: Optional[datetime.datetime] = None,
description: str = ''):

if expiration_time is not None:
expiration_time = int(datetime.datetime.timestamp(expiration_time))
ts = None if expiration_time is None \
else int(expiration_time.timestamp())

return await self.query(
"""//ti
et = is_nil(expiration_time) ? nil : datetime(expiration_time);
new_token(user, et, description);
""",
user=user,
expiration_time=expiration_time,
expiration_time=ts,
description=description,
scope='@t')

Expand Down Expand Up @@ -334,7 +345,8 @@ async def set_module_scope(
module_scope=scope,
scope='@t')

async def set_password(self, user: str, new_password: str = None) -> None:
async def set_password(self, user: str,
new_password: Optional[str] = None) -> None:
return await self.query(
'set_password(user, new_password)',
user=user,
Expand Down Expand Up @@ -412,16 +424,15 @@ async def new_backup(
max_files: Optional[int] = 7,
scope='@n'):

if start_ts is not None:
start_ts = int(datetime.datetime.timestamp(start_ts))
ts = None if start_ts is None else int(start_ts.timestamp())

return await self.query(
"""//ti
start_ts = is_nil(start_ts) ? nil : datetime(start_ts);
new_backup(file_template, start_ts, repeat, max_files);
""",
file_template=file_template,
start_ts=start_ts,
start_ts=ts,
repeat=repeat,
max_files=max_files,
scope=scope)
Expand Down Expand Up @@ -454,14 +465,14 @@ async def restart_module(self, name: str) -> None:
return await self.query('restart_module(name)', name=name, scope='@t')

async def set_log_level(self, log_level: str, scope='@n') -> None:
log_level = (
level = (
'DEBUG',
'INFO',
'WARNING',
'ERROR',
'CRITICAL').index(log_level)
return await self.query(
'set_log_level(log_level)', log_level=log_level, scope=scope)
'set_log_level(log_level)', log_level=level, scope=scope)

async def shutdown(self, scope='@n') -> None:
"""Shutdown the node in the selected scope.
Expand Down
68 changes: 39 additions & 29 deletions thingsdb/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from collections import defaultdict
from ssl import SSLContext, PROTOCOL_TLS
from typing import Optional, Union, Any
from typing import Optional, Union, Any, List
from concurrent.futures import CancelledError
from .buildin import Buildin
from .protocol import Proto, Protocol, ProtocolWS
Expand Down Expand Up @@ -135,7 +135,7 @@ def connection_info(self) -> str:
"""
if not self.is_connected():
return 'disconnected'
socket = self._protocol.info()
socket = self._protocol.info() # type: ignore
if socket is None:
return 'unknown_addr'
addr, port = socket.getpeername()[:2]
Expand All @@ -145,7 +145,7 @@ def connect_pool(
self,
pool: list,
*auth: Union[str, tuple]
) -> asyncio.Future:
) -> asyncio.Future[None]:
"""Connect using a connection pool.

When using a connection pool, the client will randomly choose a node
Expand Down Expand Up @@ -183,21 +183,23 @@ def connect_pool(
assert self._reconnecting is False
assert len(pool), 'pool must contain at least one node'
if len(auth) == 1:
auth = auth[0]
auth = auth[0] # type: ignore

self._pool = tuple((
(address, 9200) if isinstance(address, str) else address
for address in pool))
self._auth = self._auth_check(auth)
self._pool_idx = random.randint(0, len(pool) - 1)
return self.reconnect()
fut = self.reconnect()
if fut is None:
raise ConnectionError('client already connecting')
return fut

def connect(
async def connect(
self,
host: str,
port: int = 9200,
timeout: Optional[int] = 5
) -> asyncio.Future:
timeout: Optional[int] = 5):
"""Connect to ThingsDB.

This method will *only* create a connection, so the connection is not
Expand Down Expand Up @@ -231,9 +233,9 @@ def connect(
assert self.is_connected() is False
self._pool = ((host, port),)
self._pool_idx = 0
return self._connect(timeout=timeout)
await self._connect(timeout=timeout)

def reconnect(self) -> Optional[asyncio.Future]:
def reconnect(self) -> Optional[asyncio.Future[Any]]:
"""Re-connect to ThingsDB.

This method can be used, even when a connection still exists. In case
Expand Down Expand Up @@ -286,7 +288,7 @@ async def authenticate(
wait forever on a response. Defaults to 5.
"""
if len(auth) == 1:
auth = auth[0]
auth = auth[0] # type: ignore
self._auth = self._auth_check(auth)
await self._authenticate(timeout)

Expand All @@ -297,7 +299,7 @@ def query(
timeout: Optional[int] = None,
skip_strip_code: bool = False,
**kwargs: Any
) -> asyncio.Future:
) -> asyncio.Future[Any]:
"""Query ThingsDB.

Use this method to run `code` in a scope.
Expand Down Expand Up @@ -348,17 +350,18 @@ def query(

data = [scope, code]
if kwargs:
data.append(kwargs)
data.append(kwargs) # type: ignore

return self._write_pkg(Proto.REQ_QUERY, data, timeout=timeout)
return self._write_pkg(
Proto.REQ_QUERY, data, timeout=timeout) # type: ignore

async def _ensure_write(
self,
tp: Proto,
data: Any = None,
is_bin: bool = False,
timeout: Optional[int] = None
) -> asyncio.Future:
) -> asyncio.Future[Any]:
if not self._pool:
raise ConnectionError('no connection')

Expand All @@ -372,6 +375,7 @@ async def _ensure_write(
continue

try:
assert self._protocol # we're connected
res = await self._protocol.write(tp, data, is_bin, timeout)
except (asyncio.exceptions.CancelledError,
CancelledError, NodeError, AuthError) as e:
Expand All @@ -394,9 +398,10 @@ async def _write(
data: Any = None,
is_bin: bool = False,
timeout: Optional[int] = None
) -> asyncio.Future:
) -> asyncio.Future[Any]:
if not self.is_connected():
raise ConnectionError('no connection')
assert self._protocol # we are connected
return await self._protocol.write(tp, data, is_bin, timeout)

def run(
Expand All @@ -406,7 +411,7 @@ def run(
scope: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> asyncio.Future:
) -> asyncio.Future[Any]:
"""Run a procedure.

Use this method to run a stored procedure in a scope.
Expand Down Expand Up @@ -449,23 +454,23 @@ def run(
data = [scope, procedure]

if args:
data.append(args)
data.append(args) # type: ignore
if kwargs:
raise ValueError(
'it is not possible to use both keyword arguments '
'and positional arguments at the same time')
elif kwargs:
data.append(kwargs)
data.append(kwargs) # type: ignore

return self._write_pkg(Proto.REQ_RUN, data, timeout=timeout)
return self._write_pkg(
Proto.REQ_RUN, data, timeout=timeout) # type: ignore

def _emit(
async def _emit(
self,
room_id: Union[int, str],
event: str,
*args: Optional[Any],
scope: Optional[str] = None,
) -> asyncio.Future:
scope: Optional[str] = None):
"""Emit an event.

Use Room(room_id, scope=scope).emit(..) instead of this function to
Expand All @@ -492,10 +497,11 @@ def _emit(
"""
if scope is None:
scope = self._scope
return self._write_pkg(Proto.REQ_EMIT, [scope, room_id, event, *args])
await self._write_pkg(Proto.REQ_EMIT, [scope, room_id, event, *args])

def _join(self, *ids: Union[int, str],
scope: Optional[str] = None) -> asyncio.Future:
scope: Optional[str] = None
) -> asyncio.Future[List[Optional[int]]]:
"""Join one or more rooms.

Args:
Expand All @@ -521,10 +527,11 @@ def _join(self, *ids: Union[int, str],
if scope is None:
scope = self._scope

return self._write_pkg(Proto.REQ_JOIN, [scope, *ids])
return self._write_pkg(Proto.REQ_JOIN, [scope, *ids]) # type: ignore

def _leave(self, *ids: Union[int, str],
scope: Optional[str] = None) -> asyncio.Future:
scope: Optional[str] = None
) -> asyncio.Future[List[Optional[int]]]:
"""Leave one or more rooms.

Stop receiving events for the rooms given by one or more ids. It is
Expand Down Expand Up @@ -553,7 +560,7 @@ def _leave(self, *ids: Union[int, str],
if scope is None:
scope = self._scope

return self._write_pkg(Proto.REQ_LEAVE, [scope, *ids])
return self._write_pkg(Proto.REQ_LEAVE, [scope, *ids]) # type: ignore

@staticmethod
def _auth_check(auth):
Expand All @@ -574,7 +581,9 @@ def _auth_check(auth):
def _is_websocket_host(host):
return host.startswith('ws://') or host.startswith('wss://')

async def _connect(self, timeout=5):
async def _connect(self, timeout: Optional[int] = 5):
if not self._pool:
return
host, port = self._pool[self._pool_idx]
try:
if self._is_websocket_host(host):
Expand Down Expand Up @@ -646,6 +655,7 @@ def _on_connection_lost(self, protocol, exc):
self.reconnect()

async def _reconnect_loop(self):
assert self._pool # only when we have a pool
try:
wait_time = 1
timeout = 2
Expand Down
4 changes: 2 additions & 2 deletions thingsdb/client/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ def _handle_fail_file(self, message: bytes):
def extract_data_from(self, barray: bytearray) -> None:
try:
self.data = msgpack.unpackb(
barray[self.__class__.st_package.size:self.total],
bytes(barray[self.__class__.st_package.size:self.total]),
raw=False) \
if self.length else None
except Exception as e:
self._handle_fail_file(barray)
self._handle_fail_file(bytes(barray))
raise e
finally:
del barray[:self.total]
Expand Down
Loading