diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py
index 0375794..c745dff 100644
--- a/bluesky_httpserver/_authentication.py
+++ b/bluesky_httpserver/_authentication.py
@@ -6,12 +6,13 @@
from datetime import datetime, timedelta
from typing import Optional
-from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket
+from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket
from fastapi.openapi.models import APIKey, APIKeyIn
-from fastapi.responses import JSONResponse
+from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery
from fastapi.security.utils import get_authorization_scheme_param
+from sqlalchemy.exc import IntegrityError
# To hide third-party warning
# .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning:
@@ -33,7 +34,14 @@
from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME
from .core import json_or_msgpack
from .database import orm
-from .database.core import create_user, latest_principal_activity, lookup_valid_api_key, lookup_valid_session
+from .database.core import (
+ create_user,
+ latest_principal_activity,
+ lookup_valid_api_key,
+ lookup_valid_pending_session_by_device_code,
+ lookup_valid_pending_session_by_user_code,
+ lookup_valid_session,
+)
from .settings import get_sessionmaker, get_settings
from .utils import (
API_KEY_COOKIE_NAME,
@@ -48,6 +56,10 @@
ALGORITHM = "HS256"
UNIT_SECOND = timedelta(seconds=1)
+# Device code flow constants
+DEVICE_CODE_MAX_AGE = timedelta(minutes=10)
+DEVICE_CODE_POLLING_INTERVAL = 5 # seconds
+
def utcnow():
"UTC now with second resolution"
@@ -505,6 +517,349 @@ async def handle_credentials(
return handle_credentials
+def create_pending_session(db):
+ """
+ Create a pending session for device code flow.
+
+ Returns a dict with 'user_code' (user-facing code) and 'device_code' (for polling).
+ """
+ device_code = secrets.token_bytes(32)
+ hashed_device_code = hashlib.sha256(device_code).digest()
+ for _ in range(3):
+ user_code = secrets.token_hex(4).upper() # 8 digit code
+ pending_session = orm.PendingSession(
+ user_code=user_code,
+ hashed_device_code=hashed_device_code,
+ expiration_time=utcnow() + DEVICE_CODE_MAX_AGE,
+ )
+ db.add(pending_session)
+ try:
+ db.commit()
+ except IntegrityError:
+ # Since the user_code is short, we cannot completely dismiss the
+ # possibility of a collision. Retry.
+ db.rollback()
+ continue
+ break
+ formatted_user_code = f"{user_code[:4]}-{user_code[4:]}"
+ return {
+ "user_code": formatted_user_code,
+ "device_code": device_code.hex(),
+ }
+
+
+def build_authorize_route(authenticator, provider):
+ """Build a GET route that redirects the browser to the OIDC provider for authentication."""
+
+ async def authorize_redirect(
+ request: Request,
+ state: Optional[str] = Query(None),
+ ):
+ """Redirect browser to OAuth provider for authentication."""
+ redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code"
+
+ params = {
+ "client_id": authenticator.client_id,
+ "response_type": "code",
+ "scope": "openid profile email",
+ "redirect_uri": redirect_uri,
+ }
+ if state:
+ params["state"] = state
+
+ auth_url = authenticator.authorization_endpoint.copy_with(params=params)
+ return RedirectResponse(url=str(auth_url))
+
+ return authorize_redirect
+
+
+def build_device_code_authorize_route(authenticator, provider):
+ """Build a POST route that initiates the device code flow for CLI/headless clients."""
+
+ async def device_code_authorize(
+ request: Request,
+ settings: BaseSettings = Depends(get_settings),
+ ):
+ """
+ Initiate device code flow.
+
+ Returns authorization_uri for the user to visit in browser,
+ and device_code + user_code for the CLI client to poll.
+ """
+ request.state.endpoint = "auth"
+ with get_sessionmaker(settings.database_settings)() as db:
+ pending_session = create_pending_session(db)
+
+ verification_uri = f"{get_base_url(request)}/auth/provider/{provider}/token"
+ authorization_uri = authenticator.authorization_endpoint.copy_with(
+ params={
+ "client_id": authenticator.client_id,
+ "response_type": "code",
+ "scope": "openid profile email",
+ "redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code",
+ }
+ )
+ return {
+ "authorization_uri": str(authorization_uri), # URL that user should visit in browser
+ "verification_uri": str(verification_uri), # URL that terminal client will poll
+ "interval": DEVICE_CODE_POLLING_INTERVAL, # suggested polling interval
+ "device_code": pending_session["device_code"],
+ "expires_in": int(DEVICE_CODE_MAX_AGE.total_seconds()), # seconds
+ "user_code": pending_session["user_code"],
+ }
+
+ return device_code_authorize
+
+
+def build_device_code_form_route(authenticator, provider):
+ """Build a GET route that shows the user code entry form."""
+
+ async def device_code_form(
+ request: Request,
+ code: str,
+ ):
+ """Show form for user to enter user code after browser auth."""
+ action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}"
+ html_content = f"""
+
+
+
+ Authorize Session
+
+
+
+ Authorize Bluesky HTTP Server Session
+
+
+
+"""
+ return HTMLResponse(content=html_content)
+
+ return device_code_form
+
+
+def build_device_code_submit_route(authenticator, provider):
+ """Build a POST route that handles user code submission after browser auth."""
+
+ async def device_code_submit(
+ request: Request,
+ code: str = Form(),
+ user_code: str = Form(),
+ settings: BaseSettings = Depends(get_settings),
+ api_access_manager=Depends(get_api_access_manager),
+ ):
+ """Handle user code submission and link to authenticated session."""
+ request.state.endpoint = "auth"
+ action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}"
+ normalized_user_code = user_code.upper().replace("-", "").strip()
+
+ with get_sessionmaker(settings.database_settings)() as db:
+ pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code)
+ if pending_session is None:
+ error_html = f"""
+
+
+Error
+
+
+
+ Authorization Failed
+ Invalid user code. It may have been mistyped, or the pending request may have expired.
+
Try again
+
+
+"""
+ return HTMLResponse(content=error_html, status_code=401)
+
+ # Authenticate with the OIDC provider using the authorization code
+ user_session_state = await authenticator.authenticate(request)
+ if not user_session_state:
+ error_html = """
+
+
+Authentication Failed
+
+
+
+ Authentication Failed
+ User code was correct but authentication with the identity provider failed. Please contact the administrator.
+
+
+"""
+ return HTMLResponse(content=error_html, status_code=401)
+
+ username = user_session_state.user_name
+ if not api_access_manager.is_user_known(username):
+ error_html = f"""
+
+
+Authorization Failed
+
+
+
+ Authorization Failed
+ User '{username}' is not authorized to access this server.
+
+
+"""
+ return HTMLResponse(content=error_html, status_code=403)
+
+ # Create the session
+ session = await asyncio.get_running_loop().run_in_executor(
+ None, _create_session_orm, settings, provider, username, db
+ )
+
+ # Link the pending session to the real session
+ pending_session.session_id = session.id
+ db.add(pending_session)
+ db.commit()
+
+ success_html = f"""
+
+
+Success
+
+
+
+ Success!
+ You have been authenticated. Return to your terminal application - within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in.
+
+
+"""
+ return HTMLResponse(content=success_html)
+
+ return device_code_submit
+
+
+def _create_session_orm(settings, identity_provider, id, db):
+ """
+ Create a session and return the ORM object (for device code flow).
+
+ Unlike create_session(), this returns the ORM object so we can link it
+ to the pending session.
+ """
+ # Have we seen this Identity before?
+ identity = (
+ db.query(orm.Identity)
+ .filter(orm.Identity.id == id)
+ .filter(orm.Identity.provider == identity_provider)
+ .first()
+ )
+ now = utcnow()
+ if identity is None:
+ # We have not. Make a new Principal and link this new Identity to it.
+ principal = create_user(db, identity_provider, id)
+ (new_identity,) = principal.identities
+ new_identity.latest_login = now
+ else:
+ identity.latest_login = now
+ principal = identity.principal
+
+ session = orm.Session(
+ principal_id=principal.id,
+ expiration_time=utcnow() + settings.session_max_age,
+ )
+ db.add(session)
+ db.commit()
+ db.refresh(session)
+ return session
+
+
+def build_device_code_token_route(authenticator, provider):
+ """Build a POST route for the CLI client to poll for tokens."""
+
+ async def device_code_token(
+ request: Request,
+ body: schemas.DeviceCode,
+ settings: BaseSettings = Depends(get_settings),
+ api_access_manager=Depends(get_api_access_manager),
+ ):
+ """
+ Poll for tokens after device code flow authentication.
+
+ Returns tokens if the user has authenticated, or 400 with
+ 'authorization_pending' error if still waiting.
+ """
+ request.state.endpoint = "auth"
+ device_code_hex = body.device_code
+ try:
+ device_code = bytes.fromhex(device_code_hex)
+ except Exception:
+ # Not valid hex, therefore not a valid device_code
+ raise HTTPException(status_code=401, detail="Invalid device code")
+
+ with get_sessionmaker(settings.database_settings)() as db:
+ pending_session = lookup_valid_pending_session_by_device_code(db, device_code)
+ if pending_session is None:
+ raise HTTPException(
+ status_code=404,
+ detail="No such device_code. The pending request may have expired.",
+ )
+ if pending_session.session_id is None:
+ raise HTTPException(status_code=400, detail={"error": "authorization_pending"})
+
+ session = pending_session.session
+ principal = session.principal
+
+ # Get scopes for the user
+ # Find an identity to get the username
+ identity = db.query(orm.Identity).filter(orm.Identity.principal_id == principal.id).first()
+ if identity and api_access_manager.is_user_known(identity.id):
+ scopes = api_access_manager.get_user_scopes(identity.id)
+ else:
+ scopes = set()
+
+ # The pending session can only be used once
+ db.delete(pending_session)
+ db.commit()
+
+ # Generate tokens
+ data = {
+ "sub": principal.uuid.hex,
+ "sub_typ": principal.type.value,
+ "scp": list(scopes),
+ "ids": [{"id": ident.id, "idp": ident.provider} for ident in principal.identities],
+ }
+ access_token = create_access_token(
+ data=data,
+ expires_delta=settings.access_token_max_age,
+ secret_key=settings.secret_keys[0],
+ )
+ refresh_token = create_refresh_token(
+ session_id=session.uuid.hex,
+ expires_delta=settings.refresh_token_max_age,
+ secret_key=settings.secret_keys[0],
+ )
+
+ return {
+ "access_token": access_token,
+ "expires_in": int(settings.access_token_max_age / UNIT_SECOND),
+ "refresh_token": refresh_token,
+ "refresh_token_expires_in": int(settings.refresh_token_max_age / UNIT_SECOND),
+ "token_type": "bearer",
+ }
+
+ return device_code_token
+
+
def generate_apikey(db, principal, apikey_params, request, allowed_scopes, source_api_key_scopes):
# Use API key scopes if API key is generated based on existing API key, otherwise used allowed scopes
if (source_api_key_scopes is not None) and ("inherit" not in source_api_key_scopes):
diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py
index 9a8420a..0d96667 100644
--- a/bluesky_httpserver/app.py
+++ b/bluesky_httpserver/app.py
@@ -160,6 +160,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server
from .authentication import (
base_authentication_router,
build_auth_code_route,
+ build_authorize_route,
+ build_device_code_authorize_route,
+ build_device_code_form_route,
+ build_device_code_submit_route,
+ build_device_code_token_route,
build_handle_credentials_route,
oauth2_scheme,
)
@@ -184,12 +189,34 @@ def build_app(authentication=None, api_access=None, resource_access=None, server
build_handle_credentials_route(authenticator, provider)
)
elif isinstance(authenticator, ExternalAuthenticator):
+ # Standard OAuth callback route (authorization code flow)
authentication_router.get(f"/provider/{provider}/code")(
build_auth_code_route(authenticator, provider)
)
authentication_router.post(f"/provider/{provider}/code")(
build_auth_code_route(authenticator, provider)
)
+ # Device code flow routes for CLI/headless clients
+ # GET /authorize - redirects browser to OIDC provider
+ authentication_router.get(f"/provider/{provider}/authorize")(
+ build_authorize_route(authenticator, provider)
+ )
+ # POST /authorize - initiates device code flow (returns device_code, user_code, etc.)
+ authentication_router.post(f"/provider/{provider}/authorize")(
+ build_device_code_authorize_route(authenticator, provider)
+ )
+ # GET /device_code - shows user code entry form
+ authentication_router.get(f"/provider/{provider}/device_code")(
+ build_device_code_form_route(authenticator, provider)
+ )
+ # POST /device_code - handles user code submission after browser auth
+ authentication_router.post(f"/provider/{provider}/device_code")(
+ build_device_code_submit_route(authenticator, provider)
+ )
+ # POST /token - CLI client polls this for tokens
+ authentication_router.post(f"/provider/{provider}/token")(
+ build_device_code_token_route(authenticator, provider)
+ )
else:
raise ValueError(f"unknown authenticator type {type(authenticator)}")
for custom_router in getattr(authenticator, "include_routers", []):
diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py
index fc35cdd..85d835e 100644
--- a/bluesky_httpserver/authentication/__init__.py
+++ b/bluesky_httpserver/authentication/__init__.py
@@ -1,6 +1,11 @@
from .._authentication import (
base_authentication_router,
build_auth_code_route,
+ build_authorize_route,
+ build_device_code_authorize_route,
+ build_device_code_form_route,
+ build_device_code_submit_route,
+ build_device_code_token_route,
build_handle_credentials_route,
get_current_principal,
get_current_principal_websocket,
@@ -20,6 +25,11 @@
"get_current_principal_websocket",
"base_authentication_router",
"build_auth_code_route",
+ "build_authorize_route",
+ "build_device_code_authorize_route",
+ "build_device_code_form_route",
+ "build_device_code_submit_route",
+ "build_device_code_token_route",
"build_handle_credentials_route",
"oauth2_scheme",
]
diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py
index 78b6cf1..a58fedf 100644
--- a/bluesky_httpserver/authenticators.py
+++ b/bluesky_httpserver/authenticators.py
@@ -222,18 +222,38 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
if response.is_error:
logger.error("Authentication error: %r", response_body)
return None
- response_body = response.json()
id_token = response_body["id_token"]
- access_token = response_body["access_token"]
+ # NOTE: We decode the id_token, not access_token, because:
+ # 1. The id_token is the OIDC identity assertion meant for the client
+ # 2. Some providers (like Microsoft Entra) return opaque access_tokens
+ # that cannot be decoded with the JWKS keys when the resource is
+ # a first-party Microsoft API (e.g., Graph API with User.Read scope)
try:
- verified_body = self.decode_token(access_token)
+ verified_body = self.decode_token(id_token)
except JWTError:
logger.exception(
"Authentication error. Unverified token: %r",
jwt.get_unverified_claims(id_token),
)
return None
- return UserSessionState(verified_body["sub"], {})
+ # Use preferred_username as the user identifier, extracting just the username
+ # part if it's in email format (user@domain.com -> user)
+ preferred_username = verified_body.get("preferred_username")
+ if preferred_username and "@" in preferred_username:
+ user_id = preferred_username.split("@")[0]
+ elif preferred_username:
+ user_id = preferred_username
+ else:
+ user_id = verified_body["sub"]
+ logger.info(
+ "OIDC authentication successful. user_id=%r (sub=%r, preferred_username=%r, email=%r, name=%r)",
+ user_id,
+ verified_body.get("sub"),
+ verified_body.get("preferred_username"),
+ verified_body.get("email"),
+ verified_body.get("name"),
+ )
+ return UserSessionState(user_id, {})
class ProxiedOIDCAuthenticator(OIDCAuthenticator):
diff --git a/bluesky_httpserver/config_schemas/examples/oidc_config.yml b/bluesky_httpserver/config_schemas/examples/oidc_config.yml
new file mode 100644
index 0000000..c2f8d24
--- /dev/null
+++ b/bluesky_httpserver/config_schemas/examples/oidc_config.yml
@@ -0,0 +1,78 @@
+# Example OIDC Configuration for Bluesky HTTP Server
+#
+# This example shows how to configure OIDC (OpenID Connect) authentication.
+# OIDC is used by providers like Google, Microsoft Entra (Azure AD), Okta, Keycloak, etc.
+#
+# Required environment variables:
+# - OIDC_CLIENT_ID: The client ID from your OIDC provider
+# - OIDC_CLIENT_SECRET: The client secret from your OIDC provider
+# - OIDC_WELL_KNOWN_URI: The .well-known/openid-configuration URL
+#
+# Example for Google:
+# OIDC_WELL_KNOWN_URI=https://accounts.google.com/.well-known/openid-configuration
+#
+# Example for Microsoft Entra (Azure AD):
+# OIDC_WELL_KNOWN_URI=https://login.microsoftonline.com/{tenant-id}/v2.0/.well-known/openid-configuration
+#
+# Example for Keycloak:
+# OIDC_WELL_KNOWN_URI=https://your-keycloak-server/realms/{realm}/.well-known/openid-configuration
+
+authentication:
+ providers:
+ - provider: oidc
+ authenticator: bluesky_httpserver.authenticators:OIDCAuthenticator
+ args:
+ # The audience should match the client_id or be a value expected by your OIDC provider
+ audience: ${OIDC_CLIENT_ID}
+ client_id: ${OIDC_CLIENT_ID}
+ client_secret: ${OIDC_CLIENT_SECRET}
+ well_known_uri: ${OIDC_WELL_KNOWN_URI}
+ confirmation_message: "You have successfully logged in via OIDC as {id}."
+ # Optional: redirect URLs after authentication
+ # redirect_on_success: https://your-app.example.com/success
+ # redirect_on_failure: https://your-app.example.com/login-failed
+
+ # Secret keys used to sign secure tokens (generate with: openssl rand -hex 32)
+ secret_keys:
+ - ${SECRET_KEY}
+
+ # Allow unauthenticated access to public endpoints
+ allow_anonymous_access: false
+
+ # Token lifetimes (in seconds)
+ access_token_max_age: 900 # 15 minutes
+ refresh_token_max_age: 604800 # 7 days
+
+# Database for storing sessions and API keys
+database:
+ uri: ${DATABASE_URI}
+ pool_size: 5
+ pool_pre_ping: true
+
+# API access control - configure which users have access
+api_access:
+ policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl
+ args:
+ users:
+ # Add users identified by their OIDC subject ID (sub claim)
+ # The ID typically looks like an email or UUID depending on your OIDC provider
+ user@example.com:
+ roles:
+ - admin
+ - user
+
+# Resource access control
+resource_access:
+ policy: bluesky_httpserver.authorization:DefaultResourceAccessControl
+ args:
+ default_group: root
+
+# Queue Server connection
+qserver_zmq_configuration:
+ control_address: tcp://localhost:60615
+ info_address: tcp://localhost:60625
+
+# HTTP Server configuration
+uvicorn:
+ host: 0.0.0.0
+ port: 8000
diff --git a/bluesky_httpserver/config_schemas/service_configuration.yml b/bluesky_httpserver/config_schemas/service_configuration.yml
index 57343f7..12f01a3 100644
--- a/bluesky_httpserver/config_schemas/service_configuration.yml
+++ b/bluesky_httpserver/config_schemas/service_configuration.yml
@@ -47,14 +47,14 @@ properties:
properties:
custom_routers:
type: array
- item:
+ items:
type: string
description: |
The list of Python modules with custom routers. Overrides the list of modules set using
QSERVER_HTTP_CUSTOM_ROUTERS environment variable.
custom_modules:
type: array
- item:
+ items:
type: string
description: |
THE FUNCTIONALITY WILL BE DEPRECATED IN FAVOR OF CUSTOM ROUTERS. Overrides the list of modules
@@ -65,7 +65,7 @@ properties:
properties:
providers:
type: array
- item:
+ items:
type: object
additionalProperties: false
required:
@@ -83,7 +83,7 @@ properties:
description: |
Type of Authenticator to use.
- These are typically from the tiled.authenticators module,
+ These are typically from the bluesky_httpserver.authenticators module,
though user-defined ones may be used as well.
This is given as an import path. In an import path, packages/modules
@@ -92,21 +92,21 @@ properties:
Example:
```yaml
- authenticator: bluesky_httpserver.examples.DummyAuthenticator
+ authenticator: bluesky_httpserver.authenticators:DummyAuthenticator
```
- args:
- type: [object, "null"]
- description: |
- Named arguments to pass to Authenticator. If there are none,
- `args` may be omitted or empty.
+ args:
+ type: object
+ description: |
+ Named arguments to pass to Authenticator. If there are none,
+ `args` may be omitted or empty.
- Example:
+ Example:
- ```yaml
- authenticator: bluesky_httpserver.examples.PAMAuthenticator
- args:
- service: "custom_service"
- ```
+ ```yaml
+ authenticator: bluesky_httpserver.authenticators:PAMAuthenticator
+ args:
+ service: "custom_service"
+ ```
# qserver_admins:
# type: array
# items:
diff --git a/bluesky_httpserver/database/core.py b/bluesky_httpserver/database/core.py
index 163fac3..f096edc 100644
--- a/bluesky_httpserver/database/core.py
+++ b/bluesky_httpserver/database/core.py
@@ -1,6 +1,7 @@
import hashlib
import uuid as uuid_module
from datetime import datetime
+from typing import Optional
from alembic import command
from alembic.config import Config
@@ -10,13 +11,13 @@
from .alembic_utils import temp_alembic_ini
from .base import Base
-from .orm import APIKey, Identity, Principal, Session # , Role
+from .orm import APIKey, Identity, PendingSession, Principal, Session # , Role
# This is the alembic revision ID of the database revision
# required by this version of Tiled.
-REQUIRED_REVISION = "722ff4e4fcc7"
+REQUIRED_REVISION = "a1b2c3d4e5f6"
# This is list of all valid revisions (from current to oldest).
-ALL_REVISIONS = ["722ff4e4fcc7", "481830dd6c11"]
+ALL_REVISIONS = ["a1b2c3d4e5f6", "722ff4e4fcc7", "481830dd6c11"]
# def create_default_roles(engine):
@@ -294,3 +295,36 @@ def latest_principal_activity(db, principal):
if all([t is None for t in all_activity]):
return None
return max(t for t in all_activity if t is not None)
+
+
+def lookup_valid_pending_session_by_device_code(db, device_code: bytes) -> Optional[PendingSession]:
+ """
+ Look up a pending session by its device code.
+
+ Returns None if the pending session is not found or has expired.
+ """
+ hashed_device_code = hashlib.sha256(device_code).digest()
+ pending_session = db.query(PendingSession).filter(PendingSession.hashed_device_code == hashed_device_code).first()
+ if pending_session is None:
+ return None
+ if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow():
+ db.delete(pending_session)
+ db.commit()
+ return None
+ return pending_session
+
+
+def lookup_valid_pending_session_by_user_code(db, user_code: str) -> Optional[PendingSession]:
+ """
+ Look up a pending session by its user code.
+
+ Returns None if the pending session is not found or has expired.
+ """
+ pending_session = db.query(PendingSession).filter(PendingSession.user_code == user_code).first()
+ if pending_session is None:
+ return None
+ if pending_session.expiration_time is not None and pending_session.expiration_time < datetime.utcnow():
+ db.delete(pending_session)
+ db.commit()
+ return None
+ return pending_session
diff --git a/bluesky_httpserver/database/orm.py b/bluesky_httpserver/database/orm.py
index 17d7c82..7611824 100644
--- a/bluesky_httpserver/database/orm.py
+++ b/bluesky_httpserver/database/orm.py
@@ -181,3 +181,24 @@ class Session(Timestamped, Base):
revoked = Column(Boolean, default=False, nullable=False)
principal = relationship("Principal", back_populates="sessions")
+ pending_sessions = relationship("PendingSession", back_populates="session")
+
+
+class PendingSession(Timestamped, Base):
+ """
+ This is used only in Device Code Flow for OIDC authentication.
+
+ When a CLI client initiates the device code flow, a pending session is created
+ with a device_code (for the client to poll) and a user_code (for the user to
+ enter in the browser). Once the user authenticates, the pending session is
+ linked to a real session, which the polling client then receives.
+ """
+
+ __tablename__ = "pending_sessions"
+
+ hashed_device_code = Column(LargeBinary(32), primary_key=True, index=True, nullable=False)
+ user_code = Column(Unicode(8), index=True, nullable=False)
+ expiration_time = Column(DateTime(timezone=False), nullable=False)
+ session_id = Column(Integer, ForeignKey("sessions.id"), nullable=True)
+
+ session = relationship("Session", back_populates="pending_sessions")
diff --git a/bluesky_httpserver/schemas.py b/bluesky_httpserver/schemas.py
index c52d8f2..f1d9fcb 100644
--- a/bluesky_httpserver/schemas.py
+++ b/bluesky_httpserver/schemas.py
@@ -163,6 +163,23 @@ class RefreshToken(pydantic.BaseModel):
refresh_token: str
+class DeviceCode(pydantic.BaseModel):
+ """Schema for device code token polling request."""
+
+ device_code: str
+
+
+class DeviceCodeResponse(pydantic.BaseModel):
+ """Schema for device code flow initiation response."""
+
+ authorization_uri: str
+ verification_uri: str
+ device_code: str
+ user_code: str
+ expires_in: int
+ interval: int
+
+
class AuthenticationMode(str, enum.Enum):
password = "password"
external = "external"
diff --git a/bluesky_httpserver/tests/conftest.py b/bluesky_httpserver/tests/conftest.py
index ec69415..3c43529 100644
--- a/bluesky_httpserver/tests/conftest.py
+++ b/bluesky_httpserver/tests/conftest.py
@@ -195,3 +195,28 @@ def wait_for_ip_kernel_idle(timeout, polling_period=0.2, api_key=API_KEY_FOR_TES
return True
return False
+
+
+# ============================================================================
+# OIDC Test Fixtures
+# ============================================================================
+
+@pytest.fixture
+def oidc_base_url() -> str:
+ """Base URL for mock OIDC provider."""
+ return "https://example.com/realms/example/"
+
+
+@pytest.fixture
+def well_known_response(oidc_base_url: str) -> dict:
+ """Mock OIDC well-known configuration response."""
+ return {
+ "id_token_signing_alg_values_supported": ["RS256"],
+ "issuer": oidc_base_url.rstrip("/"),
+ "jwks_uri": f"{oidc_base_url}protocol/openid-connect/certs",
+ "authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth",
+ "token_endpoint": f"{oidc_base_url}protocol/openid-connect/token",
+ "device_authorization_endpoint": f"{oidc_base_url}protocol/openid-connect/auth/device",
+ "end_session_endpoint": f"{oidc_base_url}protocol/openid-connect/logout",
+ }
+
diff --git a/bluesky_httpserver/tests/test_oidc_authenticators.py b/bluesky_httpserver/tests/test_oidc_authenticators.py
new file mode 100644
index 0000000..30303e4
--- /dev/null
+++ b/bluesky_httpserver/tests/test_oidc_authenticators.py
@@ -0,0 +1,224 @@
+"""Tests for OIDC Authenticator functionality."""
+
+import time
+from typing import Any, Tuple
+
+import httpx
+import pytest
+from cryptography.hazmat.primitives.asymmetric import rsa
+from jose import ExpiredSignatureError, jwt
+from jose.backends import RSAKey
+from respx import MockRouter
+
+from bluesky_httpserver.authenticators import OIDCAuthenticator, ProxiedOIDCAuthenticator
+
+
+@pytest.fixture
+def oidc_well_known_url(oidc_base_url: str) -> str:
+ return f"{oidc_base_url}.well-known/openid-configuration"
+
+
+@pytest.fixture
+def keys() -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]:
+ """Generate RSA key pair for testing."""
+ private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+ public_key = private_key.public_key()
+ return (private_key, public_key)
+
+
+@pytest.fixture
+def json_web_keyset(keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]) -> list[dict[str, Any]]:
+ """Create a JSON Web Key Set from the test keys."""
+ _, public_key = keys
+ return [RSAKey(key=public_key, algorithm="RS256").to_dict()]
+
+
+@pytest.fixture
+def mock_oidc_server(
+ respx_mock: MockRouter,
+ oidc_well_known_url: str,
+ well_known_response: dict[str, Any],
+ json_web_keyset: list[dict[str, Any]],
+) -> MockRouter:
+ """Set up mock OIDC server endpoints."""
+ respx_mock.get(oidc_well_known_url).mock(
+ return_value=httpx.Response(httpx.codes.OK, json=well_known_response)
+ )
+ respx_mock.get(well_known_response["jwks_uri"]).mock(
+ return_value=httpx.Response(httpx.codes.OK, json={"keys": json_web_keyset})
+ )
+ return respx_mock
+
+
+def create_token(issued: bool, expired: bool) -> dict[str, Any]:
+ """Create a test JWT token."""
+ now = time.time()
+ return {
+ "aud": "test_client",
+ "exp": (now - 1500) if expired else (now + 1500),
+ "iat": (now - 1500) if issued else (now + 1500),
+ "iss": "https://example.com/realms/example",
+ "sub": "test_user",
+ }
+
+
+def encrypt_token(token: dict[str, Any], private_key: rsa.RSAPrivateKey) -> str:
+ """Encrypt a token with the test private key."""
+ return jwt.encode(
+ token,
+ key=private_key,
+ algorithm="RS256",
+ headers={"kid": "test_key"},
+ )
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+class TestOIDCAuthenticator:
+ """Tests for OIDCAuthenticator class."""
+
+ def test_oidc_authenticator_caching(
+ self,
+ mock_oidc_server: MockRouter,
+ oidc_well_known_url: str,
+ well_known_response: dict[str, Any],
+ json_web_keyset: list[dict[str, Any]],
+ ):
+ """Test that OIDC configuration is cached after first fetch."""
+ authenticator = OIDCAuthenticator(
+ audience="test_client",
+ client_id="test_client",
+ client_secret="secret",
+ well_known_uri=oidc_well_known_url,
+ )
+
+ # Access multiple properties to ensure caching works
+ assert authenticator.client_id == "test_client"
+ assert authenticator.authorization_endpoint == well_known_response["authorization_endpoint"]
+ assert (
+ authenticator.id_token_signing_alg_values_supported
+ == well_known_response["id_token_signing_alg_values_supported"]
+ )
+ assert authenticator.issuer == well_known_response["issuer"]
+ assert authenticator.jwks_uri == well_known_response["jwks_uri"]
+ assert authenticator.token_endpoint == well_known_response["token_endpoint"]
+ assert (
+ authenticator.device_authorization_endpoint
+ == well_known_response["device_authorization_endpoint"]
+ )
+ assert authenticator.end_session_endpoint == well_known_response["end_session_endpoint"]
+
+ # Should only call well-known endpoint once due to caching
+ assert len(mock_oidc_server.calls) == 1
+ call_request = mock_oidc_server.calls[0].request
+ assert call_request.method == "GET"
+ assert call_request.url == oidc_well_known_url
+
+ # Keys should also be cached
+ assert authenticator.keys() == json_web_keyset
+ assert len(mock_oidc_server.calls) == 2 # Now also fetched JWKS
+
+ # Multiple calls should still be cached
+ for _ in range(5):
+ assert authenticator.keys() == json_web_keyset
+ assert len(mock_oidc_server.calls) == 2 # No new calls
+
+ @pytest.mark.parametrize("issued", [True, False])
+ @pytest.mark.parametrize("expired", [True, False])
+ def test_oidc_token_decoding(
+ self,
+ mock_oidc_server: MockRouter,
+ oidc_well_known_url: str,
+ issued: bool,
+ expired: bool,
+ keys: Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey],
+ ):
+ """Test token decoding with various validity scenarios."""
+ private_key, _ = keys
+ authenticator = OIDCAuthenticator(
+ audience="test_client",
+ client_id="test_client",
+ client_secret="secret",
+ well_known_uri=oidc_well_known_url,
+ )
+
+ token = create_token(issued, expired)
+ encrypted = encrypt_token(token, private_key)
+
+ if not expired:
+ # Non-expired tokens should decode successfully
+ decoded = authenticator.decode_token(encrypted)
+ assert decoded["sub"] == "test_user"
+ assert decoded["aud"] == "test_client"
+ else:
+ # Expired tokens should raise an error
+ with pytest.raises(ExpiredSignatureError):
+ authenticator.decode_token(encrypted)
+
+ def test_oidc_authenticator_properties(
+ self,
+ mock_oidc_server: MockRouter,
+ oidc_well_known_url: str,
+ well_known_response: dict[str, Any],
+ ):
+ """Test that all authenticator properties are correctly set."""
+ authenticator = OIDCAuthenticator(
+ audience="my_audience",
+ client_id="my_client_id",
+ client_secret="my_secret",
+ well_known_uri=oidc_well_known_url,
+ confirmation_message="Logged in as {id}",
+ redirect_on_success="https://app.example.com/success",
+ redirect_on_failure="https://app.example.com/failure",
+ )
+
+ assert authenticator.client_id == "my_client_id"
+ assert authenticator.confirmation_message == "Logged in as {id}"
+ assert authenticator.redirect_on_success == "https://app.example.com/success"
+ assert authenticator.redirect_on_failure == "https://app.example.com/failure"
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+class TestProxiedOIDCAuthenticator:
+ """Tests for ProxiedOIDCAuthenticator class."""
+
+ @pytest.mark.asyncio
+ async def test_proxied_oidc_oauth2_schema(
+ self,
+ mock_oidc_server: MockRouter,
+ oidc_well_known_url: str,
+ ):
+ """Test that ProxiedOIDCAuthenticator extracts bearer token correctly."""
+ authenticator = ProxiedOIDCAuthenticator(
+ audience="test_client",
+ client_id="test_client",
+ well_known_uri=oidc_well_known_url,
+ device_flow_client_id="test_cli_client",
+ )
+
+ # Create a mock request with Authorization header
+ test_request = httpx.Request(
+ "GET",
+ "http://example.com/api/test",
+ headers={"Authorization": "Bearer TEST_TOKEN"},
+ )
+
+ # The oauth2_schema should extract the bearer token
+ token = await authenticator.oauth2_schema(test_request)
+ assert token == "TEST_TOKEN"
+
+ def test_proxied_oidc_with_scopes(
+ self,
+ mock_oidc_server: MockRouter,
+ oidc_well_known_url: str,
+ ):
+ """Test ProxiedOIDCAuthenticator with custom scopes."""
+ authenticator = ProxiedOIDCAuthenticator(
+ audience="test_client",
+ client_id="test_client",
+ well_known_uri=oidc_well_known_url,
+ device_flow_client_id="test_cli_client",
+ scopes=["openid", "profile", "email"],
+ )
+
+ assert authenticator.scopes == ["openid", "profile", "email"]
+ assert authenticator.device_flow_client_id == "test_cli_client"
diff --git a/requirements-dev.txt b/requirements-dev.txt
index dd7212a..e47dd72 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,13 +3,16 @@
black
codecov
coverage
+cryptography
fastapi[all]
flake8
isort
pre-commit
pytest
+pytest-asyncio
pytest-xprocess
py
+respx
sphinx
ipython
numpydoc
diff --git a/requirements.txt b/requirements.txt
index 818362f..1377ef0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ bluesky-queueserver
bluesky-queueserver-api
cachetools
fastapi
+httpx
ldap3
orjson
pamela