From 3feb791f781ee3ebbfa9f5c5000623f74e39da0e Mon Sep 17 00:00:00 2001 From: fadymak Date: Fri, 6 Feb 2026 10:40:34 +0100 Subject: [PATCH] chore: remove legacy JWT-based flow state handling --- internal/api/context.go | 13 --- internal/api/external.go | 104 +---------------------- internal/api/external_oauth.go | 4 +- internal/api/external_test.go | 150 +++++++-------------------------- 4 files changed, 38 insertions(+), 233 deletions(-) diff --git a/internal/api/context.go b/internal/api/context.go index 7a0df9ed5..f8367a4ab 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -33,7 +33,6 @@ const ( oauthVerifierKey = contextKey("oauth_verifier") ssoProviderKey = contextKey("sso_provider") externalHostKey = contextKey("external_host") - flowStateKey = contextKey("flow_state_id") oauthClientStateKey = contextKey("oauth_client_state_id") flowStateContextKey = contextKey("flow_state") ) @@ -128,18 +127,6 @@ func withInviteToken(ctx context.Context, token string) context.Context { return context.WithValue(ctx, inviteTokenKey, token) } -func withFlowStateID(ctx context.Context, FlowStateID string) context.Context { - return context.WithValue(ctx, flowStateKey, FlowStateID) -} - -func getFlowStateID(ctx context.Context) string { - obj := ctx.Value(flowStateKey) - if obj == nil { - return "" - } - return obj.(string) -} - func withOAuthClientStateID(ctx context.Context, oauthClientStateID uuid.UUID) context.Context { return context.WithValue(ctx, oauthClientStateKey, oauthClientStateID) } diff --git a/internal/api/external.go b/internal/api/external.go index a3597f5be..d48722910 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -10,7 +10,6 @@ import ( "github.com/fatih/structs" "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt/v5" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/provider" @@ -23,18 +22,6 @@ import ( "golang.org/x/oauth2" ) -// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow -type ExternalProviderClaims struct { - AuthMicroserviceClaims - Provider string `json:"provider"` - InviteToken string `json:"invite_token,omitempty"` - Referrer string `json:"referrer,omitempty"` - FlowStateID string `json:"flow_state_id"` - OAuthClientStateID string `json:"oauth_client_state_id,omitempty"` - LinkingTargetID string `json:"linking_target_id,omitempty"` - EmailOptional bool `json:"email_optional,omitempty"` -} - // ExternalProviderRedirect redirects the request to the oauth provider func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error { rurl, err := a.GetExternalProviderRedirectURL(w, r, nil) @@ -203,20 +190,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re providerAccessToken := data.token providerRefreshToken := data.refreshToken - // Get flow state from context (new UUID format) or load from FlowStateID (legacy JWT format) flowState := getFlowState(ctx) - if flowState == nil { - // Backward compatibility: load from FlowStateID for legacy JWT state - // To be removed in subsequent release. - if flowStateID := getFlowStateID(ctx); flowStateID != "" { - flowState, err = models.FindFlowStateByID(db, flowStateID) - if models.IsNotFoundError(err) { - return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) - } else if err != nil { - return apierrors.NewInternalServerError("Failed to find flow state").WithInternalError(err) - } - } - } targetUser := getTargetUser(ctx) inviteToken := getInviteToken(ctx) @@ -545,13 +519,12 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth state parameter missing") } - // Try to parse state as UUID first (new format) - if stateUUID, err := uuid.FromString(state); err == nil { - return a.loadExternalStateFromUUID(ctx, db, stateUUID) + stateUUID, err := uuid.FromString(state) + if err != nil { + return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth state parameter is invalid") } - // Fall back to JWT parsing for backward compatibility - return a.loadExternalStateFromJWT(ctx, db, state) + return a.loadExternalStateFromUUID(ctx, db, stateUUID) } // loadExternalStateFromUUID loads OAuth state from a flow_state record (new UUID format) @@ -598,75 +571,6 @@ func (a *API) loadExternalStateFromUUID(ctx context.Context, db *storage.Connect return withSignature(ctx, stateID.String()), nil } -// loadExternalStateFromJWT loads OAuth state from a JWT (legacy format for backward compatibility) -func (a *API) loadExternalStateFromJWT(ctx context.Context, db *storage.Connection, state string) (context.Context, error) { - config := a.config - claims := ExternalProviderClaims{} - p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods)) - _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { - if kid, ok := token.Header["kid"]; ok { - if kidStr, ok := kid.(string); ok { - key, err := conf.FindPublicKeyByKid(kidStr, &config.JWT) - if err != nil { - return nil, err - } - - if key != nil { - return key, nil - } - - // otherwise try to use fallback - } - } - if alg, ok := token.Header["alg"]; ok { - if alg == jwt.SigningMethodHS256.Name { - // preserve backward compatibility for cases where the kid is not set or potentially invalid but the key can be decoded with the secret - return []byte(config.JWT.Secret), nil - } - } - - return nil, fmt.Errorf("unrecognized JWT kid %v for algorithm %v", token.Header["kid"], token.Header["alg"]) - }) - if err != nil { - return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) - } - if claims.Provider == "" { - return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") - } - if claims.InviteToken != "" { - ctx = withInviteToken(ctx, claims.InviteToken) - } - if claims.Referrer != "" { - ctx = withExternalReferrer(ctx, claims.Referrer) - } - if claims.FlowStateID != "" { - ctx = withFlowStateID(ctx, claims.FlowStateID) - } - if claims.OAuthClientStateID != "" { - oauthClientStateID, err := uuid.FromString(claims.OAuthClientStateID) - if err != nil { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (oauth_client_state_id must be UUID)") - } - ctx = withOAuthClientStateID(ctx, oauthClientStateID) - } - if claims.LinkingTargetID != "" { - linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) - if err != nil { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") - } - u, err := models.FindUserByID(db, linkingTargetUserID) - if err != nil { - if models.IsNotFoundError(err) { - return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found") - } - return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err) - } - ctx = withTargetUser(ctx, u) - } - ctx = withExternalProviderType(ctx, claims.Provider, claims.EmailOptional) - return withSignature(ctx, state), nil -} - // Provider returns a Provider interface for the given name. func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, conf.OAuthProviderConfiguration, error) { config := a.config diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index 40e737a04..837aace9a 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -26,8 +26,8 @@ type OAuthProviderData struct { code string } -// loadFlowState parses the `state` query parameter as a JWS payload, -// extracting the provider requested +// loadFlowState parses the `state` query parameter as a UUID, +// loads the flow state from the database, and extracts the provider requested func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) { ctx := r.Context() db := a.db.WithContext(ctx) diff --git a/internal/api/external_test.go b/internal/api/external_test.go index eaa4cbb78..b70da6431 100644 --- a/internal/api/external_test.go +++ b/internal/api/external_test.go @@ -6,15 +6,12 @@ import ( "net/http/httptest" "net/url" "testing" - "time" "github.com/gofrs/uuid" - jwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" - "github.com/supabase/auth/internal/tokens" ) type ExternalTestSuite struct { @@ -352,145 +349,62 @@ func setupGenericOAuthServer(ts *ExternalTestSuite, code string) *httptest.Serve return server } -// TestOAuthState_BackwardCompatibleJWT tests that the callback endpoint -// still accepts the legacy JWT state format for backward compatibility during migration. -func (ts *ExternalTestSuite) TestOAuthState_BackwardCompatibleJWT() { +// TestOAuthState_UUIDFormat tests that the callback endpoint processes UUID state correctly. +func (ts *ExternalTestSuite) TestOAuthState_UUIDFormat() { code := "authcode" server := setupGenericOAuthServer(ts, code) defer server.Close() - // Create a legacy JWT state token manually - claims := &ExternalProviderClaims{ - AuthMicroserviceClaims: AuthMicroserviceClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: ts.Config.JWT.Issuer, - }, - }, - Provider: "github", - Referrer: "https://example.com/admin", - EmailOptional: false, - } + // Use the standard authorization flow which generates UUID state + w := performAuthorizationRequest(ts, "github", "") + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err) - jwtState, err := tokens.SignJWT(&ts.Config.JWT, claims) - require.NoError(ts.T(), err) - require.NotEmpty(ts.T(), jwtState) + state := u.Query().Get("state") + ts.Require().NotEmpty(state) + + stateUUID, err := uuid.FromString(state) + require.NoError(ts.T(), err, "state should be a valid UUID") + require.NotEqual(ts.T(), uuid.Nil, stateUUID) testURL, err := url.Parse("http://localhost/callback") require.NoError(ts.T(), err) v := testURL.Query() v.Set("code", code) - v.Set("state", jwtState) + v.Set("state", state) testURL.RawQuery = v.Encode() req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) - w := httptest.NewRecorder() + w = httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) ts.Require().Equal(http.StatusFound, w.Code) - u, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - ts.Require().Equal("/admin", u.Path) - - fragment, err := url.ParseQuery(u.Fragment) + resultURL, err := url.Parse(w.Header().Get("Location")) ts.Require().NoError(err) - ts.NotEmpty(fragment.Get("access_token"), "should have access_token") - ts.NotEmpty(fragment.Get("refresh_token"), "should have refresh_token") - user, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - require.NotNil(ts.T(), user) + fragment, err := url.ParseQuery(resultURL.Fragment) + ts.Require().NoError(err) + ts.NotEmpty(fragment.Get("access_token"), "UUID state should result in access_token") } -// TestOAuthState_MigrationScenario tests that both UUID and JWT state formats -// can be processed during the migration period. -func (ts *ExternalTestSuite) TestOAuthState_MigrationScenario() { +// TestOAuthState_InvalidFormat tests that non-UUID state parameters are rejected. +func (ts *ExternalTestSuite) TestOAuthState_InvalidFormat() { code := "authcode" server := setupGenericOAuthServer(ts, code) defer server.Close() - ts.Run("NewUUIDFormat", func() { - // Use the standard authorization flow which now generates UUID state - w := performAuthorizationRequest(ts, "github", "") - ts.Require().Equal(http.StatusFound, w.Code) - u, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err) - - state := u.Query().Get("state") - ts.Require().NotEmpty(state) - - // Verify state is a valid UUID - stateUUID, err := uuid.FromString(state) - require.NoError(ts.T(), err, "state should be a valid UUID") - require.NotEqual(ts.T(), uuid.Nil, stateUUID) - - // Complete the callback - testURL, err := url.Parse("http://localhost/callback") - require.NoError(ts.T(), err) - v := testURL.Query() - v.Set("code", code) - v.Set("state", state) - testURL.RawQuery = v.Encode() - - req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - - ts.Require().Equal(http.StatusFound, w.Code) - resultURL, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err) - - fragment, err := url.ParseQuery(resultURL.Fragment) - ts.Require().NoError(err) - ts.NotEmpty(fragment.Get("access_token"), "UUID state should result in access_token") - }) - - // Clean up user for next test - user, _ := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) - if user != nil { - require.NoError(ts.T(), ts.API.db.Destroy(user)) - } - - ts.Run("LegacyJWTFormat", func() { - // Create a legacy JWT state - claims := &ExternalProviderClaims{ - AuthMicroserviceClaims: AuthMicroserviceClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: ts.Config.JWT.Issuer, - }, - }, - Provider: "github", - Referrer: "https://example.com/admin", - } - - jwtState, err := tokens.SignJWT(&ts.Config.JWT, claims) - require.NoError(ts.T(), err) - - // Verify state is NOT a UUID (it's a JWT) - _, uuidErr := uuid.FromString(jwtState) - require.Error(ts.T(), uuidErr, "JWT state should not be parseable as UUID") - - // Complete the callback with JWT state - testURL, err := url.Parse("http://localhost/callback") - require.NoError(ts.T(), err) - v := testURL.Query() - v.Set("code", code) - v.Set("state", jwtState) - testURL.RawQuery = v.Encode() - - req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) + testURL, err := url.Parse("http://localhost/callback") + require.NoError(ts.T(), err) + v := testURL.Query() + v.Set("code", code) + v.Set("state", "not-a-valid-uuid") + testURL.RawQuery = v.Encode() - ts.Require().Equal(http.StatusFound, w.Code) - resultURL, err := url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err) + req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) - fragment, err := url.ParseQuery(resultURL.Fragment) - ts.Require().NoError(err) - ts.NotEmpty(fragment.Get("access_token"), "JWT state should also result in access_token") - }) + // Should redirect to site URL with error since state is invalid + ts.Require().Equal(http.StatusSeeOther, w.Code) }