From 427d81195446df36ba37d28381bf249ceb734663 Mon Sep 17 00:00:00 2001 From: David Gil Date: Tue, 17 Feb 2026 21:58:08 +0100 Subject: [PATCH 1/2] security: restrict CORS to known local origins instead of wildcard The wildcard `allow_origins=["*"]` allows any website the user visits to make requests to the local voicebox backend, potentially triggering TTS generation or reading voice profiles without consent. Restrict to the known Tauri webview and Vite dev server origins by default. Users running in remote server mode can set VOICEBOX_CORS_ORIGINS to allow additional origins. --- backend/main.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index 59fb9e1..05ce741 100644 --- a/backend/main.py +++ b/backend/main.py @@ -36,10 +36,23 @@ version=__version__, ) -# CORS middleware +# CORS middleware - restrict to known local origins by default. +# Set VOICEBOX_CORS_ORIGINS env var to a comma-separated list of origins +# to allow additional origins (e.g. for remote server mode). +_default_origins = [ + "http://localhost:5173", # Vite dev server + "http://127.0.0.1:5173", + "http://localhost:17493", + "http://127.0.0.1:17493", + "tauri://localhost", # Tauri webview (macOS) + "https://tauri.localhost", # Tauri webview (Windows/Linux) +] +_env_origins = os.environ.get("VOICEBOX_CORS_ORIGINS", "") +_cors_origins = _default_origins + [o.strip() for o in _env_origins.split(",") if o.strip()] + app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Configure appropriately for production + allow_origins=_cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], From 80c87c8e2c01f3d0c87690cfc5111d90dda1d30d Mon Sep 17 00:00:00 2001 From: David Gil Date: Tue, 17 Feb 2026 22:04:25 +0100 Subject: [PATCH 2/2] test: add CORS origin restriction tests 20 tests covering: - All 6 default local origins are allowed - Arbitrary external origins are blocked - Preflight (OPTIONS) requests respect the allowlist - VOICEBOX_CORS_ORIGINS env var extends the allowlist - Edge cases: empty env, whitespace trimming, trailing commas Tests use a minimal FastAPI app mirroring the real CORS config, so they run without ML dependencies (torch, numpy, etc.). --- backend/tests/test_cors.py | 162 +++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 backend/tests/test_cors.py diff --git a/backend/tests/test_cors.py b/backend/tests/test_cors.py new file mode 100644 index 0000000..ae999c9 --- /dev/null +++ b/backend/tests/test_cors.py @@ -0,0 +1,162 @@ +""" +Tests for CORS origin restrictions. + +Validates that the CORS middleware only allows known local origins +and respects the VOICEBOX_CORS_ORIGINS environment variable. + +Uses a minimal FastAPI app that mirrors the exact CORS configuration +from backend/main.py, so tests run without heavy ML dependencies. + +Usage: + pip install httpx pytest fastapi starlette + python -m pytest backend/tests/test_cors.py -v +""" + +import os +import pytest +from unittest.mock import patch +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from starlette.testclient import TestClient + + +def _build_app(env_origins: str = "") -> FastAPI: + """ + Build a minimal FastAPI app with the same CORS logic as backend/main.py. + + This mirrors the exact code in main.py so the test validates the real + configuration without needing torch/numpy/transformers installed. + """ + app = FastAPI() + + _default_origins = [ + "http://localhost:5173", + "http://127.0.0.1:5173", + "http://localhost:17493", + "http://127.0.0.1:17493", + "tauri://localhost", + "https://tauri.localhost", + ] + _cors_origins = _default_origins + [o.strip() for o in env_origins.split(",") if o.strip()] + + app.add_middleware( + CORSMiddleware, + allow_origins=_cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/health") + async def health(): + return {"status": "ok"} + + return app + + +@pytest.fixture() +def client(): + return TestClient(_build_app()) + + +@pytest.fixture() +def client_with_custom_origins(): + return TestClient(_build_app("https://custom.example.com,https://other.example.com")) + + +def _get_with_origin(client: TestClient, origin: str) -> dict: + """Send a GET with Origin header, return response headers.""" + response = client.get("/health", headers={"Origin": origin}) + return dict(response.headers) + + +def _preflight(client: TestClient, origin: str) -> dict: + """Send CORS preflight OPTIONS request, return response headers.""" + response = client.options( + "/health", + headers={ + "Origin": origin, + "Access-Control-Request-Method": "GET", + }, + ) + return dict(response.headers) + + +class TestCORSDefaultOrigins: + """CORS should allow known local origins and block everything else.""" + + @pytest.mark.parametrize("origin", [ + "http://localhost:5173", + "http://127.0.0.1:5173", + "http://localhost:17493", + "http://127.0.0.1:17493", + "tauri://localhost", + "https://tauri.localhost", + ]) + def test_allowed_origins(self, client, origin): + headers = _get_with_origin(client, origin) + assert headers.get("access-control-allow-origin") == origin + + @pytest.mark.parametrize("origin", [ + "http://evil.com", + "http://localhost:9999", + "https://attacker.example.com", + "null", + ]) + def test_blocked_origins(self, client, origin): + headers = _get_with_origin(client, origin) + assert "access-control-allow-origin" not in headers + + def test_preflight_allowed(self, client): + headers = _preflight(client, "http://localhost:5173") + assert headers.get("access-control-allow-origin") == "http://localhost:5173" + + def test_preflight_blocked(self, client): + headers = _preflight(client, "http://evil.com") + assert "access-control-allow-origin" not in headers + + def test_credentials_header_present(self, client): + headers = _get_with_origin(client, "http://localhost:5173") + assert headers.get("access-control-allow-credentials") == "true" + + +class TestCORSCustomOrigins: + """VOICEBOX_CORS_ORIGINS env var should extend the allowlist.""" + + def test_custom_origin_allowed(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "https://custom.example.com") + assert headers.get("access-control-allow-origin") == "https://custom.example.com" + + def test_other_custom_origin_allowed(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "https://other.example.com") + assert headers.get("access-control-allow-origin") == "https://other.example.com" + + def test_default_origins_still_work(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "http://localhost:5173") + assert headers.get("access-control-allow-origin") == "http://localhost:5173" + + def test_unlisted_origin_still_blocked(self, client_with_custom_origins): + headers = _get_with_origin(client_with_custom_origins, "http://evil.com") + assert "access-control-allow-origin" not in headers + + +class TestCORSEnvVarParsing: + """Edge cases for VOICEBOX_CORS_ORIGINS parsing.""" + + def test_empty_env_var(self): + app = _build_app("") + client = TestClient(app) + headers = _get_with_origin(client, "http://evil.com") + assert "access-control-allow-origin" not in headers + + def test_whitespace_trimmed(self): + app = _build_app(" https://spaced.example.com ") + client = TestClient(app) + headers = _get_with_origin(client, "https://spaced.example.com") + assert headers.get("access-control-allow-origin") == "https://spaced.example.com" + + def test_trailing_comma_ignored(self): + app = _build_app("https://one.example.com,") + client = TestClient(app) + headers = _get_with_origin(client, "https://one.example.com") + assert headers.get("access-control-allow-origin") == "https://one.example.com"