Skip to content

Commit df86bf1

Browse files
authored
If server allows anonymous access, allow anon access to WS too. (#1219)
* If server allows anonymous access, allow anon access to WS too. * fix: Test logic was not quite right * Extend timeouts (tests have been flaky) * tst: Test Subscription without API key
1 parent b098863 commit df86bf1

File tree

5 files changed

+160
-41
lines changed

5 files changed

+160
-41
lines changed

tiled/_tests/conftest.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,7 @@ def minio_uri():
350350
raise pytest.skip("No TILED_TEST_BUCKET configured")
351351

352352

353-
@pytest.fixture(scope="function")
354-
def tiled_websocket_context(tmpdir, redis_uri):
355-
"""Fixture that provides a Tiled context with websocket support."""
353+
def build_test_app(tmpdir, redis_uri, public=False):
356354
tree = from_uri(
357355
"sqlite:///:memory:",
358356
writable_storage=[
@@ -373,10 +371,25 @@ def tiled_websocket_context(tmpdir, redis_uri):
373371
)
374372
app = build_app(
375373
tree,
376-
authentication=Authentication(single_user_api_key="secret"),
374+
authentication=Authentication(
375+
single_user_api_key="secret",
376+
allow_anonymous_access=public,
377+
),
377378
)
379+
return app
380+
381+
382+
@pytest.fixture(scope="function")
383+
def tiled_websocket_context(tmpdir, redis_uri):
384+
"""Fixture that provides a Tiled context with websocket support."""
385+
with Context.from_app(build_test_app(tmpdir, redis_uri, public=False)) as context:
386+
yield context
378387

379-
with Context.from_app(app) as context:
388+
389+
@pytest.fixture(scope="function")
390+
def tiled_websocket_context_public(tmpdir, redis_uri):
391+
"""Fixture that provides a Tiled context with websocket support."""
392+
with Context.from_app(build_test_app(tmpdir, redis_uri, public=True)) as context:
380393
yield context
381394

382395

tiled/_tests/test_subscription.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def callback(update):
6868
streaming_node.write(new_arr)
6969

7070
# Wait for all messages to be received
71-
assert received_event.wait(timeout=5.0), "Timeout waiting for messages"
71+
assert received_event.wait(timeout=10.0), "Timeout waiting for messages"
7272

7373
# Verify all updates received in order
7474
assert len(received) == 3
@@ -139,7 +139,7 @@ def callback(update):
139139
streaming_node.write(new_arr)
140140

141141
# Wait for messages to be received
142-
assert received_event.wait(timeout=5.0), "Timeout waiting for messages"
142+
assert received_event.wait(timeout=10.0), "Timeout waiting for messages"
143143

144144
# Should only receive the 2 new updates (not the first one)
145145
assert len(received) == 2
@@ -196,7 +196,7 @@ def callback(update):
196196
streaming_node.write(new_arr)
197197

198198
# Wait for all messages to be received
199-
assert received_event.wait(timeout=5.0), "Timeout waiting for messages"
199+
assert received_event.wait(timeout=10.0), "Timeout waiting for messages"
200200

201201
# Should receive: initial array + first update + 2 new updates = 4 total
202202
assert len(received) == 4
@@ -253,7 +253,7 @@ def child_metadata_updated_cb(update):
253253
time.sleep(0.1)
254254
unique_key = f"{uuid.uuid4().hex[:8]}"
255255
uploaded_nodes.append(client.create_container(unique_key))
256-
assert created_3.wait(timeout=5.0), "Timeout waiting for messages"
256+
assert created_3.wait(timeout=10.0), "Timeout waiting for messages"
257257
downloaded_nodes = list(client.values())
258258
for up, streamed, down in zip(uploaded_nodes, streamed_nodes, downloaded_nodes):
259259
pass
@@ -262,7 +262,7 @@ def child_metadata_updated_cb(update):
262262

263263
assert len(child_metadata_updated_updates) == 0
264264
client.values().last().update_metadata({"color": "blue"})
265-
assert received_event.wait(timeout=5.0), "Timeout waiting for messages"
265+
assert received_event.wait(timeout=10.0), "Timeout waiting for messages"
266266
assert len(child_metadata_updated_updates) == 1
267267

268268

@@ -283,7 +283,7 @@ def callback(sub):
283283
sub.stream_closed.add_callback(callback)
284284
assert not event.is_set()
285285
x.close_stream()
286-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
286+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
287287

288288

289289
def test_subscribe_to_disconnected(
@@ -305,7 +305,7 @@ def callback(sub):
305305
sub.disconnected.add_callback(callback)
306306
assert not event.is_set()
307307
sub.disconnect()
308-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
308+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
309309

310310
# If the writer closes the stream, the client is disconnected.
311311
with x.subscribe().start_in_thread() as sub:
@@ -317,7 +317,7 @@ def callback(sub):
317317
sub.disconnected.add_callback(callback)
318318
assert not event.is_set()
319319
x.close_stream()
320-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
320+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
321321

322322

323323
def test_subscribe_to_array_registered_with_patch(tiled_websocket_context, tmp_path):
@@ -412,7 +412,7 @@ def on_child_created(update):
412412
content=safe_json_dump({"data_source": updated_data_source}),
413413
params=params,
414414
).raise_for_status()
415-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
415+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
416416
x.close_stream()
417417
client.close_stream()
418418
x.refresh()
@@ -503,7 +503,7 @@ def on_child_created(update):
503503
}
504504
),
505505
).raise_for_status()
506-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
506+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
507507
x.close_stream()
508508
client.close_stream()
509509
x.refresh()
@@ -533,12 +533,12 @@ def collect(update):
533533
sub = client[key].subscribe()
534534
sub.new_data.add_callback(collect)
535535
with sub.start_in_thread(1):
536-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
536+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
537537
actual = updates[0].data()
538538
assert_frame_equal(actual, df1)
539539
event.clear()
540540
x.write(df2)
541-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
541+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
542542
assert not updates[1].append
543543
actual_updated = updates[1].data()
544544
assert_frame_equal(actual_updated, df2)
@@ -565,14 +565,14 @@ def collect(update):
565565
sub.new_data.add_callback(collect)
566566
with sub.start_in_thread(1):
567567
x.append_partition(0, table1)
568-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
568+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
569569
assert updates[0].append
570570
streamed1 = updates[0].data()
571571
streamed1_pyarrow = pyarrow.Table.from_pandas(streamed1, preserve_index=False)
572572
assert streamed1_pyarrow == table1
573573
event.clear()
574574
x.append_partition(0, table2)
575-
assert event.wait(timeout=5.0), "Timeout waiting for messages"
575+
assert event.wait(timeout=10.0), "Timeout waiting for messages"
576576
assert updates[1].append
577577
streamed2 = updates[1].data()
578578
streamed2_pyarrow = pyarrow.Table.from_pandas(streamed2, preserve_index=False)
@@ -641,3 +641,51 @@ def __call__(self, timeout=None):
641641

642642
# Restore original recv before disconnecting to avoid cleanup issues
643643
subscription._websocket.recv = original_recv
644+
645+
646+
def test_subscribe_no_api_key_rejected(tiled_websocket_context):
647+
"Private server does not allow anonymous user to subscribe."
648+
context = tiled_websocket_context
649+
client = from_context(context)
650+
651+
arr = np.arange(10)
652+
streaming_node = client.write_array(arr, key="test_stream_immediate")
653+
654+
received_event = threading.Event()
655+
656+
def callback(update):
657+
"Set event once any update has been received."
658+
received_event.set()
659+
660+
# Any further requests will be unauthenticated.
661+
context.api_key = None
662+
663+
subscription = streaming_node.subscribe()
664+
subscription.new_data.add_callback(callback)
665+
666+
with pytest.raises(WebSocketDenialResponse):
667+
subscription.start(0)
668+
669+
670+
def test_subscribe_no_api_key_public(tiled_websocket_context_public):
671+
"Public server allows anonymous user to subscribe."
672+
context = tiled_websocket_context_public
673+
client = from_context(context)
674+
675+
arr = np.arange(10)
676+
streaming_node = client.write_array(arr, key="test_stream_immediate")
677+
678+
received_event = threading.Event()
679+
680+
def callback(update):
681+
"Set event once any update has been received."
682+
received_event.set()
683+
684+
# Any further requests will be unauthenticated.
685+
context.api_key = None
686+
687+
subscription = streaming_node.subscribe()
688+
subscription.new_data.add_callback(callback)
689+
690+
with subscription.start_in_thread(0):
691+
assert received_event.wait(timeout=10.0), "Timeout waiting for messages"

tiled/_tests/test_websockets.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import msgpack
55
import numpy as np
66
import pytest
7+
from starlette.testclient import WebSocketDenialResponse
78

89
from ..client import from_context
910
from ..config import parse_configs
@@ -381,7 +382,6 @@ def test_close_stream_not_found(tiled_websocket_context):
381382

382383
def test_websocket_connection_wrong_api_key(tiled_websocket_context):
383384
"""Test websocket connection with wrong API key fails with 401."""
384-
from starlette.testclient import WebSocketDenialResponse
385385

386386
context = tiled_websocket_context
387387
client = from_context(context)
@@ -402,6 +402,50 @@ def test_websocket_connection_wrong_api_key(tiled_websocket_context):
402402
assert exc_info.value.status_code == 401
403403

404404

405+
def test_websocket_connection_no_api_key(tiled_websocket_context):
406+
"""Test websocket connection with no API key fails with 401."""
407+
408+
context = tiled_websocket_context
409+
client = from_context(context)
410+
test_client = context.http_client
411+
412+
# Create streaming array node using correct key
413+
arr = np.arange(10)
414+
client.write_array(arr, key="test_auth_websocket")
415+
416+
# Strip API key so requests below are unauthenticated.
417+
context.api_key = None
418+
419+
# Try to connect to websocket with no API key
420+
with pytest.raises(WebSocketDenialResponse) as exc_info:
421+
with test_client.websocket_connect(
422+
"/api/v1/stream/single/test_auth_websocket?envelope_format=msgpack",
423+
):
424+
pass
425+
426+
assert exc_info.value.status_code == 401
427+
428+
429+
def test_websocket_connection_public_no_api_key(tiled_websocket_context_public):
430+
"""Test websocket connection to a public server with no API key works."""
431+
context = tiled_websocket_context_public
432+
client = from_context(context)
433+
test_client = context.http_client
434+
435+
# Create streaming array node using correct key
436+
arr = np.arange(10)
437+
client.write_array(arr, key="test_auth_websocket")
438+
439+
# Strip API key so requests below are unauthenticated.
440+
context.api_key = None
441+
442+
# Try to connect to (public) websocket with no API key
443+
with test_client.websocket_connect(
444+
"/api/v1/stream/single/test_auth_websocket?envelope_format=msgpack",
445+
):
446+
pass
447+
448+
405449
def test_close_stream_wrong_api_key(tiled_websocket_context):
406450
"""Test close endpoint returns 403 with wrong API key."""
407451
context = tiled_websocket_context

tiled/client/stream.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,18 @@ def __init__(self, http_client, uri: httpx.URL):
6767
self._websocket = None
6868
self._connection_lock = threading.Lock()
6969

70-
def connect(self, api_key: str, start: Optional[int] = None):
70+
def connect(self, api_key: Optional[str], start: Optional[int] = None):
7171
"""Connect to the websocket."""
72+
params = self._uri.params
73+
headers = {}
74+
if api_key:
75+
headers["Authorization"] = f"Apikey {api_key}"
76+
if start is not None:
77+
params = params.set("start", start)
7278
with self._connection_lock:
73-
params = self._uri.params
74-
if start is not None:
75-
params = params.set("start", start)
7679
self._websocket = self._http_client.websocket_connect(
7780
str(self._uri.copy_with(params=params)),
78-
headers={"Authorization": f"Apikey {api_key}"},
81+
headers=headers,
7982
)
8083
self._websocket.__enter__()
8184

@@ -105,14 +108,17 @@ def __init__(self, http_client, uri: httpx.URL):
105108
self._uri = uri
106109
self._websocket = None
107110

108-
def connect(self, api_key: str, start: Optional[int] = None):
111+
def connect(self, api_key: Optional[str], start: Optional[int] = None):
109112
"""Connect to the websocket."""
110113
params = self._uri.params
114+
headers = {}
115+
if api_key:
116+
headers["Authorization"] = f"Apikey {api_key}"
111117
if start is not None:
112118
params = params.set("start", start)
113119
self._websocket = connect(
114120
str(self._uri.copy_with(params=params)),
115-
additional_headers={"Authorization": f"Apikey {api_key}"},
121+
additional_headers=headers,
116122
)
117123

118124
def recv(self, timeout=None):
@@ -341,7 +347,7 @@ def _connect(self, start: Optional[int] = None) -> None:
341347
)
342348
api_key = key_info["secret"]
343349
else:
344-
# Use single-user API key.
350+
# Use single-user API key or None (if unauthenticated).
345351
api_key = self.context.api_key
346352

347353
# Connect using the websocket wrapper

tiled/server/authentication.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,9 @@ async def get_current_access_tags(
284284

285285
def get_api_key_websocket(
286286
authorization: Annotated[Optional[str], Header()] = None,
287-
):
287+
) -> Optional[str]:
288288
if authorization is None:
289-
raise HTTPException(
290-
status_code=HTTP_401_UNAUTHORIZED,
291-
detail="An API key must be passed in the Authorization header",
292-
)
289+
return None
293290
scheme, api_key = get_authorization_scheme_param(authorization)
294291
if scheme.lower() != "apikey":
295292
raise HTTPException(
@@ -489,19 +486,30 @@ async def get_current_principal_from_api_key(
489486

490487
async def get_current_principal_websocket(
491488
websocket: WebSocket,
492-
api_key: str = Depends(get_api_key_websocket),
489+
api_key: Optional[str] = Depends(get_api_key_websocket),
493490
settings: Settings = Depends(get_settings),
494491
db_factory: Callable[[], Optional[AsyncSession]] = Depends(
495492
get_database_session_factory
496493
),
497494
):
498-
async with db_factory() as db:
499-
principal = await get_current_principal_from_api_key(
500-
api_key, websocket.app.state.authenticated, db, settings
501-
)
502-
if principal is None and websocket.app.state.authenticated:
503-
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key")
504-
return principal
495+
if api_key is not None:
496+
async with db_factory() as db:
497+
principal = await get_current_principal_from_api_key(
498+
api_key, websocket.app.state.authenticated, db, settings
499+
)
500+
if (principal is None) and websocket.app.state.authenticated:
501+
raise HTTPException(
502+
status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key"
503+
)
504+
return principal
505+
else:
506+
if settings.allow_anonymous_access:
507+
return None
508+
else:
509+
raise HTTPException(
510+
status_code=HTTP_401_UNAUTHORIZED,
511+
detail="No API key was provided with this request.",
512+
)
505513

506514

507515
async def get_current_principal(

0 commit comments

Comments
 (0)