diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index 0375794..c745dff 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -6,12 +6,13 @@ from datetime import datetime, timedelta from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket +from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn -from fastapi.responses import JSONResponse +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery from fastapi.security.utils import get_authorization_scheme_param +from sqlalchemy.exc import IntegrityError # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: @@ -33,7 +34,14 @@ from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm -from .database.core import create_user, latest_principal_activity, lookup_valid_api_key, lookup_valid_session +from .database.core import ( + create_user, + latest_principal_activity, + lookup_valid_api_key, + lookup_valid_pending_session_by_device_code, + lookup_valid_pending_session_by_user_code, + lookup_valid_session, +) from .settings import get_sessionmaker, get_settings from .utils import ( API_KEY_COOKIE_NAME, @@ -48,6 +56,10 @@ ALGORITHM = "HS256" UNIT_SECOND = timedelta(seconds=1) +# Device code flow constants +DEVICE_CODE_MAX_AGE = timedelta(minutes=10) +DEVICE_CODE_POLLING_INTERVAL = 5 # seconds + def utcnow(): "UTC now with second resolution" @@ -505,6 +517,349 @@ async def handle_credentials( return handle_credentials +def create_pending_session(db): + """ + Create a pending session for device code flow. + + Returns a dict with 'user_code' (user-facing code) and 'device_code' (for polling). + """ + device_code = secrets.token_bytes(32) + hashed_device_code = hashlib.sha256(device_code).digest() + for _ in range(3): + user_code = secrets.token_hex(4).upper() # 8 digit code + pending_session = orm.PendingSession( + user_code=user_code, + hashed_device_code=hashed_device_code, + expiration_time=utcnow() + DEVICE_CODE_MAX_AGE, + ) + db.add(pending_session) + try: + db.commit() + except IntegrityError: + # Since the user_code is short, we cannot completely dismiss the + # possibility of a collision. Retry. + db.rollback() + continue + break + formatted_user_code = f"{user_code[:4]}-{user_code[4:]}" + return { + "user_code": formatted_user_code, + "device_code": device_code.hex(), + } + + +def build_authorize_route(authenticator, provider): + """Build a GET route that redirects the browser to the OIDC provider for authentication.""" + + async def authorize_redirect( + request: Request, + state: Optional[str] = Query(None), + ): + """Redirect browser to OAuth provider for authentication.""" + redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code" + + params = { + "client_id": authenticator.client_id, + "response_type": "code", + "scope": "openid profile email", + "redirect_uri": redirect_uri, + } + if state: + params["state"] = state + + auth_url = authenticator.authorization_endpoint.copy_with(params=params) + return RedirectResponse(url=str(auth_url)) + + return authorize_redirect + + +def build_device_code_authorize_route(authenticator, provider): + """Build a POST route that initiates the device code flow for CLI/headless clients.""" + + async def device_code_authorize( + request: Request, + settings: BaseSettings = Depends(get_settings), + ): + """ + Initiate device code flow. + + Returns authorization_uri for the user to visit in browser, + and device_code + user_code for the CLI client to poll. + """ + request.state.endpoint = "auth" + with get_sessionmaker(settings.database_settings)() as db: + pending_session = create_pending_session(db) + + verification_uri = f"{get_base_url(request)}/auth/provider/{provider}/token" + authorization_uri = authenticator.authorization_endpoint.copy_with( + params={ + "client_id": authenticator.client_id, + "response_type": "code", + "scope": "openid profile email", + "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code", + } + ) + return { + "authorization_uri": str(authorization_uri), # URL that user should visit in browser + "verification_uri": str(verification_uri), # URL that terminal client will poll + "interval": DEVICE_CODE_POLLING_INTERVAL, # suggested polling interval + "device_code": pending_session["device_code"], + "expires_in": int(DEVICE_CODE_MAX_AGE.total_seconds()), # seconds + "user_code": pending_session["user_code"], + } + + return device_code_authorize + + +def build_device_code_form_route(authenticator, provider): + """Build a GET route that shows the user code entry form.""" + + async def device_code_form( + request: Request, + code: str, + ): + """Show form for user to enter user code after browser auth.""" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + html_content = f""" + + + + Authorize Session + + + +

Authorize Bluesky HTTP Server Session

+
+ + + +
+ +
+ + +""" + return HTMLResponse(content=html_content) + + return device_code_form + + +def build_device_code_submit_route(authenticator, provider): + """Build a POST route that handles user code submission after browser auth.""" + + async def device_code_submit( + request: Request, + code: str = Form(), + user_code: str = Form(), + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """Handle user code submission and link to authenticated session.""" + request.state.endpoint = "auth" + action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}" + normalized_user_code = user_code.upper().replace("-", "").strip() + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code) + if pending_session is None: + error_html = f""" + + +Error + + + +

Authorization Failed

+
Invalid user code. It may have been mistyped, or the pending request may have expired.
+
Try again + + +""" + return HTMLResponse(content=error_html, status_code=401) + + # Authenticate with the OIDC provider using the authorization code + user_session_state = await authenticator.authenticate(request) + if not user_session_state: + error_html = """ + + +Authentication Failed + + + +

Authentication Failed

+
User code was correct but authentication with the identity provider failed. Please contact the administrator.
+ + +""" + return HTMLResponse(content=error_html, status_code=401) + + username = user_session_state.user_name + if not api_access_manager.is_user_known(username): + error_html = f""" + + +Authorization Failed + + + +

Authorization Failed

+
User '{username}' is not authorized to access this server.
+ + +""" + return HTMLResponse(content=error_html, status_code=403) + + # Create the session + session = await asyncio.get_running_loop().run_in_executor( + None, _create_session_orm, settings, provider, username, db + ) + + # Link the pending session to the real session + pending_session.session_id = session.id + db.add(pending_session) + db.commit() + + success_html = f""" + + +Success + + + +

Success!

+
You have been authenticated. Return to your terminal application - within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in.
+ + +""" + return HTMLResponse(content=success_html) + + return device_code_submit + + +def _create_session_orm(settings, identity_provider, id, db): + """ + Create a session and return the ORM object (for device code flow). + + Unlike create_session(), this returns the ORM object so we can link it + to the pending session. + """ + # Have we seen this Identity before? + identity = ( + db.query(orm.Identity) + .filter(orm.Identity.id == id) + .filter(orm.Identity.provider == identity_provider) + .first() + ) + now = utcnow() + if identity is None: + # We have not. Make a new Principal and link this new Identity to it. + principal = create_user(db, identity_provider, id) + (new_identity,) = principal.identities + new_identity.latest_login = now + else: + identity.latest_login = now + principal = identity.principal + + session = orm.Session( + principal_id=principal.id, + expiration_time=utcnow() + settings.session_max_age, + ) + db.add(session) + db.commit() + db.refresh(session) + return session + + +def build_device_code_token_route(authenticator, provider): + """Build a POST route for the CLI client to poll for tokens.""" + + async def device_code_token( + request: Request, + body: schemas.DeviceCode, + settings: BaseSettings = Depends(get_settings), + api_access_manager=Depends(get_api_access_manager), + ): + """ + Poll for tokens after device code flow authentication. + + Returns tokens if the user has authenticated, or 400 with + 'authorization_pending' error if still waiting. + """ + request.state.endpoint = "auth" + device_code_hex = body.device_code + try: + device_code = bytes.fromhex(device_code_hex) + except Exception: + # Not valid hex, therefore not a valid device_code + raise HTTPException(status_code=401, detail="Invalid device code") + + with get_sessionmaker(settings.database_settings)() as db: + pending_session = lookup_valid_pending_session_by_device_code(db, device_code) + if pending_session is None: + raise HTTPException( + status_code=404, + detail="No such device_code. The pending request may have expired.", + ) + if pending_session.session_id is None: + raise HTTPException(status_code=400, detail={"error": "authorization_pending"}) + + session = pending_session.session + principal = session.principal + + # Get scopes for the user + # Find an identity to get the username + identity = db.query(orm.Identity).filter(orm.Identity.principal_id == principal.id).first() + if identity and api_access_manager.is_user_known(identity.id): + scopes = api_access_manager.get_user_scopes(identity.id) + else: + scopes = set() + + # The pending session can only be used once + db.delete(pending_session) + db.commit() + + # Generate tokens + data = { + "sub": principal.uuid.hex, + "sub_typ": principal.type.value, + "scp": list(scopes), + "ids": [{"id": ident.id, "idp": ident.provider} for ident in principal.identities], + } + access_token = create_access_token( + data=data, + expires_delta=settings.access_token_max_age, + secret_key=settings.secret_keys[0], + ) + refresh_token = create_refresh_token( + session_id=session.uuid.hex, + expires_delta=settings.refresh_token_max_age, + secret_key=settings.secret_keys[0], + ) + + return { + "access_token": access_token, + "expires_in": int(settings.access_token_max_age / UNIT_SECOND), + "refresh_token": refresh_token, + "refresh_token_expires_in": int(settings.refresh_token_max_age / UNIT_SECOND), + "token_type": "bearer", + } + + return device_code_token + + def generate_apikey(db, principal, apikey_params, request, allowed_scopes, source_api_key_scopes): # Use API key scopes if API key is generated based on existing API key, otherwise used allowed scopes if (source_api_key_scopes is not None) and ("inherit" not in source_api_key_scopes): diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index 9a8420a..0d96667 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -160,6 +160,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server from .authentication import ( base_authentication_router, build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, build_handle_credentials_route, oauth2_scheme, ) @@ -184,12 +189,34 @@ def build_app(authentication=None, api_access=None, resource_access=None, server build_handle_credentials_route(authenticator, provider) ) elif isinstance(authenticator, ExternalAuthenticator): + # Standard OAuth callback route (authorization code flow) authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) authentication_router.post(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) + # Device code flow routes for CLI/headless clients + # GET /authorize - redirects browser to OIDC provider + authentication_router.get(f"/provider/{provider}/authorize")( + build_authorize_route(authenticator, provider) + ) + # POST /authorize - initiates device code flow (returns device_code, user_code, etc.) + authentication_router.post(f"/provider/{provider}/authorize")( + build_device_code_authorize_route(authenticator, provider) + ) + # GET /device_code - shows user code entry form + authentication_router.get(f"/provider/{provider}/device_code")( + build_device_code_form_route(authenticator, provider) + ) + # POST /device_code - handles user code submission after browser auth + authentication_router.post(f"/provider/{provider}/device_code")( + build_device_code_submit_route(authenticator, provider) + ) + # POST /token - CLI client polls this for tokens + authentication_router.post(f"/provider/{provider}/token")( + build_device_code_token_route(authenticator, provider) + ) else: raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index fc35cdd..85d835e 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,6 +1,11 @@ from .._authentication import ( base_authentication_router, build_auth_code_route, + build_authorize_route, + build_device_code_authorize_route, + build_device_code_form_route, + build_device_code_submit_route, + build_device_code_token_route, build_handle_credentials_route, get_current_principal, get_current_principal_websocket, @@ -20,6 +25,11 @@ "get_current_principal_websocket", "base_authentication_router", "build_auth_code_route", + "build_authorize_route", + "build_device_code_authorize_route", + "build_device_code_form_route", + "build_device_code_submit_route", + "build_device_code_token_route", "build_handle_credentials_route", "oauth2_scheme", ] diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 78b6cf1..a58fedf 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -222,18 +222,38 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]: if response.is_error: logger.error("Authentication error: %r", response_body) return None - response_body = response.json() id_token = response_body["id_token"] - access_token = response_body["access_token"] + # NOTE: We decode the id_token, not access_token, because: + # 1. The id_token is the OIDC identity assertion meant for the client + # 2. Some providers (like Microsoft Entra) return opaque access_tokens + # that cannot be decoded with the JWKS keys when the resource is + # a first-party Microsoft API (e.g., Graph API with User.Read scope) try: - verified_body = self.decode_token(access_token) + verified_body = self.decode_token(id_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", jwt.get_unverified_claims(id_token), ) return None - return UserSessionState(verified_body["sub"], {}) + # Use preferred_username as the user identifier, extracting just the username + # part if it's in email format (user@domain.com -> user) + preferred_username = verified_body.get("preferred_username") + if preferred_username and "@" in preferred_username: + user_id = preferred_username.split("@")[0] + elif preferred_username: + user_id = preferred_username + else: + user_id = verified_body["sub"] + logger.info( + "OIDC authentication successful. user_id=%r (sub=%r, preferred_username=%r, email=%r, name=%r)", + user_id, + verified_body.get("sub"), + verified_body.get("preferred_username"), + verified_body.get("email"), + verified_body.get("name"), + ) + return UserSessionState(user_id, {}) class ProxiedOIDCAuthenticator(OIDCAuthenticator): diff --git a/bluesky_httpserver/config_schemas/examples/oidc_config.yml b/bluesky_httpserver/config_schemas/examples/oidc_config.yml new file mode 100644 index 0000000..c2f8d24 --- /dev/null +++ b/bluesky_httpserver/config_schemas/examples/oidc_config.yml @@ -0,0 +1,78 @@ +# Example OIDC Configuration for Bluesky HTTP Server +# +# This example shows how to configure OIDC (OpenID Connect) authentication. +# OIDC is used by providers like Google, Microsoft Entra (Azure AD), Okta, Keycloak, etc. +# +# Required environment variables: +# - OIDC_CLIENT_ID: The client ID from your OIDC provider +# - OIDC_CLIENT_SECRET: The client secret from your OIDC provider +# - OIDC_WELL_KNOWN_URI: The .well-known/openid-configuration URL +# +# Example for Google: +# OIDC_WELL_KNOWN_URI=https://accounts.google.com/.well-known/openid-configuration +# +# Example for Microsoft Entra (Azure AD): +# OIDC_WELL_KNOWN_URI=https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration +# +# Example for Keycloak: +# OIDC_WELL_KNOWN_URI=https://your-keycloak-server/realms/{realm}/.well-known/openid-configuration + +authentication: + providers: + - provider: oidc + authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator + args: + # The audience should match the client_id or be a value expected by your OIDC provider + audience: ${OIDC_CLIENT_ID} + client_id: ${OIDC_CLIENT_ID} + client_secret: ${OIDC_CLIENT_SECRET} + well_known_uri: ${OIDC_WELL_KNOWN_URI} + confirmation_message: "You have successfully logged in via OIDC as {id}." + # Optional: redirect URLs after authentication + # redirect_on_success: https://your-app.example.com/success + # redirect_on_failure: https://your-app.example.com/login-failed + + # Secret keys used to sign secure tokens (generate with: openssl rand -hex 32) + secret_keys: + - ${SECRET_KEY} + + # Allow unauthenticated access to public endpoints + allow_anonymous_access: false + + # Token lifetimes (in seconds) + access_token_max_age: 900 # 15 minutes + refresh_token_max_age: 604800 # 7 days + +# Database for storing sessions and API keys +database: + uri: ${DATABASE_URI} + pool_size: 5 + pool_pre_ping: true + +# API access control - configure which users have access +api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + # Add users identified by their OIDC subject ID (sub claim) + # The ID typically looks like an email or UUID depending on your OIDC provider + user@example.com: + roles: + - admin + - user + +# Resource access control +resource_access: + policy: bluesky_httpserver.authorization:DefaultResourceAccessControl + args: + default_group: root + +# Queue Server connection +qserver_zmq_configuration: + control_address: tcp://localhost:60615 + info_address: tcp://localhost:60625 + +# HTTP Server configuration +uvicorn: + host: 0.0.0.0 + port: 8000 diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml index 57343f7..12f01a3 100644 --- a/bluesky_httpserver/config_schemas/service_configuration.yml +++ b/bluesky_httpserver/config_schemas/service_configuration.yml @@ -47,14 +47,14 @@ properties: properties: custom_routers: type: array - item: + items: type: string description: | The list of Python modules with custom routers. Overrides the list of modules set using QSERVER_HTTP_CUSTOM_ROUTERS environment variable. custom_modules: type: array - item: + items: type: string description: | THE FUNCTIONALITY WILL BE DEPRECATED IN FAVOR OF CUSTOM ROUTERS. Overrides the list of modules @@ -65,7 +65,7 @@ properties: properties: providers: type: array - item: + items: type: object additionalProperties: false required: @@ -83,7 +83,7 @@ properties: description: | Type of Authenticator to use. - These are typically from the tiled.authenticators module, + These are typically from the bluesky_httpserver.authenticators module, though user-defined ones may be used as well. This is given as an import path. In an import path, packages/modules @@ -92,21 +92,21 @@ properties: Example: ```yaml - authenticator: bluesky_httpserver.examples.DummyAuthenticator + authenticator: bluesky_httpserver.authenticators:DummyAuthenticator ``` - args: - type: [object, "null"] - description: | - Named arguments to pass to Authenticator. If there are none, - `args` may be omitted or empty. + args: + type: object + description: | + Named arguments to pass to Authenticator. If there are none, + `args` may be omitted or empty. - Example: + Example: - ```yaml - authenticator: bluesky_httpserver.examples.PAMAuthenticator - args: - service: "custom_service" - ``` + ```yaml + authenticator: bluesky_httpserver.authenticators:PAMAuthenticator + args: + service: "custom_service" + ``` # qserver_admins: # type: array # items: diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py index 163fac3..f096edc 100644 --- a/bluesky_httpserver/database/core.py +++ b/bluesky_httpserver/database/core.py @@ -1,6 +1,7 @@ import hashlib import uuid as uuid_module from datetime import datetime +from typing import Optional from alembic import command from alembic.config import Config @@ -10,13 +11,13 @@ from .alembic_utils import temp_alembic_ini from .base import Base -from .orm import APIKey, Identity, Principal, Session # , Role +from .orm import APIKey, Identity, PendingSession, Principal, Session # , Role # This is the alembic revision ID of the database revision # required by this version of Tiled. -REQUIRED_REVISION = "722ff4e4fcc7" +REQUIRED_REVISION = "a1b2c3d4e5f6" # This is list of all valid revisions (from current to oldest). -ALL_REVISIONS = ["722ff4e4fcc7", "481830dd6c11"] +ALL_REVISIONS = ["a1b2c3d4e5f6", "722ff4e4fcc7", "481830dd6c11"] # def create_default_roles(engine): @@ -294,3 +295,36 @@ def latest_principal_activity(db, principal): if all([t is None for t in all_activity]): return None return max(t for t in all_activity if t is not None) + + +def lookup_valid_pending_session_by_device_code(db, device_code: bytes) -> Optional[PendingSession]: + """ + Look up a pending session by its device code. + + Returns None if the pending session is not found or has expired. + """ + hashed_device_code = hashlib.sha256(device_code).digest() + pending_session = db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first() + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session + + +def lookup_valid_pending_session_by_user_code(db, user_code: str) -> Optional[PendingSession]: + """ + Look up a pending session by its user code. + + Returns None if the pending session is not found or has expired. + """ + pending_session = db.query(PendingSession).filter(PendingSession.user_code == user_code).first() + if pending_session is None: + return None + if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow(): + db.delete(pending_session) + db.commit() + return None + return pending_session diff --git a/bluesky_httpserver/database/orm.py b/bluesky_httpserver/database/orm.py index 17d7c82..7611824 100644 --- a/bluesky_httpserver/database/orm.py +++ b/bluesky_httpserver/database/orm.py @@ -181,3 +181,24 @@ class Session(Timestamped, Base): revoked = Column(Boolean, default=False, nullable=False) principal = relationship("Principal", back_populates="sessions") + pending_sessions = relationship("PendingSession", back_populates="session") + + +class PendingSession(Timestamped, Base): + """ + This is used only in Device Code Flow for OIDC authentication. + + When a CLI client initiates the device code flow, a pending session is created + with a device_code (for the client to poll) and a user_code (for the user to + enter in the browser). Once the user authenticates, the pending session is + linked to a real session, which the polling client then receives. + """ + + __tablename__ = "pending_sessions" + + hashed_device_code = Column(LargeBinary(32), primary_key=True, index=True, nullable=False) + user_code = Column(Unicode(8), index=True, nullable=False) + expiration_time = Column(DateTime(timezone=False), nullable=False) + session_id = Column(Integer, ForeignKey("sessions.id"), nullable=True) + + session = relationship("Session", back_populates="pending_sessions") diff --git a/bluesky_httpserver/schemas.py b/bluesky_httpserver/schemas.py index c52d8f2..f1d9fcb 100644 --- a/bluesky_httpserver/schemas.py +++ b/bluesky_httpserver/schemas.py @@ -163,6 +163,23 @@ class RefreshToken(pydantic.BaseModel): refresh_token: str +class DeviceCode(pydantic.BaseModel): + """Schema for device code token polling request.""" + + device_code: str + + +class DeviceCodeResponse(pydantic.BaseModel): + """Schema for device code flow initiation response.""" + + authorization_uri: str + verification_uri: str + device_code: str + user_code: str + expires_in: int + interval: int + + class AuthenticationMode(str, enum.Enum): password = "password" external = "external" diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py index ec69415..3c43529 100644 --- a/bluesky_httpserver/tests/conftest.py +++ b/bluesky_httpserver/tests/conftest.py @@ -195,3 +195,28 @@ def wait_for_ip_kernel_idle(timeout, polling_period=0.2, api_key=API_KEY_FOR_TES return True return False + + +# ============================================================================ +# OIDC Test Fixtures +# ============================================================================ + +@pytest.fixture +def oidc_base_url() -> str: + """Base URL for mock OIDC provider.""" + return "https://example.com/realms/example/" + + +@pytest.fixture +def well_known_response(oidc_base_url: str) -> dict: + """Mock OIDC well-known configuration response.""" + return { + "id_token_signing_alg_values_supported": ["RS256"], + "issuer": oidc_base_url.rstrip("/"), + "jwks_uri": f"{oidc_base_url}protocol/openid-connect/certs", + "authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth", + "token_endpoint": f"{oidc_base_url}protocol/openid-connect/token", + "device_authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth/device", + "end_session_endpoint": f"{oidc_base_url}protocol/openid-connect/logout", + } + diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py new file mode 100644 index 0000000..30303e4 --- /dev/null +++ b/bluesky_httpserver/tests/test_oidc_authenticators.py @@ -0,0 +1,224 @@ +"""Tests for OIDC Authenticator functionality.""" + +import time +from typing import Any, Tuple + +import httpx +import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from jose import ExpiredSignatureError, jwt +from jose.backends import RSAKey +from respx import MockRouter + +from bluesky_httpserver.authenticators import OIDCAuthenticator, ProxiedOIDCAuthenticator + + +@pytest.fixture +def oidc_well_known_url(oidc_base_url: str) -> str: + return f"{oidc_base_url}.well-known/openid-configuration" + + +@pytest.fixture +def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + """Generate RSA key pair for testing.""" + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_key = private_key.public_key() + return (private_key, public_key) + + +@pytest.fixture +def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]: + """Create a JSON Web Key Set from the test keys.""" + _, public_key = keys + return [RSAKey(key=public_key, algorithm="RS256").to_dict()] + + +@pytest.fixture +def mock_oidc_server( + respx_mock: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], +) -> MockRouter: + """Set up mock OIDC server endpoints.""" + respx_mock.get(oidc_well_known_url).mock( + return_value=httpx.Response(httpx.codes.OK, json=well_known_response) + ) + respx_mock.get(well_known_response["jwks_uri"]).mock( + return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset}) + ) + return respx_mock + + +def create_token(issued: bool, expired: bool) -> dict[str, Any]: + """Create a test JWT token.""" + now = time.time() + return { + "aud": "test_client", + "exp": (now - 1500) if expired else (now + 1500), + "iat": (now - 1500) if issued else (now + 1500), + "iss": "https://example.com/realms/example", + "sub": "test_user", + } + + +def encrypt_token(token: dict[str, Any], private_key: rsa.RSAPrivateKey) -> str: + """Encrypt a token with the test private key.""" + return jwt.encode( + token, + key=private_key, + algorithm="RS256", + headers={"kid": "test_key"}, + ) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestOIDCAuthenticator: + """Tests for OIDCAuthenticator class.""" + + def test_oidc_authenticator_caching( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + json_web_keyset: list[dict[str, Any]], + ): + """Test that OIDC configuration is cached after first fetch.""" + authenticator = OIDCAuthenticator( + audience="test_client", + client_id="test_client", + client_secret="secret", + well_known_uri=oidc_well_known_url, + ) + + # Access multiple properties to ensure caching works + assert authenticator.client_id == "test_client" + assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"] + assert ( + authenticator.id_token_signing_alg_values_supported + == well_known_response["id_token_signing_alg_values_supported"] + ) + assert authenticator.issuer == well_known_response["issuer"] + assert authenticator.jwks_uri == well_known_response["jwks_uri"] + assert authenticator.token_endpoint == well_known_response["token_endpoint"] + assert ( + authenticator.device_authorization_endpoint + == well_known_response["device_authorization_endpoint"] + ) + assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"] + + # Should only call well-known endpoint once due to caching + assert len(mock_oidc_server.calls) == 1 + call_request = mock_oidc_server.calls[0].request + assert call_request.method == "GET" + assert call_request.url == oidc_well_known_url + + # Keys should also be cached + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 # Now also fetched JWKS + + # Multiple calls should still be cached + for _ in range(5): + assert authenticator.keys() == json_web_keyset + assert len(mock_oidc_server.calls) == 2 # No new calls + + @pytest.mark.parametrize("issued", [True, False]) + @pytest.mark.parametrize("expired", [True, False]) + def test_oidc_token_decoding( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + issued: bool, + expired: bool, + keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey], + ): + """Test token decoding with various validity scenarios.""" + private_key, _ = keys + authenticator = OIDCAuthenticator( + audience="test_client", + client_id="test_client", + client_secret="secret", + well_known_uri=oidc_well_known_url, + ) + + token = create_token(issued, expired) + encrypted = encrypt_token(token, private_key) + + if not expired: + # Non-expired tokens should decode successfully + decoded = authenticator.decode_token(encrypted) + assert decoded["sub"] == "test_user" + assert decoded["aud"] == "test_client" + else: + # Expired tokens should raise an error + with pytest.raises(ExpiredSignatureError): + authenticator.decode_token(encrypted) + + def test_oidc_authenticator_properties( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + well_known_response: dict[str, Any], + ): + """Test that all authenticator properties are correctly set.""" + authenticator = OIDCAuthenticator( + audience="my_audience", + client_id="my_client_id", + client_secret="my_secret", + well_known_uri=oidc_well_known_url, + confirmation_message="Logged in as {id}", + redirect_on_success="https://app.example.com/success", + redirect_on_failure="https://app.example.com/failure", + ) + + assert authenticator.client_id == "my_client_id" + assert authenticator.confirmation_message == "Logged in as {id}" + assert authenticator.redirect_on_success == "https://app.example.com/success" + assert authenticator.redirect_on_failure == "https://app.example.com/failure" + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +class TestProxiedOIDCAuthenticator: + """Tests for ProxiedOIDCAuthenticator class.""" + + @pytest.mark.asyncio + async def test_proxied_oidc_oauth2_schema( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + ): + """Test that ProxiedOIDCAuthenticator extracts bearer token correctly.""" + authenticator = ProxiedOIDCAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + ) + + # Create a mock request with Authorization header + test_request = httpx.Request( + "GET", + "http://example.com/api/test", + headers={"Authorization": "Bearer TEST_TOKEN"}, + ) + + # The oauth2_schema should extract the bearer token + token = await authenticator.oauth2_schema(test_request) + assert token == "TEST_TOKEN" + + def test_proxied_oidc_with_scopes( + self, + mock_oidc_server: MockRouter, + oidc_well_known_url: str, + ): + """Test ProxiedOIDCAuthenticator with custom scopes.""" + authenticator = ProxiedOIDCAuthenticator( + audience="test_client", + client_id="test_client", + well_known_uri=oidc_well_known_url, + device_flow_client_id="test_cli_client", + scopes=["openid", "profile", "email"], + ) + + assert authenticator.scopes == ["openid", "profile", "email"] + assert authenticator.device_flow_client_id == "test_cli_client" diff --git a/requirements-dev.txt b/requirements-dev.txt index dd7212a..e47dd72 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,13 +3,16 @@ black codecov coverage +cryptography fastapi[all] flake8 isort pre-commit pytest +pytest-asyncio pytest-xprocess py +respx sphinx ipython numpydoc diff --git a/requirements.txt b/requirements.txt index 818362f..1377ef0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ bluesky-queueserver bluesky-queueserver-api cachetools fastapi +httpx ldap3 orjson pamela