diff --git a/token/jwt.go b/token/jwt.go index a653e62..7073e27 100644 --- a/token/jwt.go +++ b/token/jwt.go @@ -247,8 +247,14 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) { return Claims{}, fmt.Errorf("failed to make token token: %w", err) } - if j.SendJWTHeader { + // For OAuth handshake, always set cookies regardless of SendJWTHeader flag + // This allows the OAuth flow to complete successfully + needsCookies := claims.Handshake != nil + + if j.SendJWTHeader && !needsCookies { w.Header().Set(j.JWTHeaderKey, tokenString) + // reset cookies in case they were set by OAuth handshake + j.Reset(w) return claims, nil } diff --git a/token/jwt_test.go b/token/jwt_test.go index b1d5e81..cdb9986 100644 --- a/token/jwt_test.go +++ b/token/jwt_test.go @@ -258,13 +258,27 @@ func TestJWT_SendJWTHeader(t *testing.T) { SendJWTHeader: true, }) - rr := httptest.NewRecorder() - _, err := j.Set(rr, testClaims) - assert.NoError(t, err) - cookies := rr.Result().Cookies() - t.Log(cookies) - require.Equal(t, 0, len(cookies), "no cookies set") - assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT")) + t.Run("with handshake", func(t *testing.T) { + rr := httptest.NewRecorder() + _, err := j.Set(rr, testClaims) + assert.NoError(t, err) + cookies := rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies), "cookies are set for handshake") + assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT")) + }) + + t.Run("without handshake", func(t *testing.T) { + rr := httptest.NewRecorder() + claimsNoHandshake := testClaims + claimsNoHandshake.Handshake = nil + _, err := j.Set(rr, claimsNoHandshake) + assert.NoError(t, err) + cookies := rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies), "empty cookies set for non-handshake") + assert.Equal(t, testJwtValidNoHandshake, rr.Result().Header.Get("X-JWT")) + }) } func TestJWT_SetProlonged(t *testing.T) { diff --git a/v2/token/jwt.go b/v2/token/jwt.go index 22954ef..15b04d7 100644 --- a/v2/token/jwt.go +++ b/v2/token/jwt.go @@ -264,8 +264,14 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) { return Claims{}, fmt.Errorf("failed to make token token: %w", err) } - if j.SendJWTHeader { + // For OAuth handshake, always set cookies regardless of SendJWTHeader flag + // This allows the OAuth flow to complete successfully + needsCookies := claims.Handshake != nil + + if j.SendJWTHeader && !needsCookies { w.Header().Set(j.JWTHeaderKey, tokenString) + // reset cookies in case they were set by OAuth handshake + j.Reset(w) return claims, nil } @@ -274,6 +280,7 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) { cookieExpiration = int(j.CookieDuration.Seconds()) } + // Set cookies (always for OAuth handshake, or when SendJWTHeader is false) jwtCookie := http.Cookie{Name: j.JWTCookieName, Value: tokenString, HttpOnly: true, Path: "/", Domain: j.JWTCookieDomain, MaxAge: cookieExpiration, Secure: j.SecureCookies, SameSite: j.SameSite} http.SetCookie(w, &jwtCookie) diff --git a/v2/token/jwt_test.go b/v2/token/jwt_test.go index ce71fde..d492b73 100644 --- a/v2/token/jwt_test.go +++ b/v2/token/jwt_test.go @@ -258,13 +258,27 @@ func TestJWT_SendJWTHeader(t *testing.T) { SendJWTHeader: true, }) - rr := httptest.NewRecorder() - _, err := j.Set(rr, testClaims) - assert.NoError(t, err) - cookies := rr.Result().Cookies() - t.Log(cookies) - require.Equal(t, 0, len(cookies), "no cookies set") - assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT")) + t.Run("with handshake", func(t *testing.T) { + rr := httptest.NewRecorder() + _, err := j.Set(rr, testClaims) + assert.NoError(t, err) + cookies := rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies), "cookies are set for handshake") + assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT")) + }) + + t.Run("without handshake", func(t *testing.T) { + rr := httptest.NewRecorder() + claimsNoHandshake := testClaims + claimsNoHandshake.Handshake = nil + _, err := j.Set(rr, claimsNoHandshake) + assert.NoError(t, err) + cookies := rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies), "empty cookies set for non-handshake") + assert.Equal(t, testJwtValidNoHandshake, rr.Result().Header.Get("X-JWT")) + }) } func TestJWT_SetProlonged(t *testing.T) {