diff --git a/fastapi_utils/cbv.py b/fastapi_utils/cbv.py index 231325c..677b931 100644 --- a/fastapi_utils/cbv.py +++ b/fastapi_utils/cbv.py @@ -13,7 +13,7 @@ import pydantic from fastapi import APIRouter, Depends -from fastapi.routing import APIRoute +from fastapi.routing import APIRoute, APIWebSocketRoute from starlette.routing import Route, WebSocketRoute PYDANTIC_VERSION = pydantic.VERSION @@ -108,12 +108,16 @@ def _register_endpoints(router: APIRouter, cls: Type[Any], *urls: str) -> None: _allocate_routes_by_method_name(router, url, function_members) router_roles = [] for route in router.routes: - if not isinstance(route, APIRoute): - raise ValueError("The provided routes should be of type APIRoute") + if not (isinstance(route, APIRoute) or isinstance(route, APIWebSocketRoute)): + raise ValueError("The provided routes should be of type APIRoute or APIWebSocketRoute") - route_methods: Any = route.methods - cast(Tuple[Any], route_methods) - router_roles.append((route.path, tuple(route_methods))) + if isinstance(route, APIRoute): + route_methods: Any = route.methods + cast(Tuple[Any], route_methods) + router_roles.append((route.path, tuple(route_methods))) + + if isinstance(route, APIWebSocketRoute): + router_roles.append((route.path, tuple(["WS"]))) if len(set(router_roles)) != len(router_roles): raise Exception("An identical route role has been implemented more then once") diff --git a/tests/test_cbv.py b/tests/test_cbv.py index b5b9f64..c2ca264 100644 --- a/tests/test_cbv.py +++ b/tests/test_cbv.py @@ -3,8 +3,9 @@ from typing import Any, ClassVar, Optional import pytest -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, Request, WebSocket from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect from fastapi_utils.cbv import cbv @@ -147,3 +148,21 @@ def example(self, request: Request) -> str: client = TestClient(router) response = client.get("/foo") assert response.json() == "http://testserver/bar" + + def test_websocket_router(self, router: APIRouter) -> None: + @cbv(router) + class Foo: + @router.websocket("/ws") + async def example(self, websocket: WebSocket) -> None: + await websocket.accept() + await websocket.send_text("hello") + "VALID_VALUE" == await websocket.receive_text() + await websocket.close() + + client = TestClient(router) + with client.websocket_connect("/ws") as websocket: + assert websocket.receive_text() == "hello" + websocket.send_text("VALID_VALUE") + + with pytest.raises(WebSocketDisconnect): + assert websocket.receive_text()