Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion token/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,14 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) {
return Claims{}, fmt.Errorf("failed to make token token: %w", err)
}

if j.SendJWTHeader {
// For OAuth handshake, always set cookies regardless of SendJWTHeader flag
// This allows the OAuth flow to complete successfully
needsCookies := claims.Handshake != nil

if j.SendJWTHeader && !needsCookies {
w.Header().Set(j.JWTHeaderKey, tokenString)
// reset cookies in case they were set by OAuth handshake
j.Reset(w)
return claims, nil
}

Expand Down
28 changes: 21 additions & 7 deletions token/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,27 @@ func TestJWT_SendJWTHeader(t *testing.T) {
SendJWTHeader: true,
})

rr := httptest.NewRecorder()
_, err := j.Set(rr, testClaims)
assert.NoError(t, err)
cookies := rr.Result().Cookies()
t.Log(cookies)
require.Equal(t, 0, len(cookies), "no cookies set")
assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT"))
t.Run("with handshake", func(t *testing.T) {
rr := httptest.NewRecorder()
_, err := j.Set(rr, testClaims)
assert.NoError(t, err)
cookies := rr.Result().Cookies()
t.Log(cookies)
require.Equal(t, 2, len(cookies), "cookies are set for handshake")
assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT"))
})

t.Run("without handshake", func(t *testing.T) {
rr := httptest.NewRecorder()
claimsNoHandshake := testClaims
claimsNoHandshake.Handshake = nil
_, err := j.Set(rr, claimsNoHandshake)
assert.NoError(t, err)
cookies := rr.Result().Cookies()
t.Log(cookies)
require.Equal(t, 2, len(cookies), "empty cookies set for non-handshake")
assert.Equal(t, testJwtValidNoHandshake, rr.Result().Header.Get("X-JWT"))
})
}

func TestJWT_SetProlonged(t *testing.T) {
Expand Down
9 changes: 8 additions & 1 deletion v2/token/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,14 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) {
return Claims{}, fmt.Errorf("failed to make token token: %w", err)
}

if j.SendJWTHeader {
// For OAuth handshake, always set cookies regardless of SendJWTHeader flag
// This allows the OAuth flow to complete successfully
needsCookies := claims.Handshake != nil

if j.SendJWTHeader && !needsCookies {
w.Header().Set(j.JWTHeaderKey, tokenString)
// reset cookies in case they were set by OAuth handshake
j.Reset(w)
return claims, nil
}

Expand All @@ -274,6 +280,7 @@ func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) {
cookieExpiration = int(j.CookieDuration.Seconds())
}

// Set cookies (always for OAuth handshake, or when SendJWTHeader is false)
jwtCookie := http.Cookie{Name: j.JWTCookieName, Value: tokenString, HttpOnly: true, Path: "/", Domain: j.JWTCookieDomain,
MaxAge: cookieExpiration, Secure: j.SecureCookies, SameSite: j.SameSite}
http.SetCookie(w, &jwtCookie)
Expand Down
28 changes: 21 additions & 7 deletions v2/token/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,27 @@ func TestJWT_SendJWTHeader(t *testing.T) {
SendJWTHeader: true,
})

rr := httptest.NewRecorder()
_, err := j.Set(rr, testClaims)
assert.NoError(t, err)
cookies := rr.Result().Cookies()
t.Log(cookies)
require.Equal(t, 0, len(cookies), "no cookies set")
assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT"))
t.Run("with handshake", func(t *testing.T) {
rr := httptest.NewRecorder()
_, err := j.Set(rr, testClaims)
assert.NoError(t, err)
cookies := rr.Result().Cookies()
t.Log(cookies)
require.Equal(t, 2, len(cookies), "cookies are set for handshake")
assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT"))
})

t.Run("without handshake", func(t *testing.T) {
rr := httptest.NewRecorder()
claimsNoHandshake := testClaims
claimsNoHandshake.Handshake = nil
_, err := j.Set(rr, claimsNoHandshake)
assert.NoError(t, err)
cookies := rr.Result().Cookies()
t.Log(cookies)
require.Equal(t, 2, len(cookies), "empty cookies set for non-handshake")
assert.Equal(t, testJwtValidNoHandshake, rr.Result().Header.Get("X-JWT"))
})
}

func TestJWT_SetProlonged(t *testing.T) {
Expand Down
Loading