From 10c3a097477129ed09d6f02301bd5f26776506c0 Mon Sep 17 00:00:00 2001 From: Sam Willcocks Date: Mon, 25 Aug 2025 23:57:08 +0300 Subject: [PATCH] oidc: 1st draft --- apps/oidc/__init__.py | 0 apps/oidc/oauth.py | 134 ++++++++++++++++++ apps/oidc/routes.py | 93 ++++++++++++ main.py | 4 + .../versions/26b2a94192b3_oidc_stuff.py | 83 +++++++++++ models/__init__.py | 1 + models/oidc.py | 27 ++++ pyproject.toml | 2 + templates/oidc/authorize.html | 50 +++++++ uv.lock | 14 ++ 10 files changed, 408 insertions(+) create mode 100644 apps/oidc/__init__.py create mode 100644 apps/oidc/oauth.py create mode 100644 apps/oidc/routes.py create mode 100644 migrations/versions/26b2a94192b3_oidc_stuff.py create mode 100644 models/oidc.py create mode 100644 templates/oidc/authorize.html diff --git a/apps/oidc/__init__.py b/apps/oidc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/oidc/oauth.py b/apps/oidc/oauth.py new file mode 100644 index 000000000..83029da6d --- /dev/null +++ b/apps/oidc/oauth.py @@ -0,0 +1,134 @@ +from typing import TypedDict + +import authlib.oauth2.rfc6749.grants as rfc6749_grants +import authlib.oidc.core.grants as oidc_core_grants +from authlib.integrations.flask_oauth2 import AuthorizationServer +from authlib.integrations.sqla_oauth2 import create_query_client_func, create_save_token_func +from authlib.oauth2.rfc6749 import OAuth2Request +from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oidc.core.claims import UserInfo +from flask import request + +from main import db +from models.oidc import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token +from models.user import User + +SCOPES: dict[str, str] = { + "openid": "Your anonymous account identifier", + "email": "The email address associated with your account", + "permissions": "Your EMF website permissions", +} + + +def get_issuer(): + return f"https://{request.host}" + + +def save_authorization_code(code: str, request: OAuth2Request) -> OAuth2AuthorizationCode: + client = request.client + item = OAuth2AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.payload.redirect_uri, + scope=request.payload.scope, + user_id=request.user.id, + ) + db.session.add(item) + db.session.commit() + return item + + +def exists_nonce(nonce, request) -> bool: + # TODO: implement me! + return False + + +class JWTConfig(TypedDict): + key: str + alg: str + iss: str + exp: int + + +def get_jwt_config(grant) -> JWTConfig: + return JWTConfig(key="foo", alg="HS256", iss=get_issuer(), exp=999999999999) + + +def generate_user_info(user: User, scope) -> UserInfo: + info = { + "sub": user.id, + } + scopes = scope_to_list(scope) + if "email" in scopes: + info["email"] = user.email + if "permissions" in scopes: + info["permissions"] = [p.name for p in user.permissions] + + return UserInfo(info) + + +class AuthorizationCodeGrant(rfc6749_grants.AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + def parse_authorization_code(self, code, client) -> OAuth2AuthorizationCode | None: + authcode = OAuth2AuthorizationCode.query.filter_by(code=code, client_id=client.client_id).first() + if authcode and not authcode.is_expired(): + return authcode + return None + + def delete_authorization_code(self, authorization_code): + db.session.delete(authorization_code) + db.session.commit() + + def authenticate_user(self, authorization_code): + return User.query.get(authorization_code.user_id) + + +class OpenIDCode(oidc_core_grants.OpenIDCode): + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def get_jwt_config(self, grant) -> JWTConfig: + return get_jwt_config(grant) + + def generate_user_info(self, user, scope) -> UserInfo: + return generate_user_info(user, scope) + + +class ImplicitGrant(oidc_core_grants.OpenIDImplicitGrant): + def exists_nonce(self, nonce, request) -> bool: + return exists_nonce(nonce, request) + + def get_jwt_config(self) -> JWTConfig: + return get_jwt_config(grant=None) + + def generate_user_info(self, user, scope) -> UserInfo: + return generate_user_info(user, scope) + + +class HybridGrant(oidc_core_grants.OpenIDHybridGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + def exists_nonce(self, nonce, request) -> bool: + return exists_nonce(nonce, request) + + def get_jwt_config(self) -> JWTConfig: + return get_jwt_config(grant=None) + + def generate_user_info(self, user, scope) -> UserInfo: + return generate_user_info(user, scope) + + +authorization = AuthorizationServer() + + +def init_oauth(app): + query_client = create_query_client_func(db.session, OAuth2Client) + save_token = create_save_token_func(db.session, OAuth2Token) + + authorization.init_app(app, query_client=query_client, save_token=save_token) + authorization.register_grant(AuthorizationCodeGrant, [OpenIDCode(require_nonce=True)]) + authorization.register_grant(ImplicitGrant) + authorization.register_grant(HybridGrant) diff --git a/apps/oidc/routes.py b/apps/oidc/routes.py new file mode 100644 index 000000000..5bee1d9d1 --- /dev/null +++ b/apps/oidc/routes.py @@ -0,0 +1,93 @@ +import logging + +import click +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oidc.discovery import OpenIDProviderMetadata +from flask import Blueprint, jsonify, render_template, request, url_for +from flask.typing import ResponseValue +from flask_login import current_user, login_required +from werkzeug.security import gen_salt + +from main import db +from models.oidc import OAuth2Client + +from .oauth import SCOPES, authorization, get_issuer + +logger = logging.getLogger(__name__) +oidc = Blueprint("oidc", "oidc") + + +@oidc.cli.command("create_client") +@click.option("--name", type=str) +@click.option("--redirecturi", type=str) +@click.option("--official/--unofficial") +@click.option("--scope", default=["openid"], multiple=True) +def create_client(name: str, redirecturi: str, official: bool, scope: list[str]): + if invalid := [s for s in scope if s not in SCOPES]: + logger.error("Invalid scopes: %s", ", ".join(invalid)) + raise click.exceptions.Exit(1) + + client = OAuth2Client( + client_id=gen_salt(24), + official=official, + ) + client.set_client_metadata( + { + "client_name": name, + "redirect_uris": [redirecturi], + "grant_types": ["code"], + "response_types": ["code id_token"], + "scope": " ".join(scope), + } + ) + db.session.add(client) + db.session.commit() + logger.info("New OIDC client created. Client id: %s", client.client_id) + + +@oidc.get("/.well-known/openid-configuration") +def discovery() -> ResponseValue: + """Implements the OpenID Connect Discovery protocol. + + https://openid.net/specs/openid-connect-discovery-1_0.html + """ + m = OpenIDProviderMetadata( + issuer=get_issuer(), + authorization_endpoint=url_for("oidc.authorize", _external=True), + token_endpoint=url_for("oidc.token", _external=True), + jwks_uri=url_for("oidc.jwks", _external=True), + response_types_supported=["code", "id_token", "code id_token"], + subject_types_supported=["public"], + id_token_signing_alg_values_supported=["RS256"], + ) + m.validate() + return m + + +@oidc.route("/oidc/authorize", methods=["GET", "POST"]) +@login_required +def authorize() -> ResponseValue: + if request.method == "GET": + try: + grant = authorization.get_consent_grant(end_user=current_user) + scope = grant.client.get_allowed_scope(grant.request.payload.scope) + except OAuth2Error as e: + return jsonify(dict(e.get_body())) + + scopes = {s: SCOPES[s] for s in scope_to_list(scope)} + scopents = {scope: desc for scope, desc in SCOPES.items() if scope not in scopes} + return render_template("oidc/authorize.html", grant=grant, scopes=scopes, scopents=scopents) + res = authorization.create_authorization_response( + grant_user=current_user if "authorize" in request.form else None + ) + return res + + +@oidc.post("/oidc/token") +def token(): + return authorization.create_token_response() + + +@oidc.get("/.well-known/jwks.json") +def jwks(): ... diff --git a/main.py b/main.py index 9d588f255..d24217164 100644 --- a/main.py +++ b/main.py @@ -321,6 +321,8 @@ def shell_imports(): from apps.cfp import cfp from apps.cfp_review import cfp_review from apps.metrics import metrics + from apps.oidc.oauth import init_oauth + from apps.oidc.routes import oidc from apps.payments import payments from apps.schedule import schedule from apps.tickets import tickets @@ -344,7 +346,9 @@ def shell_imports(): app.register_blueprint(admin, url_prefix="/admin") app.register_blueprint(volunteer, url_prefix="/volunteer") app.register_blueprint(notify, url_prefix="/volunteer/admin/notify") + app.register_blueprint(oidc) + init_oauth(app) volunteer_admin.init_app(app) return app diff --git a/migrations/versions/26b2a94192b3_oidc_stuff.py b/migrations/versions/26b2a94192b3_oidc_stuff.py new file mode 100644 index 000000000..67b08caf4 --- /dev/null +++ b/migrations/versions/26b2a94192b3_oidc_stuff.py @@ -0,0 +1,83 @@ +"""oidc_stuff + +Revision ID: 26b2a94192b3 +Revises: bcf21daa6073 +Create Date: 2025-08-25 21:28:12.004215 + +""" + +# revision identifiers, used by Alembic. +revision = '26b2a94192b3' +down_revision = 'bcf21daa6073' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('oauth2_client', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('official', sa.Boolean(), nullable=True), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('client_secret', sa.String(length=120), nullable=True), + sa.Column('client_id_issued_at', sa.Integer(), nullable=False), + sa.Column('client_secret_expires_at', sa.Integer(), nullable=False), + sa.Column('client_metadata', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_client')) + ) + with op.batch_alter_table('oauth2_client', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_oauth2_client_client_id'), ['client_id'], unique=False) + + op.create_table('oauth2_authcode', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('code', sa.String(length=120), nullable=False), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('redirect_uri', sa.Text(), nullable=True), + sa.Column('response_type', sa.Text(), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('nonce', sa.Text(), nullable=True), + sa.Column('auth_time', sa.Integer(), nullable=False), + sa.Column('acr', sa.Text(), nullable=True), + sa.Column('amr', sa.Text(), nullable=True), + sa.Column('code_challenge', sa.Text(), nullable=True), + sa.Column('code_challenge_method', sa.String(length=48), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_oauth2_authcode_user_id_user'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_authcode')), + sa.UniqueConstraint('code', name=op.f('uq_oauth2_authcode_code')) + ) + op.create_table('oauth2_token', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('token_type', sa.String(length=40), nullable=True), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('issued_at', sa.Integer(), nullable=False), + sa.Column('access_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('refresh_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('expires_in', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_oauth2_token_user_id_user'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_token')), + sa.UniqueConstraint('access_token', name=op.f('uq_oauth2_token_access_token')) + ) + with op.batch_alter_table('oauth2_token', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_oauth2_token_refresh_token'), ['refresh_token'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('oauth2_token', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_oauth2_token_refresh_token')) + + op.drop_table('oauth2_token') + op.drop_table('oauth2_authcode') + with op.batch_alter_table('oauth2_client', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_oauth2_client_client_id')) + + op.drop_table('oauth2_client') + # ### end Alembic commands ### diff --git a/models/__init__.py b/models/__init__.py index 57b53952f..dcf6bc3a1 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -184,6 +184,7 @@ def config_date(key): from .email import * # noqa: F403 from .event_tickets import * # noqa: F403 from .feature_flag import * # noqa: F403 +from .oidc import * # noqa: F403 from .payment import * # noqa: F403 from .permission import * # noqa: F403 from .product import * # noqa: F403 diff --git a/models/oidc.py b/models/oidc.py new file mode 100644 index 000000000..053f97800 --- /dev/null +++ b/models/oidc.py @@ -0,0 +1,27 @@ +from authlib.integrations.sqla_oauth2 import OAuth2AuthorizationCodeMixin, OAuth2ClientMixin, OAuth2TokenMixin + +from main import db + +from . import BaseModel + + +class OAuth2Client(BaseModel, OAuth2ClientMixin): + __tablename__ = "oauth2_client" + + id = db.Column(db.Integer, primary_key=True) + official = db.Column(db.Boolean) + + +class OAuth2Token(BaseModel, OAuth2TokenMixin): + __tablename__ = "oauth2_token" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + + +class OAuth2AuthorizationCode(BaseModel, OAuth2AuthorizationCodeMixin): + __tablename__ = "oauth2_authcode" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") diff --git a/pyproject.toml b/pyproject.toml index 433af23e3..37e6aa94b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ requires-python = ">=3.13,<3.14" dependencies = [ "alembic~=1.1", + "authlib>=1.6.2", "cryptography>=44.0.3", "css-inline<0.18", "decorator", @@ -146,5 +147,6 @@ module = [ "sqlalchemy_continuum.*", "sqlalchemy.engine.row", "wtforms_sqlalchemy.*", + "authlib.*", # Annoyingly authlib has typeshed stubs... but they suck ] ignore_missing_imports = true diff --git a/templates/oidc/authorize.html b/templates/oidc/authorize.html new file mode 100644 index 000000000..daf747c96 --- /dev/null +++ b/templates/oidc/authorize.html @@ -0,0 +1,50 @@ +{% extends "base.html" %} +{% block title %}Authorize{% endblock %} +{% block body %} +

Log in with your EMF account

+ +
+
+
+

+ {{ grant.client.client_name }} wants to authenticate you using your EMF account. +

+

+ {{ grant.client.client_name }} + {% if grant.client.official %} + is an official EMF orga-run service. + {% else %} + is a third party service. EMF is not responsible for the service or what it does with your data! + {% endif %} +

+

+ The app will be given read-only access to: +

    + {% for name, detail in scopes.items() %} +
  • ✅ {{ detail }}
  • + {% endfor %} +
+

+ {% if scopents %} +

+ The app will not have access to: +

    + {% for name, detail in scopents.items() %} +
  • ⛔️ {{ detail }}
  • + {% endfor %} +
+

+ {% endif %} +
+ +
+
+{% endblock %} diff --git a/uv.lock b/uv.lock index 28ba55b1c..918ad4f29 100644 --- a/uv.lock +++ b/uv.lock @@ -60,6 +60,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] +[[package]] +name = "authlib" +version = "1.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8a/95/e4f4ab5ce465821fe2229e10985ab80462941fe5d96387ae76bafd36f0ba/authlib-1.6.2.tar.gz", hash = "sha256:3bde83ac0392683eeef589cd5ab97e63cbe859e552dd75dca010548e79202cb1", size = 160429, upload-time = "2025-08-23T08:34:32.665Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/00/fb65909bf4c8d7da893a12006074343402a8dc8c00d916b3cee524d97f3f/authlib-1.6.2-py2.py3-none-any.whl", hash = "sha256:2dd5571013cacf6b15f7addce03ed057ffdf629e9e81bacd9c08455a190e9b57", size = 239601, upload-time = "2025-08-23T08:34:31.4Z" }, +] + [[package]] name = "bidict" version = "0.23.1" @@ -2118,6 +2130,7 @@ version = "0.1" source = { virtual = "." } dependencies = [ { name = "alembic" }, + { name = "authlib" }, { name = "cryptography" }, { name = "css-inline" }, { name = "decorator" }, @@ -2203,6 +2216,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "alembic", specifier = "~=1.1" }, + { name = "authlib", specifier = ">=1.6.2" }, { name = "cryptography", specifier = ">=44.0.3" }, { name = "css-inline", specifier = "<0.18" }, { name = "decorator" },