diff --git a/apps/oidc/__init__.py b/apps/oidc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/oidc/oauth.py b/apps/oidc/oauth.py new file mode 100644 index 000000000..83029da6d --- /dev/null +++ b/apps/oidc/oauth.py @@ -0,0 +1,134 @@ +from typing import TypedDict + +import authlib.oauth2.rfc6749.grants as rfc6749_grants +import authlib.oidc.core.grants as oidc_core_grants +from authlib.integrations.flask_oauth2 import AuthorizationServer +from authlib.integrations.sqla_oauth2 import create_query_client_func, create_save_token_func +from authlib.oauth2.rfc6749 import OAuth2Request +from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oidc.core.claims import UserInfo +from flask import request + +from main import db +from models.oidc import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token +from models.user import User + +SCOPES: dict[str, str] = { + "openid": "Your anonymous account identifier", + "email": "The email address associated with your account", + "permissions": "Your EMF website permissions", +} + + +def get_issuer(): + return f"https://{request.host}" + + +def save_authorization_code(code: str, request: OAuth2Request) -> OAuth2AuthorizationCode: + client = request.client + item = OAuth2AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.payload.redirect_uri, + scope=request.payload.scope, + user_id=request.user.id, + ) + db.session.add(item) + db.session.commit() + return item + + +def exists_nonce(nonce, request) -> bool: + # TODO: implement me! + return False + + +class JWTConfig(TypedDict): + key: str + alg: str + iss: str + exp: int + + +def get_jwt_config(grant) -> JWTConfig: + return JWTConfig(key="foo", alg="HS256", iss=get_issuer(), exp=999999999999) + + +def generate_user_info(user: User, scope) -> UserInfo: + info = { + "sub": user.id, + } + scopes = scope_to_list(scope) + if "email" in scopes: + info["email"] = user.email + if "permissions" in scopes: + info["permissions"] = [p.name for p in user.permissions] + + return UserInfo(info) + + +class AuthorizationCodeGrant(rfc6749_grants.AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + def parse_authorization_code(self, code, client) -> OAuth2AuthorizationCode | None: + authcode = OAuth2AuthorizationCode.query.filter_by(code=code, client_id=client.client_id).first() + if authcode and not authcode.is_expired(): + return authcode + return None + + def delete_authorization_code(self, authorization_code): + db.session.delete(authorization_code) + db.session.commit() + + def authenticate_user(self, authorization_code): + return User.query.get(authorization_code.user_id) + + +class OpenIDCode(oidc_core_grants.OpenIDCode): + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def get_jwt_config(self, grant) -> JWTConfig: + return get_jwt_config(grant) + + def generate_user_info(self, user, scope) -> UserInfo: + return generate_user_info(user, scope) + + +class ImplicitGrant(oidc_core_grants.OpenIDImplicitGrant): + def exists_nonce(self, nonce, request) -> bool: + return exists_nonce(nonce, request) + + def get_jwt_config(self) -> JWTConfig: + return get_jwt_config(grant=None) + + def generate_user_info(self, user, scope) -> UserInfo: + return generate_user_info(user, scope) + + +class HybridGrant(oidc_core_grants.OpenIDHybridGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + def exists_nonce(self, nonce, request) -> bool: + return exists_nonce(nonce, request) + + def get_jwt_config(self) -> JWTConfig: + return get_jwt_config(grant=None) + + def generate_user_info(self, user, scope) -> UserInfo: + return generate_user_info(user, scope) + + +authorization = AuthorizationServer() + + +def init_oauth(app): + query_client = create_query_client_func(db.session, OAuth2Client) + save_token = create_save_token_func(db.session, OAuth2Token) + + authorization.init_app(app, query_client=query_client, save_token=save_token) + authorization.register_grant(AuthorizationCodeGrant, [OpenIDCode(require_nonce=True)]) + authorization.register_grant(ImplicitGrant) + authorization.register_grant(HybridGrant) diff --git a/apps/oidc/routes.py b/apps/oidc/routes.py new file mode 100644 index 000000000..5bee1d9d1 --- /dev/null +++ b/apps/oidc/routes.py @@ -0,0 +1,93 @@ +import logging + +import click +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oidc.discovery import OpenIDProviderMetadata +from flask import Blueprint, jsonify, render_template, request, url_for +from flask.typing import ResponseValue +from flask_login import current_user, login_required +from werkzeug.security import gen_salt + +from main import db +from models.oidc import OAuth2Client + +from .oauth import SCOPES, authorization, get_issuer + +logger = logging.getLogger(__name__) +oidc = Blueprint("oidc", "oidc") + + +@oidc.cli.command("create_client") +@click.option("--name", type=str) +@click.option("--redirecturi", type=str) +@click.option("--official/--unofficial") +@click.option("--scope", default=["openid"], multiple=True) +def create_client(name: str, redirecturi: str, official: bool, scope: list[str]): + if invalid := [s for s in scope if s not in SCOPES]: + logger.error("Invalid scopes: %s", ", ".join(invalid)) + raise click.exceptions.Exit(1) + + client = OAuth2Client( + client_id=gen_salt(24), + official=official, + ) + client.set_client_metadata( + { + "client_name": name, + "redirect_uris": [redirecturi], + "grant_types": ["code"], + "response_types": ["code id_token"], + "scope": " ".join(scope), + } + ) + db.session.add(client) + db.session.commit() + logger.info("New OIDC client created. Client id: %s", client.client_id) + + +@oidc.get("/.well-known/openid-configuration") +def discovery() -> ResponseValue: + """Implements the OpenID Connect Discovery protocol. + + https://openid.net/specs/openid-connect-discovery-1_0.html + """ + m = OpenIDProviderMetadata( + issuer=get_issuer(), + authorization_endpoint=url_for("oidc.authorize", _external=True), + token_endpoint=url_for("oidc.token", _external=True), + jwks_uri=url_for("oidc.jwks", _external=True), + response_types_supported=["code", "id_token", "code id_token"], + subject_types_supported=["public"], + id_token_signing_alg_values_supported=["RS256"], + ) + m.validate() + return m + + +@oidc.route("/oidc/authorize", methods=["GET", "POST"]) +@login_required +def authorize() -> ResponseValue: + if request.method == "GET": + try: + grant = authorization.get_consent_grant(end_user=current_user) + scope = grant.client.get_allowed_scope(grant.request.payload.scope) + except OAuth2Error as e: + return jsonify(dict(e.get_body())) + + scopes = {s: SCOPES[s] for s in scope_to_list(scope)} + scopents = {scope: desc for scope, desc in SCOPES.items() if scope not in scopes} + return render_template("oidc/authorize.html", grant=grant, scopes=scopes, scopents=scopents) + res = authorization.create_authorization_response( + grant_user=current_user if "authorize" in request.form else None + ) + return res + + +@oidc.post("/oidc/token") +def token(): + return authorization.create_token_response() + + +@oidc.get("/.well-known/jwks.json") +def jwks(): ... diff --git a/main.py b/main.py index 9d588f255..d24217164 100644 --- a/main.py +++ b/main.py @@ -321,6 +321,8 @@ def shell_imports(): from apps.cfp import cfp from apps.cfp_review import cfp_review from apps.metrics import metrics + from apps.oidc.oauth import init_oauth + from apps.oidc.routes import oidc from apps.payments import payments from apps.schedule import schedule from apps.tickets import tickets @@ -344,7 +346,9 @@ def shell_imports(): app.register_blueprint(admin, url_prefix="/admin") app.register_blueprint(volunteer, url_prefix="/volunteer") app.register_blueprint(notify, url_prefix="/volunteer/admin/notify") + app.register_blueprint(oidc) + init_oauth(app) volunteer_admin.init_app(app) return app diff --git a/migrations/versions/26b2a94192b3_oidc_stuff.py b/migrations/versions/26b2a94192b3_oidc_stuff.py new file mode 100644 index 000000000..67b08caf4 --- /dev/null +++ b/migrations/versions/26b2a94192b3_oidc_stuff.py @@ -0,0 +1,83 @@ +"""oidc_stuff + +Revision ID: 26b2a94192b3 +Revises: bcf21daa6073 +Create Date: 2025-08-25 21:28:12.004215 + +""" + +# revision identifiers, used by Alembic. +revision = '26b2a94192b3' +down_revision = 'bcf21daa6073' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('oauth2_client', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('official', sa.Boolean(), nullable=True), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('client_secret', sa.String(length=120), nullable=True), + sa.Column('client_id_issued_at', sa.Integer(), nullable=False), + sa.Column('client_secret_expires_at', sa.Integer(), nullable=False), + sa.Column('client_metadata', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_client')) + ) + with op.batch_alter_table('oauth2_client', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_oauth2_client_client_id'), ['client_id'], unique=False) + + op.create_table('oauth2_authcode', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('code', sa.String(length=120), nullable=False), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('redirect_uri', sa.Text(), nullable=True), + sa.Column('response_type', sa.Text(), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('nonce', sa.Text(), nullable=True), + sa.Column('auth_time', sa.Integer(), nullable=False), + sa.Column('acr', sa.Text(), nullable=True), + sa.Column('amr', sa.Text(), nullable=True), + sa.Column('code_challenge', sa.Text(), nullable=True), + sa.Column('code_challenge_method', sa.String(length=48), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_oauth2_authcode_user_id_user'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_authcode')), + sa.UniqueConstraint('code', name=op.f('uq_oauth2_authcode_code')) + ) + op.create_table('oauth2_token', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('token_type', sa.String(length=40), nullable=True), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('issued_at', sa.Integer(), nullable=False), + sa.Column('access_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('refresh_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('expires_in', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_oauth2_token_user_id_user'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_token')), + sa.UniqueConstraint('access_token', name=op.f('uq_oauth2_token_access_token')) + ) + with op.batch_alter_table('oauth2_token', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_oauth2_token_refresh_token'), ['refresh_token'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('oauth2_token', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_oauth2_token_refresh_token')) + + op.drop_table('oauth2_token') + op.drop_table('oauth2_authcode') + with op.batch_alter_table('oauth2_client', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_oauth2_client_client_id')) + + op.drop_table('oauth2_client') + # ### end Alembic commands ### diff --git a/models/__init__.py b/models/__init__.py index 57b53952f..dcf6bc3a1 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -184,6 +184,7 @@ def config_date(key): from .email import * # noqa: F403 from .event_tickets import * # noqa: F403 from .feature_flag import * # noqa: F403 +from .oidc import * # noqa: F403 from .payment import * # noqa: F403 from .permission import * # noqa: F403 from .product import * # noqa: F403 diff --git a/models/oidc.py b/models/oidc.py new file mode 100644 index 000000000..053f97800 --- /dev/null +++ b/models/oidc.py @@ -0,0 +1,27 @@ +from authlib.integrations.sqla_oauth2 import OAuth2AuthorizationCodeMixin, OAuth2ClientMixin, OAuth2TokenMixin + +from main import db + +from . import BaseModel + + +class OAuth2Client(BaseModel, OAuth2ClientMixin): + __tablename__ = "oauth2_client" + + id = db.Column(db.Integer, primary_key=True) + official = db.Column(db.Boolean) + + +class OAuth2Token(BaseModel, OAuth2TokenMixin): + __tablename__ = "oauth2_token" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + + +class OAuth2AuthorizationCode(BaseModel, OAuth2AuthorizationCodeMixin): + __tablename__ = "oauth2_authcode" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") diff --git a/pyproject.toml b/pyproject.toml index 433af23e3..37e6aa94b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ requires-python = ">=3.13,<3.14" dependencies = [ "alembic~=1.1", + "authlib>=1.6.2", "cryptography>=44.0.3", "css-inline<0.18", "decorator", @@ -146,5 +147,6 @@ module = [ "sqlalchemy_continuum.*", "sqlalchemy.engine.row", "wtforms_sqlalchemy.*", + "authlib.*", # Annoyingly authlib has typeshed stubs... but they suck ] ignore_missing_imports = true diff --git a/templates/oidc/authorize.html b/templates/oidc/authorize.html new file mode 100644 index 000000000..daf747c96 --- /dev/null +++ b/templates/oidc/authorize.html @@ -0,0 +1,50 @@ +{% extends "base.html" %} +{% block title %}Authorize{% endblock %} +{% block body %} +