Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ const (
oauthVerifierKey = contextKey("oauth_verifier")
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
oauthClientStateKey = contextKey("oauth_client_state_id")
flowStateContextKey = contextKey("flow_state")
)
Expand Down Expand Up @@ -128,18 +127,6 @@ func withInviteToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, inviteTokenKey, token)
}

func withFlowStateID(ctx context.Context, FlowStateID string) context.Context {
return context.WithValue(ctx, flowStateKey, FlowStateID)
}

func getFlowStateID(ctx context.Context) string {
obj := ctx.Value(flowStateKey)
if obj == nil {
return ""
}
return obj.(string)
}

func withOAuthClientStateID(ctx context.Context, oauthClientStateID uuid.UUID) context.Context {
return context.WithValue(ctx, oauthClientStateKey, oauthClientStateID)
}
Expand Down
104 changes: 4 additions & 100 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/fatih/structs"
"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/provider"
Expand All @@ -23,18 +22,6 @@ import (
"golang.org/x/oauth2"
)

// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
type ExternalProviderClaims struct {
AuthMicroserviceClaims
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
OAuthClientStateID string `json:"oauth_client_state_id,omitempty"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
EmailOptional bool `json:"email_optional,omitempty"`
}

// ExternalProviderRedirect redirects the request to the oauth provider
func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error {
rurl, err := a.GetExternalProviderRedirectURL(w, r, nil)
Expand Down Expand Up @@ -203,20 +190,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
providerAccessToken := data.token
providerRefreshToken := data.refreshToken

// Get flow state from context (new UUID format) or load from FlowStateID (legacy JWT format)
flowState := getFlowState(ctx)
if flowState == nil {
// Backward compatibility: load from FlowStateID for legacy JWT state
// To be removed in subsequent release.
if flowStateID := getFlowStateID(ctx); flowStateID != "" {
flowState, err = models.FindFlowStateByID(db, flowStateID)
if models.IsNotFoundError(err) {
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err)
} else if err != nil {
return apierrors.NewInternalServerError("Failed to find flow state").WithInternalError(err)
}
}
}

targetUser := getTargetUser(ctx)
inviteToken := getInviteToken(ctx)
Expand Down Expand Up @@ -545,13 +519,12 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
}

// Try to parse state as UUID first (new format)
if stateUUID, err := uuid.FromString(state); err == nil {
return a.loadExternalStateFromUUID(ctx, db, stateUUID)
stateUUID, err := uuid.FromString(state)
if err != nil {
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth state parameter is invalid")
}

// Fall back to JWT parsing for backward compatibility
return a.loadExternalStateFromJWT(ctx, db, state)
return a.loadExternalStateFromUUID(ctx, db, stateUUID)
}

// loadExternalStateFromUUID loads OAuth state from a flow_state record (new UUID format)
Expand Down Expand Up @@ -598,75 +571,6 @@ func (a *API) loadExternalStateFromUUID(ctx context.Context, db *storage.Connect
return withSignature(ctx, stateID.String()), nil
}

// loadExternalStateFromJWT loads OAuth state from a JWT (legacy format for backward compatibility)
func (a *API) loadExternalStateFromJWT(ctx context.Context, db *storage.Connection, state string) (context.Context, error) {
config := a.config
claims := ExternalProviderClaims{}
p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods))
_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
if kid, ok := token.Header["kid"]; ok {
if kidStr, ok := kid.(string); ok {
key, err := conf.FindPublicKeyByKid(kidStr, &config.JWT)
if err != nil {
return nil, err
}

if key != nil {
return key, nil
}

// otherwise try to use fallback
}
}
if alg, ok := token.Header["alg"]; ok {
if alg == jwt.SigningMethodHS256.Name {
// preserve backward compatibility for cases where the kid is not set or potentially invalid but the key can be decoded with the secret
return []byte(config.JWT.Secret), nil
}
}

return nil, fmt.Errorf("unrecognized JWT kid %v for algorithm %v", token.Header["kid"], token.Header["alg"])
})
if err != nil {
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
}
if claims.Provider == "" {
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
}
if claims.InviteToken != "" {
ctx = withInviteToken(ctx, claims.InviteToken)
}
if claims.Referrer != "" {
ctx = withExternalReferrer(ctx, claims.Referrer)
}
if claims.FlowStateID != "" {
ctx = withFlowStateID(ctx, claims.FlowStateID)
}
if claims.OAuthClientStateID != "" {
oauthClientStateID, err := uuid.FromString(claims.OAuthClientStateID)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (oauth_client_state_id must be UUID)")
}
ctx = withOAuthClientStateID(ctx, oauthClientStateID)
}
if claims.LinkingTargetID != "" {
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)")
}
u, err := models.FindUserByID(db, linkingTargetUserID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found")
}
return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err)
}
ctx = withTargetUser(ctx, u)
}
ctx = withExternalProviderType(ctx, claims.Provider, claims.EmailOptional)
return withSignature(ctx, state), nil
}

// Provider returns a Provider interface for the given name.
func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, conf.OAuthProviderConfiguration, error) {
config := a.config
Expand Down
4 changes: 2 additions & 2 deletions internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ type OAuthProviderData struct {
code string
}

// loadFlowState parses the `state` query parameter as a JWS payload,
// extracting the provider requested
// loadFlowState parses the `state` query parameter as a UUID,
// loads the flow state from the database, and extracts the provider requested
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx := r.Context()
db := a.db.WithContext(ctx)
Expand Down
150 changes: 32 additions & 118 deletions internal/api/external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@ import (
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/tokens"
)

type ExternalTestSuite struct {
Expand Down Expand Up @@ -352,145 +349,62 @@ func setupGenericOAuthServer(ts *ExternalTestSuite, code string) *httptest.Serve
return server
}

// TestOAuthState_BackwardCompatibleJWT tests that the callback endpoint
// still accepts the legacy JWT state format for backward compatibility during migration.
func (ts *ExternalTestSuite) TestOAuthState_BackwardCompatibleJWT() {
// TestOAuthState_UUIDFormat tests that the callback endpoint processes UUID state correctly.
func (ts *ExternalTestSuite) TestOAuthState_UUIDFormat() {
code := "authcode"
server := setupGenericOAuthServer(ts, code)
defer server.Close()

// Create a legacy JWT state token manually
claims := &ExternalProviderClaims{
AuthMicroserviceClaims: AuthMicroserviceClaims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: ts.Config.JWT.Issuer,
},
},
Provider: "github",
Referrer: "https://example.com/admin",
EmailOptional: false,
}
// Use the standard authorization flow which generates UUID state
w := performAuthorizationRequest(ts, "github", "")
ts.Require().Equal(http.StatusFound, w.Code)
u, err := url.Parse(w.Header().Get("Location"))
ts.Require().NoError(err)

jwtState, err := tokens.SignJWT(&ts.Config.JWT, claims)
require.NoError(ts.T(), err)
require.NotEmpty(ts.T(), jwtState)
state := u.Query().Get("state")
ts.Require().NotEmpty(state)

stateUUID, err := uuid.FromString(state)
require.NoError(ts.T(), err, "state should be a valid UUID")
require.NotEqual(ts.T(), uuid.Nil, stateUUID)

testURL, err := url.Parse("http://localhost/callback")
require.NoError(ts.T(), err)
v := testURL.Query()
v.Set("code", code)
v.Set("state", jwtState)
v.Set("state", state)
testURL.RawQuery = v.Encode()

req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
w := httptest.NewRecorder()
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

ts.Require().Equal(http.StatusFound, w.Code)
u, err := url.Parse(w.Header().Get("Location"))
ts.Require().NoError(err, "redirect url parse failed")
ts.Require().Equal("/admin", u.Path)

fragment, err := url.ParseQuery(u.Fragment)
resultURL, err := url.Parse(w.Header().Get("Location"))
ts.Require().NoError(err)
ts.NotEmpty(fragment.Get("access_token"), "should have access_token")
ts.NotEmpty(fragment.Get("refresh_token"), "should have refresh_token")

user, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
require.NotNil(ts.T(), user)
fragment, err := url.ParseQuery(resultURL.Fragment)
ts.Require().NoError(err)
ts.NotEmpty(fragment.Get("access_token"), "UUID state should result in access_token")
}

// TestOAuthState_MigrationScenario tests that both UUID and JWT state formats
// can be processed during the migration period.
func (ts *ExternalTestSuite) TestOAuthState_MigrationScenario() {
// TestOAuthState_InvalidFormat tests that non-UUID state parameters are rejected.
func (ts *ExternalTestSuite) TestOAuthState_InvalidFormat() {
code := "authcode"
server := setupGenericOAuthServer(ts, code)
defer server.Close()

ts.Run("NewUUIDFormat", func() {
// Use the standard authorization flow which now generates UUID state
w := performAuthorizationRequest(ts, "github", "")
ts.Require().Equal(http.StatusFound, w.Code)
u, err := url.Parse(w.Header().Get("Location"))
ts.Require().NoError(err)

state := u.Query().Get("state")
ts.Require().NotEmpty(state)

// Verify state is a valid UUID
stateUUID, err := uuid.FromString(state)
require.NoError(ts.T(), err, "state should be a valid UUID")
require.NotEqual(ts.T(), uuid.Nil, stateUUID)

// Complete the callback
testURL, err := url.Parse("http://localhost/callback")
require.NoError(ts.T(), err)
v := testURL.Query()
v.Set("code", code)
v.Set("state", state)
testURL.RawQuery = v.Encode()

req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

ts.Require().Equal(http.StatusFound, w.Code)
resultURL, err := url.Parse(w.Header().Get("Location"))
ts.Require().NoError(err)

fragment, err := url.ParseQuery(resultURL.Fragment)
ts.Require().NoError(err)
ts.NotEmpty(fragment.Get("access_token"), "UUID state should result in access_token")
})

// Clean up user for next test
user, _ := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
if user != nil {
require.NoError(ts.T(), ts.API.db.Destroy(user))
}

ts.Run("LegacyJWTFormat", func() {
// Create a legacy JWT state
claims := &ExternalProviderClaims{
AuthMicroserviceClaims: AuthMicroserviceClaims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: ts.Config.JWT.Issuer,
},
},
Provider: "github",
Referrer: "https://example.com/admin",
}

jwtState, err := tokens.SignJWT(&ts.Config.JWT, claims)
require.NoError(ts.T(), err)

// Verify state is NOT a UUID (it's a JWT)
_, uuidErr := uuid.FromString(jwtState)
require.Error(ts.T(), uuidErr, "JWT state should not be parseable as UUID")

// Complete the callback with JWT state
testURL, err := url.Parse("http://localhost/callback")
require.NoError(ts.T(), err)
v := testURL.Query()
v.Set("code", code)
v.Set("state", jwtState)
testURL.RawQuery = v.Encode()

req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
testURL, err := url.Parse("http://localhost/callback")
require.NoError(ts.T(), err)
v := testURL.Query()
v.Set("code", code)
v.Set("state", "not-a-valid-uuid")
testURL.RawQuery = v.Encode()

ts.Require().Equal(http.StatusFound, w.Code)
resultURL, err := url.Parse(w.Header().Get("Location"))
ts.Require().NoError(err)
req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

fragment, err := url.ParseQuery(resultURL.Fragment)
ts.Require().NoError(err)
ts.NotEmpty(fragment.Get("access_token"), "JWT state should also result in access_token")
})
// Should redirect to site URL with error since state is invalid
ts.Require().Equal(http.StatusSeeOther, w.Code)
}