Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 177 additions & 13 deletions backend/src/xfd_django/xfd_api/auth_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ def _env_truthy(in_var: Optional[str]) -> bool:
SAML_SP_CERT = os.getenv("SAML_SP_CERT")
SAML_SP_PRIVATE_KEY = os.getenv("SAML_SP_PRIVATE_KEY")

LOGGER.info(
"SAML init: BACKEND_DOMAIN='%s', FRONTEND_DOMAIN='%s', IS_LOCAL=%s, "
"OKTA_METADATA_URL set=%s",
BACKEND_DOMAIN or "<unset>",
FRONTEND_DOMAIN or "<unset>",
IS_LOCAL,
bool(OKTA_METADATA_URL),
)


# =============================================================================
# Injectable IdP parser + lazy settings cache to avoid import-time fetch errors
Expand All @@ -67,21 +76,35 @@ class _SamlConfig:

def reset_saml_settings_cache_for_tests() -> None:
"""Clear cached SAML settings for tests."""
LOGGER.info("Resetting SAML settings cache (tests).")
_SamlConfig.settings_cache = None


def set_idp_metadata_parser_for_tests(parser_cls) -> None:
"""Inject a fake IdP metadata parser for tests."""
LOGGER.info("Overriding IdP metadata parser for tests: %s", parser_cls)
_SamlConfig.idp_parser = parser_cls


def _build_sp_settings() -> Dict[str, Any]:
"""Build SAML settings dict by merging SP config with Okta IdP metadata."""
if not OKTA_METADATA_URL:
LOGGER.error("OKTA_SAML_METADATA_URL is not set; cannot build SAML settings.")
raise RuntimeError("OKTA_SAML_METADATA_URL is not set")

LOGGER.info(
"Building SAML SP settings from IdP metadata URL: %s", OKTA_METADATA_URL
)

# Fetch & parse IdP metadata
idp_data = _SamlConfig.idp_parser.parse_remote(OKTA_METADATA_URL)
try:
idp_data = _SamlConfig.idp_parser.parse_remote(OKTA_METADATA_URL)
LOGGER.info("Successfully parsed IdP metadata from %s", OKTA_METADATA_URL)
except Exception as exc: # noqa: BLE001
LOGGER.exception(
"Failed to parse IdP metadata from %s: %s", OKTA_METADATA_URL, exc
)
raise

sp_settings: Dict[str, Any] = {
"strict": False, # consider True once IdP config is finalized
Expand Down Expand Up @@ -128,18 +151,29 @@ def _build_sp_settings() -> Dict[str, Any]:
sp_settings["security"]["wantAssertionsEncrypted"] = True
sp_settings["security"]["wantNameIdEncrypted"] = True
sp_settings["security"]["authnRequestsSigned"] = True
LOGGER.info("SAML encryption ENABLED (inline cert/key configured).")
LOGGER.info(
"SAML encryption ENABLED (inline cert/key configured; authnRequestsSigned=True)."
)
else:
LOGGER.info("No SP cert configured; encryption NOT advertised.")

# Merge with IdP metadata
return _SamlConfig.idp_parser.merge_settings(sp_settings, idp_data)
merged = _SamlConfig.idp_parser.merge_settings(sp_settings, idp_data)
LOGGER.info(
"Merged SAML settings built: entityId=%s, acs_url=%s",
merged.get("sp", {}).get("entityId"),
merged.get("sp", {}).get("assertionConsumerService", {}).get("url"),
)
return merged


def _get_settings_dict() -> Dict[str, Any]:
"""Return the merged SAML settings (build lazily and cache)."""
if _SamlConfig.settings_cache is None:
LOGGER.info("SAML settings cache miss; building settings.")
_SamlConfig.settings_cache = _build_sp_settings()
else:
LOGGER.info("SAML settings cache hit; reusing existing settings.")
return _SamlConfig.settings_cache


Expand All @@ -148,7 +182,7 @@ def _get_settings_dict() -> Dict[str, Any]:
# =============================================================================
def _starlette_to_saml_request(request: Request) -> Dict[str, Any]:
"""Translate FastAPI/Starlette request to python3-saml expected dict."""
return {
data = {
"https": "on" if request.url.scheme == "https" else "off",
"http_host": request.headers.get("host"),
"script_name": request.scope.get("root_path", ""),
Expand All @@ -157,6 +191,14 @@ def _starlette_to_saml_request(request: Request) -> Dict[str, Any]:
"get_data": dict(request.query_params),
"post_data": {}, # populated for ACS
}
LOGGER.info(
"Converted Starlette request to SAML dict: path=%s, https=%s, host=%s, port=%s",
request.url.path,
data["https"],
data["http_host"],
data["server_port"],
)
return data


def _get_auth(
Expand All @@ -166,7 +208,16 @@ def _get_auth(
"""Return a OneLogin_Saml2_Auth instance for the given request."""
data = _starlette_to_saml_request(request)
if with_post:
LOGGER.info(
"Initializing SAML auth with POST data for path=%s, RelayState=%s",
request.url.path,
with_post.get("RelayState"),
)
data["post_data"] = with_post
else:
LOGGER.info(
"Initializing SAML auth without POST data for path=%s", request.url.path
)
settings = OneLogin_Saml2_Settings(settings=_get_settings_dict())
return OneLogin_Saml2_Auth(data, old_settings=settings)

Expand All @@ -178,7 +229,15 @@ def _path_only(raw: Optional[str]) -> str:
"""Return a safe, path-only string that always starts with '/'."""
val = raw or "/"
decoded = urllib.parse.unquote(val)
return decoded if decoded.startswith("/") else "/"
cleaned = decoded if decoded.startswith("/") else "/"
if raw != cleaned:
LOGGER.info(
"Normalized path-only value: raw=%r, decoded=%r, cleaned=%r",
raw,
decoded,
cleaned,
)
return cleaned


def _extract_identity(auth: OneLogin_Saml2_Auth) -> Dict[str, Any]:
Expand All @@ -192,14 +251,26 @@ def _extract_identity(auth: OneLogin_Saml2_Auth) -> Dict[str, Any]:
last = (attrs.get("lastName") or attrs.get("family_name") or [""])[0]
groups = attrs.get("groups") or []

return {
identity = {
"okta_id": okta_id,
"email": email,
"first": first,
"last": last,
"groups": groups,
}

LOGGER.info(
"Extracted SAML identity: okta_id=%s, email=%s, first=%s, last=%s, group_count=%d",
okta_id,
email,
first,
last,
len(groups),
)
LOGGER.info("SAML identity groups: %s", groups)

return identity


def _upsert_user(identity: Dict[str, Any]) -> User:
"""Upsert a User keyed by OktaId, with legacy email attachment path."""
Expand All @@ -209,20 +280,35 @@ def _upsert_user(identity: Dict[str, Any]) -> User:
last = identity["last"]
groups = identity["groups"]

LOGGER.info(
"Upserting user from SAML identity: okta_id=%s, email=%s", okta_id, email
)

# Try to find the user by OktaId first
user = User.objects.filter(okta_id=okta_id).first()

if not user:
LOGGER.info(
"No existing user found with okta_id=%s; attempting legacy email lookup for %s",
okta_id,
email,
)
# If no user with OktaId exists, try to find a legacy user by email
user = User.objects.filter(email=email).first()
if user:
LOGGER.info(
"Found legacy user with email=%s; attaching okta_id=%s", email, okta_id
)
# Update the legacy user in place
user.okta_id = okta_id
user.first_name = user.first_name or (first or None)
user.last_name = user.last_name or (last or None)
if user.invite_pending:
user.invite_pending = False
else:
LOGGER.info(
"No legacy user for email=%s; creating new user with okta_id=%s",
email,
okta_id,
)
# Create a new user if no legacy user exists
user = User(
okta_id=okta_id,
Expand All @@ -234,10 +320,21 @@ def _upsert_user(identity: Dict[str, Any]) -> User:
can_select_own_state=True,
)
else:
LOGGER.info(
"Found existing user with okta_id=%s (id=%s); updating identity fields.",
okta_id,
user.id,
)
# Update the existing user with OktaId
user.first_name = user.first_name or (first or None)
user.last_name = user.last_name or (last or None)
if email and user.email != email:
LOGGER.info(
"Updating email for user id=%s from %s to %s",
user.id,
user.email,
email,
)
user.email = email

# Update additional fields
Expand All @@ -247,16 +344,30 @@ def _upsert_user(identity: Dict[str, Any]) -> User:
user.cognito_groups = groups
user.last_logged_in = datetime.now(timezone.utc)

LOGGER.info(
"Updating login block status for user id=%s (okta_id=%s).",
getattr(user, "id", None),
okta_id,
)
# Update login block status and save the user
update_login_block_status(user)
user.save()
LOGGER.info("User upserted and saved: id=%s, okta_id=%s", user.id, okta_id)
return user


def _redirect_with_cookies(relay: Optional[str], token: str) -> RedirectResponse:
"""Return a 303 redirect to the SPA and set auth cookies."""
relay_path = _path_only(relay)
target = f"{FRONTEND_DOMAIN.rstrip('/')}{relay_path}"

LOGGER.info(
"Redirecting after SAML ACS with RelayState=%r -> relay_path=%r, target=%r",
relay,
relay_path,
target,
)

resp = RedirectResponse(target, status_code=303)

is_https = BACKEND_DOMAIN.startswith("https://")
Expand All @@ -279,6 +390,10 @@ def _redirect_with_cookies(relay: Optional[str], token: str) -> RedirectResponse
samesite="Lax",
path="/",
)
LOGGER.info(
"Auth cookies set on redirect response; secure=%s, samesite='Lax', path='/'",
is_https,
)
return resp


Expand All @@ -288,14 +403,17 @@ def _redirect_with_cookies(relay: Optional[str], token: str) -> RedirectResponse
@router.get("/saml/metadata")
def saml_metadata():
"""Return the SAML SP metadata document."""
LOGGER.info("Serving SAML SP metadata.")
settings = OneLogin_Saml2_Settings(settings=_get_settings_dict())
metadata = settings.get_sp_metadata()
errors = settings.validate_metadata(metadata)
if errors:
LOGGER.error("SP metadata validation failed: %s", errors)
raise HTTPException(
status_code=500,
detail=f"SP metadata invalid: {', '.join(errors)}",
)
LOGGER.info("SP metadata validated successfully.")
return Response(content=metadata, media_type="application/xml")


Expand All @@ -306,44 +424,90 @@ def saml_login(request: Request, next: str = "/"):

Optional `next` controls where the user lands after login.
"""
next_path = _path_only(request.query_params.get("next"))
raw_next = request.query_params.get("next", next)
next_path = _path_only(raw_next)
LOGGER.info(
"SAML login initiated: path=%s, raw_next=%r, normalized_next=%r",
request.url.path,
raw_next,
next_path,
)

auth = _get_auth(request)
return RedirectResponse(auth.login(return_to=next_path))
login_url = auth.login(return_to=next_path)
LOGGER.info("Generated SAML login URL: %s", login_url)
return RedirectResponse(login_url)


@router.post("/saml/acs")
async def saml_acs(request: Request):
"""Process the SAML response, upsert a user, issue JWT, and set cookies."""
LOGGER.info("SAML ACS endpoint called for path=%s", request.url.path)
form = dict(await request.form())
relay_state = form.get("RelayState")
LOGGER.info("SAML ACS received RelayState=%r", relay_state)

auth = _get_auth(request, with_post=form)
auth.process_response()

errors = auth.get_errors()
if errors or not auth.is_authenticated():
is_auth = auth.is_authenticated()

if errors or not is_auth:
LOGGER.error(
"SAML auth failed at ACS: errors=%s, is_authenticated=%s",
errors,
is_auth,
)
raise HTTPException(status_code=401, detail=f"SAML auth failed: {errors}")

LOGGER.info(
"SAML auth successful at ACS: is_authenticated=%s",
is_auth,
)

identity = _extract_identity(auth)
if not identity["okta_id"]:
LOGGER.error(
"SAML assertion missing OktaId (NameID/custom:OKTA_ID); identity=%s",
identity,
)
raise HTTPException(
status_code=400,
detail="No OktaId (NameID/custom:OKTA_ID) in SAML assertion",
)

user = _upsert_user(identity)
LOGGER.info(
"Creating JWT token for user id=%s, okta_id=%s", user.id, identity["okta_id"]
)
token = create_jwt_token(user)

LOGGER.info("Validating JSON serialization for user_to_dict output.")
validate_json_serialization(user_to_dict(user), label="User Dict")

relay = form.get("RelayState") or "/"
relay = relay_state or "/"
LOGGER.info("Final RelayState used for redirect: %r", relay)
return _redirect_with_cookies(relay, token)


@router.get("/saml/logout")
def saml_logout(request: Request, next: str = "/"):
"""Log the user out of the app and clear auth cookies."""
next_path = _path_only(request.query_params.get("next"))
raw_next = request.query_params.get("next", next)
next_path = _path_only(raw_next)
target = f"{FRONTEND_DOMAIN.rstrip('/')}{next_path}"

LOGGER.info(
"SAML logout called: path=%s, raw_next=%r, normalized_next=%r, target=%r",
request.url.path,
raw_next,
next_path,
target,
)

resp = RedirectResponse(target, status_code=303)
resp.delete_cookie("token", path="/")
resp.delete_cookie("crossfeed-token", path="/")
LOGGER.info("Cleared auth cookies on logout response.")
return resp
Loading