diff --git a/internal/api/api.go b/internal/api/api.go index c2536c0a7..07798e606 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -375,6 +375,21 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne }) }) } + + // Custom OAuth/OIDC provider management endpoints + if globalConfig.CustomOAuth.Enabled { + r.Route("/custom-providers", func(r *router) { + // supports both OAuth2 and OIDC via provider_type) + r.Get("/", api.adminCustomOAuthProvidersList) // Optional ?type=oauth2 or ?type=oidc filter + r.Post("/", api.adminCustomOAuthProviderCreate) // provider_type in request body + + r.Route("/{identifier}", func(r *router) { + r.Get("/", api.adminCustomOAuthProviderGet) + r.Put("/", api.adminCustomOAuthProviderUpdate) + r.Delete("/", api.adminCustomOAuthProviderDelete) + }) + }) + } }) // OAuth Dynamic Client Registration endpoint (public, rate limited) diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go index 58963eea3..2e7789928 100644 --- a/internal/api/apierrors/errorcode.go +++ b/internal/api/apierrors/errorcode.go @@ -104,4 +104,9 @@ const ( ErrorCodeOAuthClientNotFound ErrorCode = "oauth_client_not_found" ErrorCodeOAuthAuthorizationNotFound ErrorCode = "oauth_authorization_not_found" ErrorCodeOAuthConsentNotFound ErrorCode = "oauth_consent_not_found" + + // Custom OAuth/OIDC provider error codes + ErrorCodeProviderNotFound ErrorCode = "provider_not_found" + ErrorCodeFeatureDisabled ErrorCode = "feature_disabled" + ErrorCodeOverQuota ErrorCode = "over_quota" ) diff --git a/internal/api/custom_oauth_admin.go b/internal/api/custom_oauth_admin.go new file mode 100644 index 000000000..9405830d2 --- /dev/null +++ b/internal/api/custom_oauth_admin.go @@ -0,0 +1,667 @@ +package api + +import ( + "net/http" + "slices" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +// TODO: Admin Audit Logging for Custom OAuth/OIDC Providers +// +// Current state: No audit logging is implemented for provider management operations. +// +// Why: The existing audit logging system (models.NewAuditLogEntry) is designed for +// user-centric actions where there's always a "user actor" performing an action. +// Admin infrastructure operations like provider management are different: +// - They're admin-only configuration changes +// - They don't have a regular "user" as the actor (it's an admin/operator) +// - They need different metadata (who made the change, what was changed, when, from where) +// +// What's needed: +// 1. Design a separate admin audit log system or extend the existing one +// 2. Consider what should be logged: +// - WHO: Admin identifier (could be service role, API key, or admin user) +// - WHAT: Operation (create/update/delete provider) +// - WHEN: Timestamp +// - WHERE: IP address, request ID +// - DETAILS: Provider identifier, what changed (for updates) +// 3. Consider compliance requirements (SOC2, GDPR, etc.) +// 4. Decide on storage (same audit_log_entries table or separate table?) +// +// For now, all create/update/delete operations have TODO comments where audit +// logging should be added once the design is finalized. + +// AdminCustomOAuthProviderParams defines parameters for creating/updating providers +type AdminCustomOAuthProviderParams struct { + // Common fields + ProviderType string `json:"provider_type"` // "oauth2" or "oidc" + Identifier string `json:"identifier"` + Name string `json:"name"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + AcceptableClientIDs []string `json:"acceptable_client_ids,omitempty"` + Scopes []string `json:"scopes"` + PKCEEnabled *bool `json:"pkce_enabled,omitempty"` + AttributeMapping map[string]interface{} `json:"attribute_mapping,omitempty"` + AuthorizationParams map[string]interface{} `json:"authorization_params,omitempty"` + Enabled *bool `json:"enabled,omitempty"` + EmailOptional *bool `json:"email_optional,omitempty"` + + // OIDC-specific fields + Issuer string `json:"issuer,omitempty"` + DiscoveryURL *string `json:"discovery_url,omitempty"` + SkipNonceCheck *bool `json:"skip_nonce_check,omitempty"` + + // OAuth2-specific fields + AuthorizationURL string `json:"authorization_url,omitempty"` + TokenURL string `json:"token_url,omitempty"` + UserinfoURL string `json:"userinfo_url,omitempty"` + JwksURI *string `json:"jwks_uri,omitempty"` +} + +// =================================== +// Provider Admin Endpoints +// =================================== + +// adminCustomOAuthProvidersList returns all custom OAuth/OIDC providers +func (a *API) adminCustomOAuthProvidersList(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + // Check for optional type filter + providerTypeParam := r.URL.Query().Get("type") + var providers []*models.CustomOAuthProvider + var err error + + if providerTypeParam != "" { + // Validate type parameter + providerType := models.ProviderType(providerTypeParam) + if providerType != models.ProviderTypeOAuth2 && providerType != models.ProviderTypeOIDC { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "type must be either 'oauth2' or 'oidc'", + ) + } + providers, err = models.FindAllCustomOAuthProvidersByType(db, providerType) + } else { + providers, err = models.FindAllCustomOAuthProviders(db) + } + + if err != nil { + return apierrors.NewInternalServerError("Error retrieving custom OAuth providers").WithInternalError(err) + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "providers": providers, + }) +} + +// adminCustomOAuthProviderGet returns a single custom OAuth/OIDC provider +func (a *API) adminCustomOAuthProviderGet(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + identifier := chi.URLParam(r, "identifier") + if identifier == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") + } + + // Validate identifier starts with 'custom:' prefix + if !strings.HasPrefix(identifier, "custom:") { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix") + } + + observability.LogEntrySetField(r, "identifier", identifier) + + provider, err := models.FindCustomOAuthProviderByIdentifier(db, identifier) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewNotFoundError(apierrors.ErrorCodeProviderNotFound, "Custom OAuth provider not found") + } + return apierrors.NewInternalServerError("Error retrieving custom OAuth provider").WithInternalError(err) + } + + return sendJSON(w, http.StatusOK, provider) +} + +// adminCustomOAuthProviderCreate creates a new custom OAuth/OIDC provider +func (a *API) adminCustomOAuthProviderCreate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + // Check if custom OAuth is enabled + if !config.CustomOAuth.Enabled { + return apierrors.NewBadRequestError(apierrors.ErrorCodeFeatureDisabled, "Custom OAuth/OIDC providers are not enabled") + } + + // Parse request parameters + params := &AdminCustomOAuthProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + // Validate provider type + if params.ProviderType != string(models.ProviderTypeOAuth2) && params.ProviderType != string(models.ProviderTypeOIDC) { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "provider_type must be either 'oauth2' or 'oidc'", + ) + } + + providerType := models.ProviderType(params.ProviderType) + + // Validate type-specific required fields + if err := validateProviderParams(params, providerType); err != nil { + return err + } + + // Validate authorization params (no reserved OAuth parameters) + if err := validateAuthorizationParams(params.AuthorizationParams); err != nil { + return err + } + + // Validate attribute mapping (no protected system fields) + if err := validateAttributeMapping(params.AttributeMapping); err != nil { + return err + } + + // Check quota if configured + if config.CustomOAuth.MaxProviders > 0 { + totalCount, err := models.CountCustomOAuthProviders(db) + if err != nil { + return apierrors.NewInternalServerError("Error checking provider quota").WithInternalError(err) + } + if totalCount >= config.CustomOAuth.MaxProviders { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeOverQuota, + "Maximum number of custom OAuth/OIDC providers reached", + ) + } + } + + // Validate URLs based on provider type + if err := validateProviderURLs(params, providerType); err != nil { + return err + } + + // Check if provider with this identifier already exists + existingProvider, err := models.FindCustomOAuthProviderByIdentifier(db, params.Identifier) + if err != nil && !models.IsNotFoundError(err) { + return apierrors.NewInternalServerError("Error checking for existing provider").WithInternalError(err) + } + if existingProvider != nil { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeConflict, + "A custom OAuth provider with this identifier already exists", + ) + } + + // Create provider model + provider := buildProviderFromParams(params, providerType) + + // Encrypt and store client secret + if err := provider.SetClientSecret(params.ClientSecret, config.Security.DBEncryption); err != nil { + return apierrors.NewInternalServerError("Error encrypting custom OAuth provider client secret").WithInternalError(err) + } + + // Create in database + err = db.Transaction(func(tx *storage.Connection) error { + if terr := models.CreateCustomOAuthProvider(tx, provider); terr != nil { + return terr + } + + // TODO: Implement proper admin audit logging for infrastructure changes + // The current audit log is user-centric. We need a separate audit mechanism + // for admin operations like provider management that doesn't require a "user actor" + // but tracks admin API changes for security and compliance. + + return nil + }) + + if err != nil { + return apierrors.NewInternalServerError("Error creating custom OAuth provider").WithInternalError(err) + } + + return sendJSON(w, http.StatusCreated, provider) +} + +// adminCustomOAuthProviderUpdate updates an existing custom OAuth/OIDC provider +func (a *API) adminCustomOAuthProviderUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + identifier := chi.URLParam(r, "identifier") + if identifier == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") + } + + // Validate identifier starts with 'custom:' prefix + if !strings.HasPrefix(identifier, "custom:") { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix") + } + + observability.LogEntrySetField(r, "identifier", identifier) + + // Parse request parameters + params := &AdminCustomOAuthProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + // Validate authorization params if provided + if params.AuthorizationParams != nil { + if err := validateAuthorizationParams(params.AuthorizationParams); err != nil { + return err + } + } + + // Validate attribute mapping if provided + if params.AttributeMapping != nil { + if err := validateAttributeMapping(params.AttributeMapping); err != nil { + return err + } + } + + var provider *models.CustomOAuthProvider + err := db.Transaction(func(tx *storage.Connection) error { + var terr error + provider, terr = models.FindCustomOAuthProviderByIdentifier(tx, identifier) + if terr != nil { + if models.IsNotFoundError(terr) { + return apierrors.NewNotFoundError(apierrors.ErrorCodeProviderNotFound, "Custom OAuth provider not found") + } + return apierrors.NewInternalServerError("Error retrieving custom OAuth provider").WithInternalError(terr) + } + + // Update provider with new non-secret values + if terr := updateProviderFromParams(provider, params); terr != nil { + return terr + } + + // If a new client secret is provided, encrypt and store it (likely move to out of the transaction) + if params.ClientSecret != "" { + if terr := provider.SetClientSecret(params.ClientSecret, config.Security.DBEncryption); terr != nil { + return apierrors.NewInternalServerError("Error encrypting custom OAuth provider client secret").WithInternalError(terr) + } + } + + if terr := models.UpdateCustomOAuthProvider(tx, provider); terr != nil { + return apierrors.NewInternalServerError("Error updating custom OAuth provider").WithInternalError(terr) + } + + // TODO: Add admin audit logging here (see create endpoint for details) + + return nil + }) + + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, provider) +} + +// adminCustomOAuthProviderDelete deletes a custom OAuth/OIDC provider +func (a *API) adminCustomOAuthProviderDelete(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + identifier := chi.URLParam(r, "identifier") + if identifier == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") + } + + // Validate identifier starts with 'custom:' prefix + if !strings.HasPrefix(identifier, "custom:") { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix") + } + + observability.LogEntrySetField(r, "identifier", identifier) + + err := db.Transaction(func(tx *storage.Connection) error { + provider, terr := models.FindCustomOAuthProviderByIdentifier(tx, identifier) + if terr != nil { + if models.IsNotFoundError(terr) { + return apierrors.NewNotFoundError(apierrors.ErrorCodeProviderNotFound, "Custom OAuth provider not found") + } + return apierrors.NewInternalServerError("Error retrieving custom OAuth provider").WithInternalError(terr) + } + + // TODO: Add admin audit logging here (see create endpoint for details) + + if terr := models.DeleteCustomOAuthProvider(tx, provider.ID); terr != nil { + return apierrors.NewInternalServerError("Error deleting custom OAuth provider").WithInternalError(terr) + } + + return nil + }) + + if err != nil { + return err + } + + w.WriteHeader(http.StatusNoContent) + return nil +} + +// =================================== +// Helper Functions +// =================================== + +// validateProviderParams validates type-specific required fields +func validateProviderParams(params *AdminCustomOAuthProviderParams, providerType models.ProviderType) error { + // Common validations + if params.Identifier == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") + } + + // Ensure identifier starts with 'custom:' prefix + if !strings.HasPrefix(params.Identifier, "custom:") { + params.Identifier = "custom:" + params.Identifier + } + + // Check for reserved provider names (built-in OAuth providers) + // These are already handled by Supabase Auth and shouldn't be overridden with custom providers + reservedProviderNames := []string{ + "apple", "azure", "bitbucket", "discord", "facebook", "figma", "fly", "github", "gitlab", + "google", "kakao", "keycloak", "linkedin_oidc", "linkedin", "notion", "slack_oidc", + "slack", "spotify", "twitch", "twitter", "workos", "x", "zoom", + } + + // Extract the base identifier without the "custom:" prefix for checking + baseIdentifier := strings.TrimPrefix(params.Identifier, "custom:") + if slices.Contains(reservedProviderNames, strings.ToLower(baseIdentifier)) { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "Cannot use reserved provider name: %s. This provider is already built into Supabase Auth.", + baseIdentifier, + ) + } + + if params.Name == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "name is required") + } + if params.ClientID == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required") + } + if params.ClientSecret == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_secret is required") + } + + // Type-specific validations + if providerType == models.ProviderTypeOIDC { + if params.Issuer == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "issuer is required for OIDC providers") + } + } else if providerType == models.ProviderTypeOAuth2 { + if params.AuthorizationURL == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_url is required for OAuth2 providers") + } + if params.TokenURL == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "token_url is required for OAuth2 providers") + } + if params.UserinfoURL == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "userinfo_url is required for OAuth2 providers") + } + } + + return nil +} + +// validateProviderURLs validates URLs with SSRF protection +func validateProviderURLs(params *AdminCustomOAuthProviderParams, providerType models.ProviderType) error { + var urls []string + + if providerType == models.ProviderTypeOIDC { + urls = append(urls, params.Issuer) + if params.DiscoveryURL != nil && *params.DiscoveryURL != "" { + urls = append(urls, *params.DiscoveryURL) + } + } else if providerType == models.ProviderTypeOAuth2 { + urls = []string{ + params.AuthorizationURL, + params.TokenURL, + params.UserinfoURL, + } + if params.JwksURI != nil && *params.JwksURI != "" { + urls = append(urls, *params.JwksURI) + } + } + + for _, urlStr := range urls { + if urlStr != "" { + if err := utilities.ValidateOAuthURL(urlStr); err != nil { + return err + } + } + } + + return nil +} + +// buildProviderFromParams creates a provider model from params +func buildProviderFromParams(params *AdminCustomOAuthProviderParams, providerType models.ProviderType) *models.CustomOAuthProvider { + provider := &models.CustomOAuthProvider{ + ProviderType: providerType, + Identifier: params.Identifier, + Name: params.Name, + ClientID: params.ClientID, + AcceptableClientIDs: models.StringSlice(params.AcceptableClientIDs), + Scopes: models.StringSlice(params.Scopes), + PKCEEnabled: getBoolOrDefault(params.PKCEEnabled, true), + AttributeMapping: models.OAuthAttributeMapping(params.AttributeMapping), + AuthorizationParams: models.OAuthAuthorizationParams(params.AuthorizationParams), + Enabled: getBoolOrDefault(params.Enabled, true), + EmailOptional: getBoolOrDefault(params.EmailOptional, false), + } + + // Set type-specific fields + if providerType == models.ProviderTypeOIDC { + provider.Issuer = ¶ms.Issuer + provider.DiscoveryURL = params.DiscoveryURL + provider.SkipNonceCheck = getBoolOrDefault(params.SkipNonceCheck, false) + + // Ensure openid scope is present for OIDC + hasOpenID := false + for _, scope := range provider.Scopes { + if scope == "openid" { + hasOpenID = true + break + } + } + if !hasOpenID { + provider.Scopes = append(models.StringSlice{"openid"}, provider.Scopes...) + } + } else if providerType == models.ProviderTypeOAuth2 { + provider.AuthorizationURL = ¶ms.AuthorizationURL + provider.TokenURL = ¶ms.TokenURL + provider.UserinfoURL = ¶ms.UserinfoURL + provider.JwksURI = params.JwksURI + } + + // Initialize empty maps if nil + if provider.AttributeMapping == nil { + provider.AttributeMapping = make(models.OAuthAttributeMapping) + } + if provider.AuthorizationParams == nil { + provider.AuthorizationParams = make(models.OAuthAuthorizationParams) + } + + return provider +} + +// updateProviderFromParams updates a provider model from params +func updateProviderFromParams(provider *models.CustomOAuthProvider, params *AdminCustomOAuthProviderParams) error { + // Update common fields + if params.Name != "" { + provider.Name = params.Name + } + if params.ClientID != "" { + provider.ClientID = params.ClientID + } + if params.AcceptableClientIDs != nil { + provider.AcceptableClientIDs = models.StringSlice(params.AcceptableClientIDs) + } + if params.Scopes != nil { + provider.Scopes = models.StringSlice(params.Scopes) + // Ensure openid scope for OIDC + if provider.IsOIDC() { + hasOpenID := false + for _, scope := range provider.Scopes { + if scope == "openid" { + hasOpenID = true + break + } + } + if !hasOpenID { + provider.Scopes = append(models.StringSlice{"openid"}, provider.Scopes...) + } + } + } + if params.PKCEEnabled != nil { + provider.PKCEEnabled = *params.PKCEEnabled + } + if params.AttributeMapping != nil { + provider.AttributeMapping = models.OAuthAttributeMapping(params.AttributeMapping) + } + if params.AuthorizationParams != nil { + provider.AuthorizationParams = models.OAuthAuthorizationParams(params.AuthorizationParams) + } + if params.Enabled != nil { + provider.Enabled = *params.Enabled + } + if params.EmailOptional != nil { + provider.EmailOptional = *params.EmailOptional + } + + // Update type-specific fields + if provider.IsOIDC() { + if params.Issuer != "" { + if err := utilities.ValidateOAuthURL(params.Issuer); err != nil { + return err + } + provider.Issuer = ¶ms.Issuer + } + if params.DiscoveryURL != nil && *params.DiscoveryURL != "" { + if err := utilities.ValidateOAuthURL(*params.DiscoveryURL); err != nil { + return err + } + provider.DiscoveryURL = params.DiscoveryURL + } + if params.SkipNonceCheck != nil { + provider.SkipNonceCheck = *params.SkipNonceCheck + } + } else if provider.IsOAuth2() { + if params.AuthorizationURL != "" { + if err := utilities.ValidateOAuthURL(params.AuthorizationURL); err != nil { + return err + } + provider.AuthorizationURL = ¶ms.AuthorizationURL + } + if params.TokenURL != "" { + if err := utilities.ValidateOAuthURL(params.TokenURL); err != nil { + return err + } + provider.TokenURL = ¶ms.TokenURL + } + if params.UserinfoURL != "" { + if err := utilities.ValidateOAuthURL(params.UserinfoURL); err != nil { + return err + } + provider.UserinfoURL = ¶ms.UserinfoURL + } + if params.JwksURI != nil && *params.JwksURI != "" { + if err := utilities.ValidateOAuthURL(*params.JwksURI); err != nil { + return err + } + provider.JwksURI = params.JwksURI + } + } + + return nil +} + +// getBoolOrDefault returns the value or default if nil +func getBoolOrDefault(value *bool, defaultValue bool) bool { + if value == nil { + return defaultValue + } + return *value +} + +// validateAuthorizationParams ensures no reserved OAuth parameters are overridden +func validateAuthorizationParams(params map[string]interface{}) error { + if params == nil { + return nil + } + + // Reserved OAuth2/OIDC parameters that should never be overridden + // These are set by the auth server and allowing override would be a security issue + reservedParams := []string{ + "client_id", + "client_secret", + "redirect_uri", + "response_type", + "state", + "code_challenge", + "code_challenge_method", + "code_verifier", + "nonce", // We control nonce generation for security + } + + for key := range params { + if slices.Contains(reservedParams, key) { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "Cannot override reserved OAuth parameter: %s", key, + ) + } + } + + return nil +} + +// validateAttributeMapping ensures no sensitive system fields are targeted +func validateAttributeMapping(mapping map[string]interface{}) error { + if mapping == nil { + return nil + } + + // System fields that should never be populated from external providers + // Allowing these could lead to privilege escalation or security bypass + blockedTargets := []string{ + "id", // User UUID - system generated + "aud", // JWT audience - system controlled + "role", // User role - should be managed via database, not external provider + "app_metadata", // Admin-only metadata - not for external providers + "created_at", // System timestamp + "updated_at", // System timestamp + "confirmed_at", // Email confirmation - system controlled + "email_confirmed_at", + "phone_confirmed_at", + "email_verified", // Email verification status - should come from provider, not be overridden + "phone_verified", // Phone verification status - should come from provider, not be overridden + "banned_until", // Security field - system controlled + "is_super_admin", // Admin flag - system controlled + } + + for targetField := range mapping { + if slices.Contains(blockedTargets, targetField) { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "Cannot map to protected system field: %s", targetField, + ) + } + } + + return nil +} diff --git a/internal/api/custom_oauth_admin_test.go b/internal/api/custom_oauth_admin_test.go new file mode 100644 index 000000000..34f28dd42 --- /dev/null +++ b/internal/api/custom_oauth_admin_test.go @@ -0,0 +1,639 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +type CustomOAuthAdminTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + token string +} + +func TestCustomOAuthAdmin(t *testing.T) { + api, config, err := setupAPIForTestWithCallback(func(config *conf.GlobalConfiguration, conn *storage.Connection) { + if config != nil { + // Enable custom OAuth feature before API initialization + config.CustomOAuth.Enabled = true + config.CustomOAuth.MaxProviders = 10 + // Ensure database encryption is enabled for tests that rely on encrypted client_secret + config.Security.DBEncryption.Encrypt = true + } + }) + require.NoError(t, err) + + ts := &CustomOAuthAdminTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *CustomOAuthAdminTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Reset config to default values before each test + // This prevents config changes from one test affecting others + ts.Config.CustomOAuth.Enabled = true + ts.Config.CustomOAuth.MaxProviders = 10 + ts.Config.Security.DBEncryption.Encrypt = true + + // Generate admin token + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err, "Error generating admin jwt") + ts.token = token +} + +// Test POST /admin/custom-providers (Create) + +func (ts *CustomOAuthAdminTestSuite) TestCreateOAuth2Provider() { + payload := map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "github-enterprise", + "name": "GitHub Enterprise", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": []string{"read:user", "user:email"}, + "authorization_url": "https://example.com/oauth/authorize", + "token_url": "https://example.com/oauth/token", + "userinfo_url": "https://example.com/api/user", + "pkce_enabled": true, + "enabled": true, + } + + var body bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&body).Encode(payload)) + + req := httptest.NewRequest(http.MethodPost, "/admin/custom-providers", &body) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusCreated, w.Code) + + var provider models.CustomOAuthProvider + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&provider)) + + assert.Equal(ts.T(), models.ProviderTypeOAuth2, provider.ProviderType) + assert.Equal(ts.T(), "custom:github-enterprise", provider.Identifier) // Prefix added + assert.Equal(ts.T(), "GitHub Enterprise", provider.Name) + assert.True(ts.T(), provider.PKCEEnabled) + assert.True(ts.T(), provider.Enabled) + + // Ensure client secret is not exposed in JSON and is stored encrypted + assert.Empty(ts.T(), provider.ClientSecret) +} + +func (ts *CustomOAuthAdminTestSuite) TestCreateOIDCProvider() { + payload := map[string]interface{}{ + "provider_type": "oidc", + "identifier": "self-keycloak", + "name": "Keycloak", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "issuer": "https://example.com/realms/myrealm", + "scopes": []string{"profile", "email"}, + "pkce_enabled": true, + "enabled": true, + } + + var body bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&body).Encode(payload)) + + req := httptest.NewRequest(http.MethodPost, "/admin/custom-providers", &body) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusCreated, w.Code) + + var provider models.CustomOAuthProvider + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&provider)) + + assert.Equal(ts.T(), models.ProviderTypeOIDC, provider.ProviderType) + assert.Equal(ts.T(), "custom:self-keycloak", provider.Identifier) + assert.Contains(ts.T(), provider.Scopes, "openid") // Auto-added for OIDC + assert.Contains(ts.T(), provider.Scopes, "profile") + + // Ensure client secret is not exposed in JSON + assert.Empty(ts.T(), provider.ClientSecret) +} + +func (ts *CustomOAuthAdminTestSuite) TestCreateProviderValidation() { + tests := []struct { + name string + payload map[string]interface{} + wantStatus int + errMsg string + }{ + { + name: "Missing provider_type", + payload: map[string]interface{}{ + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + }, + wantStatus: http.StatusBadRequest, + errMsg: "provider_type must be either 'oauth2' or 'oidc'", + }, + { + name: "Invalid provider_type", + payload: map[string]interface{}{ + "provider_type": "invalid", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + }, + wantStatus: http.StatusBadRequest, + errMsg: "provider_type must be either 'oauth2' or 'oidc'", + }, + { + name: "Missing OAuth2 required fields", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + // Missing authorization_url, token_url, userinfo_url + }, + wantStatus: http.StatusBadRequest, + errMsg: "authorization_url is required", + }, + { + name: "Missing OIDC issuer", + payload: map[string]interface{}{ + "provider_type": "oidc", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + // Missing issuer + }, + wantStatus: http.StatusBadRequest, + errMsg: "issuer is required", + }, + { + name: "Reserved provider name", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "google", + "name": "Google", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "https://example.com/authorize", + "token_url": "https://example.com/token", + "userinfo_url": "https://example.com/userinfo", + }, + wantStatus: http.StatusBadRequest, + }, + { + name: "HTTP URL not allowed", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "http://example.com/authorize", + "token_url": "https://example.com/token", + "userinfo_url": "https://example.com/userinfo", + }, + wantStatus: http.StatusBadRequest, + errMsg: "URL must use HTTPS", + }, + { + name: "Localhost blocked (SSRF)", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "https://example.com/authorize", + "token_url": "https://localhost/token", + "userinfo_url": "https://example.com/userinfo", + }, + wantStatus: http.StatusBadRequest, + errMsg: "localhost", + }, + { + name: "Private IP blocked (SSRF)", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "https://example.com/authorize", + "token_url": "https://10.0.0.1/token", + "userinfo_url": "https://example.com/userinfo", + }, + wantStatus: http.StatusBadRequest, + errMsg: "private network", + }, + { + name: "Reserved OAuth param (client_id)", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "https://example.com/authorize", + "token_url": "https://example.com/token", + "userinfo_url": "https://example.com/userinfo", + "authorization_params": map[string]interface{}{ + "client_id": "overridden", + }, + }, + wantStatus: http.StatusBadRequest, + errMsg: "reserved OAuth parameter", + }, + { + name: "Reserved OAuth param (state)", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "https://example.com/authorize", + "token_url": "https://example.com/token", + "userinfo_url": "https://example.com/userinfo", + "authorization_params": map[string]interface{}{ + "state": "custom", + }, + }, + wantStatus: http.StatusBadRequest, + errMsg: "reserved OAuth parameter", + }, + { + name: "Protected system field in attribute mapping (id)", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "https://example.com/authorize", + "token_url": "https://example.com/token", + "userinfo_url": "https://example.com/userinfo", + "attribute_mapping": map[string]interface{}{ + "id": "external_id", + }, + }, + wantStatus: http.StatusBadRequest, + errMsg: "protected system field", + }, + { + name: "Protected system field in attribute mapping (role)", + payload: map[string]interface{}{ + "provider_type": "oauth2", + "identifier": "test", + "name": "Test", + "client_id": "id", + "client_secret": "secret", + "authorization_url": "https://example.com/authorize", + "token_url": "https://example.com/token", + "userinfo_url": "https://example.com/userinfo", + "attribute_mapping": map[string]interface{}{ + "role": "admin", + }, + }, + wantStatus: http.StatusBadRequest, + errMsg: "protected system field", + }, + } + + for _, tt := range tests { + ts.Run(tt.name, func() { + var body bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&body).Encode(tt.payload)) + + req := httptest.NewRequest(http.MethodPost, "/admin/custom-providers", &body) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), tt.wantStatus, w.Code) + + if tt.errMsg != "" { + var apiErr apierrors.HTTPError + json.NewDecoder(w.Body).Decode(&apiErr) + assert.Contains(ts.T(), apiErr.Message, tt.errMsg) + } + }) + } +} + +func (ts *CustomOAuthAdminTestSuite) TestCreateProviderQuotaEnforcement() { + // Set quota to 2 + ts.Config.CustomOAuth.MaxProviders = 2 + + // Create first provider + payload1 := ts.createTestOAuth2Payload("provider1") + ts.createProvider(payload1, http.StatusCreated) + + // Create second provider + payload2 := ts.createTestOAuth2Payload("provider2") + ts.createProvider(payload2, http.StatusCreated) + + // Third provider should fail (quota exceeded) + payload3 := ts.createTestOAuth2Payload("provider3") + w := ts.createProvider(payload3, http.StatusBadRequest) + + var apiErr apierrors.HTTPError + json.NewDecoder(w.Body).Decode(&apiErr) + assert.Contains(ts.T(), apiErr.Message, "Maximum number") +} + +func (ts *CustomOAuthAdminTestSuite) TestCreateProviderFeatureDisabled() { + ts.Config.CustomOAuth.Enabled = false + + payload := ts.createTestOAuth2Payload("test") + w := ts.createProvider(payload, http.StatusBadRequest) + + var apiErr apierrors.HTTPError + json.NewDecoder(w.Body).Decode(&apiErr) + assert.Contains(ts.T(), apiErr.Message, "not enabled") +} + +func (ts *CustomOAuthAdminTestSuite) TestCreateProviderDuplicateIdentifier() { + identifier := "duplicate-test" + + // Create first provider + payload1 := ts.createTestOAuth2Payload(identifier) + w := ts.createProvider(payload1, http.StatusCreated) + require.Equal(ts.T(), http.StatusCreated, w.Code) + + // Try to create another provider with the same identifier + payload2 := ts.createTestOAuth2Payload(identifier) + w = ts.createProvider(payload2, http.StatusBadRequest) + + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var apiErr apierrors.HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&apiErr)) + assert.Equal(ts.T(), apierrors.ErrorCodeConflict, apiErr.ErrorCode) + assert.Contains(ts.T(), apiErr.Message, "already exists") + assert.Contains(ts.T(), apiErr.Message, "identifier") +} + +func (ts *CustomOAuthAdminTestSuite) TestCreateProviderDuplicateIdentifierWithCustomPrefix() { + // Test that identifier normalization works correctly + // User provides "custom:test" and we should still detect duplicates + identifier := "custom:duplicate-prefix-test" + + // Create first provider with explicit custom: prefix + payload1 := ts.createTestOAuth2Payload(identifier) + w := ts.createProvider(payload1, http.StatusCreated) + require.Equal(ts.T(), http.StatusCreated, w.Code) + + // Try to create another provider with the same identifier (without prefix) + payload2 := ts.createTestOAuth2Payload("duplicate-prefix-test") + w = ts.createProvider(payload2, http.StatusBadRequest) + + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var apiErr apierrors.HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&apiErr)) + assert.Equal(ts.T(), apierrors.ErrorCodeConflict, apiErr.ErrorCode) + assert.Contains(ts.T(), apiErr.Message, "already exists") +} + +// Test GET /admin/custom-providers (List) + +func (ts *CustomOAuthAdminTestSuite) TestListProviders() { + // Create some providers + ts.createProvider(ts.createTestOAuth2Payload("oauth2-1"), http.StatusCreated) + ts.createProvider(ts.createTestOAuth2Payload("oauth2-2"), http.StatusCreated) + ts.createProvider(ts.createTestOIDCPayload("oidc-1", "https://oidc1.example.com"), http.StatusCreated) + + req := httptest.NewRequest(http.MethodGet, "/admin/custom-providers", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + var response map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + + providers := response["providers"].([]interface{}) + assert.Len(ts.T(), providers, 3) +} + +func (ts *CustomOAuthAdminTestSuite) TestListProvidersWithTypeFilter() { + // Create mixed providers + ts.createProvider(ts.createTestOAuth2Payload("oauth2-1"), http.StatusCreated) + ts.createProvider(ts.createTestOIDCPayload("oidc-1", "https://oidc1.example.com"), http.StatusCreated) + ts.createProvider(ts.createTestOIDCPayload("oidc-2", "https://oidc2.example.com"), http.StatusCreated) + + // Filter by OAuth2 + req := httptest.NewRequest(http.MethodGet, "/admin/custom-providers?type=oauth2", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + var response map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + + providers := response["providers"].([]interface{}) + assert.Len(ts.T(), providers, 1) + + // Filter by OIDC + req = httptest.NewRequest(http.MethodGet, "/admin/custom-providers?type=oidc", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + + providers = response["providers"].([]interface{}) + assert.Len(ts.T(), providers, 2) +} + +// Test GET /admin/custom-providers/:id (Get) + +func (ts *CustomOAuthAdminTestSuite) TestGetProvider() { + w := ts.createProvider(ts.createTestOAuth2Payload("test-provider"), http.StatusCreated) + + var created models.CustomOAuthProvider + json.NewDecoder(w.Body).Decode(&created) + + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + var provider models.CustomOAuthProvider + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&provider)) + + assert.Equal(ts.T(), created.ID, provider.ID) + assert.Equal(ts.T(), created.Identifier, provider.Identifier) +} + +func (ts *CustomOAuthAdminTestSuite) TestGetProviderNotFound() { + // Use a valid identifier format but non-existent provider + fakeIdentifier := "custom:non-existent-provider" + + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/custom-providers/%s", fakeIdentifier), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) +} + +// Test PUT /admin/custom-providers/:id (Update) + +func (ts *CustomOAuthAdminTestSuite) TestUpdateProvider() { + w := ts.createProvider(ts.createTestOAuth2Payload("test-provider"), http.StatusCreated) + + var created models.CustomOAuthProvider + json.NewDecoder(w.Body).Decode(&created) + + updatePayload := map[string]interface{}{ + "name": "Updated Name", + "client_id": "new-client-id", + "enabled": false, + "scopes": []string{"openid", "profile", "email"}, + } + + var body bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&body).Encode(updatePayload)) + + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), &body) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + var updated models.CustomOAuthProvider + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&updated)) + + assert.Equal(ts.T(), "Updated Name", updated.Name) + assert.Equal(ts.T(), "new-client-id", updated.ClientID) + assert.False(ts.T(), updated.Enabled) + assert.Equal(ts.T(), models.StringSlice{"openid", "profile", "email"}, updated.Scopes) +} + +// Test DELETE /admin/custom-providers/:id (Delete) + +func (ts *CustomOAuthAdminTestSuite) TestDeleteProvider() { + w := ts.createProvider(ts.createTestOAuth2Payload("test-provider"), http.StatusCreated) + + var created models.CustomOAuthProvider + json.NewDecoder(w.Body).Decode(&created) + + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + // Verify the response body is empty (204 No Content should have no body) + assert.Empty(ts.T(), w.Body.String()) + + // Verify deletion + req = httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) +} + +// Helper methods + +func (ts *CustomOAuthAdminTestSuite) createTestOAuth2Payload(identifier string) map[string]interface{} { + return map[string]interface{}{ + "provider_type": "oauth2", + "identifier": identifier, + "name": "Test OAuth2 Provider", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": []string{"openid", "profile"}, + "authorization_url": "https://example.com/authorize", + "token_url": "https://example.com/token", + "userinfo_url": "https://example.com/userinfo", + "pkce_enabled": true, + "enabled": true, + } +} + +func (ts *CustomOAuthAdminTestSuite) createTestOIDCPayload(identifier, issuer string) map[string]interface{} { + // If issuer is not provided or uses non-resolvable domain, use example.com + if issuer == "" || strings.Contains(issuer, "oidc1.example.com") || strings.Contains(issuer, "oidc2.example.com") { + issuer = "https://example.com/realms/" + identifier + } + return map[string]interface{}{ + "provider_type": "oidc", + "identifier": identifier, + "name": "Test OIDC Provider", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "issuer": issuer, + "scopes": []string{"profile", "email"}, + "pkce_enabled": true, + "enabled": true, + } +} + +func (ts *CustomOAuthAdminTestSuite) createProvider(payload map[string]interface{}, expectedStatus int) *httptest.ResponseRecorder { + var body bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&body).Encode(payload)) + + req := httptest.NewRequest(http.MethodPost, "/admin/custom-providers", &body) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), expectedStatus, w.Code) + + return w +} diff --git a/internal/api/external.go b/internal/api/external.go index a3597f5be..725077ee0 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -670,12 +670,18 @@ func (a *API) loadExternalStateFromJWT(ctx context.Context, db *storage.Connecti // Provider returns a Provider interface for the given name. func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, conf.OAuthProviderConfiguration, error) { config := a.config + db := a.db.WithContext(ctx) name = strings.ToLower(name) var err error var p provider.Provider var pConfig conf.OAuthProviderConfiguration + // Check if this is a custom provider (format: custom:identifier) + if strings.HasPrefix(name, "custom:") { + return a.loadCustomProvider(ctx, db, name, scopes) + } + switch name { case "apple": pConfig = config.External.Apple @@ -759,6 +765,121 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide return p, pConfig, err } +// loadCustomProvider loads a custom OAuth or OIDC provider from the database +// identifier should be the full provider name with 'custom:' prefix (e.g., 'custom:github-enterprise') +func (a *API) loadCustomProvider(ctx context.Context, db *storage.Connection, identifier string, scopes string) (provider.Provider, conf.OAuthProviderConfiguration, error) { + config := a.config + var pConfig conf.OAuthProviderConfiguration + + // Build the redirect URL + redirectURL := config.API.ExternalURL + "/callback" + + // Parse scopes + var scopeList []string + if scopes != "" { + scopeList = strings.Split(scopes, ",") + // Trim whitespace + for i := range scopeList { + scopeList[i] = strings.TrimSpace(scopeList[i]) + } + } + + // Find the custom provider by identifier (which now includes 'custom:' prefix) + customProvider, err := models.FindCustomOAuthProviderByIdentifier(db, identifier) + if err != nil { + if models.IsNotFoundError(err) { + return nil, pConfig, fmt.Errorf("custom provider %s not found", identifier) + } + return nil, pConfig, fmt.Errorf("error finding custom provider: %w", err) + } + + // Check if provider is enabled + if !customProvider.Enabled { + return nil, pConfig, fmt.Errorf("custom provider %s is disabled", identifier) + } + + // Use provider scopes if not overridden + if len(scopeList) == 0 { + scopeList = customProvider.Scopes + } + + // Decrypt client secret for runtime use + clientSecret, err := customProvider.GetClientSecret(config.Security.DBEncryption) + if err != nil { + return nil, pConfig, fmt.Errorf("error decrypting client secret for provider %s: %w", identifier, err) + } + + // Handle based on provider type + if customProvider.IsOAuth2() { + // OAuth2 provider + if customProvider.AuthorizationURL == nil || customProvider.TokenURL == nil || customProvider.UserinfoURL == nil { + return nil, pConfig, fmt.Errorf("OAuth2 provider %s missing required endpoints", identifier) + } + + // Create custom OAuth provider instance + p := provider.NewCustomOAuthProvider( + customProvider.ClientID, + clientSecret, + *customProvider.AuthorizationURL, + *customProvider.TokenURL, + *customProvider.UserinfoURL, + redirectURL, + scopeList, + customProvider.PKCEEnabled, + customProvider.AcceptableClientIDs, + customProvider.AttributeMapping, + customProvider.AuthorizationParams, + ) + + // Build provider configuration + pConfig = conf.OAuthProviderConfiguration{ + Enabled: true, + ClientID: []string{customProvider.ClientID}, + Secret: clientSecret, + RedirectURI: redirectURL, + URL: *customProvider.AuthorizationURL, + EmailOptional: customProvider.EmailOptional, + } + + return p, pConfig, nil + } + + // OIDC provider + if customProvider.Issuer == nil { + return nil, pConfig, fmt.Errorf("OIDC provider %s missing issuer", identifier) + } + + // Create custom OIDC provider instance + // oidc.NewProvider() will automatically fetch discovery document + p, err := provider.NewCustomOIDCProvider( + ctx, + customProvider.ClientID, + clientSecret, + redirectURL, + scopeList, + *customProvider.Issuer, + customProvider.PKCEEnabled, + customProvider.AcceptableClientIDs, + customProvider.AttributeMapping, + customProvider.AuthorizationParams, + ) + if err != nil { + return nil, pConfig, fmt.Errorf("error creating OIDC provider: %w", err) + } + + // Build provider configuration + pConfig = conf.OAuthProviderConfiguration{ + Enabled: true, + ClientID: []string{customProvider.ClientID}, + Secret: clientSecret, + RedirectURI: redirectURL, + URL: p.Config().Endpoint.AuthURL, + EmailOptional: customProvider.EmailOptional, + } + + return p, pConfig, nil +} + func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { ctx := r.Context() log := observability.GetLogEntry(r).Entry diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 24643f736..33cb01816 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -48,6 +48,7 @@ func (a *API) requestAud(ctx context.Context, r *http.Request) string { type RequestParams interface { AdminUserParams | + AdminCustomOAuthProviderParams | CreateSSOProviderParams | EnrollFactorParams | GenerateLinkParams | diff --git a/internal/api/provider/custom_oauth.go b/internal/api/provider/custom_oauth.go new file mode 100644 index 000000000..4906ad547 --- /dev/null +++ b/internal/api/provider/custom_oauth.go @@ -0,0 +1,353 @@ +package provider + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +// CustomOAuthProvider implements OAuthProvider for custom OAuth2 providers +type CustomOAuthProvider struct { + config *oauth2.Config + userinfoURL string + pkceEnabled bool + acceptableClientIDs []string + attributeMapping map[string]interface{} + authorizationParams map[string]interface{} +} + +// NewCustomOAuthProvider creates a new custom OAuth provider +func NewCustomOAuthProvider( + clientID, clientSecret, authorizationURL, tokenURL, userinfoURL, redirectURL string, + scopes []string, + pkceEnabled bool, + acceptableClientIDs []string, + attributeMapping, authorizationParams map[string]interface{}, +) *CustomOAuthProvider { + config := &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Scopes: scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: authorizationURL, + TokenURL: tokenURL, + }, + } + + return &CustomOAuthProvider{ + config: config, + userinfoURL: userinfoURL, + pkceEnabled: pkceEnabled, + acceptableClientIDs: acceptableClientIDs, + attributeMapping: attributeMapping, + authorizationParams: authorizationParams, + } +} + +// AuthCodeURL returns the authorization URL for the OAuth flow +func (p *CustomOAuthProvider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + // Add any additional authorization parameters + if len(p.authorizationParams) > 0 { + for key, value := range p.authorizationParams { + // Convert value to string if needed + var strValue string + switch v := value.(type) { + case string: + strValue = v + default: + // For complex types, serialize to JSON + if b, err := json.Marshal(v); err == nil { + strValue = string(b) + } + } + if strValue != "" { + opts = append(opts, oauth2.SetAuthURLParam(key, strValue)) + } + } + } + + return p.config.AuthCodeURL(state, opts...) +} + +// GetOAuthToken exchanges the authorization code for an access token +func (p *CustomOAuthProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code, opts...) +} + +// GetUserData fetches user data from the provider's userinfo endpoint +func (p *CustomOAuthProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var claims Claims + if err := makeRequest(ctx, tok, p.config, p.userinfoURL, &claims); err != nil { + return nil, err + } + + // Apply attribute mapping if configured + if len(p.attributeMapping) > 0 { + claims = applyAttributeMapping(claims, p.attributeMapping) + } + + // Extract emails + emails := []Email{} + if claims.Email != "" { + emails = append(emails, Email{ + Email: claims.Email, + Verified: claims.EmailVerified, + Primary: true, + }) + } + + return &UserProvidedData{ + Emails: emails, + Metadata: &claims, + }, nil +} + +// RequiresPKCE returns whether this provider requires PKCE +func (p *CustomOAuthProvider) RequiresPKCE() bool { + return p.pkceEnabled +} + +// CustomOIDCProvider implements OAuthProvider for custom OIDC providers +type CustomOIDCProvider struct { + config *oauth2.Config + oidcProvider *oidc.Provider + userinfoEndpoint string + pkceEnabled bool + acceptableClientIDs []string + attributeMapping map[string]interface{} + authorizationParams map[string]interface{} +} + +// NewCustomOIDCProvider creates a new custom OIDC provider +func NewCustomOIDCProvider( + ctx context.Context, + clientID, clientSecret, redirectURL string, + scopes []string, + issuer string, + pkceEnabled bool, + acceptableClientIDs []string, + attributeMapping, authorizationParams map[string]interface{}, +) (*CustomOIDCProvider, error) { + // Ensure 'openid' scope is always present for OIDC + hasOpenID := false + for _, scope := range scopes { + if scope == "openid" { + hasOpenID = true + break + } + } + if !hasOpenID { + scopes = append([]string{"openid"}, scopes...) + } + + // Create OIDC provider - this automatically fetches discovery document + oidcProvider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, fmt.Errorf("failed to create OIDC provider: %w", err) + } + + // Get endpoints from the OIDC provider + endpoint := oidcProvider.Endpoint() + userinfoEndpoint := oidcProvider.UserInfoEndpoint() + + config := &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Scopes: scopes, + Endpoint: endpoint, + } + + return &CustomOIDCProvider{ + config: config, + oidcProvider: oidcProvider, + userinfoEndpoint: userinfoEndpoint, + pkceEnabled: pkceEnabled, + acceptableClientIDs: acceptableClientIDs, + attributeMapping: attributeMapping, + authorizationParams: authorizationParams, + }, nil +} + +// AuthCodeURL returns the authorization URL for the OIDC flow +func (p *CustomOIDCProvider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + // Add any additional authorization parameters + if len(p.authorizationParams) > 0 { + for key, value := range p.authorizationParams { + // Convert value to string if needed + var strValue string + switch v := value.(type) { + case string: + strValue = v + default: + // For complex types, serialize to JSON + if b, err := json.Marshal(v); err == nil { + strValue = string(b) + } + } + if strValue != "" { + opts = append(opts, oauth2.SetAuthURLParam(key, strValue)) + } + } + } + + return p.config.AuthCodeURL(state, opts...) +} + +// GetOAuthToken exchanges the authorization code for an access token +func (p *CustomOIDCProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code, opts...) +} + +// GetUserData fetches user data from the provider's userinfo endpoint or ID token +func (p *CustomOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + // First, try to extract and verify claims from ID token if present + if idToken, ok := tok.Extra("id_token").(string); ok && idToken != "" { + // Skip client ID check in the library and validate manually to support multiple client IDs + idTokenObj, userData, err := ParseIDToken(ctx, p.oidcProvider, &oidc.Config{ + SkipClientIDCheck: true, // We'll validate audience manually + }, idToken, ParseIDTokenOptions{ + SkipAccessTokenCheck: true, // We don't need at_hash validation in callback flow + }) + if err != nil { + // If ID token verification fails, fall back to userinfo endpoint + if p.userinfoEndpoint != "" { + var claims Claims + if err := makeRequest(ctx, tok, p.config, p.userinfoEndpoint, &claims); err != nil { + return nil, fmt.Errorf("failed to verify ID token and fetch userinfo: %w", err) + } + + // Apply attribute mapping + if len(p.attributeMapping) > 0 { + claims = applyAttributeMapping(claims, p.attributeMapping) + } + + // Extract emails + emails := []Email{} + if claims.Email != "" { + emails = append(emails, Email{ + Email: claims.Email, + Verified: claims.EmailVerified, + Primary: true, + }) + } + + return &UserProvidedData{ + Emails: emails, + Metadata: &claims, + }, nil + } + return nil, fmt.Errorf("failed to verify ID token and no userinfo endpoint available: %w", err) + } + + // Validate audience claim against acceptable client IDs + if err := p.validateAudience(idTokenObj.Audience); err != nil { + return nil, err + } + + // Apply attribute mapping to the metadata from ID token + if len(p.attributeMapping) > 0 && userData.Metadata != nil { + *userData.Metadata = applyAttributeMapping(*userData.Metadata, p.attributeMapping) + } + + return userData, nil + } + + // No ID token, use userinfo endpoint + if p.userinfoEndpoint != "" { + var claims Claims + if err := makeRequest(ctx, tok, p.config, p.userinfoEndpoint, &claims); err != nil { + return nil, err + } + + // Apply attribute mapping + if len(p.attributeMapping) > 0 { + claims = applyAttributeMapping(claims, p.attributeMapping) + } + + // Extract emails + emails := []Email{} + if claims.Email != "" { + emails = append(emails, Email{ + Email: claims.Email, + Verified: claims.EmailVerified, + Primary: true, + }) + } + + return &UserProvidedData{ + Emails: emails, + Metadata: &claims, + }, nil + } + + return nil, errors.New("no ID token or userinfo endpoint available") +} + +// RequiresPKCE returns whether this provider requires PKCE +func (p *CustomOIDCProvider) RequiresPKCE() bool { + return p.pkceEnabled +} + +// Config returns the OAuth2 config for accessing endpoints +func (p *CustomOIDCProvider) Config() *oauth2.Config { + return p.config +} + +// validateAudience validates that the token's audience matches one of the acceptable client IDs +func (p *CustomOIDCProvider) validateAudience(audiences []string) error { + // Build list of acceptable audiences: main client_id + acceptable_client_ids + acceptableAudiences := append([]string{p.config.ClientID}, p.acceptableClientIDs...) + + // Check if any audience in the token matches any acceptable audience + for _, tokenAud := range audiences { + for _, acceptableAud := range acceptableAudiences { + if tokenAud == acceptableAud { + return nil // Valid audience found + } + } + } + + // No valid audience found + return fmt.Errorf("token audience %v does not match any acceptable client ID", audiences) +} + +// applyAttributeMapping applies custom attribute mapping to claims +func applyAttributeMapping(claims Claims, mapping map[string]interface{}) Claims { + // Create a map representation of claims for easier manipulation + claimsMap := make(map[string]interface{}) + claimsBytes, _ := json.Marshal(claims) + if err := json.Unmarshal(claimsBytes, &claimsMap); err != nil { + // If unmarshaling fails, return original claims + return claims + } + + // Apply mappings + for targetField, sourceFieldOrValue := range mapping { + switch v := sourceFieldOrValue.(type) { + case string: + // If it's a string, treat it as a source field name + if value, exists := claimsMap[v]; exists { + claimsMap[targetField] = value + } + default: + // Otherwise, use it as a literal value + claimsMap[targetField] = v + } + } + + // Convert back to Claims struct + var result Claims + mappedBytes, _ := json.Marshal(claimsMap) + if err := json.Unmarshal(mappedBytes, &result); err != nil { + // If unmarshaling fails, return original claims + return claims + } + + return result +} diff --git a/internal/api/provider/custom_oauth_test.go b/internal/api/provider/custom_oauth_test.go new file mode 100644 index 000000000..1e51fa882 --- /dev/null +++ b/internal/api/provider/custom_oauth_test.go @@ -0,0 +1,514 @@ +package provider + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestNewCustomOAuthProvider(t *testing.T) { + provider := NewCustomOAuthProvider( + "test-client-id", + "test-client-secret", + "https://example.com/authorize", + "https://example.com/token", + "https://example.com/userinfo", + "https://myapp.com/callback", + []string{"openid", "profile"}, + true, // PKCE enabled + []string{"ios-client-id", "android-client-id"}, + map[string]interface{}{ + "email": "user_email", + }, + map[string]interface{}{ + "prompt": "consent", + }, + ) + + assert.NotNil(t, provider) + assert.Equal(t, "test-client-id", provider.config.ClientID) + assert.Equal(t, "test-client-secret", provider.config.ClientSecret) + assert.Equal(t, "https://myapp.com/callback", provider.config.RedirectURL) + assert.Equal(t, []string{"openid", "profile"}, provider.config.Scopes) + assert.Equal(t, "https://example.com/authorize", provider.config.Endpoint.AuthURL) + assert.Equal(t, "https://example.com/token", provider.config.Endpoint.TokenURL) + assert.Equal(t, "https://example.com/userinfo", provider.userinfoURL) + assert.True(t, provider.RequiresPKCE()) + assert.Equal(t, []string{"ios-client-id", "android-client-id"}, provider.acceptableClientIDs) + assert.Equal(t, "user_email", provider.attributeMapping["email"]) + assert.Equal(t, "consent", provider.authorizationParams["prompt"]) +} + +func TestCustomOAuthProvider_AuthCodeURL(t *testing.T) { + t.Run("Auth URL with authorization params", func(t *testing.T) { + provider := NewCustomOAuthProvider( + "client-id", + "client-secret", + "https://example.com/authorize", + "https://example.com/token", + "https://example.com/userinfo", + "https://myapp.com/callback", + []string{"openid", "profile"}, + false, + nil, + nil, + map[string]interface{}{ + "prompt": "consent", + "access_type": "offline", + "custom_param": "custom_value", + }, + ) + + authURL := provider.AuthCodeURL("test-state") + + assert.Contains(t, authURL, "client_id=client-id") + assert.Contains(t, authURL, "redirect_uri=https") + assert.Contains(t, authURL, "response_type=code") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "prompt=consent") + assert.Contains(t, authURL, "access_type=offline") + assert.Contains(t, authURL, "custom_param=custom_value") + }) + + t.Run("Auth URL with complex authorization params (JSON serialization)", func(t *testing.T) { + provider := NewCustomOAuthProvider( + "client-id", + "client-secret", + "https://example.com/authorize", + "https://example.com/token", + "https://example.com/userinfo", + "https://myapp.com/callback", + []string{"openid"}, + false, + nil, + nil, + map[string]interface{}{ + "complex_param": map[string]interface{}{ + "key": "value", + }, + }, + ) + + authURL := provider.AuthCodeURL("test-state") + + // Complex params should be JSON serialized + assert.Contains(t, authURL, "complex_param=") + }) +} + +func TestCustomOAuthProvider_GetUserData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify bearer token + authHeader := r.Header.Get("Authorization") + if authHeader != "Bearer test-access-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "sub": "user-123", + "email": "test@example.com", + "email_verified": true, + "name": "Test User", + "picture": "https://example.com/avatar.jpg", + }) + })) + defer server.Close() + + provider := NewCustomOAuthProvider( + "client-id", + "client-secret", + "https://example.com/authorize", + "https://example.com/token", + server.URL, // userinfo URL + "https://myapp.com/callback", + []string{"openid", "profile", "email"}, + false, + nil, + nil, + nil, + ) + + token := &oauth2.Token{ + AccessToken: "test-access-token", + TokenType: "Bearer", + } + + userData, err := provider.GetUserData(context.Background(), token) + require.NoError(t, err) + require.NotNil(t, userData) + + require.Len(t, userData.Emails, 1) + assert.Equal(t, "test@example.com", userData.Emails[0].Email) + assert.True(t, userData.Emails[0].Verified) + assert.True(t, userData.Emails[0].Primary) +} + +func TestCustomOAuthProvider_GetUserDataWithAttributeMapping(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "sub": "user-123", + "email": "test@example.com", + "email_verified": false, // Will be overridden by literal mapping + "full_name": "John Doe", + }) + })) + defer server.Close() + + provider := NewCustomOAuthProvider( + "client-id", + "client-secret", + "https://example.com/authorize", + "https://example.com/token", + server.URL, + "https://myapp.com/callback", + []string{"openid"}, + false, + nil, + map[string]interface{}{ + "email_verified": true, // Override with literal boolean value + "name": "full_name", // Map full_name field to name + }, + nil, + ) + + token := &oauth2.Token{ + AccessToken: "test-access-token", + TokenType: "Bearer", + } + + userData, err := provider.GetUserData(context.Background(), token) + require.NoError(t, err) + require.NotNil(t, userData) + + require.Len(t, userData.Emails, 1) + assert.Equal(t, "test@example.com", userData.Emails[0].Email) + assert.True(t, userData.Emails[0].Verified) // Should be true from literal mapping +} + +func TestApplyAttributeMapping(t *testing.T) { + tests := []struct { + name string + claims Claims + mapping map[string]interface{} + expected Claims + }{ + { + name: "Map with literal non-string values", + claims: Claims{ + Subject: "user-456", + Email: "test@example.com", + }, + mapping: map[string]interface{}{ + "email_verified": true, // Literal boolean value + "iat": float64(1234567890), // Literal number value + }, + expected: Claims{ + Subject: "user-456", + Email: "test@example.com", + EmailVerified: true, + Iat: float64(1234567890), + }, + }, + { + name: "Map between existing fields", + claims: Claims{ + Subject: "user-123", + Email: "test@example.com", + FullName: "John Doe", + AvatarURL: "https://example.com/avatar.jpg", + }, + mapping: map[string]interface{}{ + "name": "full_name", // Map full_name -> name + "picture": "avatar_url", // Map avatar_url -> picture + }, + expected: Claims{ + Subject: "user-123", + Email: "test@example.com", + Name: "John Doe", + Picture: "https://example.com/avatar.jpg", + FullName: "John Doe", // Original field still exists + AvatarURL: "https://example.com/avatar.jpg", + }, + }, + { + name: "Empty mapping returns original claims", + claims: Claims{ + Subject: "user-789", + Email: "unchanged@example.com", + }, + mapping: map[string]interface{}{}, + expected: Claims{ + Subject: "user-789", + Email: "unchanged@example.com", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := applyAttributeMapping(tt.claims, tt.mapping) + + assert.Equal(t, tt.expected.Subject, result.Subject) + assert.Equal(t, tt.expected.Email, result.Email) + assert.Equal(t, tt.expected.EmailVerified, result.EmailVerified) + assert.Equal(t, tt.expected.Name, result.Name) + if tt.expected.Picture != "" { + assert.Equal(t, tt.expected.Picture, result.Picture) + } + }) + } +} + +func TestNewCustomOIDCProvider(t *testing.T) { + // Mock OIDC provider server + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": server.URL, + "authorization_endpoint": server.URL + "/authorize", + "token_endpoint": server.URL + "/token", + "userinfo_endpoint": server.URL + "/userinfo", + "jwks_uri": server.URL + "/jwks", + }) + } else if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "keys": []interface{}{}, + }) + } + })) + defer server.Close() + + // Pass issuer URL directly - oidc.NewProvider will fetch discovery automatically + provider, err := NewCustomOIDCProvider( + context.Background(), + "test-client-id", + "test-client-secret", + "https://myapp.com/callback", + []string{"profile", "email"}, // Without openid + server.URL, // issuer + true, // PKCE enabled + []string{"ios-client", "android-client"}, + map[string]interface{}{"email": "user_email"}, + map[string]interface{}{"prompt": "consent"}, + ) + + require.NoError(t, err) + require.NotNil(t, provider) + + // Verify openid scope was automatically added + assert.Contains(t, provider.config.Scopes, "openid") + assert.Contains(t, provider.config.Scopes, "profile") + assert.Contains(t, provider.config.Scopes, "email") + + assert.True(t, provider.RequiresPKCE()) + assert.Equal(t, []string{"ios-client", "android-client"}, provider.acceptableClientIDs) +} + +func TestCustomOIDCProvider_ValidateAudience(t *testing.T) { + tests := []struct { + name string + clientID string + acceptableClientIDs []string + tokenAudiences []string + wantErr bool + }{ + { + name: "Valid single audience matches client ID", + clientID: "web-client-id", + acceptableClientIDs: nil, + tokenAudiences: []string{"web-client-id"}, + wantErr: false, + }, + { + name: "Valid audience matches one of acceptable client IDs", + clientID: "web-client-id", + acceptableClientIDs: []string{"ios-client-id", "android-client-id"}, + tokenAudiences: []string{"ios-client-id"}, + wantErr: false, + }, + { + name: "Valid audience matches different acceptable client ID", + clientID: "web-client-id", + acceptableClientIDs: []string{"ios-client-id", "android-client-id"}, + tokenAudiences: []string{"android-client-id"}, + wantErr: false, + }, + { + name: "Valid multiple audiences, one matches", + clientID: "web-client-id", + acceptableClientIDs: []string{"ios-client-id"}, + tokenAudiences: []string{"web-client-id", "other-client-id"}, + wantErr: false, + }, + { + name: "Invalid - no matching audience", + clientID: "web-client-id", + acceptableClientIDs: []string{"ios-client-id", "android-client-id"}, + tokenAudiences: []string{"unknown-client-id"}, + wantErr: true, + }, + { + name: "Invalid - empty token audiences", + clientID: "web-client-id", + acceptableClientIDs: []string{"ios-client-id"}, + tokenAudiences: []string{}, + wantErr: true, + }, + { + name: "Valid - multiple acceptable client IDs, multi-platform scenario", + clientID: "web-client-id", + acceptableClientIDs: []string{"com.myapp.ios", "com.myapp.android", "com.myapp.macos"}, + tokenAudiences: []string{"com.myapp.ios"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a minimal OIDC provider for testing validateAudience + provider := &CustomOIDCProvider{ + config: &oauth2.Config{ + ClientID: tt.clientID, + }, + acceptableClientIDs: tt.acceptableClientIDs, + } + + err := provider.validateAudience(tt.tokenAudiences) + + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), "does not match any acceptable client ID") + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCustomOIDCProvider_AuthCodeURL(t *testing.T) { + // Mock OIDC provider server + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": server.URL, + "authorization_endpoint": server.URL + "/authorize", + "token_endpoint": server.URL + "/token", + "jwks_uri": server.URL + "/jwks", + }) + } else if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "keys": []interface{}{}, + }) + } + })) + defer server.Close() + + // Pass issuer URL directly - oidc.NewProvider will fetch discovery automatically + provider, err := NewCustomOIDCProvider( + context.Background(), + "client-id", + "client-secret", + "https://myapp.com/callback", + []string{"openid", "profile"}, + server.URL, // issuer + false, + nil, + nil, + map[string]interface{}{ + "prompt": "consent", + "max_age": "3600", + "ui_locales": "en", + "login_hint": "user@example.com", + }, + ) + + require.NoError(t, err) + + authURL := provider.AuthCodeURL("test-state") + + // Verify standard OAuth2 params + assert.Contains(t, authURL, "client_id=client-id") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "response_type=code") + + // Verify custom authorization params + assert.Contains(t, authURL, "prompt=consent") + assert.Contains(t, authURL, "max_age=3600") + assert.Contains(t, authURL, "ui_locales=en") + assert.Contains(t, authURL, "login_hint=user") +} + +func TestCustomOIDCProvider_RequiresPKCE(t *testing.T) { + // Mock OIDC provider server + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": server.URL, + "authorization_endpoint": server.URL + "/authorize", + "token_endpoint": server.URL + "/token", + "jwks_uri": server.URL + "/jwks", + }) + } else if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "keys": []interface{}{}, + }) + } + })) + defer server.Close() + + t.Run("PKCE enabled", func(t *testing.T) { + // Pass issuer URL directly - oidc.NewProvider will fetch discovery automatically + provider, err := NewCustomOIDCProvider( + context.Background(), + "client-id", + "client-secret", + "https://myapp.com/callback", + []string{"openid"}, + server.URL, // issuer + true, // PKCE enabled + nil, + nil, + nil, + ) + + require.NoError(t, err) + assert.True(t, provider.RequiresPKCE()) + }) + + t.Run("PKCE disabled", func(t *testing.T) { + // Pass issuer URL directly - oidc.NewProvider will fetch discovery automatically + provider, err := NewCustomOIDCProvider( + context.Background(), + "client-id", + "client-secret", + "https://myapp.com/callback", + []string{"openid"}, + server.URL, // issuer + false, // PKCE disabled + nil, + nil, + nil, + ) + + require.NoError(t, err) + assert.False(t, provider.RequiresPKCE()) + }) +} diff --git a/internal/api/provider/provider.go b/internal/api/provider/provider.go index f7acb91e6..c0e31a7be 100644 --- a/internal/api/provider/provider.go +++ b/internal/api/provider/provider.go @@ -43,6 +43,42 @@ func (a *audience) UnmarshalJSON(b []byte) error { return nil } +// UnixTimeOrString accepts either: +// - number: seconds since epoch (OIDC NumericDate) +// - string: RFC3339 timestamp +type UnixTimeOrString time.Time + +func (t *UnixTimeOrString) UnmarshalJSON(b []byte) error { + // null + if bytes.Equal(b, []byte("null")) { + *t = UnixTimeOrString(time.Time{}) + return nil + } + + // number (possibly float) + var f float64 + if err := json.Unmarshal(b, &f); err == nil { + sec := int64(f) + *t = UnixTimeOrString(time.Unix(sec, 0).UTC()) + return nil + } + + // string (RFC3339) + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + if s == "" { + *t = UnixTimeOrString(time.Time{}) + return nil + } + if sec, err := time.Parse(time.RFC3339, s); err == nil { + *t = UnixTimeOrString(sec.UTC()) + return nil + } + return &time.ParseError{Layout: time.RFC3339, Value: s} +} + type Claims struct { // Reserved claims Issuer string `json:"iss,omitempty" structs:"iss,omitempty"` @@ -65,7 +101,7 @@ type Claims struct { Birthdate string `json:"birthdate,omitempty" structs:"birthdate,omitempty"` ZoneInfo string `json:"zoneinfo,omitempty" structs:"zoneinfo,omitempty"` Locale string `json:"locale,omitempty" structs:"locale,omitempty"` - UpdatedAt string `json:"updated_at,omitempty" structs:"updated_at,omitempty"` + UpdatedAt *UnixTimeOrString `json:"updated_at,omitempty" structs:"updated_at,omitempty"` Email string `json:"email,omitempty" structs:"email,omitempty"` EmailVerified bool `json:"email_verified,omitempty" structs:"email_verified"` Phone string `json:"phone,omitempty" structs:"phone,omitempty"` diff --git a/internal/api/provider/provider_test.go b/internal/api/provider/provider_test.go new file mode 100644 index 000000000..e00c2fe3c --- /dev/null +++ b/internal/api/provider/provider_test.go @@ -0,0 +1,25 @@ +package provider + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestClaimsUpdatedAt_Unmarshal(t *testing.T) { + t.Run("numeric date seconds", func(t *testing.T) { + var c Claims + require.NoError(t, json.Unmarshal([]byte(`{"updated_at": 1700000000}`), &c)) + require.NotNil(t, c.UpdatedAt) + require.Equal(t, int64(1700000000), time.Time(*c.UpdatedAt).Unix()) + }) + + t.Run("rfc3339 string", func(t *testing.T) { + var c Claims + require.NoError(t, json.Unmarshal([]byte(`{"updated_at": "2024-01-02T03:04:05Z"}`), &c)) + require.NotNil(t, c.UpdatedAt) + require.Equal(t, int64(1704164645), time.Time(*c.UpdatedAt).Unix()) + }) +} diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 817a9b059..a1c6776e5 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -29,7 +29,7 @@ type IdTokenGrantParams struct { LinkIdentity bool `json:"link_identity"` } -func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, bool, string, []string, bool, error) { +func (p *IdTokenGrantParams) getProvider(ctx context.Context, db *storage.Connection, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, bool, string, []string, bool, error) { log := observability.GetLogEntry(r).Entry var cfg *conf.OAuthProviderConfiguration @@ -123,6 +123,40 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa issuer = provider.IssuerSnapchat acceptableClientIDs = append(acceptableClientIDs, config.External.Snapchat.ClientID...) + case strings.HasPrefix(p.Provider, "custom:"): + // Custom OIDC provider - identifier already includes 'custom:' prefix + customProvider, err := models.FindCustomOAuthProviderByIdentifier(db, p.Provider) + if err != nil { + if models.IsNotFoundError(err) { + return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Custom provider %q not found", p.Provider) + } + return nil, false, "", nil, false, apierrors.NewInternalServerError("Error finding custom provider").WithInternalError(err) + } + + if !customProvider.Enabled { + return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeProviderDisabled, "Custom provider %q is disabled", p.Provider) + } + + // Ensure it's an OIDC provider + if !customProvider.IsOIDC() { + return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Provider %q is not an OIDC provider", p.Provider) + } + + if customProvider.Issuer == nil { + return nil, false, "", nil, false, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "OIDC provider %q missing issuer", p.Provider) + } + + providerType = p.Provider + issuer = *customProvider.Issuer + acceptableClientIDs = append(acceptableClientIDs, customProvider.ClientID) + acceptableClientIDs = append(acceptableClientIDs, customProvider.AcceptableClientIDs...) + + cfg = &conf.OAuthProviderConfiguration{ + Enabled: true, // already checked above + SkipNonceCheck: customProvider.SkipNonceCheck, + EmailOptional: customProvider.EmailOptional, + } + default: log.WithField("issuer", p.Issuer).WithField("client_id", p.ClientID).Warn("Use of POST /token with arbitrary issuer and client_id is deprecated for security reasons. Please switch to using the API with provider only!") @@ -200,7 +234,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R ctx = withTargetUser(ctx, targetUser) } - oidcProvider, skipNonceCheck, providerType, acceptableClientIDs, emailOptional, err := params.getProvider(ctx, config, r) + oidcProvider, skipNonceCheck, providerType, acceptableClientIDs, emailOptional, err := params.getProvider(ctx, db, config, r) if err != nil { return err } diff --git a/internal/api/token_oidc_test.go b/internal/api/token_oidc_test.go index 2ca589370..4cd0e5fb3 100644 --- a/internal/api/token_oidc_test.go +++ b/internal/api/token_oidc_test.go @@ -62,7 +62,7 @@ func (ts *TokenOIDCTestSuite) TestGetProvider() { ts.Config.External.AllowedIdTokenIssuers = []string{server.URL} req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) - oidcProvider, skipNonceCheck, providerType, acceptableClientIds, emailOptional, err := params.getProvider(context.Background(), ts.Config, req) + oidcProvider, skipNonceCheck, providerType, acceptableClientIds, emailOptional, err := params.getProvider(context.Background(), ts.API.db, ts.Config, req) require.NoError(ts.T(), err) require.NotNil(ts.T(), oidcProvider) require.False(ts.T(), skipNonceCheck) @@ -115,7 +115,7 @@ func (ts *TokenOIDCTestSuite) TestGetProviderAppleWithIncorrectIssuer() { } req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) - _, _, _, _, _, err := params.getProvider(context.Background(), ts.Config, req) + _, _, _, _, _, err := params.getProvider(context.Background(), ts.API.db, ts.Config, req) require.Error(ts.T(), err) require.Contains(ts.T(), err.Error(), "not an Apple ID token issuer") @@ -138,7 +138,7 @@ func (ts *TokenOIDCTestSuite) TestGetProviderAzureWithNonAzureTokenIssuer() { } req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) - _, _, _, _, _, err := params.getProvider(context.Background(), ts.Config, req) + _, _, _, _, _, err := params.getProvider(context.Background(), ts.API.db, ts.Config, req) // This should fail - the token's issuer is not an accepted issuer require.Error(ts.T(), err) @@ -162,7 +162,7 @@ func (ts *TokenOIDCTestSuite) TestGetProviderAppleWithNonAppleIssuerInToken() { } req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) - _, _, _, _, _, err := params.getProvider(context.Background(), ts.Config, req) + _, _, _, _, _, err := params.getProvider(context.Background(), ts.API.db, ts.Config, req) // This should fail - the token's actual issuer is not appleid.apple.com require.Error(ts.T(), err) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 3e397be69..7cd4320d5 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -85,6 +85,12 @@ type AnonymousProviderConfiguration struct { Enabled bool `json:"enabled" default:"false"` } +// CustomOAuthConfiguration holds configuration for custom OAuth and OIDC providers +type CustomOAuthConfiguration struct { + Enabled bool `json:"enabled" split_words:"true" default:"false"` + MaxProviders int `json:"max_providers" split_words:"true" default:"0"` +} + type EmailProviderConfiguration struct { Enabled bool `json:"enabled" default:"true"` @@ -318,6 +324,7 @@ type GlobalConfiguration struct { API APIConfiguration DB DBConfiguration External ProviderConfiguration + CustomOAuth CustomOAuthConfiguration `envconfig:"CUSTOM_OAUTH"` OAuthServer OAuthServerConfiguration `envconfig:"OAUTH_SERVER"` Logging LoggingConfig `envconfig:"LOG"` Profiler ProfilerConfig `envconfig:"PROFILER"` diff --git a/internal/models/connection.go b/internal/models/connection.go index 82a5e8775..363c16980 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -50,6 +50,7 @@ func TruncateAll(conn *storage.Connection) error { (&pop.Model{Value: FlowState{}}).TableName(), (&pop.Model{Value: OneTimeToken{}}).TableName(), (&pop.Model{Value: OAuthServerClient{}}).TableName(), + (&pop.Model{Value: CustomOAuthProvider{}}).TableName(), } for _, tableName := range tables { diff --git a/internal/models/custom_oauth_provider.go b/internal/models/custom_oauth_provider.go new file mode 100644 index 000000000..fe43f5f67 --- /dev/null +++ b/internal/models/custom_oauth_provider.go @@ -0,0 +1,430 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" +) + +// ProviderType represents the type of OAuth provider +type ProviderType string + +const ( + ProviderTypeOAuth2 ProviderType = "oauth2" + ProviderTypeOIDC ProviderType = "oidc" +) + +// CustomOAuthProvider represents a custom OAuth2 or OIDC provider configuration +type CustomOAuthProvider struct { + ID uuid.UUID `db:"id" json:"id"` + ProviderType ProviderType `db:"provider_type" json:"provider_type"` + + // Common fields for both OAuth2 and OIDC + Identifier string `db:"identifier" json:"identifier"` + Name string `db:"name" json:"name"` + ClientID string `db:"client_id" json:"client_id"` + // TODO: Implement client_secret encryption + // + // Current state: Client secrets are stored as plaintext in the database. + // + // Security concern: OAuth client secrets are sensitive credentials that should be + // encrypted at rest. If the database is compromised, attackers could use these + // secrets to impersonate the auth server to external OAuth providers. + // + // What's needed: + // 1. Choose encryption approach: + // - Application-level encryption using a master key (stored in env/secrets manager) + // - Database-level encryption (PostgreSQL pgcrypto extension) + // 2. Implement encryption when storing (CreateCustomOAuthProvider, UpdateCustomOAuthProvider) + // 3. Implement decryption when retrieving (all Find* functions) + // 4. Consider key rotation strategy + ClientSecret string `db:"client_secret" json:"-"` // Never expose in JSON + AcceptableClientIDs StringSlice `db:"acceptable_client_ids" json:"acceptable_client_ids"` + Scopes StringSlice `db:"scopes" json:"scopes"` + PKCEEnabled bool `db:"pkce_enabled" json:"pkce_enabled"` + AttributeMapping OAuthAttributeMapping `db:"attribute_mapping" json:"attribute_mapping"` + AuthorizationParams OAuthAuthorizationParams `db:"authorization_params" json:"authorization_params"` + Enabled bool `db:"enabled" json:"enabled"` + EmailOptional bool `db:"email_optional" json:"email_optional"` + + // OIDC-specific fields (null for OAuth2 providers) + Issuer *string `db:"issuer" json:"issuer,omitempty"` + DiscoveryURL *string `db:"discovery_url" json:"discovery_url,omitempty"` + SkipNonceCheck bool `db:"skip_nonce_check" json:"skip_nonce_check"` + CachedDiscovery *OIDCDiscovery `db:"cached_discovery" json:"-"` // Internal caching, not exposed in API + DiscoveryCachedAt *time.Time `db:"discovery_cached_at" json:"-"` // Internal caching, not exposed in API + + // OAuth2-specific fields (null for OIDC providers) + AuthorizationURL *string `db:"authorization_url" json:"authorization_url,omitempty"` + TokenURL *string `db:"token_url" json:"token_url,omitempty"` + UserinfoURL *string `db:"userinfo_url" json:"userinfo_url,omitempty"` + JwksURI *string `db:"jwks_uri" json:"jwks_uri,omitempty"` + + // Timestamps + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (p CustomOAuthProvider) TableName() string { + return "custom_oauth_providers" +} + +// SetClientSecret encrypts and stores the client secret using the configured +// database encryption settings. If encryption is disabled, the secret is +// stored in plaintext (temporary fallback for now) +func (p *CustomOAuthProvider) SetClientSecret(secret string, dbEncryption conf.DatabaseEncryptionConfiguration) error { + if !dbEncryption.Encrypt { + // Fallback: store in plaintext when encryption is not enabled. + p.ClientSecret = secret + return nil + } + + if dbEncryption.EncryptionKeyID == "" || dbEncryption.EncryptionKey == "" { + return errors.New("database encryption key configuration is invalid") + } + + es, err := crypto.NewEncryptedString(p.ID.String(), []byte(secret), dbEncryption.EncryptionKeyID, dbEncryption.EncryptionKey) + if err != nil { + return errors.Wrap(err, "error encrypting custom OAuth client secret") + } + + p.ClientSecret = es.String() + return nil +} + +// GetClientSecret decrypts and returns the client secret using the configured +// database decryption keys. It expects the client secret to be stored in +// encrypted form when encryption is enabled, but will also handle plaintext +// secrets (for deployments where encryption is not yet configured). +func (p *CustomOAuthProvider) GetClientSecret(dbEncryption conf.DatabaseEncryptionConfiguration) (string, error) { + if p.ClientSecret == "" { + return "", nil + } + + es := crypto.ParseEncryptedString(p.ClientSecret) + if es == nil { + // Not an encrypted string – treat as plaintext. + return p.ClientSecret, nil + } + + if dbEncryption.DecryptionKeys == nil { + return "", errors.New("database decryption keys not configured") + } + + bytes, err := es.Decrypt(p.ID.String(), dbEncryption.DecryptionKeys) + if err != nil { + return "", errors.Wrap(err, "error decrypting custom OAuth client secret") + } + + return string(bytes), nil +} + +// IsOIDC returns true if this is an OIDC provider +func (p *CustomOAuthProvider) IsOIDC() bool { + return p.ProviderType == ProviderTypeOIDC +} + +// IsOAuth2 returns true if this is an OAuth2 provider +func (p *CustomOAuthProvider) IsOAuth2() bool { + return p.ProviderType == ProviderTypeOAuth2 +} + +// GetProviderName returns the provider identifier (which already includes "custom:" prefix) +func (p *CustomOAuthProvider) GetProviderName() string { + return p.Identifier +} + +// GetDiscoveryURL returns the discovery URL for OIDC providers +// If discovery_url is set, use that; otherwise construct from issuer +func (p *CustomOAuthProvider) GetDiscoveryURL() string { + if !p.IsOIDC() || p.Issuer == nil { + return "" + } + + if p.DiscoveryURL != nil && *p.DiscoveryURL != "" { + return *p.DiscoveryURL + } + + return *p.Issuer + "/.well-known/openid-configuration" +} + +// StringSlice handles JSON-encoded string arrays stored as jsonb +type StringSlice []string + +func (s *StringSlice) Scan(src interface{}) error { + if src == nil { + *s = []string{} + return nil + } + + var b []byte + switch v := src.(type) { + case []byte: + b = v + case string: + b = []byte(v) + default: + return fmt.Errorf("cannot scan %T into StringSlice", src) + } + + // Handle empty/null JSON values + b = []byte(strings.TrimSpace(string(b))) + if len(b) == 0 || string(b) == "null" || string(b) == "[]" { + *s = []string{} + return nil + } + + var tmp []string + if err := json.Unmarshal(b, &tmp); err != nil { + return errors.Wrap(err, "error unmarshaling StringSlice") + } + + *s = StringSlice(tmp) + return nil +} + +func (s StringSlice) Value() (driver.Value, error) { + if len(s) == 0 { + return []byte("[]"), nil + } + + b, err := json.Marshal([]string(s)) + if err != nil { + return nil, errors.Wrap(err, "error marshaling StringSlice") + } + return b, nil +} + +// OAuthAttributeMapping defines how to map provider attributes to user fields +type OAuthAttributeMapping map[string]interface{} + +func (m *OAuthAttributeMapping) Scan(src interface{}) error { + if src == nil { + *m = make(OAuthAttributeMapping) + return nil + } + + b, ok := src.([]byte) + if !ok { + str, ok := src.(string) + if !ok { + return errors.New("scan source was not []byte or string") + } + b = []byte(str) + } + + if err := json.Unmarshal(b, m); err != nil { + return errors.Wrap(err, "error unmarshaling attribute mapping") + } + + return nil +} + +func (m OAuthAttributeMapping) Value() (driver.Value, error) { + if m == nil { + return []byte("{}"), nil + } + + b, err := json.Marshal(m) + if err != nil { + return nil, errors.Wrap(err, "error marshaling attribute mapping") + } + + return b, nil +} + +// OAuthAuthorizationParams holds additional parameters for authorization requests +type OAuthAuthorizationParams map[string]interface{} + +func (p *OAuthAuthorizationParams) Scan(src interface{}) error { + if src == nil { + *p = make(OAuthAuthorizationParams) + return nil + } + + b, ok := src.([]byte) + if !ok { + str, ok := src.(string) + if !ok { + return errors.New("scan source was not []byte or string") + } + b = []byte(str) + } + + if err := json.Unmarshal(b, p); err != nil { + return errors.Wrap(err, "error unmarshaling authorization params") + } + + return nil +} + +func (p OAuthAuthorizationParams) Value() (driver.Value, error) { + if p == nil { + return []byte("{}"), nil + } + + b, err := json.Marshal(p) + if err != nil { + return nil, errors.Wrap(err, "error marshaling authorization params") + } + + return b, nil +} + +// OIDCDiscovery represents cached OIDC discovery document +type OIDCDiscovery struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` + JwksURI string `json:"jwks_uri"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported,omitempty"` + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + SubjectTypesSupported []string `json:"subject_types_supported,omitempty"` +} + +func (d *OIDCDiscovery) Scan(src interface{}) error { + if src == nil { + return nil + } + + b, ok := src.([]byte) + if !ok { + str, ok := src.(string) + if !ok { + return errors.New("scan source was not []byte or string") + } + b = []byte(str) + } + + if err := json.Unmarshal(b, d); err != nil { + return errors.Wrap(err, "error unmarshaling OIDC discovery") + } + + return nil +} + +func (d *OIDCDiscovery) Value() (driver.Value, error) { + if d == nil { + return nil, nil + } + + b, err := json.Marshal(d) + if err != nil { + return nil, errors.Wrap(err, "error marshaling OIDC discovery") + } + + return b, nil +} + +// CRUD operations for CustomOAuthProvider + +// FindCustomOAuthProviderByID finds a custom OAuth provider by ID +func FindCustomOAuthProviderByID(tx *storage.Connection, id uuid.UUID) (*CustomOAuthProvider, error) { + var provider CustomOAuthProvider + + if err := tx.Q().Where("id = ?", id).First(&provider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, CustomOAuthProviderNotFoundError{} + } + return nil, errors.Wrap(err, "error finding custom OAuth provider by ID") + } + + return &provider, nil +} + +// FindCustomOAuthProviderByIdentifier finds a custom OAuth provider by identifier +func FindCustomOAuthProviderByIdentifier(tx *storage.Connection, identifier string) (*CustomOAuthProvider, error) { + var provider CustomOAuthProvider + + if err := tx.Q().Where("identifier = ?", identifier).First(&provider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, CustomOAuthProviderNotFoundError{} + } + return nil, errors.Wrap(err, "error finding custom OAuth provider by identifier") + } + + return &provider, nil +} + +// FindAllCustomOAuthProviders finds all custom OAuth providers +func FindAllCustomOAuthProviders(tx *storage.Connection) ([]*CustomOAuthProvider, error) { + var providers []*CustomOAuthProvider + + if err := tx.Q().Order("created_at desc").All(&providers); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return []*CustomOAuthProvider{}, nil + } + return nil, errors.Wrap(err, "error finding all custom OAuth providers") + } + + return providers, nil +} + +// FindAllCustomOAuthProvidersByType finds all custom OAuth providers of a specific type +func FindAllCustomOAuthProvidersByType(tx *storage.Connection, providerType ProviderType) ([]*CustomOAuthProvider, error) { + var providers []*CustomOAuthProvider + + if err := tx.Q().Where("provider_type = ?", providerType).Order("created_at desc").All(&providers); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return []*CustomOAuthProvider{}, nil + } + return nil, errors.Wrap(err, "error finding custom OAuth providers by type") + } + + return providers, nil +} + +// CountCustomOAuthProviders counts all custom OAuth providers +func CountCustomOAuthProviders(tx *storage.Connection) (int, error) { + count, err := tx.Q().Count(&CustomOAuthProvider{}) + if err != nil { + return 0, errors.Wrap(err, "error counting custom OAuth providers") + } + return count, nil +} + +// CreateCustomOAuthProvider creates a new custom OAuth provider +func CreateCustomOAuthProvider(tx *storage.Connection, provider *CustomOAuthProvider) error { + if provider.ID == uuid.Nil { + id, err := uuid.NewV4() + if err != nil { + return errors.Wrap(err, "error generating custom OAuth provider ID") + } + provider.ID = id + } + + if err := tx.Create(provider); err != nil { + return errors.Wrap(err, "error creating custom OAuth provider") + } + + return nil +} + +// UpdateCustomOAuthProvider updates an existing custom OAuth provider +func UpdateCustomOAuthProvider(tx *storage.Connection, provider *CustomOAuthProvider) error { + // Set updated_at timestamp explicitly in application code + provider.UpdatedAt = time.Now() + if err := tx.Update(provider); err != nil { + return errors.Wrap(err, "error updating custom OAuth provider") + } + return nil +} + +// DeleteCustomOAuthProvider deletes a custom OAuth provider +func DeleteCustomOAuthProvider(tx *storage.Connection, id uuid.UUID) error { + if err := tx.Destroy(&CustomOAuthProvider{ID: id}); err != nil { + return errors.Wrap(err, "error deleting custom OAuth provider") + } + return nil +} diff --git a/internal/models/custom_oauth_provider_test.go b/internal/models/custom_oauth_provider_test.go new file mode 100644 index 000000000..ce0480de0 --- /dev/null +++ b/internal/models/custom_oauth_provider_test.go @@ -0,0 +1,428 @@ +package models + +import ( + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type CustomOAuthProviderTestSuite struct { + suite.Suite + db *storage.Connection + config *conf.GlobalConfiguration +} + +func (ts *CustomOAuthProviderTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestCustomOAuthProvider(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &CustomOAuthProviderTestSuite{ + db: conn, + config: globalConfig, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +// Test CRUD Operations + +func (ts *CustomOAuthProviderTestSuite) TestCreateCustomOAuthProvider() { + tests := []struct { + name string + providerType ProviderType + }{ + { + name: "Create OAuth2 provider", + providerType: ProviderTypeOAuth2, + }, + { + name: "Create OIDC provider", + providerType: ProviderTypeOIDC, + }, + } + + for _, tt := range tests { + ts.Run(tt.name, func() { + provider := ts.createTestProvider(tt.providerType, "test-provider-"+string(tt.providerType)) + + require.NotEqual(ts.T(), uuid.Nil, provider.ID) + require.Equal(ts.T(), tt.providerType, provider.ProviderType) + require.NotEmpty(ts.T(), provider.CreatedAt) + require.NotEmpty(ts.T(), provider.UpdatedAt) + }) + } +} + +func (ts *CustomOAuthProviderTestSuite) TestFindCustomOAuthProviderByID() { + provider := ts.createTestProvider(ProviderTypeOAuth2, "custom:test-oauth2") + + found, err := FindCustomOAuthProviderByID(ts.db, provider.ID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), provider.ID, found.ID) + require.Equal(ts.T(), provider.Identifier, found.Identifier) + + // Test not found + nonExistentID, _ := uuid.NewV4() + _, err = FindCustomOAuthProviderByID(ts.db, nonExistentID) + require.Error(ts.T(), err) + require.True(ts.T(), IsNotFoundError(err)) +} + +func (ts *CustomOAuthProviderTestSuite) TestFindCustomOAuthProviderByIdentifier() { + identifier := "custom:test-github-enterprise" + provider := ts.createTestProvider(ProviderTypeOAuth2, identifier) + + found, err := FindCustomOAuthProviderByIdentifier(ts.db, identifier) + require.NoError(ts.T(), err) + require.Equal(ts.T(), provider.ID, found.ID) + require.Equal(ts.T(), identifier, found.Identifier) + + // Test not found + _, err = FindCustomOAuthProviderByIdentifier(ts.db, "custom:nonexistent") + require.Error(ts.T(), err) + require.True(ts.T(), IsNotFoundError(err)) +} + +func (ts *CustomOAuthProviderTestSuite) TestFindAllCustomOAuthProviders() { + ts.createTestProvider(ProviderTypeOAuth2, "custom:provider1") + ts.createTestProvider(ProviderTypeOAuth2, "custom:provider2") + ts.createTestOIDCProvider("custom:provider3", "https://oidc1.example.com") + + providers, err := FindAllCustomOAuthProviders(ts.db) + require.NoError(ts.T(), err) + require.Len(ts.T(), providers, 3) +} + +func (ts *CustomOAuthProviderTestSuite) TestFindAllCustomOAuthProvidersByType() { + ts.createTestProvider(ProviderTypeOAuth2, "custom:oauth2-1") + ts.createTestProvider(ProviderTypeOAuth2, "custom:oauth2-2") + ts.createTestOIDCProvider("custom:oidc-1", "https://oidc1.example.com") + ts.createTestOIDCProvider("custom:oidc-2", "https://oidc2.example.com") + + oauth2Providers, err := FindAllCustomOAuthProvidersByType(ts.db, ProviderTypeOAuth2) + require.NoError(ts.T(), err) + require.Len(ts.T(), oauth2Providers, 2) + + oidcProviders, err := FindAllCustomOAuthProvidersByType(ts.db, ProviderTypeOIDC) + require.NoError(ts.T(), err) + require.Len(ts.T(), oidcProviders, 2) +} + +func (ts *CustomOAuthProviderTestSuite) TestCountCustomOAuthProviders() { + ts.createTestProvider(ProviderTypeOAuth2, "custom:count1") + ts.createTestOIDCProvider("custom:count2", "https://count.example.com") + + count, err := CountCustomOAuthProviders(ts.db) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), 2, count) +} +func (ts *CustomOAuthProviderTestSuite) TestUpdateCustomOAuthProvider() { + provider := ts.createTestProvider(ProviderTypeOAuth2, "custom:update-test") + + // Update name + provider.Name = "Updated Name" + provider.ClientID = "new-client-id" + provider.Scopes = StringSlice{"openid", "profile", "email"} + + err := UpdateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + // Verify update + updated, err := FindCustomOAuthProviderByID(ts.db, provider.ID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), "Updated Name", updated.Name) + assert.Equal(ts.T(), "new-client-id", updated.ClientID) + assert.Equal(ts.T(), StringSlice{"openid", "profile", "email"}, updated.Scopes) +} + +func (ts *CustomOAuthProviderTestSuite) TestDeleteCustomOAuthProvider() { + provider := ts.createTestProvider(ProviderTypeOAuth2, "custom:delete-test") + + err := DeleteCustomOAuthProvider(ts.db, provider.ID) + require.NoError(ts.T(), err) + + // Verify deletion + _, err = FindCustomOAuthProviderByID(ts.db, provider.ID) + require.Error(ts.T(), err) + require.True(ts.T(), IsNotFoundError(err)) +} + +// Test Custom Types + +func (ts *CustomOAuthProviderTestSuite) TestStringSliceSerialisation() { + provider := &CustomOAuthProvider{ + ProviderType: ProviderTypeOAuth2, + Identifier: "custom:string-slice-test", + Name: "String Slice Test", + ClientID: "client-id", + Scopes: StringSlice{"openid", "profile", "email"}, + AcceptableClientIDs: StringSlice{"ios-client", "android-client", "web-client"}, + AuthorizationURL: stringPtr("https://example.com/authorize"), + TokenURL: stringPtr("https://example.com/token"), + UserinfoURL: stringPtr("https://example.com/userinfo"), + PKCEEnabled: true, + Enabled: true, + } + + err := CreateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + // Retrieve and verify + found, err := FindCustomOAuthProviderByID(ts.db, provider.ID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), StringSlice{"openid", "profile", "email"}, found.Scopes) + assert.Equal(ts.T(), StringSlice{"ios-client", "android-client", "web-client"}, found.AcceptableClientIDs) + + // Test empty slice + provider.Scopes = StringSlice{} + err = UpdateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + found, err = FindCustomOAuthProviderByID(ts.db, provider.ID) + require.NoError(ts.T(), err) + assert.Empty(ts.T(), found.Scopes) +} + +func (ts *CustomOAuthProviderTestSuite) TestOAuthAttributeMappingSerialization() { + mapping := OAuthAttributeMapping{ + "email": "user_email", + "name": "full_name", + "avatar_url": "picture", + "custom_field": map[string]interface{}{ + "nested": "value", + }, + } + + provider := &CustomOAuthProvider{ + ProviderType: ProviderTypeOAuth2, + Identifier: "custom:mapping-test", + Name: "Mapping Test", + ClientID: "client-id", + AuthorizationURL: stringPtr("https://example.com/authorize"), + TokenURL: stringPtr("https://example.com/token"), + UserinfoURL: stringPtr("https://example.com/userinfo"), + AttributeMapping: mapping, + PKCEEnabled: true, + Enabled: true, + } + + err := CreateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + // Retrieve and verify + found, err := FindCustomOAuthProviderByID(ts.db, provider.ID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), "user_email", found.AttributeMapping["email"]) + assert.Equal(ts.T(), "full_name", found.AttributeMapping["name"]) + assert.Equal(ts.T(), "picture", found.AttributeMapping["avatar_url"]) + assert.NotNil(ts.T(), found.AttributeMapping["custom_field"]) +} + +func (ts *CustomOAuthProviderTestSuite) TestOAuthAuthorizationParamsSerialization() { + params := OAuthAuthorizationParams{ + "prompt": "consent", + "access_type": "offline", + "custom_param": "value", + "complex_param": []string{"val1", "val2"}, + } + + provider := &CustomOAuthProvider{ + ProviderType: ProviderTypeOAuth2, + Identifier: "custom:params-test", + Name: "Params Test", + ClientID: "client-id", + AuthorizationURL: stringPtr("https://example.com/authorize"), + TokenURL: stringPtr("https://example.com/token"), + UserinfoURL: stringPtr("https://example.com/userinfo"), + AuthorizationParams: params, + PKCEEnabled: true, + Enabled: true, + } + + err := CreateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + // Retrieve and verify + found, err := FindCustomOAuthProviderByID(ts.db, provider.ID) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), "consent", found.AuthorizationParams["prompt"]) + assert.Equal(ts.T(), "offline", found.AuthorizationParams["access_type"]) + assert.Equal(ts.T(), "value", found.AuthorizationParams["custom_param"]) +} + +func (ts *CustomOAuthProviderTestSuite) TestOIDCDiscoverySerialization() { + issuer := "https://oidc-discovery-test.example.com" + discovery := &OIDCDiscovery{ + Issuer: issuer, + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + UserinfoEndpoint: issuer + "/userinfo", + JwksURI: issuer + "/jwks", + ScopesSupported: []string{"openid", "profile", "email"}, + ResponseTypesSupported: []string{"code", "token", "id_token"}, + GrantTypesSupported: []string{"authorization_code", "refresh_token"}, + SubjectTypesSupported: []string{"public"}, + } + + provider := &CustomOAuthProvider{ + ProviderType: ProviderTypeOIDC, + Identifier: "custom:oidc-discovery-test", + Name: "OIDC Discovery Test", + ClientID: "client-id", + Issuer: &issuer, + CachedDiscovery: discovery, + PKCEEnabled: true, + Enabled: true, + } + + err := CreateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + // Retrieve and verify + found, err := FindCustomOAuthProviderByID(ts.db, provider.ID) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), found.CachedDiscovery) + assert.Equal(ts.T(), issuer, found.CachedDiscovery.Issuer) + assert.Equal(ts.T(), issuer+"/authorize", found.CachedDiscovery.AuthorizationEndpoint) + assert.Equal(ts.T(), issuer+"/token", found.CachedDiscovery.TokenEndpoint) + assert.Equal(ts.T(), issuer+"/userinfo", found.CachedDiscovery.UserinfoEndpoint) + assert.Equal(ts.T(), issuer+"/jwks", found.CachedDiscovery.JwksURI) + assert.Equal(ts.T(), []string{"openid", "profile", "email"}, found.CachedDiscovery.ScopesSupported) +} + +// Test Helper Methods + +func (ts *CustomOAuthProviderTestSuite) TestIsOIDC() { + oauth2Provider := ts.createTestProvider(ProviderTypeOAuth2, "custom:oauth2-check") + assert.False(ts.T(), oauth2Provider.IsOIDC()) + assert.True(ts.T(), oauth2Provider.IsOAuth2()) + + oidcProvider := ts.createTestOIDCProvider("custom:oidc-check", "https://oidc.example.com") + assert.True(ts.T(), oidcProvider.IsOIDC()) + assert.False(ts.T(), oidcProvider.IsOAuth2()) +} + +func (ts *CustomOAuthProviderTestSuite) TestGetProviderName() { + provider := ts.createTestProvider(ProviderTypeOAuth2, "custom:github-enterprise") + assert.Equal(ts.T(), "custom:github-enterprise", provider.GetProviderName()) +} + +func (ts *CustomOAuthProviderTestSuite) TestGetDiscoveryURL() { + issuer1 := "https://oidc-auto.example.com" + + // Test without explicit discovery URL (should construct from issuer) + provider := &CustomOAuthProvider{ + ProviderType: ProviderTypeOIDC, + Identifier: "custom:oidc-auto-discovery", + Name: "OIDC Auto Discovery", + ClientID: "client-id", + Issuer: &issuer1, + PKCEEnabled: true, + Enabled: true, + } + + err := CreateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), "https://oidc-auto.example.com/.well-known/openid-configuration", provider.GetDiscoveryURL()) + + // Test with explicit discovery URL (use different issuer to avoid constraint violation) + issuer2 := "https://oidc-explicit.example.com" + explicitDiscoveryURL := "https://oidc-explicit.example.com/.well-known/openid-configuration-custom" + provider2 := &CustomOAuthProvider{ + ProviderType: ProviderTypeOIDC, + Identifier: "custom:oidc-explicit-discovery", + Name: "OIDC Explicit Discovery", + ClientID: "client-id", + Issuer: &issuer2, + DiscoveryURL: &explicitDiscoveryURL, + PKCEEnabled: true, + Enabled: true, + } + + err = CreateCustomOAuthProvider(ts.db, provider2) + require.NoError(ts.T(), err) + + assert.Equal(ts.T(), explicitDiscoveryURL, provider2.GetDiscoveryURL()) + + // Test OAuth2 provider returns empty string + oauth2Provider := ts.createTestProvider(ProviderTypeOAuth2, "custom:oauth2-no-discovery") + assert.Empty(ts.T(), oauth2Provider.GetDiscoveryURL()) +} + +// Helper functions + +func (ts *CustomOAuthProviderTestSuite) createTestProvider(providerType ProviderType, identifier string) *CustomOAuthProvider { + provider := &CustomOAuthProvider{ + ProviderType: providerType, + Identifier: identifier, + Name: "Test Provider", + ClientID: "test-client-id", + Scopes: StringSlice{"openid", "profile"}, + PKCEEnabled: true, + Enabled: true, + } + + if providerType == ProviderTypeOAuth2 { + authURL := "https://example.com/authorize" + // #nosec G101 - These are test URLs, not actual credentials + tokenURL := "https://example.com/token" + userinfoURL := "https://example.com/userinfo" + provider.AuthorizationURL = &authURL + provider.TokenURL = &tokenURL + provider.UserinfoURL = &userinfoURL + } else if providerType == ProviderTypeOIDC { + // For OIDC, generate a unique issuer to avoid constraint violations + issuer := "https://oidc-" + identifier + ".example.com" + provider.Issuer = &issuer + } + + // Encrypt and set client secret before persisting + err := provider.SetClientSecret("test-client-secret", ts.config.Security.DBEncryption) + require.NoError(ts.T(), err) + + err = CreateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + return provider +} + +func (ts *CustomOAuthProviderTestSuite) createTestOIDCProvider(identifier, issuer string) *CustomOAuthProvider { + provider := &CustomOAuthProvider{ + ProviderType: ProviderTypeOIDC, + Identifier: identifier, + Name: "Test OIDC Provider", + ClientID: "test-client-id", + Issuer: &issuer, + Scopes: StringSlice{"openid", "profile"}, + PKCEEnabled: true, + Enabled: true, + } + + // Encrypt and set client secret before persisting + err := provider.SetClientSecret("test-client-secret", ts.config.Security.DBEncryption) + require.NoError(ts.T(), err) + + err = CreateCustomOAuthProvider(ts.db, provider) + require.NoError(ts.T(), err) + + return provider +} + +func stringPtr(s string) *string { + return &s +} diff --git a/internal/models/errors.go b/internal/models/errors.go index 4f1c95e60..c78c14522 100644 --- a/internal/models/errors.go +++ b/internal/models/errors.go @@ -33,6 +33,8 @@ func IsNotFoundError(err error) bool { return true case OAuthClientStateNotFoundError, *OAuthClientStateNotFoundError: return true + case CustomOAuthProviderNotFoundError, *CustomOAuthProviderNotFoundError: + return true } return false } @@ -135,3 +137,10 @@ type OAuthClientStateNotFoundError struct{} func (e OAuthClientStateNotFoundError) Error() string { return "OAuth state not found" } + +// CustomOAuthProviderNotFoundError represents an error when a custom OAuth/OIDC provider can't be found +type CustomOAuthProviderNotFoundError struct{} + +func (e CustomOAuthProviderNotFoundError) Error() string { + return "Custom OAuth provider not found" +} diff --git a/internal/utilities/url_validator.go b/internal/utilities/url_validator.go new file mode 100644 index 000000000..d6dd97211 --- /dev/null +++ b/internal/utilities/url_validator.go @@ -0,0 +1,205 @@ +package utilities + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/supabase/auth/internal/api/apierrors" +) + +// ValidateOAuthURL validates that a URL is safe for OAuth/OIDC operations +// and protects against SSRF attacks by blocking private IPs and metadata endpoints +func ValidateOAuthURL(urlStr string) error { + // Parse the URL + parsedURL, err := url.Parse(urlStr) + if err != nil { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "Invalid URL format", + ).WithInternalError(err) + } + + // Enforce HTTPS + if parsedURL.Scheme != "https" { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL must use HTTPS", + ) + } + + // Extract hostname + hostname := parsedURL.Hostname() + if hostname == "" { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL must have a valid hostname", + ) + } + + // Check for localhost and loopback + if isLocalhost(hostname) { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL cannot point to localhost or loopback addresses", + ) + } + + // Resolve hostname to IP addresses + ips, err := net.LookupIP(hostname) + if err != nil { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "Unable to resolve hostname", + ).WithInternalError(err) + } + + // Check each resolved IP + for _, ip := range ips { + if err := validateIP(ip); err != nil { + return err + } + } + + return nil +} + +// isLocalhost checks if the hostname is localhost or a loopback address +func isLocalhost(hostname string) bool { + hostname = strings.ToLower(hostname) + + localhostVariants := []string{ + "localhost", + "127.0.0.1", + "::1", + "0.0.0.0", + "::", + } + + for _, variant := range localhostVariants { + if hostname == variant { + return true + } + } + + // Check for localhost subdomains like "foo.localhost" + if strings.HasSuffix(hostname, ".localhost") { + return true + } + + return false +} + +// validateIP checks if an IP address is safe for OAuth/OIDC operations +func validateIP(ip net.IP) error { + // Block loopback addresses (127.0.0.0/8, ::1) + if ip.IsLoopback() { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL cannot resolve to loopback addresses", + ) + } + + // Block private network addresses (RFC 1918) + // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 + if ip.IsPrivate() { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL cannot resolve to private network addresses", + ) + } + + // Block link-local addresses (169.254.0.0/16, fe80::/10) + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL cannot resolve to link-local addresses", + ) + } + + // Block cloud metadata endpoints (169.254.169.254) + if ip.String() == "169.254.169.254" { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL cannot resolve to cloud metadata endpoints", + ) + } + + // Block multicast addresses + if ip.IsMulticast() { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL cannot resolve to multicast addresses", + ) + } + + // Block unspecified addresses (0.0.0.0, ::) + if ip.IsUnspecified() { + return apierrors.NewBadRequestError( + apierrors.ErrorCodeValidationFailed, + "URL cannot resolve to unspecified addresses", + ) + } + + return nil +} + +// FetchURLWithTimeout fetches a URL with timeout and SSRF protection +// This is used for fetching OIDC discovery documents and JWKS +func FetchURLWithTimeout(ctx context.Context, urlStr string, timeout time.Duration) (*http.Response, error) { + // Validate URL first + if err := ValidateOAuthURL(urlStr); err != nil { + return nil, err + } + + // Create HTTP client with timeout + client := &http.Client{ + Timeout: timeout, + // Use a custom transport that re-validates the IP after DNS resolution + Transport: &ssrfProtectedTransport{ + base: http.DefaultTransport, + }, + } + + // Create request with context + req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil) + if err != nil { + return nil, apierrors.NewInternalServerError( + "Failed to create HTTP request", + ).WithInternalError(err) + } + + // Set user agent + req.Header.Set("User-Agent", "Supabase-Auth/1.0") + req.Header.Set("Accept", "application/json") + + // Execute request + resp, err := client.Do(req) + if err != nil { + return nil, apierrors.NewInternalServerError( + "Failed to fetch URL", + ).WithInternalError(err) + } + + return resp, nil +} + +// ssrfProtectedTransport wraps http.RoundTripper with additional SSRF checks +// TODO(cemal) :: should we keep it? +type ssrfProtectedTransport struct { + base http.RoundTripper +} + +func (t *ssrfProtectedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Re-validate URL before making the request + // This protects against DNS rebinding attacks + if err := ValidateOAuthURL(req.URL.String()); err != nil { + return nil, fmt.Errorf("SSRF protection: %w", err) + } + + return t.base.RoundTrip(req) +} diff --git a/internal/utilities/url_validator_test.go b/internal/utilities/url_validator_test.go new file mode 100644 index 000000000..63d29b707 --- /dev/null +++ b/internal/utilities/url_validator_test.go @@ -0,0 +1,243 @@ +package utilities + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/apierrors" +) + +func TestValidateOAuthURL(t *testing.T) { + tests := []struct { + name string + url string + wantErr bool + errMsg string + }{ + // Valid URLs (these resolve via DNS and pass SSRF checks) + { + name: "Valid HTTPS URL", + url: "https://example.com/oauth/authorize", + wantErr: false, + }, + { + name: "Valid HTTPS URL with port", + url: "https://example.com:8443/authorize", + wantErr: false, + }, + { + name: "Valid HTTPS URL with path and query", + url: "https://example.com/path?query=value", + wantErr: false, + }, + + // Invalid scheme + { + name: "HTTP not allowed", + url: "http://example.com/authorize", + wantErr: true, + errMsg: "URL must use HTTPS", + }, + { + name: "FTP not allowed", + url: "ftp://example.com/file", + wantErr: true, + errMsg: "URL must use HTTPS", + }, + + // Localhost variants + { + name: "Localhost blocked", + url: "https://localhost/authorize", + wantErr: true, + errMsg: "URL cannot point to localhost or loopback addresses", + }, + { + name: "127.0.0.1 blocked", + url: "https://127.0.0.1/authorize", + wantErr: true, + errMsg: "URL cannot point to localhost or loopback addresses", + }, + { + name: "::1 (IPv6 loopback) blocked", + url: "https://[::1]/authorize", + wantErr: true, + errMsg: "URL cannot point to localhost or loopback addresses", + }, + { + name: "0.0.0.0 blocked", + url: "https://0.0.0.0/authorize", + wantErr: true, + errMsg: "URL cannot point to localhost or loopback addresses", + }, + { + name: "Subdomain of localhost blocked", + url: "https://test.localhost/authorize", + wantErr: true, + errMsg: "URL cannot point to localhost or loopback addresses", + }, + + // Private IP ranges (RFC 1918) + { + name: "10.0.0.0/8 network blocked", + url: "https://10.1.2.3/authorize", + wantErr: true, + errMsg: "URL cannot resolve to private network addresses", + }, + { + name: "172.16.0.0/12 network blocked", + url: "https://172.16.0.1/authorize", + wantErr: true, + errMsg: "URL cannot resolve to private network addresses", + }, + { + name: "192.168.0.0/16 network blocked", + url: "https://192.168.1.1/authorize", + wantErr: true, + errMsg: "URL cannot resolve to private network addresses", + }, + + // Cloud metadata endpoint (caught by link-local check) + { + name: "Cloud metadata endpoint blocked", + url: "https://169.254.169.254/latest/meta-data", + wantErr: true, + errMsg: "link-local", // 169.254.0.0/16 is link-local and caught first + }, + + // Invalid URLs (caught by URL parsing) + { + name: "Malformed URL", + url: "not-a-valid-url", + wantErr: true, + errMsg: "HTTPS", // Parse succeeds but scheme check fails + }, + { + name: "Empty URL", + url: "", + wantErr: true, + errMsg: "HTTPS", // Parse succeeds (empty string) but scheme check fails + }, + { + name: "URL without hostname", + url: "https://", + wantErr: true, + errMsg: "URL must have a valid hostname", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateOAuthURL(tt.url) + + if tt.wantErr { + require.Error(t, err) + if tt.errMsg != "" { + apiErr, ok := err.(*apierrors.HTTPError) + require.True(t, ok, "expected apierrors.HTTPError") + assert.Contains(t, apiErr.Message, tt.errMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + name string + hostname string + want bool + }{ + {"localhost", "localhost", true}, + {"Localhost uppercase", "Localhost", true}, + {"LOCALHOST uppercase", "LOCALHOST", true}, + {"127.0.0.1", "127.0.0.1", true}, + {"::1 IPv6", "::1", true}, + {"0.0.0.0", "0.0.0.0", true}, + {"::", "::", true}, + {"test.localhost subdomain", "test.localhost", true}, + {"api.test.localhost nested subdomain", "api.test.localhost", true}, + {"example.com", "example.com", false}, + {"localhostbutnotreally.com", "localhostbutnotreally.com", false}, + {"localhost.example.com", "localhost.example.com", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isLocalhost(tt.hostname) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestFetchURLWithTimeout(t *testing.T) { + t.Run("Successful fetch", func(t *testing.T) { + // Use TLS server since ValidateOAuthURL enforces HTTPS + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"test": "data"}`)) + })) + defer server.Close() + + // Note: This test will fail DNS validation since the test server uses 127.0.0.1 + // This is expected behavior - in production, HTTPS URLs must point to public IPs + // For this test, we're just verifying the HTTP client setup works + ctx := context.Background() + _, err := FetchURLWithTimeout(ctx, server.URL, 5*time.Second) + // Expect error due to localhost/loopback check (SSRF protection) + require.Error(t, err) + assert.Contains(t, err.Error(), "localhost or loopback") + }) + + t.Run("Invalid URL fails SSRF check", func(t *testing.T) { + ctx := context.Background() + _, err := FetchURLWithTimeout(ctx, "https://localhost/test", 5*time.Second) + require.Error(t, err) + assert.Contains(t, err.Error(), "localhost or loopback addresses") + }) + + t.Run("HTTP URL rejected", func(t *testing.T) { + ctx := context.Background() + _, err := FetchURLWithTimeout(ctx, "http://example.com/test", 5*time.Second) + require.Error(t, err) + assert.Contains(t, err.Error(), "URL must use HTTPS") + }) + + t.Run("Context timeout", func(t *testing.T) { + // This test verifies timeout behavior would work, but we can't actually test it + // with a real server due to SSRF protection. We're testing the error path instead. + ctx := context.Background() + _, err := FetchURLWithTimeout(ctx, "https://127.0.0.1:8443/test", 50*time.Millisecond) + require.Error(t, err) + // Should fail SSRF check before timeout + assert.Contains(t, err.Error(), "loopback") + }) + +} + +func TestSSRFProtectedTransport(t *testing.T) { + t.Run("SSRF protection re-validates on redirect", func(t *testing.T) { + // This test verifies that the SSRF protection re-validates URLs + // even after DNS resolution, protecting against DNS rebinding attacks + + transport := &ssrfProtectedTransport{ + base: http.DefaultTransport, + } + + // Try to create a request to a private IP + req, err := http.NewRequest(http.MethodGet, "https://10.0.0.1/test", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.Error(t, err) + assert.Contains(t, err.Error(), "SSRF protection") + }) +} diff --git a/migrations/20260128120000_add_custom_oauth_providers.up.sql b/migrations/20260128120000_add_custom_oauth_providers.up.sql new file mode 100644 index 000000000..d67513665 --- /dev/null +++ b/migrations/20260128120000_add_custom_oauth_providers.up.sql @@ -0,0 +1,114 @@ +-- Create unified custom OAuth/OIDC providers table +-- This table stores both OAuth2 and OIDC providers with type discrimination + +/* auth_migration: 20260128120000 */ +create table if not exists {{ index .Options "Namespace" }}.custom_oauth_providers ( + id uuid not null default gen_random_uuid(), + + -- Provider type: 'oauth2' or 'oidc' + provider_type text not null check (provider_type in ('oauth2', 'oidc')), + + -- Common fields for both OAuth2 and OIDC + identifier text not null, + name text not null, + client_id text not null, + client_secret text not null, -- Encrypted at application level + -- Store JSON-encoded string slices in jsonb columns + acceptable_client_ids jsonb not null default '[]'::jsonb, -- Additional client IDs for multi-platform apps + scopes jsonb not null default '[]'::jsonb, + pkce_enabled boolean not null default true, + attribute_mapping jsonb not null default '{}', + authorization_params jsonb not null default '{}', + enabled boolean not null default true, + email_optional boolean not null default false, -- Allow sign-in without email + + -- OIDC-specific fields (null for OAuth2 providers) + issuer text null, + discovery_url text null, -- Optional override for .well-known/openid-configuration + skip_nonce_check boolean not null default false, + cached_discovery jsonb null, + discovery_cached_at timestamptz null, + + -- OAuth2-specific fields (null for OIDC providers) + authorization_url text null, + token_url text null, + userinfo_url text null, + jwks_uri text null, + + -- Timestamps + created_at timestamptz not null default now(), + updated_at timestamptz not null default now(), + + -- Primary key and unique constraints + constraint custom_oauth_providers_pkey primary key (id), + constraint custom_oauth_providers_identifier_key unique (identifier), + + -- OIDC-specific constraints + constraint custom_oauth_providers_oidc_requires_issuer check ( + provider_type != 'oidc' or issuer is not null + ), + constraint custom_oauth_providers_oidc_issuer_https check ( + provider_type != 'oidc' or issuer is null or issuer like 'https://%' + ), + constraint custom_oauth_providers_oidc_discovery_url_https check ( + provider_type != 'oidc' or discovery_url is null or discovery_url like 'https://%' + ), + + -- OAuth2-specific constraints + constraint custom_oauth_providers_oauth2_requires_endpoints check ( + provider_type != 'oauth2' or ( + authorization_url is not null and + token_url is not null and + userinfo_url is not null + ) + ), + + -- Format and length constraints + -- Identifier must be alphanumeric with optional hyphens (no leading/trailing hyphens) + constraint custom_oauth_providers_identifier_format check ( + identifier ~ '^[a-z0-9][a-z0-9:-]{0,48}[a-z0-9]$' + ), + constraint custom_oauth_providers_identifier_length check ( + char_length(identifier) >= 1 and char_length(identifier) <= 50 + ), + constraint custom_oauth_providers_name_length check ( + char_length(name) >= 1 and char_length(name) <= 100 + ), + constraint custom_oauth_providers_issuer_length check ( + issuer is null or (char_length(issuer) >= 1 and char_length(issuer) <= 2048) + ), + constraint custom_oauth_providers_discovery_url_length check ( + discovery_url is null or char_length(discovery_url) <= 2048 + ), + constraint custom_oauth_providers_authorization_url_length check ( + authorization_url is null or char_length(authorization_url) <= 2048 + ), + constraint custom_oauth_providers_token_url_length check ( + token_url is null or char_length(token_url) <= 2048 + ), + constraint custom_oauth_providers_userinfo_url_length check ( + userinfo_url is null or char_length(userinfo_url) <= 2048 + ), + constraint custom_oauth_providers_jwks_uri_length check ( + jwks_uri is null or char_length(jwks_uri) <= 2048 + ), + constraint custom_oauth_providers_client_id_length check ( + char_length(client_id) >= 1 and char_length(client_id) <= 512 + ) +); + +/* auth_migration: 20260128120000 */ +create index if not exists custom_oauth_providers_identifier_idx + on {{ index .Options "Namespace" }}.custom_oauth_providers (identifier); + +/* auth_migration: 20260128120000 */ +create index if not exists custom_oauth_providers_provider_type_idx + on {{ index .Options "Namespace" }}.custom_oauth_providers (provider_type); + +/* auth_migration: 20260128120000 */ +create index if not exists custom_oauth_providers_enabled_idx + on {{ index .Options "Namespace" }}.custom_oauth_providers (enabled); + +/* auth_migration: 20260128120000 */ +create index if not exists custom_oauth_providers_created_at_idx + on {{ index .Options "Namespace" }}.custom_oauth_providers (created_at); diff --git a/openapi.yaml b/openapi.yaml index 30a675efc..7dd8310c6 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -2234,6 +2234,444 @@ paths: 403: $ref: "#/components/responses/ForbiddenResponse" + /admin/custom-providers: + get: + summary: List all custom OIDC/OAuth providers + description: > + Retrieves a list of all custom OAuth 2.0 and OIDC provider configurations. + Optionally filter by provider type. Only available when custom OIDC/OAuth providers are enabled + (set `GOTRUE_CUSTOM_OAUTH_ENABLED=true` for self-hosted or enable in Supabase Dashboard). + tags: + - admin + - oauth-client + security: + - APIKeyAuth: [] + AdminAuth: [] + parameters: + - name: type + in: query + required: false + description: Filter by provider type + schema: + type: string + enum: + - oauth2 + - oidc + responses: + 200: + description: List of custom OIDC/OAuth providers + content: + application/json: + schema: + type: object + properties: + providers: + type: array + items: + $ref: "#/components/schemas/CustomOAuthProviderSchema" + 400: + description: Invalid type parameter + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + invalid_type: + summary: validation_failed + value: + error_code: validation_failed + msg: "type must be either 'oauth2' or 'oidc'" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + post: + summary: Create a new custom OIDC/OAuth provider + description: > + Creates a new custom OAuth 2.0 or OIDC provider configuration. + Required fields differ based on provider_type. Only available when custom OIDC/OAuth providers are enabled. + tags: + - admin + - oauth-client + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - provider_type + - identifier + - name + - client_id + - client_secret + properties: + provider_type: + type: string + enum: + - oauth2 + - oidc + description: Type of OAuth provider + identifier: + type: string + pattern: "^[a-zA-Z0-9_-]+$" + description: Unique identifier (will be prefixed with 'custom:' automatically). Cannot use reserved provider names. + example: "mycompany" + name: + type: string + description: Human-readable display name + example: "My Company SSO" + client_id: + type: string + description: OAuth client ID from the provider + client_secret: + type: string + description: OAuth client secret (will be encrypted at rest) + acceptable_client_ids: + type: array + items: + type: string + description: Additional acceptable client IDs for token validation + scopes: + type: array + items: + type: string + description: OAuth scopes to request (OIDC providers will automatically include 'openid') + example: ["email", "profile"] + pkce_enabled: + type: boolean + description: Enable PKCE (Proof Key for Code Exchange) + default: true + attribute_mapping: + type: object + description: Map provider claims to user attributes (cannot map to protected system fields) + additionalProperties: true + example: + email: "user_email" + name: "full_name" + authorization_params: + type: object + description: Additional authorization request parameters (cannot override reserved OAuth parameters) + additionalProperties: true + example: + prompt: "consent" + enabled: + type: boolean + description: Whether the provider is enabled + default: true + email_optional: + type: boolean + description: Whether email is optional for users from this provider + default: false + issuer: + type: string + format: uri + description: "OIDC issuer URL (required for provider_type: oidc)" + example: "https://accounts.google.com" + discovery_url: + type: string + format: uri + description: OIDC discovery URL (optional for OIDC, defaults to {issuer}/.well-known/openid-configuration) + skip_nonce_check: + type: boolean + description: Skip nonce validation for OIDC (not recommended for production) + default: false + authorization_url: + type: string + format: uri + description: "OAuth 2.0 authorization endpoint (required for provider_type: oauth2)" + example: "https://provider.com/oauth/authorize" + token_url: + type: string + format: uri + description: "OAuth 2.0 token endpoint (required for provider_type: oauth2)" + example: "https://provider.com/oauth/token" + userinfo_url: + type: string + format: uri + description: "OAuth 2.0 userinfo endpoint (required for provider_type: oauth2)" + example: "https://provider.com/oauth/userinfo" + jwks_uri: + type: string + format: uri + description: JWKS URI for token validation (optional for OAuth2) + examples: + oidc_provider: + summary: Create OIDC provider + value: + provider_type: "oidc" + identifier: "mycompany" + name: "My Company SSO" + client_id: "your-client-id" + client_secret: "your-client-secret" + issuer: "https://accounts.mycompany.com" + scopes: ["email", "profile"] + oauth2_provider: + summary: Create OAuth2 provider + value: + provider_type: "oauth2" + identifier: "customauth" + name: "Custom Auth Provider" + client_id: "your-client-id" + client_secret: "your-client-secret" + authorization_url: "https://provider.com/oauth/authorize" + token_url: "https://provider.com/oauth/token" + userinfo_url: "https://provider.com/oauth/userinfo" + scopes: ["email", "profile"] + responses: + 201: + description: Custom OIDC/OAuth provider created successfully + content: + application/json: + schema: + $ref: "#/components/schemas/CustomOAuthProviderSchema" + 400: + description: Bad request - validation failed, provider exists, or quota exceeded + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + validation_failed: + summary: validation_failed + value: + error_code: validation_failed + msg: "issuer is required for OIDC providers" + provider_exists: + summary: conflict + value: + error_code: conflict + msg: "A custom OAuth provider with this identifier already exists" + quota_exceeded: + summary: over_quota + value: + error_code: over_quota + msg: "Maximum number of custom OIDC/OAuth providers reached" + reserved_name: + summary: validation_failed + value: + error_code: validation_failed + msg: "Cannot use reserved provider name: google. This provider is already built into Supabase Auth." + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + + /admin/custom-providers/{identifier}: + parameters: + - name: identifier + in: path + required: true + description: Provider identifier (must start with 'custom:' prefix) + schema: + type: string + pattern: "^custom:[a-zA-Z0-9_-]+$" + example: "custom:mycompany" + get: + summary: Get custom OIDC/OAuth provider details + description: > + Retrieves details of a specific custom OIDC/OAuth provider. + Only available when custom OIDC/OAuth providers are enabled. + tags: + - admin + - oauth-client + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: Custom OIDC/OAuth provider details + content: + application/json: + schema: + $ref: "#/components/schemas/CustomOAuthProviderSchema" + 400: + description: Invalid identifier format + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + validation_failed: + summary: validation_failed + value: + error_code: validation_failed + msg: "identifier must start with 'custom:' prefix" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: Provider not found + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + not_found: + summary: provider_not_found + value: + error_code: provider_not_found + msg: "Custom OAuth provider not found" + put: + summary: Update custom OIDC/OAuth provider + description: > + Updates an existing custom OIDC/OAuth provider. All fields are optional. + Only provided fields will be updated. Only available when custom OIDC/OAuth providers are enabled. + tags: + - admin + - oauth-client + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: Human-readable display name + client_id: + type: string + description: OAuth client ID + client_secret: + type: string + description: OAuth client secret (only provide if changing, will be encrypted) + acceptable_client_ids: + type: array + items: + type: string + description: Additional acceptable client IDs + scopes: + type: array + items: + type: string + description: OAuth scopes to request + pkce_enabled: + type: boolean + description: Enable PKCE + attribute_mapping: + type: object + description: Map provider claims to user attributes + additionalProperties: true + authorization_params: + type: object + description: Additional authorization request parameters + additionalProperties: true + enabled: + type: boolean + description: Whether the provider is enabled + email_optional: + type: boolean + description: Whether email is optional + issuer: + type: string + format: uri + description: OIDC issuer URL (for OIDC providers) + discovery_url: + type: string + format: uri + description: OIDC discovery URL (for OIDC providers) + skip_nonce_check: + type: boolean + description: Skip nonce validation for OIDC + authorization_url: + type: string + format: uri + description: OAuth 2.0 authorization endpoint (for OAuth2 providers) + token_url: + type: string + format: uri + description: OAuth 2.0 token endpoint (for OAuth2 providers) + userinfo_url: + type: string + format: uri + description: OAuth 2.0 userinfo endpoint (for OAuth2 providers) + jwks_uri: + type: string + format: uri + description: JWKS URI for token validation (for OAuth2 providers) + examples: + update_name: + summary: Update provider name + value: + name: "Updated Company Name" + update_scopes: + summary: Update scopes + value: + scopes: ["openid", "email", "profile", "groups"] + responses: + 200: + description: Custom OIDC/OAuth provider updated successfully + content: + application/json: + schema: + $ref: "#/components/schemas/CustomOAuthProviderSchema" + 400: + description: Invalid identifier or validation failed + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: Provider not found + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + delete: + summary: Delete custom OIDC/OAuth provider + description: > + Permanently removes a custom OIDC/OAuth provider configuration. + Only available when custom OIDC/OAuth providers are enabled. + tags: + - admin + - oauth-client + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 204: + description: Custom OIDC/OAuth provider deleted successfully (no content) + 400: + description: Invalid identifier format + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + validation_failed: + summary: validation_failed + value: + error_code: validation_failed + msg: "identifier must start with 'custom:' prefix" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: Provider not found + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + examples: + not_found: + summary: provider_not_found + value: + error_code: provider_not_found + msg: "Custom OAuth provider not found" + /oauth/clients/register: post: summary: Register a new OAuth client dynamically (public endpoint). @@ -3263,6 +3701,110 @@ components: type: string format: date-time + CustomOAuthProviderSchema: + type: object + description: Represents a custom OAuth 2.0 or OIDC provider configuration + required: + - id + - provider_type + - identifier + - name + - client_id + properties: + id: + type: string + format: uuid + description: Unique provider identifier + provider_type: + type: string + enum: + - oauth2 + - oidc + description: Type of OAuth provider + identifier: + type: string + pattern: "^custom:[a-zA-Z0-9_-]+$" + description: Unique identifier for the provider (must start with 'custom:' prefix) + example: "custom:mycompany" + name: + type: string + description: Human-readable name of the provider + example: "My Company SSO" + client_id: + type: string + description: OAuth client ID + acceptable_client_ids: + type: array + items: + type: string + description: Additional acceptable client IDs for token validation + scopes: + type: array + items: + type: string + description: OAuth scopes to request (OIDC providers will automatically include 'openid') + example: ["openid", "email", "profile"] + pkce_enabled: + type: boolean + description: Whether PKCE (Proof Key for Code Exchange) is enabled + default: true + attribute_mapping: + type: object + description: Maps provider claims to user attributes + additionalProperties: true + example: + email: "user_email" + name: "full_name" + authorization_params: + type: object + description: Additional parameters to include in authorization requests (cannot override reserved OAuth parameters) + additionalProperties: true + example: + prompt: "consent" + enabled: + type: boolean + description: Whether the provider is enabled + default: true + email_optional: + type: boolean + description: Whether email is optional for users from this provider + default: false + issuer: + type: string + format: uri + description: OIDC issuer URL (required for OIDC providers) + example: "https://accounts.google.com" + discovery_url: + type: string + format: uri + description: OIDC discovery URL (optional, defaults to {issuer}/.well-known/openid-configuration) + skip_nonce_check: + type: boolean + description: Skip nonce validation for OIDC (not recommended for production) + default: false + authorization_url: + type: string + format: uri + description: OAuth 2.0 authorization endpoint (required for OAuth2 providers) + token_url: + type: string + format: uri + description: OAuth 2.0 token endpoint (required for OAuth2 providers) + userinfo_url: + type: string + format: uri + description: OAuth 2.0 userinfo endpoint (required for OAuth2 providers) + jwks_uri: + type: string + format: uri + description: JWKS URI for token validation (optional for OAuth2 providers) + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + responses: OAuthCallbackRedirectResponse: description: >