Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/mlpa/core/auth/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mlpa.core.config import env
from mlpa.core.routers.appattest import app_attest_auth
from mlpa.core.routers.fxa import fxa_auth
from mlpa.core.utils import parse_app_attest_jwt
from mlpa.core.utils import extract_user_from_play_integrity_jwt, parse_app_attest_jwt


async def authorize_request(
Expand All @@ -15,10 +15,12 @@ async def authorize_request(
service_type: Annotated[ServiceType, Header()],
use_app_attest: Annotated[bool | None, Header()] = None,
use_qa_certificates: Annotated[bool | None, Header()] = None,
use_play_integrity: Annotated[bool | None, Header()] = None,
) -> AuthorizedChatRequest:
if not authorization:
raise HTTPException(status_code=401, detail="Missing authorization header")
if use_app_attest:
# Apple App Attest
assertionAuth = parse_app_attest_jwt(authorization, "assert")
data = await app_attest_auth(assertionAuth, chat_request, use_qa_certificates)
if data:
Expand All @@ -28,8 +30,16 @@ async def authorize_request(
user=f"{assertionAuth.key_id_b64}:{service_type.value}", # "user" is key_id_b64 from app attest
**chat_request.model_dump(exclude_unset=True),
)
elif use_play_integrity:
# Google Play integrity
play_user_id = extract_user_from_play_integrity_jwt(authorization)
if play_user_id:
return AuthorizedChatRequest(
user=f"{play_user_id}:{service_type.value}",
**chat_request.model_dump(exclude_unset=True),
)
else:
# FxA authorization
# Firefox Account authorization
fxa_user_id = fxa_auth(authorization)
if fxa_user_id:
if fxa_user_id.get("error"):
Expand Down
7 changes: 7 additions & 0 deletions src/mlpa/core/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class UserUpdatePayload(BaseModel):
blocked: bool | None = None


# iOS App Attest
class AttestationAuth(BaseModel):
key_id_b64: str
challenge_b64: str
Expand All @@ -37,6 +38,12 @@ class AssertionAuth(BaseModel):
assertion_obj_b64: str


# Google Play Integrity
class PlayIntegrityRequest(BaseModel):
integrity_token: str
user_id: str


class AuthorizedChatRequest(ChatRequest):
user: str

Expand Down
8 changes: 8 additions & 0 deletions src/mlpa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,17 @@ def valid_service_types(self) -> list[str]:
APP_ATTEST_QA_BUCKET_PREFIX: str | None = None
APP_ATTEST_QA_GCP_PROJECT_ID: str | None = None

# Play Integrity
PLAY_INTEGRITY_PACKAGE_NAME: str = "com.example.app"
PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE: str = "service_account.json"
PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS: int = 30
MLPA_ACCESS_TOKEN_SECRET: str = "mlpa-dev-secret"
MLPA_ACCESS_TOKEN_TTL_SECONDS: int = 300

# FxA
CLIENT_ID: str = "default-client-id"
CLIENT_SECRET: str = "default-client-secret"
FXA_SCOPE: str = "profile"

# PostgreSQL
LITELLM_DB_NAME: str = "litellm"
Expand Down
3 changes: 2 additions & 1 deletion src/mlpa/core/routers/fxa/fxa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from fastapi import APIRouter, Header, HTTPException

from mlpa.core.config import env
from mlpa.core.logger import logger
from mlpa.core.prometheus_metrics import PrometheusResult, metrics
from mlpa.core.utils import get_fxa_client
Expand All @@ -16,7 +17,7 @@ def fxa_auth(authorization: Annotated[str | None, Header()]):
token = authorization.removeprefix("Bearer ").split()[0]
result = PrometheusResult.ERROR
try:
profile = client.verify_token(token, scope="profile")
profile = client.verify_token(token, scope=env.FXA_SCOPE)
result = PrometheusResult.SUCCESS
except Exception as e:
logger.error(f"FxA auth error: {e}")
Expand Down
3 changes: 3 additions & 0 deletions src/mlpa/core/routers/play/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from mlpa.core.routers.play.play import router as play_router

__all__ = ["play_router"]
100 changes: 100 additions & 0 deletions src/mlpa/core/routers/play/play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import hashlib
from functools import lru_cache

import httpx
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from pydantic import BaseModel

from mlpa.core.classes import PlayIntegrityRequest
from mlpa.core.config import env
from mlpa.core.http_client import get_http_client
from mlpa.core.utils import issue_mlpa_access_token, raise_and_log

router = APIRouter()

PLAY_INTEGRITY_SCOPE = "https://www.googleapis.com/auth/playintegrity"
ALLOWED_DEVICE_VERDICTS = {
"MEETS_DEVICE_INTEGRITY",
"MEETS_BASIC_INTEGRITY",
"MEETS_STRONG_INTEGRITY",
}


@lru_cache(maxsize=1)
def _get_service_account_credentials():
return service_account.Credentials.from_service_account_file(
env.PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE,
scopes=[PLAY_INTEGRITY_SCOPE],
)


def _get_play_integrity_access_token() -> str:
credentials = _get_service_account_credentials()
if not credentials.valid:
credentials.refresh(Request())
if not credentials.token:
raise HTTPException(status_code=500, detail="Failed to fetch access token")
return credentials.token


async def _decode_integrity_token(integrity_token: str) -> dict:
access_token = await run_in_threadpool(_get_play_integrity_access_token)
client = get_http_client()
try:
response = await client.post(
f"https://playintegrity.googleapis.com/v1/{env.PLAY_INTEGRITY_PACKAGE_NAME}:decodeIntegrityToken",
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"integrity_token": integrity_token},
timeout=env.PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise_and_log(e, False, 401)
except Exception as e:
raise_and_log(e, False, 502, "Play Integrity validation service unavailable")
return response.json()


def _validate_integrity_payload(payload: dict, expected_hash: str) -> None:
request_details = payload.get("requestDetails", {})
package_name = request_details.get("requestPackageName")
if package_name and package_name != env.PLAY_INTEGRITY_PACKAGE_NAME:
raise HTTPException(status_code=401, detail="Invalid package name")

token_request_hash = request_details.get("requestHash")
if token_request_hash != expected_hash:
raise HTTPException(status_code=401, detail="Invalid request hash")

app_integrity = payload.get("appIntegrity", {})
if app_integrity.get("appRecognitionVerdict") != "PLAY_RECOGNIZED":
raise HTTPException(status_code=401, detail="App not recognized by Play")

device_integrity = payload.get("deviceIntegrity", {})
device_verdicts = set(device_integrity.get("deviceRecognitionVerdict", []))
if not device_verdicts.intersection(ALLOWED_DEVICE_VERDICTS):
raise HTTPException(status_code=401, detail="Device integrity check failed")


@router.post("/play", tags=["Play Integrity"])
async def verify_play_integrity(payload: PlayIntegrityRequest):
decoded = await _decode_integrity_token(payload.integrity_token)
token_payload = decoded.get("tokenPayloadExternal") or decoded.get("tokenPayload")
if not token_payload:
raise HTTPException(status_code=401, detail="Invalid Play Integrity token")

expected_hash = hashlib.sha256(payload.user_id.encode("utf-8")).hexdigest()

_validate_integrity_payload(token_payload, expected_hash)

access_token = issue_mlpa_access_token(payload.user_id)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
}
38 changes: 37 additions & 1 deletion src/mlpa/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import ast
import base64
import json
import time

from fastapi import HTTPException
from fxa.oauth import Client
from jwtoxide import DecodingKey, ValidationOptions, decode
from jwtoxide import DecodingKey, ValidationOptions, decode, encode

from mlpa.core.classes import AssertionAuth, AttestationAuth
from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env
Expand Down Expand Up @@ -160,3 +161,38 @@ def raise_and_log(
else response_text_prefix or GENERIC_UPSTREAM_ERROR
},
)


def extract_user_from_play_integrity_jwt(authorization: str):
token = authorization.removeprefix("Bearer ").split()[0]
try:
payload = decode(
token,
env.MLPA_ACCESS_TOKEN_SECRET,
ValidationOptions(
required_spec_claims={"exp", "iat", "sub"},
iss={"mlpa"},
aud=None,
validate_aud=False,
validate_exp=True,
validate_nbf=False,
verify_signature=True,
algorithms=["HS256"],
),
)
return payload["sub"]
except Exception as e:
logger.error(f"Play Integrity JWT decode error: {e}")
raise HTTPException(status_code=401, detail="Invalid MLPA access token")


def issue_mlpa_access_token(user_id: str) -> str:
now = int(time.time())
payload = {
"sub": user_id,
"iat": now,
"exp": now + env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
"iss": "mlpa",
"typ": "mlpa_access",
}
return encode(payload, env.MLPA_ACCESS_TOKEN_SECRET, algorithm="HS256")
11 changes: 7 additions & 4 deletions src/mlpa/run.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
import time
from contextlib import asynccontextmanager
from typing import Annotated, Optional

import sentry_sdk
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.exception_handlers import http_exception_handler
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

from mlpa.core.auth.authorize import authorize_request
Expand All @@ -21,11 +19,11 @@
from mlpa.core.logger import logger, setup_logger
from mlpa.core.middleware import register_middleware
from mlpa.core.pg_services.services import app_attest_pg, litellm_pg
from mlpa.core.prometheus_metrics import metrics
from mlpa.core.routers.appattest import appattest_router
from mlpa.core.routers.fxa import fxa_router
from mlpa.core.routers.health import health_router
from mlpa.core.routers.mock import mock_router
from mlpa.core.routers.play import play_router
from mlpa.core.routers.user import user_router
from mlpa.core.utils import get_or_create_user

Expand All @@ -36,6 +34,10 @@
"name": "App Attest",
"description": "Endpoints for verifying App Attest payloads.",
},
{
"name": "Play Integrity",
"description": "Endpoints for verifying Play Integrity payloads.",
},
{"name": "LiteLLM", "description": "Endpoints for interacting with LiteLLM."},
{"name": "Mock", "description": "Mock endpoints for testing purposes."},
{
Expand Down Expand Up @@ -91,6 +93,7 @@ async def get_metrics():

app.include_router(health_router, prefix="/health")
app.include_router(appattest_router, prefix="/verify")
app.include_router(play_router, prefix="/verify")
app.include_router(fxa_router, prefix="/fxa")
app.include_router(user_router, prefix="/user")
app.include_router(mock_router, prefix="/mock")
Expand Down
12 changes: 12 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest

from mlpa.core.config import env


@pytest.fixture(autouse=True, scope="session")
def _force_mlpa_debug_false():
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setenv("MLPA_DEBUG", "false")
env.MLPA_DEBUG = False
yield
monkeypatch.undo()
83 changes: 83 additions & 0 deletions src/tests/integration/test_play_integrity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import hashlib

from mlpa.core.config import env
from mlpa.core.utils import issue_mlpa_access_token
from tests.consts import SAMPLE_CHAT_REQUEST, SUCCESSFUL_CHAT_RESPONSE, TEST_USER_ID


def _mock_decode_payload(request_hash: str) -> dict:
return {
"tokenPayloadExternal": {
"requestDetails": {
"requestPackageName": env.PLAY_INTEGRITY_PACKAGE_NAME,
"requestHash": request_hash,
},
"appIntegrity": {"appRecognitionVerdict": "PLAY_RECOGNIZED"},
"deviceIntegrity": {"deviceRecognitionVerdict": ["MEETS_DEVICE_INTEGRITY"]},
}
}


def test_verify_play_integrity_success(mocked_client_integration, mocker):
request_hash = hashlib.sha256(TEST_USER_ID.encode("utf-8")).hexdigest()
mocker.patch(
"mlpa.core.routers.play.play._decode_integrity_token",
return_value=_mock_decode_payload(request_hash),
)

response = mocked_client_integration.post(
"/verify/play",
json={"integrity_token": "test-token", "user_id": TEST_USER_ID},
)

assert response.status_code == 200
data = response.json()
assert data["access_token"]
assert data["token_type"] == "Bearer"
assert data["expires_in"] == env.MLPA_ACCESS_TOKEN_TTL_SECONDS


def test_verify_play_integrity_invalid_hash(mocked_client_integration, mocker):
mocker.patch(
"mlpa.core.routers.play.play._decode_integrity_token",
return_value=_mock_decode_payload("bad-hash"),
)

response = mocked_client_integration.post(
"/verify/play",
json={"integrity_token": "test-token", "user_id": TEST_USER_ID},
)

assert response.status_code == 401
assert response.json()["detail"] == "Invalid request hash"


def test_verify_play_integrity_missing_payload(mocked_client_integration, mocker):
mocker.patch(
"mlpa.core.routers.play.play._decode_integrity_token",
return_value={},
)

response = mocked_client_integration.post(
"/verify/play",
json={"integrity_token": "test-token", "user_id": TEST_USER_ID},
)

assert response.status_code == 401
assert response.json()["detail"] == "Invalid Play Integrity token"


def test_chat_with_play_integrity_token_success(mocked_client_integration):
access_token = issue_mlpa_access_token(TEST_USER_ID)
response = mocked_client_integration.post(
"/v1/chat/completions",
headers={
"authorization": f"Bearer {access_token}",
"use-play-integrity": "true",
"service-type": "ai",
},
json=SAMPLE_CHAT_REQUEST.model_dump(exclude_unset=True),
)

assert response.status_code == 200
assert response.json() == SUCCESSFUL_CHAT_RESPONSE