Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions fastapi_utils/cbv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import pydantic
from fastapi import APIRouter, Depends
from fastapi.routing import APIRoute
from fastapi.routing import APIRoute, APIWebSocketRoute
from starlette.routing import Route, WebSocketRoute

PYDANTIC_VERSION = pydantic.VERSION
Expand Down Expand Up @@ -108,12 +108,16 @@ def _register_endpoints(router: APIRouter, cls: Type[Any], *urls: str) -> None:
_allocate_routes_by_method_name(router, url, function_members)
router_roles = []
for route in router.routes:
if not isinstance(route, APIRoute):
raise ValueError("The provided routes should be of type APIRoute")
if not (isinstance(route, APIRoute) or isinstance(route, APIWebSocketRoute)):
raise ValueError("The provided routes should be of type APIRoute or APIWebSocketRoute")

route_methods: Any = route.methods
cast(Tuple[Any], route_methods)
router_roles.append((route.path, tuple(route_methods)))
if isinstance(route, APIRoute):
route_methods: Any = route.methods
cast(Tuple[Any], route_methods)
router_roles.append((route.path, tuple(route_methods)))

if isinstance(route, APIWebSocketRoute):
router_roles.append((route.path, tuple(["WS"])))

if len(set(router_roles)) != len(router_roles):
raise Exception("An identical route role has been implemented more then once")
Expand Down
21 changes: 20 additions & 1 deletion tests/test_cbv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import Any, ClassVar, Optional

import pytest
from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, Request, WebSocket
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect

from fastapi_utils.cbv import cbv

Expand Down Expand Up @@ -147,3 +148,21 @@ def example(self, request: Request) -> str:
client = TestClient(router)
response = client.get("/foo")
assert response.json() == "http://testserver/bar"

def test_websocket_router(self, router: APIRouter) -> None:
@cbv(router)
class Foo:
@router.websocket("/ws")
async def example(self, websocket: WebSocket) -> None:
await websocket.accept()
await websocket.send_text("hello")
"VALID_VALUE" == await websocket.receive_text()
await websocket.close()

client = TestClient(router)
with client.websocket_connect("/ws") as websocket:
assert websocket.receive_text() == "hello"
websocket.send_text("VALID_VALUE")

with pytest.raises(WebSocketDisconnect):
assert websocket.receive_text()