Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 358 additions & 3 deletions bluesky_httpserver/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.responses import JSONResponse
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery
from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy.exc import IntegrityError

# To hide third-party warning
# .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning:
Expand All @@ -33,7 +34,14 @@
from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME
from .core import json_or_msgpack
from .database import orm
from .database.core import create_user, latest_principal_activity, lookup_valid_api_key, lookup_valid_session
from .database.core import (
create_user,
latest_principal_activity,
lookup_valid_api_key,
lookup_valid_pending_session_by_device_code,
lookup_valid_pending_session_by_user_code,
lookup_valid_session,
)
from .settings import get_sessionmaker, get_settings
from .utils import (
API_KEY_COOKIE_NAME,
Expand All @@ -48,6 +56,10 @@
ALGORITHM = "HS256"
UNIT_SECOND = timedelta(seconds=1)

# Device code flow constants
DEVICE_CODE_MAX_AGE = timedelta(minutes=10)
DEVICE_CODE_POLLING_INTERVAL = 5 # seconds


def utcnow():
"UTC now with second resolution"
Expand Down Expand Up @@ -505,6 +517,349 @@ async def handle_credentials(
return handle_credentials


def create_pending_session(db):
"""
Create a pending session for device code flow.

Returns a dict with 'user_code' (user-facing code) and 'device_code' (for polling).
"""
device_code = secrets.token_bytes(32)
hashed_device_code = hashlib.sha256(device_code).digest()
for _ in range(3):
user_code = secrets.token_hex(4).upper() # 8 digit code
pending_session = orm.PendingSession(
user_code=user_code,
hashed_device_code=hashed_device_code,
expiration_time=utcnow() + DEVICE_CODE_MAX_AGE,
)
db.add(pending_session)
try:
db.commit()
except IntegrityError:
# Since the user_code is short, we cannot completely dismiss the
# possibility of a collision. Retry.
db.rollback()
continue
break
formatted_user_code = f"{user_code[:4]}-{user_code[4:]}"
return {
"user_code": formatted_user_code,
"device_code": device_code.hex(),
}


def build_authorize_route(authenticator, provider):
"""Build a GET route that redirects the browser to the OIDC provider for authentication."""

async def authorize_redirect(
request: Request,
state: Optional[str] = Query(None),
):
"""Redirect browser to OAuth provider for authentication."""
redirect_uri = f"{get_base_url(request)}/auth/provider/{provider}/code"

params = {
"client_id": authenticator.client_id,
"response_type": "code",
"scope": "openid profile email",
"redirect_uri": redirect_uri,
}
if state:
params["state"] = state

auth_url = authenticator.authorization_endpoint.copy_with(params=params)
return RedirectResponse(url=str(auth_url))

return authorize_redirect


def build_device_code_authorize_route(authenticator, provider):
"""Build a POST route that initiates the device code flow for CLI/headless clients."""

async def device_code_authorize(
request: Request,
settings: BaseSettings = Depends(get_settings),
):
"""
Initiate device code flow.

Returns authorization_uri for the user to visit in browser,
and device_code + user_code for the CLI client to poll.
"""
request.state.endpoint = "auth"
with get_sessionmaker(settings.database_settings)() as db:
pending_session = create_pending_session(db)

verification_uri = f"{get_base_url(request)}/auth/provider/{provider}/token"
authorization_uri = authenticator.authorization_endpoint.copy_with(
params={
"client_id": authenticator.client_id,
"response_type": "code",
"scope": "openid profile email",
"redirect_uri": f"{get_base_url(request)}/auth/provider/{provider}/device_code",
}
)
return {
"authorization_uri": str(authorization_uri), # URL that user should visit in browser
"verification_uri": str(verification_uri), # URL that terminal client will poll
"interval": DEVICE_CODE_POLLING_INTERVAL, # suggested polling interval
"device_code": pending_session["device_code"],
"expires_in": int(DEVICE_CODE_MAX_AGE.total_seconds()), # seconds
"user_code": pending_session["user_code"],
}

return device_code_authorize


def build_device_code_form_route(authenticator, provider):
"""Build a GET route that shows the user code entry form."""

async def device_code_form(
request: Request,
code: str,
):
"""Show form for user to enter user code after browser auth."""
action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}"
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>Authorize Session</title>
<style>
body {{ font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; }}
h1 {{ color: #333; }}
form {{ margin-top: 20px; }}
label {{ display: block; margin-bottom: 10px; }}
input[type="text"] {{ padding: 10px; font-size: 16px; width: 200px; text-transform: uppercase; }}
input[type="submit"] {{ padding: 10px 20px; font-size: 16px; background-color: #007bff; color: white; border: none; cursor: pointer; margin-top: 10px; }}
input[type="submit"]:hover {{ background-color: #0056b3; }}
</style>
</head>
<body>
<h1>Authorize Bluesky HTTP Server Session</h1>
<form action="{action}" method="post">
<label for="user_code">Enter code:</label>
<input type="text" id="user_code" name="user_code" placeholder="XXXX-XXXX" />
<input type="hidden" id="code" name="code" value="{code}" />
<br/>
<input type="submit" value="Authorize" />
</form>
</body>
</html>
"""
return HTMLResponse(content=html_content)

return device_code_form


def build_device_code_submit_route(authenticator, provider):
"""Build a POST route that handles user code submission after browser auth."""

async def device_code_submit(
request: Request,
code: str = Form(),
user_code: str = Form(),
settings: BaseSettings = Depends(get_settings),
api_access_manager=Depends(get_api_access_manager),
):
"""Handle user code submission and link to authenticated session."""
request.state.endpoint = "auth"
action = f"{get_base_url(request)}/auth/provider/{provider}/device_code?code={code}"
normalized_user_code = user_code.upper().replace("-", "").strip()

with get_sessionmaker(settings.database_settings)() as db:
pending_session = lookup_valid_pending_session_by_user_code(db, normalized_user_code)
if pending_session is None:
error_html = f"""
<!DOCTYPE html>
<html>
<head><title>Error</title>
<style>body {{ font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; }}
.error {{ background-color: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; color: #721c24; }}</style>
</head>
<body>
<h1>Authorization Failed</h1>
<div class="error">Invalid user code. It may have been mistyped, or the pending request may have expired.</div>
<br/><a href="{action.rsplit('?', 1)[0]}?code={code}">Try again</a>
</body>
</html>
"""
return HTMLResponse(content=error_html, status_code=401)

# Authenticate with the OIDC provider using the authorization code
user_session_state = await authenticator.authenticate(request)
if not user_session_state:
error_html = """
<!DOCTYPE html>
<html>
<head><title>Authentication Failed</title>
<style>body {{ font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; }}
.error {{ background-color: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; color: #721c24; }}</style>
</head>
<body>
<h1>Authentication Failed</h1>
<div class="error">User code was correct but authentication with the identity provider failed. Please contact the administrator.</div>
</body>
</html>
"""
return HTMLResponse(content=error_html, status_code=401)

username = user_session_state.user_name
if not api_access_manager.is_user_known(username):
error_html = f"""
<!DOCTYPE html>
<html>
<head><title>Authorization Failed</title>
<style>body {{ font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; }}
.error {{ background-color: #f8d7da; border: 1px solid #f5c6cb; padding: 15px; border-radius: 5px; color: #721c24; }}</style>
</head>
<body>
<h1>Authorization Failed</h1>
<div class="error">User '{username}' is not authorized to access this server.</div>
</body>
</html>
"""
return HTMLResponse(content=error_html, status_code=403)

# Create the session
session = await asyncio.get_running_loop().run_in_executor(
None, _create_session_orm, settings, provider, username, db
)

# Link the pending session to the real session
pending_session.session_id = session.id
db.add(pending_session)
db.commit()

success_html = f"""
<!DOCTYPE html>
<html>
<head><title>Success</title>
<style>body {{ font-family: Arial, sans-serif; max-width: 600px; margin: 50px auto; padding: 20px; }}
.success {{ background-color: #d4edda; border: 1px solid #c3e6cb; padding: 15px; border-radius: 5px; color: #155724; }}</style>
</head>
<body>
<h1>Success!</h1>
<div class="success">You have been authenticated. Return to your terminal application - within {DEVICE_CODE_POLLING_INTERVAL} seconds it should be successfully logged in.</div>
</body>
</html>
"""
return HTMLResponse(content=success_html)

return device_code_submit


def _create_session_orm(settings, identity_provider, id, db):
"""
Create a session and return the ORM object (for device code flow).

Unlike create_session(), this returns the ORM object so we can link it
to the pending session.
"""
# Have we seen this Identity before?
identity = (
db.query(orm.Identity)
.filter(orm.Identity.id == id)
.filter(orm.Identity.provider == identity_provider)
.first()
)
now = utcnow()
if identity is None:
# We have not. Make a new Principal and link this new Identity to it.
principal = create_user(db, identity_provider, id)
(new_identity,) = principal.identities
new_identity.latest_login = now
else:
identity.latest_login = now
principal = identity.principal

session = orm.Session(
principal_id=principal.id,
expiration_time=utcnow() + settings.session_max_age,
)
db.add(session)
db.commit()
db.refresh(session)
return session


def build_device_code_token_route(authenticator, provider):
"""Build a POST route for the CLI client to poll for tokens."""

async def device_code_token(
request: Request,
body: schemas.DeviceCode,
settings: BaseSettings = Depends(get_settings),
api_access_manager=Depends(get_api_access_manager),
):
"""
Poll for tokens after device code flow authentication.

Returns tokens if the user has authenticated, or 400 with
'authorization_pending' error if still waiting.
"""
request.state.endpoint = "auth"
device_code_hex = body.device_code
try:
device_code = bytes.fromhex(device_code_hex)
except Exception:
# Not valid hex, therefore not a valid device_code
raise HTTPException(status_code=401, detail="Invalid device code")

with get_sessionmaker(settings.database_settings)() as db:
pending_session = lookup_valid_pending_session_by_device_code(db, device_code)
if pending_session is None:
raise HTTPException(
status_code=404,
detail="No such device_code. The pending request may have expired.",
)
if pending_session.session_id is None:
raise HTTPException(status_code=400, detail={"error": "authorization_pending"})

session = pending_session.session
principal = session.principal

# Get scopes for the user
# Find an identity to get the username
identity = db.query(orm.Identity).filter(orm.Identity.principal_id == principal.id).first()
if identity and api_access_manager.is_user_known(identity.id):
scopes = api_access_manager.get_user_scopes(identity.id)
else:
scopes = set()

# The pending session can only be used once
db.delete(pending_session)
db.commit()

# Generate tokens
data = {
"sub": principal.uuid.hex,
"sub_typ": principal.type.value,
"scp": list(scopes),
"ids": [{"id": ident.id, "idp": ident.provider} for ident in principal.identities],
}
access_token = create_access_token(
data=data,
expires_delta=settings.access_token_max_age,
secret_key=settings.secret_keys[0],
)
refresh_token = create_refresh_token(
session_id=session.uuid.hex,
expires_delta=settings.refresh_token_max_age,
secret_key=settings.secret_keys[0],
)

return {
"access_token": access_token,
"expires_in": int(settings.access_token_max_age / UNIT_SECOND),
"refresh_token": refresh_token,
"refresh_token_expires_in": int(settings.refresh_token_max_age / UNIT_SECOND),
"token_type": "bearer",
}

return device_code_token


def generate_apikey(db, principal, apikey_params, request, allowed_scopes, source_api_key_scopes):
# Use API key scopes if API key is generated based on existing API key, otherwise used allowed scopes
if (source_api_key_scopes is not None) and ("inherit" not in source_api_key_scopes):
Expand Down
Loading