From 3c1650b7c0d3f3d8f6e459307c24c4e0f5aed6d7 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 3 Nov 2025 14:04:28 +0100 Subject: [PATCH 01/15] feat: Remove refresh_token grant type Signed-off-by: Jorge Turrado --- core/clients/key_flow.go | 74 ++++--------------- core/clients/key_flow_continuous_refresh.go | 2 +- .../key_flow_continuous_refresh_test.go | 41 +++------- core/clients/key_flow_test.go | 66 +---------------- 4 files changed, 27 insertions(+), 156 deletions(-) diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 589774314..9a1b5d1e8 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -68,11 +68,10 @@ type KeyFlowConfig struct { // TokenResponseBody is the API response // when requesting a new token type TokenResponseBody struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` } // ServiceAccountKeyResponse is the API response @@ -158,9 +157,9 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { return nil } -// SetToken can be used to set an access and refresh token manually in the client. +// SetToken can be used to set an access token manually in the client. // The other fields in the token field are determined by inspecting the token or setting default values. -func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { +func (c *KeyFlow) SetToken(accessToken string) error { // We can safely use ParseUnverified because we are not authenticating the user, // We are parsing the token just to get the expiration time claim parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) @@ -174,11 +173,10 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { c.tokenMutex.Lock() c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: defaultScope, - RefreshToken: refreshToken, - TokenType: defaultTokenType, + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: defaultScope, + TokenType: defaultTokenType, } c.tokenMutex.Unlock() return nil @@ -198,7 +196,7 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { return c.rt.RoundTrip(req) } -// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field +// GetAccessToken returns a short-lived access token and saves the access token in the token field func (c *KeyFlow) GetAccessToken() (string, error) { if c.rt == nil { return "", fmt.Errorf("nil http round tripper, please run Init()") @@ -219,7 +217,7 @@ func (c *KeyFlow) GetAccessToken() (string, error) { if !accessTokenExpired { return accessToken, nil } - if err = c.recreateAccessToken(); err != nil { + if err = c.createAccessToken(); err != nil { var oapiErr *oapierror.GenericOpenAPIError if ok := errors.As(err, &oapiErr); ok { reg := regexp.MustCompile("Key with kid .*? was not found") @@ -269,27 +267,6 @@ func (c *KeyFlow) validate() error { // Flow auth functions -// recreateAccessToken is used to create a new access token -// when the existing one isn't valid anymore -func (c *KeyFlow) recreateAccessToken() error { - var refreshToken string - - c.tokenMutex.RLock() - if c.token != nil { - refreshToken = c.token.RefreshToken - } - c.tokenMutex.RUnlock() - - refreshTokenExpired, err := tokenExpired(refreshToken, c.tokenExpirationLeeway) - if err != nil { - return err - } - if !refreshTokenExpired { - return c.createAccessTokenWithRefreshToken() - } - return c.createAccessToken() -} - // createAccessToken creates an access token using self signed JWT func (c *KeyFlow) createAccessToken() (err error) { grant := "urn:ietf:params:oauth:grant-type:jwt-bearer" @@ -310,26 +287,6 @@ func (c *KeyFlow) createAccessToken() (err error) { return c.parseTokenResponse(res) } -// createAccessTokenWithRefreshToken creates an access token using -// an existing pre-validated refresh token -func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { - c.tokenMutex.RLock() - refreshToken := c.token.RefreshToken - c.tokenMutex.RUnlock() - - res, err := c.requestToken("refresh_token", refreshToken) - if err != nil { - return err - } - defer func() { - tempErr := res.Body.Close() - if tempErr != nil && err == nil { - err = fmt.Errorf("close request access token with refresh token response: %w", tempErr) - } - }() - return c.parseTokenResponse(res) -} - // generateSelfSignedJWT generates JWT token func (c *KeyFlow) generateSelfSignedJWT() (string, error) { claims := jwt.MapClaims{ @@ -353,11 +310,8 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) { body := url.Values{} body.Set("grant_type", grant) - if grant == "refresh_token" { - body.Set("refresh_token", assertion) - } else { - body.Set("assertion", assertion) - } + body.Set("assertion", assertion) + payload := strings.NewReader(body.Encode()) req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) if err != nil { diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go index f5129aa02..4b971c203 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/key_flow_continuous_refresh.go @@ -125,7 +125,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim // - (false, nil) if not successful but should be retried. // - (_, err) if not successful and shouldn't be retried. func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { - err := refresher.keyFlow.recreateAccessToken() + err := refresher.keyFlow.createAccessToken() if err == nil { return true, nil } diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 7c7ee9565..983a34f37 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -95,15 +95,8 @@ func TestContinuousRefreshToken(t *testing.T) { t.Fatalf("failed to create access token: %v", err) } - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - numberDoCalls := 0 - mockDo := func(_ *http.Request) (resp *http.Response, err error) { + mockDo := func(r *http.Request) (resp *http.Response, err error) { numberDoCalls++ // count refresh attempts if tt.doError != nil { return nil, tt.doError @@ -115,8 +108,7 @@ func TestContinuousRefreshToken(t *testing.T) { t.Fatalf("Do call: failed to create access token: %v", err) } responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, + AccessToken: newAccessToken, } responseBody, err := json.Marshal(responseBodyStruct) if err != nil { @@ -153,7 +145,7 @@ func TestContinuousRefreshToken(t *testing.T) { } // Set the token after initialization - err = keyFlow.SetToken(accessToken, refreshToken) + err = keyFlow.SetToken(accessToken) if err != nil { t.Fatalf("failed to set token: %v", err) } @@ -186,7 +178,7 @@ func TestContinuousRefreshToken(t *testing.T) { } // Tests if -// - continuousRefreshToken() updates access token using the refresh token +// - continuousRefreshToken() updates access token // - The access token can be accessed while continuousRefreshToken() is trying to update it func TestContinuousRefreshTokenConcurrency(t *testing.T) { // The times here are in the order of miliseconds (so they run faster) @@ -234,14 +226,6 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("created tokens are equal") } - // The refresh token used to update the access token - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() // This cancels the refresher goroutine @@ -271,8 +255,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("Do call: failed to create additional access token: %v", err) } responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, + AccessToken: newAccessToken, } responseBody, err := json.Marshal(responseBodyStruct) if err != nil { @@ -308,18 +291,12 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { t.Fatalf("Do call: failed to parse body form: %v", err) } reqGrantType := req.Form.Get("grant_type") - if reqGrantType != "refresh_token" { - t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType) + if reqGrantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" { + t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "urn:ietf:params:oauth:grant-type:jwt-bearer", reqGrantType) } - reqRefreshToken := req.Form.Get("refresh_token") - if reqRefreshToken != refreshToken { - t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set") - } - // Return response with accessTokenSecond responseBodyStruct := TokenResponseBody{ - AccessToken: accessTokenSecond, - RefreshToken: refreshToken, + AccessToken: accessTokenSecond, } responseBody, err := json.Marshal(responseBodyStruct) if err != nil { @@ -409,7 +386,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { } // Set the token after initialization - err = keyFlow.SetToken(accessTokenFirst, refreshToken) + err = keyFlow.SetToken(accessTokenFirst) if err != nil { t.Fatalf("failed to set token: %v", err) } diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 9803f24ee..a64bee881 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -130,65 +130,6 @@ func TestKeyFlowInit(t *testing.T) { } } -func TestSetToken(t *testing.T) { - tests := []struct { - name string - tokenInvalid bool - refreshToken string - wantErr bool - }{ - { - name: "ok", - tokenInvalid: false, - refreshToken: "refresh_token", - wantErr: false, - }, - { - name: "invalid_token", - tokenInvalid: true, - refreshToken: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var accessToken string - var err error - - timestamp := time.Now().Add(24 * time.Hour) - if tt.tokenInvalid { - accessToken = "foo" - } else { - accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(timestamp)}) - accessToken, err = accessTokenJWT.SignedString(testSigningKey) - if err != nil { - t.Fatalf("get test access token as string: %s", err) - } - } - - keyFlow := &KeyFlow{} - err = keyFlow.SetToken(accessToken, tt.refreshToken) - - if (err != nil) != tt.wantErr { - t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr) - } - if err == nil { - expectedKeyFlowToken := &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(timestamp.Unix()), - RefreshToken: tt.refreshToken, - Scope: defaultScope, - TokenType: defaultTokenType, - } - if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { - t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) - } - } - }) - } -} - func TestTokenExpired(t *testing.T) { tokenExpirationLeeway := 5 * time.Second tests := []struct { @@ -442,10 +383,9 @@ func TestKeyFlow_Do(t *testing.T) { res.Header().Set("Content-Type", "application/json") token := &TokenResponseBody{ - AccessToken: testBearerToken, - ExpiresIn: 2147483647, - RefreshToken: testBearerToken, - TokenType: "Bearer", + AccessToken: testBearerToken, + ExpiresIn: 2147483647, + TokenType: "Bearer", } if err := json.NewEncoder(res.Body).Encode(token); err != nil { From a552c2545c2d054866c6c3b156e6d294d5ea2c00 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Thu, 6 Nov 2025 16:26:06 +0100 Subject: [PATCH 02/15] Update changelogs Signed-off-by: Jorge Turrado --- CHANGELOG.md | 3 +++ core/CHANGELOG.md | 4 ++++ core/VERSION | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 59906913a..cc2298ced 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ ## Release (2025-XX-YY) +- `core`: + - [v0.21.0](core/CHANGELOG.md#v0210) + - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` - `sfs`: - [v0.2.0](services/sfs/CHANGELOG.md) - **Breaking change:** Remove region configuration in `APIClient` diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 8b1d2fb86..47d06e806 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,6 @@ +## v0.21.0 +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` + ## v0.20.1 - **Improvement:** Improve error message when passing a PEM encoded file to as service account key @@ -9,6 +12,7 @@ ## v0.18.0 - **New:** Added duration utils +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/core/VERSION b/core/VERSION index 2c80271d5..fcc9d59a4 100644 --- a/core/VERSION +++ b/core/VERSION @@ -1 +1 @@ -v0.20.1 +v0.21.0 \ No newline at end of file From e561ef13edd854de02d8e51c12f586af7ea09e09 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Thu, 13 Nov 2025 10:19:42 +0100 Subject: [PATCH 03/15] update exp time of assertion as access token CAN'T be longer that it Signed-off-by: Jorge Turrado --- core/CHANGELOG.md | 1 - core/clients/key_flow.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 47d06e806..c2719e863 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -12,7 +12,6 @@ ## v0.18.0 - **New:** Added duration utils -- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 9a1b5d1e8..cedf5e937 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -295,7 +295,7 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { "jti": uuid.New(), "aud": c.key.Credentials.Aud, "iat": jwt.NewNumericDate(time.Now()), - "exp": jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), } token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims) token.Header["kid"] = c.key.Credentials.Kid From fa5f8effe448452da23d71526ea92aff05eecb82 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Thu, 6 Nov 2025 16:26:06 +0100 Subject: [PATCH 04/15] Update changelogs Signed-off-by: Jorge Turrado --- core/CHANGELOG.md | 1 + core/VERSION | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index c2719e863..47d06e806 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -12,6 +12,7 @@ ## v0.18.0 - **New:** Added duration utils +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/core/VERSION b/core/VERSION index fcc9d59a4..759e855fb 100644 --- a/core/VERSION +++ b/core/VERSION @@ -1 +1 @@ -v0.21.0 \ No newline at end of file +v0.21.0 From ebb1d5acf9ab68af404e1c217d702c4e89f6b2fe Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Wed, 10 Dec 2025 11:20:03 +0100 Subject: [PATCH 05/15] feat: Support Workload Identity Federation flow Signed-off-by: Jorge Turrado --- core/auth/auth.go | 53 +- core/auth/auth_test.go | 101 +++- core/clients/auth_flow.go | 84 +++ core/clients/key_flow.go | 126 +--- core/clients/key_flow_continuous_refresh.go | 39 +- .../key_flow_continuous_refresh_test.go | 414 ++----------- core/clients/key_flow_test.go | 17 +- core/clients/workload_identity_flow.go | 249 ++++++++ core/clients/workload_identity_flow_test.go | 566 ++++++++++++++++++ core/config/config.go | 69 ++- examples/authentication/authentication.go | 59 +- 11 files changed, 1230 insertions(+), 547 deletions(-) create mode 100644 core/clients/auth_flow.go create mode 100644 core/clients/workload_identity_flow.go create mode 100644 core/clients/workload_identity_flow_test.go diff --git a/core/auth/auth.go b/core/auth/auth.go index 568847aea..88f002fe7 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -45,6 +45,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { if cfg.CustomAuth != nil { return cfg.CustomAuth, nil + } else if useWorkloadIdentityFederation(cfg) { + wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) + if err != nil { + return nil, fmt.Errorf("configuring no auth client: %w", err) + } + return wifRoundTripper, nil } else if cfg.NoAuth { noAuthRoundTripper, err := NoAuth(cfg) if err != nil { @@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { cfg = &config.Configuration{} } - // Key flow - rt, err = KeyAuth(cfg) + // WIF flow + rt, err = WorkloadIdentityFederationAuth(cfg) if err != nil { - keyFlowErr := err - // Token flow - rt, err = TokenAuth(cfg) + // Key flow + rt, err = KeyAuth(cfg) if err != nil { - return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + keyFlowErr := err + // Token flow + rt, err = TokenAuth(cfg) + if err != nil { + return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + } } } return rt, nil @@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) { return client, nil } +// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper +// that can be used to make authenticated requests using an access token +func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) { + wifConfig := clients.WorkloadIdentityFederationFlowConfig{ + TokenUrl: cfg.TokenCustomUrl, + BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, + ClientID: cfg.ServiceAccountEmail, + FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath, + TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration, + } + + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + wifConfig.HTTPTransport = cfg.HTTPClient.Transport + } + + client := &clients.WorkloadIdentityFederationFlow{} + if err := client.Init(&wifConfig); err != nil { + return nil, fmt.Errorf("error initializing client: %w", err) + } + + return client, nil +} + // readCredentialsFile reads the credentials file from the specified path and returns Credentials func readCredentialsFile(path string) (*Credentials, error) { if path == "" { @@ -361,3 +394,11 @@ func getServiceAccountKey(cfg *config.Configuration) error { func getPrivateKey(cfg *config.Configuration) error { return getKey(&cfg.PrivateKey, &cfg.PrivateKeyPath, "STACKIT_PRIVATE_KEY_PATH", "STACKIT_PRIVATE_KEY", privateKeyPathCredentialType, privateKeyCredentialType, cfg.CredentialsFilePath) } + +func useWorkloadIdentityFederation(cfg *config.Configuration) bool { + if cfg != nil && cfg.WorkloadIdentityFederation { + return true + } + val, exists := os.LookupEnv(clients.FederatedTokenFileEnv) + return exists && val != "" +} diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index a7c776946..5e8af7203 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stackitcloud/stackit-sdk-go/core/clients" "github.com/stackitcloud/stackit-sdk-go/core/config" @@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) { } }() + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -147,25 +174,28 @@ func TestSetupAuth(t *testing.T) { desc string config *config.Configuration setToken bool + setWorkloadIdentity bool setKeys bool setKeyPaths bool setCredentialsFilePathToken bool setCredentialsFilePathKey bool - isValid bool }{ + { + desc: "wif_config", + config: nil, + setWorkloadIdentity: true, + }, { desc: "token_config", config: nil, setToken: true, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "key_config", config: nil, setKeys: true, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "key_config_path", @@ -173,7 +203,6 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: true, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "key_config_credentials_path", @@ -181,14 +210,12 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: false, setCredentialsFilePathKey: true, - isValid: true, }, { desc: "valid_path_to_file", config: nil, setToken: false, setCredentialsFilePathToken: true, - isValid: true, }, { desc: "custom_config_token", @@ -197,7 +224,6 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, - isValid: true, }, { desc: "custom_config_path", @@ -206,7 +232,6 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, - isValid: true, }, } { t.Run(test.desc, func(t *testing.T) { @@ -241,19 +266,21 @@ func TestSetupAuth(t *testing.T) { t.Setenv("STACKIT_CREDENTIALS_PATH", "") } + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") authRoundTripper, err := SetupAuth(test.config) - if err != nil && test.isValid { + if err != nil { t.Fatalf("Test returned error on valid test case: %v", err) } - if err == nil && !test.isValid { - t.Fatalf("Test didn't return error on invalid test case") - } - - if test.isValid && authRoundTripper == nil { + if authRoundTripper == nil { t.Fatalf("Roundtripper returned is nil for valid test case") } }) @@ -381,6 +408,32 @@ func TestDefaultAuth(t *testing.T) { t.Fatalf("Writing private key to temporary file: %s", err) } + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -409,6 +462,7 @@ func TestDefaultAuth(t *testing.T) { setKeyPaths bool setKeys bool setCredentialsFilePathKey bool + setWorkloadIdentity bool isValid bool expectedFlow string }{ @@ -418,6 +472,14 @@ func TestDefaultAuth(t *testing.T) { isValid: true, expectedFlow: "token", }, + { + desc: "wif_precedes_key_precedes_token", + setToken: true, + setKeyPaths: true, + setWorkloadIdentity: true, + isValid: true, + expectedFlow: "wif", + }, { desc: "key_precedes_token", setToken: true, @@ -475,6 +537,13 @@ func TestDefaultAuth(t *testing.T) { } else { t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "") } + + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") // Get the default authentication client and ensure that it's not nil @@ -501,6 +570,10 @@ func TestDefaultAuth(t *testing.T) { if _, ok := authClient.(*clients.KeyFlow); !ok { t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) } + case "wif": + if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok { + t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) + } } } }) diff --git a/core/clients/auth_flow.go b/core/clients/auth_flow.go new file mode 100644 index 000000000..141d75489 --- /dev/null +++ b/core/clients/auth_flow.go @@ -0,0 +1,84 @@ +package clients + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" +) + +const ( + defaultTokenExpirationLeeway = time.Second * 5 +) + +type AuthFlow interface { + RoundTrip(req *http.Request) (*http.Response, error) + GetAccessToken() (string, error) + GetBackgroundTokenRefreshContext() context.Context +} + +// TokenResponseBody is the API response +// when requesting a new token +type TokenResponseBody struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) { + if res == nil { + return nil, fmt.Errorf("received bad response from API") + } + if res.StatusCode != http.StatusOK { + body, err := io.ReadAll(res.Body) + if err != nil { + // Fail silently, omit body from error + // We're trying to show error details, so it's unnecessary to fail because of this err + body = []byte{} + } + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: res.StatusCode, + Body: body, + } + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + token := &TokenResponseBody{} + err = json.Unmarshal(body, token) + if err != nil { + return nil, fmt.Errorf("unmarshal token response: %w", err) + } + return token, nil +} + +func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { + if token == "" { + return true, nil + } + + // We can safely use ParseUnverified because we are not authenticating the user at this point. + // We're just checking the expiration time + tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + if err != nil { + return false, fmt.Errorf("parse token: %w", err) + } + + expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() + if err != nil { + return false, fmt.Errorf("get expiration timestamp: %w", err) + } + + // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring + // between retrieving the token and upstream systems validating it. + now := time.Now().Add(tokenExpirationLeeway) + return now.After(expirationTimestampNumeric.Time), nil +} diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index cedf5e937..83c82e778 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -4,11 +4,9 @@ import ( "context" "crypto/rsa" "crypto/x509" - "encoding/json" "encoding/pem" "errors" "fmt" - "io" "net/http" "net/url" "regexp" @@ -30,12 +28,10 @@ const ( ServiceAccountKeyPath = "STACKIT_SERVICE_ACCOUNT_KEY_PATH" PrivateKeyPath = "STACKIT_PRIVATE_KEY_PATH" tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive - defaultTokenType = "Bearer" - defaultScope = "" - - defaultTokenExpirationLeeway = time.Second * 5 ) +var _ AuthFlow = &KeyFlow{} + // KeyFlow handles auth with SA key type KeyFlow struct { rt http.RoundTripper @@ -65,15 +61,6 @@ type KeyFlowConfig struct { AuthHTTPClient *http.Client } -// TokenResponseBody is the API response -// when requesting a new token -type TokenResponseBody struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` -} - // ServiceAccountKeyResponse is the API response // when creating a new SA key type ServiceAccountKeyResponse struct { @@ -112,19 +99,6 @@ func (c *KeyFlow) GetServiceAccountEmail() string { return c.key.Credentials.Iss } -// GetToken returns the token field -func (c *KeyFlow) GetToken() TokenResponseBody { - c.tokenMutex.RLock() - defer c.tokenMutex.RUnlock() - - if c.token == nil { - return TokenResponseBody{} - } - // Returned struct is passed by value (because it's a struct) - // So no deepy copy needed - return *c.token -} - func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} @@ -157,31 +131,6 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { return nil } -// SetToken can be used to set an access token manually in the client. -// The other fields in the token field are determined by inspecting the token or setting default values. -func (c *KeyFlow) SetToken(accessToken string) error { - // We can safely use ParseUnverified because we are not authenticating the user, - // We are parsing the token just to get the expiration time claim - parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) - if err != nil { - return fmt.Errorf("parse access token to read expiration time: %w", err) - } - exp, err := parsedAccessToken.Claims.GetExpirationTime() - if err != nil { - return fmt.Errorf("get expiration time from access token: %w", err) - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: defaultScope, - TokenType: defaultTokenType, - } - c.tokenMutex.Unlock() - return nil -} - // Roundtrip performs the request func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { if c.rt == nil { @@ -201,7 +150,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) { if c.rt == nil { return "", fmt.Errorf("nil http round tripper, please run Init()") } - var accessToken string c.tokenMutex.RLock() @@ -235,6 +183,10 @@ func (c *KeyFlow) GetAccessToken() (string, error) { return accessToken, nil } +func (c *KeyFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + // validate the client is configured well func (c *KeyFlow) validate() error { if c.config.ServiceAccountKey == nil { @@ -284,7 +236,14 @@ func (c *KeyFlow) createAccessToken() (err error) { err = fmt.Errorf("close request access token response: %w", tempErr) } }() - return c.parseTokenResponse(res) + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil } // generateSelfSignedJWT generates JWT token @@ -321,60 +280,3 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) return c.authClient.Do(req) } - -// parseTokenResponse parses the response from the server -func (c *KeyFlow) parseTokenResponse(res *http.Response) error { - if res == nil { - return fmt.Errorf("received bad response from API") - } - if res.StatusCode != http.StatusOK { - body, err := io.ReadAll(res.Body) - if err != nil { - // Fail silently, omit body from error - // We're trying to show error details, so it's unnecessary to fail because of this err - body = []byte{} - } - return &oapierror.GenericOpenAPIError{ - StatusCode: res.StatusCode, - Body: body, - } - } - body, err := io.ReadAll(res.Body) - if err != nil { - return err - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{} - err = json.Unmarshal(body, c.token) - c.tokenMutex.Unlock() - if err != nil { - return fmt.Errorf("unmarshal token response: %w", err) - } - - return nil -} - -func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { - if token == "" { - return true, nil - } - - // We can safely use ParseUnverified because we are not authenticating the user at this point. - // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) - if err != nil { - return false, fmt.Errorf("parse token: %w", err) - } - - expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() - if err != nil { - return false, fmt.Errorf("get expiration timestamp: %w", err) - } - - // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring - // between retrieving the token and upstream systems validating it. - now := time.Now().Add(tokenExpirationLeeway) - - return now.After(expirationTimestampNumeric.Time), nil -} diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go index 4b971c203..702b3695c 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/key_flow_continuous_refresh.go @@ -20,9 +20,9 @@ var ( // Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates. // // To terminate this routine, close the context in keyFlow.config.BackgroundTokenRefreshContext. -func continuousRefreshToken(keyflow *KeyFlow) { +func continuousRefreshToken(flow AuthFlow) { refresher := &continuousTokenRefresher{ - keyFlow: keyflow, + flow: flow, timeStartBeforeTokenExpiration: defaultTimeStartBeforeTokenExpiration, timeBetweenContextCheck: defaultTimeBetweenContextCheck, timeBetweenTries: defaultTimeBetweenTries, @@ -32,7 +32,7 @@ func continuousRefreshToken(keyflow *KeyFlow) { } type continuousTokenRefresher struct { - keyFlow *KeyFlow + flow AuthFlow // Token refresh tries start at [Access token expiration timestamp] - [This duration] timeStartBeforeTokenExpiration time.Duration timeBetweenContextCheck time.Duration @@ -46,22 +46,12 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { // Compute timestamp where we'll refresh token // Access token may be empty at this point, we have to check it var startRefreshTimestamp time.Time - var accessToken string - refresher.keyFlow.tokenMutex.RLock() - if refresher.keyFlow.token != nil { - accessToken = refresher.keyFlow.token.AccessToken - } - refresher.keyFlow.tokenMutex.RUnlock() - if accessToken == "" { - startRefreshTimestamp = time.Now() - } else { - expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() - if err != nil { - return fmt.Errorf("get access token expiration timestamp: %w", err) - } - startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) + expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() + if err != nil { + return fmt.Errorf("get access token expiration timestamp: %w", err) } + startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) for { err := refresher.waitUntilTimestamp(startRefreshTimestamp) @@ -69,7 +59,7 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { return err } - err = refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err = refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -92,13 +82,14 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { } func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) { - refresher.keyFlow.tokenMutex.RLock() - token := refresher.keyFlow.token.AccessToken - refresher.keyFlow.tokenMutex.RUnlock() + accessToken, err := refresher.flow.GetAccessToken() + if err != nil { + return nil, err + } // We can safely use ParseUnverified because we are not doing authentication of any kind // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + tokenParsed, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } @@ -111,7 +102,7 @@ func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() ( func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Time) error { for time.Now().Before(timestamp) { - err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err := refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -125,7 +116,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim // - (false, nil) if not successful but should be retried. // - (_, err) if not successful and shouldn't be retried. func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { - err := refresher.keyFlow.createAccessToken() + _, err := refresher.flow.GetAccessToken() if err == nil { return true, nil } diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 983a34f37..cfd50e763 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -1,18 +1,13 @@ package clients import ( - "bytes" "context" - "encoding/json" "fmt" - "io" "net/http" - "net/url" "testing" "time" "github.com/golang-jwt/jwt/v5" - "github.com/stackitcloud/stackit-sdk-go/core/oapierror" ) @@ -22,9 +17,9 @@ func TestContinuousRefreshToken(t *testing.T) { jwt.TimePrecision = time.Millisecond // Refresher settings - timeStartBeforeTokenExpiration := 500 * time.Millisecond - timeBetweenContextCheck := 10 * time.Millisecond - timeBetweenTries := 100 * time.Millisecond + timeStartBeforeTokenExpiration := 0 * time.Second + timeBetweenContextCheck := 50 * time.Millisecond + timeBetweenTries := 500 * time.Millisecond // All generated acess tokens will have this time to live accessTokensTimeToLive := 1 * time.Second @@ -34,16 +29,20 @@ func TestContinuousRefreshToken(t *testing.T) { contextClosesIn time.Duration doError error expectedNumberDoCalls int - expectedCallRange []int // Optional: for tests that can have variable call counts }{ + { + desc: "update access token never", + contextClosesIn: 900 * time.Millisecond, // Should allow no refresh + expectedNumberDoCalls: 0, + }, { desc: "update access token once", - contextClosesIn: 700 * time.Millisecond, // Should allow one refresh + contextClosesIn: 1900 * time.Millisecond, // Should allow one refresh expectedNumberDoCalls: 1, }, { desc: "update access token twice", - contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes + contextClosesIn: 2900 * time.Millisecond, // Should allow two refreshes expectedNumberDoCalls: 2, }, { @@ -62,14 +61,14 @@ func TestContinuousRefreshToken(t *testing.T) { expectedNumberDoCalls: 0, }, { - desc: "refresh token fails - non-API error", - contextClosesIn: 700 * time.Millisecond, + desc: "refresh token fails - error", + contextClosesIn: 1900 * time.Millisecond, doError: fmt.Errorf("something went wrong"), expectedNumberDoCalls: 1, }, { desc: "refresh token fails - API non-5xx error", - contextClosesIn: 700 * time.Millisecond, + contextClosesIn: 1900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusBadRequest, }, @@ -77,84 +76,35 @@ func TestContinuousRefreshToken(t *testing.T) { }, { desc: "refresh token fails - API 5xx error", - contextClosesIn: 800 * time.Millisecond, + contextClosesIn: 2900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusInternalServerError, }, - expectedNumberDoCalls: 3, - expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition + expectedNumberDoCalls: 4, }, } for _, tt := range tests { + tt := tt t.Run(tt.desc, func(t *testing.T) { - accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) + t.Parallel() + accessToken, err := signToken(accessTokensTimeToLive) if err != nil { - t.Fatalf("failed to create access token: %v", err) - } - - numberDoCalls := 0 - mockDo := func(r *http.Request) (resp *http.Response, err error) { - numberDoCalls++ // count refresh attempts - if tt.doError != nil { - return nil, tt.doError - } - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("Do call: failed to create access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil + t.Fatalf("failed to sign access token: %v", err) } - ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) defer cancel() - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, - BackgroundTokenRefreshContext: nil, - } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) + authFlow := &fakeAuthFlow{ + backgroundTokenRefreshContext: ctx, + doError: tt.doError, + accessTokensTimeToLive: accessTokensTimeToLive, + accessToken: accessToken, } - // Set the token after initialization - err = keyFlow.SetToken(accessToken) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx - refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, + flow: authFlow, timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, timeBetweenContextCheck: timeBetweenContextCheck, timeBetweenTries: timeBetweenTries, @@ -164,300 +114,56 @@ func TestContinuousRefreshToken(t *testing.T) { if err == nil { t.Fatalf("routine finished with non-nil error") } - - // Check if we have a range of expected calls (for timing-sensitive tests) - if tt.expectedCallRange != nil { - if !contains(tt.expectedCallRange, numberDoCalls) { - t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls) - } - } else if numberDoCalls != tt.expectedNumberDoCalls { + numberDoCalls := authFlow.getTokenCalls() + if numberDoCalls != tt.expectedNumberDoCalls { t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) } }) } } -// Tests if -// - continuousRefreshToken() updates access token -// - The access token can be accessed while continuousRefreshToken() is trying to update it -func TestContinuousRefreshTokenConcurrency(t *testing.T) { - // The times here are in the order of miliseconds (so they run faster) - // For this to work, we need to increase precision of the expiration timestamps - jwt.TimePrecision = time.Millisecond - - // Test plan: - // 1) continuousRefreshToken() will trigger a token update. It will be blocked in the mockDo() routine (defined below) - // 2) After continuousRefreshToken() is blocked, a request will be made using the key flow. That request should carry the access token (shouldn't be blocked just because continuousRefreshToken() is trying to refresh the token) - // 3) After the request is successful, continuousRefreshToken() will be unblocked - // 4) After waiting a bit, a new request will be made using the key flow. That request should carry the new access token - - // Where we're at in the test plan: - // - Starts at 0 - // - Is set to 1 before continuousRefreshToken() is called - // - Is set to 2 once the continuousRefreshToken() is blocked - // - Is set to 3 once the first request goes through and is checked - // - Is set to 4 after a small wait after continuousRefreshToken() is unblocked - currentTestPhase := 0 - - // Used to signal continuousRefreshToken() has been blocked - chanBlockContinuousRefreshToken := make(chan bool) - - // Used to signal continuousRefreshToken() should be unblocked - chanUnblockContinuousRefreshToken := make(chan bool) - - // The access token at the start - accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), - }).SignedString([]byte("token-first")) - if err != nil { - t.Fatalf("failed to create first access token: %v", err) - } - - // The access token that will replace accessTokenFirst - // Has a much longer expiration timestamp - accessTokenSecond, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("token-second")) - if err != nil { - t.Fatalf("failed to create second access token: %v", err) - } - - if accessTokenFirst == accessTokenSecond { - t.Fatalf("created tokens are equal") - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() // This cancels the refresher goroutine - - // Extract host from tokenAPI constant for consistency - tokenURL, _ := url.Parse(tokenAPI) - tokenHost := tokenURL.Host - - // The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests - // The bools are used to make sure only one request goes through on each test phase - doTestPhase1RequestDone := false - doTestPhase2RequestDone := false - doTestPhase4RequestDone := false - mockDo := func(req *http.Request) (resp *http.Response, err error) { - // Handle auth requests (token refresh) - if req.URL.Host == tokenHost { - switch currentTestPhase { - default: - // After phase 1, allow additional auth requests but don't fail the test - // This handles the continuous nature of the refresh routine - if currentTestPhase > 1 { - // Return a valid response for any additional auth requests - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("additional-token")) - if err != nil { - t.Fatalf("Do call: failed to create additional access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal additional response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 1: // Call by continuousRefreshToken() - if doTestPhase1RequestDone { - t.Fatalf("Do call: multiple requests during test phase 1") - } - doTestPhase1RequestDone = true - - currentTestPhase = 2 - chanBlockContinuousRefreshToken <- true - - // Wait until continuousRefreshToken() is to be unblocked - <-chanUnblockContinuousRefreshToken - - if currentTestPhase != 3 { - t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase) - } - - // Check required fields are passed - err = req.ParseForm() - if err != nil { - t.Fatalf("Do call: failed to parse body form: %v", err) - } - reqGrantType := req.Form.Get("grant_type") - if reqGrantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" { - t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "urn:ietf:params:oauth:grant-type:jwt-bearer", reqGrantType) - } - // Return response with accessTokenSecond - responseBodyStruct := TokenResponseBody{ - AccessToken: accessTokenSecond, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - } - - // Handle regular HTTP requests - switch currentTestPhase { - default: - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 2: // Call by tokenFlow, first request - if doTestPhase2RequestDone { - t.Fatalf("Do call: multiple requests during test phase 2") - } - doTestPhase2RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "first-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - expectedAuthHeader := fmt.Sprintf("Bearer %s", accessTokenFirst) - if authHeader != expectedAuthHeader { - t.Fatalf("Do call: first request didn't carry first access token. Expected: %s, Got: %s", expectedAuthHeader, authHeader) - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - case 4: // Call by tokenFlow, second request - if doTestPhase4RequestDone { - t.Fatalf("Do call: multiple requests during test phase 4") - } - doTestPhase4RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "second-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: second request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - if authHeader != fmt.Sprintf("Bearer %s", accessTokenSecond) { - t.Fatalf("Do call: second request didn't carry second access token") - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - } - } - - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, // Use same mock for regular requests - // Don't start continuous refresh automatically - BackgroundTokenRefreshContext: nil, - } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) - } - - // Set the token after initialization - err = keyFlow.SetToken(accessTokenFirst) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx - - // Create a custom refresher with shorter timing for the test - refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, - timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration - timeBetweenContextCheck: 5 * time.Millisecond, - timeBetweenTries: 40 * time.Millisecond, - } - - // TEST START - currentTestPhase = 1 - // Ignore returned error as expected in test - go func() { - _ = refresher.continuousRefreshToken() - }() +func signToken(expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + }).SignedString([]byte("test")) +} - // Wait until continuousRefreshToken() is blocked - <-chanBlockContinuousRefreshToken +var _ AuthFlow = &fakeAuthFlow{} - if currentTestPhase != 2 { - t.Fatalf("Unexpected test phase %d after continuousRefreshToken() was blocked", currentTestPhase) - } +type fakeAuthFlow struct { + backgroundTokenRefreshContext context.Context + tokenCounter int + doError error + accessTokensTimeToLive time.Duration + accessToken string +} - // Perform first request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://first-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create first request failed: %v", err) - } - resp, err := keyFlow.RoundTrip(req) - if err != nil { - t.Fatalf("Perform first request failed: %v", err) - } - err = resp.Body.Close() +func (f *fakeAuthFlow) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, nil +} +func (f *fakeAuthFlow) GetAccessToken() (string, error) { + expired, err := tokenExpired(f.accessToken, 0) if err != nil { - t.Fatalf("First request body failed to close: %v", err) + return "", err } - - // Unblock continuousRefreshToken() - currentTestPhase = 3 - chanUnblockContinuousRefreshToken <- true - - // Wait for a bit - time.Sleep(10 * time.Millisecond) - currentTestPhase = 4 - - // Perform second request - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://second-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create second request failed: %v", err) + if !expired { + return f.accessToken, nil } - resp, err = keyFlow.RoundTrip(req) - if err != nil { - t.Fatalf("Second request failed: %v", err) + f.tokenCounter++ + if f.doError != nil { + return "", f.doError } - err = resp.Body.Close() + accessToken, err := signToken(f.accessTokensTimeToLive) if err != nil { - t.Fatalf("Second request body failed to close: %v", err) + return "", f.doError } + f.accessToken = accessToken + return accessToken, nil +} +func (f *fakeAuthFlow) GetBackgroundTokenRefreshContext() context.Context { + return f.backgroundTokenRefreshContext } -func contains(arr []int, val int) bool { - for _, v := range arr { - if v == val { - return true - } - } - return false +func (f *fakeAuthFlow) getTokenCalls() int { + return f.tokenCounter } diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index a64bee881..8b8877673 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -17,15 +17,10 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" "github.com/google/go-cmp/cmp" "github.com/google/uuid" ) -var ( - testSigningKey = []byte(`Test`) -) - const testBearerToken = "eyJhbGciOiJub25lIn0.eyJleHAiOjIxNDc0ODM2NDd9." //nolint:gosec // linter false positive func fixtureServiceAccountKey(mods ...func(*ServiceAccountKeyResponse)) *ServiceAccountKeyResponse { @@ -135,25 +130,25 @@ func TestTokenExpired(t *testing.T) { tests := []struct { desc string tokenInvalid bool - tokenExpiresAt time.Time + tokenDuration time.Duration expectedErr bool expectedIsExpired bool }{ { desc: "token valid", - tokenExpiresAt: time.Now().Add(time.Hour), + tokenDuration: time.Hour, expectedErr: false, expectedIsExpired: false, }, { desc: "token expired", - tokenExpiresAt: time.Now().Add(-time.Hour), + tokenDuration: -time.Hour, expectedErr: false, expectedIsExpired: true, }, { desc: "token almost expired", - tokenExpiresAt: time.Now().Add(tokenExpirationLeeway), + tokenDuration: tokenExpirationLeeway, expectedErr: false, expectedIsExpired: true, }, @@ -169,9 +164,7 @@ func TestTokenExpired(t *testing.T) { var err error token := "foo" if !tt.tokenInvalid { - token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(tt.tokenExpiresAt), - }).SignedString([]byte("test")) + token, err = signToken(tt.tokenDuration) if err != nil { t.Fatalf("failed to create token: %v", err) } diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go new file mode 100644 index 000000000..65b6fc461 --- /dev/null +++ b/core/clients/workload_identity_flow.go @@ -0,0 +1,249 @@ +package clients + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + clientIDEnv = "STACKIT_SERVICE_ACCOUNT_EMAIL" + FederatedTokenFileEnv = "STACKIT_FEDERATED_TOKEN_FILE" + wifTokenEndpointEnv = "STACKIT_IDP_ENDPOINT" + wifTokenExpirationEnv = "STACKIT_IDP_EXPIRATION_SECONDS" + + wifClientAssertionType = "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" + wifGrantType = "client_credentials" + defaultWifTokenEndpoint = "https://accounts.stackit.cloud/oauth/v2/token" + defaultFederatedTokenPath = "/var/run/secrets/stackit.cloud/serviceaccount/token" + defaultWifExpirationToken = "1h" +) + +var ( + _ = getEnvOrDefault(wifTokenExpirationEnv, defaultWifExpirationToken) // Not used yet +) + +func getEnvOrDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +var _ AuthFlow = &WorkloadIdentityFederationFlow{} + +// WorkloadIdentityFlow handles auth with Workload Identity Federation +type WorkloadIdentityFederationFlow struct { + rt http.RoundTripper + authClient *http.Client + config *WorkloadIdentityFederationFlowConfig + + tokenMutex sync.RWMutex + token *TokenResponseBody + + parser *jwt.Parser + + // If the current access token would expire in less than TokenExpirationLeeway, + // the client will refresh it early to prevent clock skew or other timing issues. + tokenExpirationLeeway time.Duration +} + +// KeyFlowConfig is the flow config +type WorkloadIdentityFederationFlowConfig struct { + TokenUrl string + ClientID string + FederatedTokenFilePath string + TokenExpiration string // Not supported yet + BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil + HTTPTransport http.RoundTripper + AuthHTTPClient *http.Client +} + +// GetConfig returns the flow configuration +func (c *WorkloadIdentityFederationFlow) GetConfig() WorkloadIdentityFederationFlowConfig { + if c.config == nil { + return WorkloadIdentityFederationFlowConfig{} + } + return *c.config +} + +// GetAccessToken implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetAccessToken() (string, error) { + if c.rt == nil { + return "", fmt.Errorf("nil http round tripper, please run Init()") + } + var accessToken string + + c.tokenMutex.RLock() + if c.token != nil { + accessToken = c.token.AccessToken + } + c.tokenMutex.RUnlock() + + accessTokenExpired, err := tokenExpired(accessToken, c.tokenExpirationLeeway) + if err != nil { + return "", fmt.Errorf("check access token is expired: %w", err) + } + if !accessTokenExpired { + return accessToken, nil + } + if err = c.createAccessToken(); err != nil { + return "", fmt.Errorf("get new access token: %w", err) + } + + c.tokenMutex.RLock() + accessToken = c.token.AccessToken + c.tokenMutex.RUnlock() + + return accessToken, nil +} + +// RoundTrip implements the http.RoundTripper interface. +// It gets a token, adds it to the request's authorization header, and performs the request. +func (c *WorkloadIdentityFederationFlow) RoundTrip(req *http.Request) (*http.Response, error) { + if c.rt == nil { + return nil, fmt.Errorf("please run Init()") + } + + accessToken, err := c.GetAccessToken() + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + return c.rt.RoundTrip(req) +} + +// GetBackgroundTokenRefreshContext implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + +func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlowConfig) error { + // No concurrency at this point, so no mutex check needed + c.token = &TokenResponseBody{} + c.config = cfg + c.parser = jwt.NewParser() + + if c.config.TokenUrl == "" { + c.config.TokenUrl = getEnvOrDefault(wifTokenEndpointEnv, defaultWifTokenEndpoint) + } + + if c.config.ClientID == "" { + c.config.ClientID = getEnvOrDefault(clientIDEnv, "") + } + + if c.config.FederatedTokenFilePath == "" { + c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath) + } + + c.tokenExpirationLeeway = defaultTokenExpirationLeeway + + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport + } + + if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil { + c.authClient = &http.Client{ + Transport: c.rt, + Timeout: DefaultClientTimeout, + } + } + + err := c.validate() + if err != nil { + return err + } + + // // Init the token + // _, err = c.GetAccessToken() + // if err != nil { + // return err + // } + + if c.config.BackgroundTokenRefreshContext != nil { + go continuousRefreshToken(c) + } + return nil +} + +// validate the client is configured well +func (c *WorkloadIdentityFederationFlow) validate() error { + if c.config.ClientID == "" { + return fmt.Errorf("client ID cannot be empty") + } + if c.config.TokenUrl == "" { + return fmt.Errorf("token URL cannot be empty") + } + if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) + } + if c.tokenExpirationLeeway < 0 { + return fmt.Errorf("token expiration leeway cannot be negative") + } + + return nil +} + +// createAccessToken creates an access token using self signed JWT +func (c *WorkloadIdentityFederationFlow) createAccessToken() (err error) { + clientAssertion, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) + if err != nil { + return fmt.Errorf("error reading service account assertion - %w", err) + } + + res, err := c.requestToken(c.config.ClientID, clientAssertion) + if err != nil { + return err + } + defer func() { + tempErr := res.Body.Close() + if tempErr != nil && err == nil { + err = fmt.Errorf("close request access token response: %w", tempErr) + } + }() + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil +} + +func (c *WorkloadIdentityFederationFlow) requestToken(clientID, assertion string) (*http.Response, error) { + body := url.Values{} + body.Set("grant_type", wifGrantType) + body.Set("client_assertion_type", wifClientAssertionType) + body.Set("client_assertion", assertion) + body.Set("client_id", clientID) + + payload := strings.NewReader(body.Encode()) + req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + return c.authClient.Do(req) +} + +func (c *WorkloadIdentityFederationFlow) readJWTFromFileSystem(tokenFilePath string) (string, error) { + token, err := os.ReadFile(tokenFilePath) + if err != nil { + return "", err + } + tokenStr := string(token) + _, _, err = c.parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", err + } + return tokenStr, nil +} diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go new file mode 100644 index 000000000..ef8f7a15f --- /dev/null +++ b/core/clients/workload_identity_flow_test.go @@ -0,0 +1,566 @@ +package clients + +import ( + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func TestWorkloadIdentityFlowInit(t *testing.T) { + tests := []struct { + name string + clientID string + clientIDAsEnv bool + customTokenUrl string + customTokenUrlEnv bool + tokenExpiration string + validAssertion bool + tokenFilePathAsEnv bool + missingTokenFilePath bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "missing client id", + validAssertion: true, + wantErr: true, + }, + { + name: "missing assertion", + clientID: "test@stackit.cloud", + missingTokenFilePath: true, + wantErr: true, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + if tt.customTokenUrl != "" { + if tt.customTokenUrlEnv { + t.Setenv("STACKIT_IDP_ENDPOINT", tt.customTokenUrl) + } else { + flowConfig.TokenUrl = tt.customTokenUrl + } + } + + if tt.clientID != "" { + if tt.clientIDAsEnv { + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", tt.clientID) + } else { + flowConfig.ClientID = tt.clientID + } + } + if tt.tokenExpiration != "" { + flowConfig.TokenExpiration = tt.tokenExpiration + } + + if !tt.missingTokenFilePath { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + if tt.validAssertion { + token, err := signTokenWithSubject("subject", time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + } + if tt.tokenFilePathAsEnv { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name()) + } else { + flowConfig.FederatedTokenFilePath = file.Name() + } + } + + if err := flow.Init(flowConfig); (err != nil) != tt.wantErr { + t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr) + } + if flow.config == nil { + t.Error("config is nil") + } + + if flow.config.ClientID != tt.clientID { + t.Errorf("clientID mismatch, want %s, got %s", tt.clientID, flow.config.ClientID) + } + + if tt.customTokenUrl != "" && flow.config.TokenUrl != tt.customTokenUrl { + t.Errorf("tokenUrl mismatch, want %s, got %s", tt.customTokenUrl, flow.config.TokenUrl) + } + + if tt.customTokenUrl == "" && flow.config.TokenUrl != "https://accounts.stackit.cloud/oauth/v2/token" { + t.Errorf("tokenUrl mismatch, want %s, got %s", "https://accounts.stackit.cloud/oauth/v2/token", flow.config.TokenUrl) + } + + if tt.missingTokenFilePath && flow.config.FederatedTokenFilePath != "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want %s, got %s", "/var/run/secrets/stackit.cloud/serviceaccount/token", flow.config.FederatedTokenFilePath) + } + + if !tt.missingTokenFilePath && flow.config.FederatedTokenFilePath == "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want different from %s", flow.config.FederatedTokenFilePath) + } + + if tt.tokenExpiration != "" && flow.config.TokenExpiration != tt.tokenExpiration { + t.Errorf("tokenExpiration mismatch, want %s, got %s", tt.tokenExpiration, flow.config.TokenExpiration) + } + }) + } +} + +func signTokenWithSubject(sub string, expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + Subject: sub, + }).SignedString([]byte("test")) +} + +func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { + validSub := "valid-sub" + serviceAccountSub := "sa-sub" + tests := []struct { + name string + clientID string + validAssertion bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + validAssertion: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertionType := r.PostForm.Get("client_assertion_type") + if assertionType != "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" { + t.Fatalf("invalid assertion type: %s", assertionType) + } + grantType := r.PostForm.Get("grant_type") + if grantType != "client_credentials" { + t.Fatalf("invalid grant type: %s", assertionType) + } + context, _, err := jwt.NewParser().ParseUnverified(r.PostForm.Get("client_assertion"), jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != validSub { + w.WriteHeader(http.StatusUnauthorized) + return + } + + token, err := signTokenWithSubject(serviceAccountSub, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + + tokenResponse := &TokenResponseBody{ + AccessToken: token, + ExpiresIn: 60, + TokenType: "Bearer", + } + + payload, err := json.Marshal(tokenResponse) + if err != nil { + t.Fatalf("failed to create token payload: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(payload) + })) + t.Cleanup(authServer.Close) + + protectedResource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + context, _, err := jwt.NewParser().ParseUnverified(strings.Fields(r.Header.Get("Authorization"))[1], jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != serviceAccountSub { + t.Fatalf("invalid token on protected resource: %v", err) + } + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(protectedResource.Close) + + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + flowConfig.TokenUrl = authServer.URL + + flowConfig.ClientID = tt.clientID + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + flowConfig.FederatedTokenFilePath = file.Name() + + subject := "wrong" + if tt.validAssertion { + subject = validSub + } + token, err := signTokenWithSubject(subject, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + + if err := flow.Init(flowConfig); err != nil { + t.Errorf("KeyFlow.Init() error = %v", err) + } + if flow.config == nil { + t.Error("config is nil") + } + + client := http.Client{ + Transport: flow, + } + resp, err := client.Get(protectedResource.URL) + if (err != nil || resp.StatusCode != http.StatusOK) && !tt.wantErr { + t.Fatalf("failed request to protected resource: %v", err) + } + }) + } +} + +// func TestRequestToken(t *testing.T) { +// testCases := []struct { +// name string +// grant string +// assertion string +// mockResponse *http.Response +// mockError error +// expectedError error +// }{ +// { +// name: "Success", +// grant: "test_grant", +// assertion: "test_assertion", +// mockResponse: &http.Response{ +// StatusCode: 200, +// Body: io.NopCloser(strings.NewReader(`{"access_token": "test_token"}`)), +// }, +// mockError: nil, +// expectedError: nil, +// }, +// { +// name: "Error", +// grant: "test_grant", +// assertion: "test_assertion", +// mockResponse: nil, +// mockError: fmt.Errorf("request error"), +// expectedError: fmt.Errorf("request error"), +// }, +// } + +// for _, tt := range testCases { +// t.Run(tt.name, func(t *testing.T) { +// keyFlow := &KeyFlow{} +// privateKeyBytes, err := generatePrivateKey() +// if err != nil { +// t.Fatalf("Error generating private key: %s", err) +// } +// keyFlowConfig := &KeyFlowConfig{ +// AuthHTTPClient: &http.Client{ +// Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) { +// return tt.mockResponse, tt.mockError +// }}, +// }, +// ServiceAccountKey: fixtureServiceAccountKey(), +// PrivateKey: string(privateKeyBytes), +// HTTPTransport: http.DefaultTransport, +// } +// err = keyFlow.Init(keyFlowConfig) +// if err != nil { +// t.Fatalf("failed to initialize key flow: %v", err) +// } + +// res, err := keyFlow.requestToken(tt.grant, tt.assertion) +// defer func() { +// if res != nil { +// tempErr := res.Body.Close() +// if tempErr != nil { +// t.Errorf("closing request token response: %s", tempErr.Error()) +// } +// } +// }() +// if tt.expectedError != nil { +// if err == nil { +// t.Errorf("Expected error '%v' but no error was returned", tt.expectedError) +// } else if errors.Is(err, tt.expectedError) { +// t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err) +// } +// } else { +// if err != nil { +// t.Errorf("Expected no error but error was returned: %v", err) +// } +// if !cmp.Equal(tt.mockResponse, res, cmp.AllowUnexported(strings.Reader{})) { +// t.Errorf("The returned result is wrong. Expected %v, got %v", tt.mockResponse, res) +// } +// } +// }) +// } +// } + +// func TestKeyFlow_Do(t *testing.T) { +// t.Parallel() + +// tests := []struct { +// name string +// handlerFn func(tb testing.TB) http.HandlerFunc +// want int +// wantErr bool +// }{ +// { +// name: "success", +// handlerFn: func(tb testing.TB) http.HandlerFunc { +// tb.Helper() + +// return func(w http.ResponseWriter, r *http.Request) { +// if r.Header.Get("Authorization") != "Bearer "+testBearerToken { +// tb.Errorf("expected Authorization header to be 'Bearer %s', but got %s", testBearerToken, r.Header.Get("Authorization")) +// } + +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: http.StatusOK, +// wantErr: false, +// }, +// { +// name: "success with code 500", +// handlerFn: func(_ testing.TB) http.HandlerFunc { +// return func(w http.ResponseWriter, _ *http.Request) { +// w.Header().Set("Content-Type", "text/html") +// w.WriteHeader(http.StatusInternalServerError) +// _, _ = fmt.Fprintln(w, `Internal Server Error`) +// } +// }, +// want: http.StatusInternalServerError, +// wantErr: false, +// }, +// { +// name: "success with custom transport", +// handlerFn: func(tb testing.TB) http.HandlerFunc { +// tb.Helper() + +// return func(w http.ResponseWriter, r *http.Request) { +// if r.Header.Get("User-Agent") != "custom_transport" { +// tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) +// } + +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: http.StatusOK, +// wantErr: false, +// }, +// { +// name: "fail with custom proxy", +// handlerFn: func(testing.TB) http.HandlerFunc { +// return func(w http.ResponseWriter, _ *http.Request) { +// w.Header().Set("Content-Type", "application/json") +// w.WriteHeader(http.StatusOK) +// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) +// } +// }, +// want: 0, +// wantErr: true, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// ctx := context.Background() +// ctx, cancel := context.WithCancel(ctx) +// t.Cleanup(cancel) // This cancels the refresher goroutine + +// privateKeyBytes, err := generatePrivateKey() +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// keyFlow := &KeyFlow{} +// keyFlowConfig := &KeyFlowConfig{ +// ServiceAccountKey: fixtureServiceAccountKey(), +// PrivateKey: string(privateKeyBytes), +// BackgroundTokenRefreshContext: ctx, +// HTTPTransport: func() http.RoundTripper { +// switch tt.name { +// case "success with custom transport": +// return mockTransportFn{ +// fn: func(req *http.Request) (*http.Response, error) { +// req.Header.Set("User-Agent", "custom_transport") +// return http.DefaultTransport.RoundTrip(req) +// }, +// } +// case "fail with custom proxy": +// return &http.Transport{ +// Proxy: func(_ *http.Request) (*url.URL, error) { +// return nil, fmt.Errorf("proxy error") +// }, +// } +// default: +// return http.DefaultTransport +// } +// }(), +// AuthHTTPClient: &http.Client{ +// Transport: mockTransportFn{ +// fn: func(_ *http.Request) (*http.Response, error) { +// res := httptest.NewRecorder() +// res.WriteHeader(http.StatusOK) +// res.Header().Set("Content-Type", "application/json") + +// token := &TokenResponseBody{ +// AccessToken: testBearerToken, +// ExpiresIn: 2147483647, +// TokenType: "Bearer", +// } + +// if err := json.NewEncoder(res.Body).Encode(token); err != nil { +// t.Logf("no error is expected, but got %v", err) +// } + +// return res.Result(), nil +// }, +// }, +// }, +// } +// err = keyFlow.Init(keyFlowConfig) +// if err != nil { +// t.Fatalf("failed to initialize key flow: %v", err) +// } + +// go continuousRefreshToken(keyFlow) + +// tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second) + +// token: +// for { +// select { +// case <-tokenCtx.Done(): +// t.Error(tokenCtx.Err()) +// case <-time.After(50 * time.Millisecond): +// keyFlow.tokenMutex.RLock() +// if keyFlow.token != nil { +// keyFlow.tokenMutex.RUnlock() +// tokenCancel() +// break token +// } + +// keyFlow.tokenMutex.RUnlock() +// } +// } + +// server := httptest.NewServer(tt.handlerFn(t)) +// t.Cleanup(server.Close) + +// u, err := url.Parse(server.URL) +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// httpClient := &http.Client{ +// Transport: keyFlow, +// } + +// res, err := httpClient.Do(req) + +// if tt.wantErr { +// if err == nil { +// t.Errorf("error is expected, but got %v", err) +// } +// } else { +// if err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// if res.StatusCode != tt.want { +// t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) +// } + +// // Defer discard and close the body +// t.Cleanup(func() { +// if _, err := io.Copy(io.Discard, res.Body); err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } + +// if err := res.Body.Close(); err != nil { +// t.Errorf("no error is expected, but got %v", err) +// } +// }) +// } +// }) +// } +// } + +// type mockTransportFn struct { +// fn func(req *http.Request) (*http.Response, error) +// } + +// func (m mockTransportFn) RoundTrip(req *http.Request) (*http.Response, error) { +// return m.fn(req) +// } diff --git a/core/config/config.go b/core/config/config.go index 93002c02a..ae2d8c498 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -75,26 +75,29 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` // Deprecated: ServiceAccountEmail is not required and will be removed after 12th June 2025. - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` - CustomAuth http.RoundTripper - Servers ServerConfigurations - OperationServers map[string]ServerConfigurations - HTTPClient *http.Client - Middleware []Middleware + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + WorkloadIdentityFederationTokenExpiration string `json:"workloadIdentityFederationTokenExpiration,omitempty"` + WorkloadIdentityFederationFederatedTokenPath string `json:"workloadIdentityFederationFederatedTokenPath,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` + CustomAuth http.RoundTripper + Servers ServerConfigurations + OperationServers map[string]ServerConfigurations + HTTPClient *http.Client + Middleware []Middleware // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. // The goroutine is killed whenever this context is canceled. @@ -176,8 +179,6 @@ func WithTokenEndpoint(url string) ConfigurationOption { } // WithServiceAccountEmail returns a ConfigurationOption that sets the service account email -// -// Deprecated: WithServiceAccountEmail is not required and will be removed after 12th June 2025. func WithServiceAccountEmail(serviceAccountEmail string) ConfigurationOption { return func(config *Configuration) error { config.ServiceAccountEmail = serviceAccountEmail @@ -237,6 +238,30 @@ func WithToken(token string) ConfigurationOption { } } +// WithWorkloadIdentityFederationAuth returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationAuth() ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederation = true + return nil + } +} + +// WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederationFederatedTokenPath = path + return nil + } +} + +// WithWorkloadIdentityFederationTokenExpiration returns a ConfigurationOption that sets the token expiration for workload identity federation flow +func WithWorkloadIdentityFederationTokenExpiration(expiration string) ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederationTokenExpiration = expiration + return nil + } +} + // Deprecated: retry options were removed to reduce complexity of the client. If this functionality is needed, you can provide your own custom HTTP client. This option has no effect, and will be removed in a later update func WithMaxRetries(_ int) ConfigurationOption { return func(_ *Configuration) error { diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 839999938..b398b19a9 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -35,18 +35,27 @@ func main() { // Create a new API client, that will authenticate using the provided bearer token token := "TOKEN" - _, err = dns.NewAPIClient(config.WithToken(token)) + dnsClient, err := dns.NewAPIClient(config.WithToken(token)) if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) os.Exit(1) } + // Check that you can make an authenticated request + getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + // Create a new API client, that will authenticate using the key flow // If you created a service account key and provided your own RSA key pair, // you need to add the path to a PEM encoded file including the private key // using config.WithPrivateKeyPath("path/to/private_key.pem") saKeyPath := "/path/to/service_account_key.json" - dnsClient, err := dns.NewAPIClient( + dnsClient, err = dns.NewAPIClient( config.WithServiceAccountKeyPath(saKeyPath), ) if err != nil { @@ -55,7 +64,51 @@ func main() { } // Check that you can make an authenticated request - getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + + // Create a new API client, that will authenticate using the wif flow + // You need to create a service account key and configure the federate identity provider, + // then you can init the SDK using default env var + os.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "my-sa@sa-stackit.cloud") + os.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "/path/to/your/federated/token") // Default "/var/run/secrets/stackit.cloud/serviceaccount/token" + os.Setenv("STACKIT_IDP_ENDPOINT", "custom token endpoint") // Default "https://accounts.stackit.cloud/oauth/v2/token" + dnsClient, err = dns.NewAPIClient() + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) + os.Exit(1) + } + + // Check that you can make an authenticated request + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() + + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) + } else { + fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) + } + + // Create a new API client, that will authenticate using the wif flow + // You need to create a service account key and configure the federate identity provider, + // then you can init the SDK setting fields + dnsClient, err = dns.NewAPIClient( + config.WithWorkloadIdentityFederationAuth(), + config.WithTokenEndpoint("custom token endpoint"), + config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token"), + config.WithServiceAccountEmail("my-sa@sa-stackit.cloud"), + ) + if err != nil { + fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) + os.Exit(1) + } + + // Check that you can make an authenticated request + getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) From c13c7adba3bbdb7c4e67e133b9b36c4d98c6d8e2 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Wed, 10 Dec 2025 12:16:08 +0100 Subject: [PATCH 06/15] update changelog Signed-off-by: Jorge Turrado --- core/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 47d06e806..aaaa0636c 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,5 +1,6 @@ ## v0.21.0 - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` +- **Feature:** Support Workload Identity Federation flow ## v0.20.1 - **Improvement:** Improve error message when passing a PEM encoded file to as service account key From a4ad581aba73211c9c72261c0f40e6c647fb49af Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 18:38:12 +0100 Subject: [PATCH 07/15] apply feedback Signed-off-by: Jorge Turrado --- core/auth/auth.go | 20 ++++------- core/auth/auth_test.go | 17 ++++++++-- core/clients/key_flow.go | 40 ++++++++++++++++++++++ core/clients/key_flow_test.go | 63 +++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 16 deletions(-) diff --git a/core/auth/auth.go b/core/auth/auth.go index 88f002fe7..e3b10bc46 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -45,18 +45,18 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { if cfg.CustomAuth != nil { return cfg.CustomAuth, nil - } else if useWorkloadIdentityFederation(cfg) { - wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) - if err != nil { - return nil, fmt.Errorf("configuring no auth client: %w", err) - } - return wifRoundTripper, nil } else if cfg.NoAuth { noAuthRoundTripper, err := NoAuth(cfg) if err != nil { return nil, fmt.Errorf("configuring no auth client: %w", err) } return noAuthRoundTripper, nil + } else if cfg.WorkloadIdentityFederation { + wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) + if err != nil { + return nil, fmt.Errorf("configuring no auth client: %w", err) + } + return wifRoundTripper, nil } else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" { keyRoundTripper, err := KeyAuth(cfg) if err != nil { @@ -394,11 +394,3 @@ func getServiceAccountKey(cfg *config.Configuration) error { func getPrivateKey(cfg *config.Configuration) error { return getKey(&cfg.PrivateKey, &cfg.PrivateKeyPath, "STACKIT_PRIVATE_KEY_PATH", "STACKIT_PRIVATE_KEY", privateKeyPathCredentialType, privateKeyCredentialType, cfg.CredentialsFilePath) } - -func useWorkloadIdentityFederation(cfg *config.Configuration) bool { - if cfg != nil && cfg.WorkloadIdentityFederation { - return true - } - val, exists := os.LookupEnv(clients.FederatedTokenFileEnv) - return exists && val != "" -} diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index 5e8af7203..b861bf581 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -179,23 +179,27 @@ func TestSetupAuth(t *testing.T) { setKeyPaths bool setCredentialsFilePathToken bool setCredentialsFilePathKey bool + isValid bool }{ { desc: "wif_config", config: nil, setWorkloadIdentity: true, + isValid: true, }, { desc: "token_config", config: nil, setToken: true, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "key_config", config: nil, setKeys: true, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "key_config_path", @@ -203,6 +207,7 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: true, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "key_config_credentials_path", @@ -210,12 +215,14 @@ func TestSetupAuth(t *testing.T) { setKeys: false, setKeyPaths: false, setCredentialsFilePathKey: true, + isValid: true, }, { desc: "valid_path_to_file", config: nil, setToken: false, setCredentialsFilePathToken: true, + isValid: true, }, { desc: "custom_config_token", @@ -224,6 +231,7 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, + isValid: true, }, { desc: "custom_config_path", @@ -232,6 +240,7 @@ func TestSetupAuth(t *testing.T) { }, setToken: false, setCredentialsFilePathToken: false, + isValid: true, }, } { t.Run(test.desc, func(t *testing.T) { @@ -276,11 +285,15 @@ func TestSetupAuth(t *testing.T) { authRoundTripper, err := SetupAuth(test.config) - if err != nil { + if err != nil && test.isValid { t.Fatalf("Test returned error on valid test case: %v", err) } - if authRoundTripper == nil { + if err == nil && !test.isValid { + t.Fatalf("Test didn't return error on invalid test case") + } + + if authRoundTripper == nil && test.isValid { t.Fatalf("Roundtripper returned is nil for valid test case") } }) diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 83c82e778..46b5d91a0 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -99,6 +99,46 @@ func (c *KeyFlow) GetServiceAccountEmail() string { return c.key.Credentials.Iss } +// GetToken returns the token field +// Deprecated: Use GetAccessToken instead +func (c *KeyFlow) GetToken() TokenResponseBody { + c.tokenMutex.RLock() + defer c.tokenMutex.RUnlock() + + if c.token == nil { + return TokenResponseBody{} + } + // Returned struct is passed by value (because it's a struct) + // So no deepy copy needed + return *c.token +} + +// SetToken can be used to set an access and refresh token manually in the client. +// The other fields in the token field are determined by inspecting the token or setting default values. +// Deprecated +func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { + // We can safely use ParseUnverified because we are not authenticating the user, + // We are parsing the token just to get the expiration time claim + parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) + if err != nil { + return fmt.Errorf("parse access token to read expiration time: %w", err) + } + exp, err := parsedAccessToken.Claims.GetExpirationTime() + if err != nil { + return fmt.Errorf("get expiration time from access token: %w", err) + } + + c.tokenMutex.Lock() + c.token = &TokenResponseBody{ + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: "", + TokenType: "Bearer", + } + c.tokenMutex.Unlock() + return nil +} + func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 8b8877673..7c094331e 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -17,10 +17,15 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/go-cmp/cmp" "github.com/google/uuid" ) +var ( + testSigningKey = []byte(`Test`) +) + const testBearerToken = "eyJhbGciOiJub25lIn0.eyJleHAiOjIxNDc0ODM2NDd9." //nolint:gosec // linter false positive func fixtureServiceAccountKey(mods ...func(*ServiceAccountKeyResponse)) *ServiceAccountKeyResponse { @@ -125,6 +130,64 @@ func TestKeyFlowInit(t *testing.T) { } } +func TestSetToken(t *testing.T) { + tests := []struct { + name string + tokenInvalid bool + refreshToken string + wantErr bool + }{ + { + name: "ok", + tokenInvalid: false, + refreshToken: "refresh_token", + wantErr: false, + }, + { + name: "invalid_token", + tokenInvalid: true, + refreshToken: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var accessToken string + var err error + + timestamp := time.Now().Add(24 * time.Hour) + if tt.tokenInvalid { + accessToken = "foo" + } else { + accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(timestamp)}) + accessToken, err = accessTokenJWT.SignedString(testSigningKey) + if err != nil { + t.Fatalf("get test access token as string: %s", err) + } + } + + keyFlow := &KeyFlow{} + err = keyFlow.SetToken(accessToken, tt.refreshToken) + + if (err != nil) != tt.wantErr { + t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr) + } + if err == nil { + expectedKeyFlowToken := &TokenResponseBody{ + AccessToken: accessToken, + ExpiresIn: int(timestamp.Unix()), + Scope: "", + TokenType: "Bearer", + } + if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { + t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) + } + } + }) + } +} + func TestTokenExpired(t *testing.T) { tokenExpirationLeeway := 5 * time.Second tests := []struct { From c00e7f098a5a1a5e3cba1c252c88caa0c9a497d9 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 18:40:51 +0100 Subject: [PATCH 08/15] apply feedback Signed-off-by: Jorge Turrado --- core/clients/key_flow.go | 52 ++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 46b5d91a0..d18d4f0bf 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -113,32 +113,6 @@ func (c *KeyFlow) GetToken() TokenResponseBody { return *c.token } -// SetToken can be used to set an access and refresh token manually in the client. -// The other fields in the token field are determined by inspecting the token or setting default values. -// Deprecated -func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { - // We can safely use ParseUnverified because we are not authenticating the user, - // We are parsing the token just to get the expiration time claim - parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) - if err != nil { - return fmt.Errorf("parse access token to read expiration time: %w", err) - } - exp, err := parsedAccessToken.Claims.GetExpirationTime() - if err != nil { - return fmt.Errorf("get expiration time from access token: %w", err) - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: "", - TokenType: "Bearer", - } - c.tokenMutex.Unlock() - return nil -} - func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // No concurrency at this point, so no mutex check needed c.token = &TokenResponseBody{} @@ -171,6 +145,32 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { return nil } +// SetToken can be used to set an access and refresh token manually in the client. +// The other fields in the token field are determined by inspecting the token or setting default values. +// Deprecated +func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { + // We can safely use ParseUnverified because we are not authenticating the user, + // We are parsing the token just to get the expiration time claim + parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) + if err != nil { + return fmt.Errorf("parse access token to read expiration time: %w", err) + } + exp, err := parsedAccessToken.Claims.GetExpirationTime() + if err != nil { + return fmt.Errorf("get expiration time from access token: %w", err) + } + + c.tokenMutex.Lock() + c.token = &TokenResponseBody{ + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: "", + TokenType: "Bearer", + } + c.tokenMutex.Unlock() + return nil +} + // Roundtrip performs the request func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { if c.rt == nil { From 5af53b20e8c657eca56f6f2b4582254340502f9d Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 19:18:07 +0100 Subject: [PATCH 09/15] apply feedback Signed-off-by: Jorge Turrado --- README.md | 47 +++++++++++++++++++++-- examples/authentication/authentication.go | 21 ---------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 69d23ae86..b2331eb88 100644 --- a/README.md +++ b/README.md @@ -105,13 +105,20 @@ To authenticate with the SDK, you need a [service account](https://docs.stackit. The SDK supports two authentication methods: -1. **Key Flow** (Recommended) +1. **Workload Identity Federation Flow** (Recommended) + + - Uses OIDC trusted tokens + - Provides best security through short-lived tokens without secrets + +> NOTE: This flow isn't publicly available yet. It'll be public during Q1 2026 + +2. **Key Flow** (Recommended) - Uses RSA key-pair based authentication - Provides better security through short-lived tokens - Supports both STACKIT-generated and custom key pairs -2. **Token Flow** +3. **Token Flow** - Uses long-lived service account tokens - Simpler but less secure @@ -120,10 +127,42 @@ The SDK supports two authentication methods: The SDK searches for credentials in the following order: 1. Explicit configuration in code -2. Environment variables (KEY_PATH for KEY) +2. Environment variables 3. Credentials file (`$HOME/.stackit/credentials.json`) -For each authentication method, the key flow is attempted first, followed by the token flow. +For each authentication method, the try order is: +1. Workload Identity Federation Flow +2. Key Flow +3. Token Flow + +### Using the Workload Identity Fedearion Flow + +1. Create a service account trusted relation in the STACKIT Portal: + + - Navigate to `Service Accounts` → Select account → `Federated Identity Providers` → Add a Federated Identity Provider + - Configure the trusted issuer and the required assertions to trust in. (Link to official docs here after GA) + +2. Configure authentication using any of these methods: + + **A. Code Configuration** + + ```go + // Using wokload identity federation flow + config.WithWorkloadIdentityFederationAuth() + // With the custom path for the external OIDC token + config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token") + // For the service account + config.WithServiceAccountEmail("my-sa@sa-stackit.cloud") + ``` + + **B. Environment Variables** + + ```bash + # With the custom path for the external OIDC token + STACKIT_FEDERATED_TOKEN_FILE=/path/to/your/federated/token + # For the service account + STACKIT_SERVICE_ACCOUNT_EMAIL=my-sa@sa-stackit.cloud + ``` ### Using the Key Flow diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index b398b19a9..8ec2a84db 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -72,27 +72,6 @@ func main() { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } - // Create a new API client, that will authenticate using the wif flow - // You need to create a service account key and configure the federate identity provider, - // then you can init the SDK using default env var - os.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "my-sa@sa-stackit.cloud") - os.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "/path/to/your/federated/token") // Default "/var/run/secrets/stackit.cloud/serviceaccount/token" - os.Setenv("STACKIT_IDP_ENDPOINT", "custom token endpoint") // Default "https://accounts.stackit.cloud/oauth/v2/token" - dnsClient, err = dns.NewAPIClient() - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) - os.Exit(1) - } - - // Check that you can make an authenticated request - getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() - - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) - } else { - fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) - } - // Create a new API client, that will authenticate using the wif flow // You need to create a service account key and configure the federate identity provider, // then you can init the SDK setting fields From e13f147959d0ebc45dd6d912e06370800849f8da Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 15 Dec 2025 19:40:02 +0100 Subject: [PATCH 10/15] apply feedback Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 8ec2a84db..64758bd87 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -14,7 +14,8 @@ func main() { // When creating a new API client without providing any configuration, it will setup default authentication. // The SDK will search for a valid service account key or token in several locations. - // It will first try to use the key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, + // It will first try to use the workload identity federation flow by looking into the variables STACKIT_FEDERATED_TOKEN_FILE, STACKIT_SERVICE_ACCOUNT_EMAIL and their default values, + // Then, it will try key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, // STACKIT_PRIVATE_KEY and STACKIT_PRIVATE_KEY_PATH. If the keys cannot be retrieved, it will check the credentials file located in STACKIT_CREDENTIALS_PATH, if specified, or in // $HOME/.stackit/credentials.json as a fallback. If the key are found and are valid, the KeyAuth flow is used. // If the key flow cannot be used, it will try to find a token in the STACKIT_SERVICE_ACCOUNT_TOKEN. If not present, it will From 84c2f5b06f1e04b7c461daa91e27684fbfedc95a Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Sun, 21 Dec 2025 22:41:42 +0100 Subject: [PATCH 11/15] remove docs from PR Signed-off-by: Jorge Turrado --- CHANGELOG.md | 1 + README.md | 49 +++-------------------- core/CHANGELOG.md | 1 - examples/authentication/authentication.go | 43 +++----------------- 4 files changed, 11 insertions(+), 83 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cc2298ced..3b6f1eb5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ - `core`: - [v0.21.0](core/CHANGELOG.md#v0210) - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` + - **Feature:** Support Workload Identity Federation flow - `sfs`: - [v0.2.0](services/sfs/CHANGELOG.md) - **Breaking change:** Remove region configuration in `APIClient` diff --git a/README.md b/README.md index b2331eb88..9ca8dcace 100644 --- a/README.md +++ b/README.md @@ -105,20 +105,13 @@ To authenticate with the SDK, you need a [service account](https://docs.stackit. The SDK supports two authentication methods: -1. **Workload Identity Federation Flow** (Recommended) - - - Uses OIDC trusted tokens - - Provides best security through short-lived tokens without secrets - -> NOTE: This flow isn't publicly available yet. It'll be public during Q1 2026 - -2. **Key Flow** (Recommended) +1. **Key Flow** (Recommended) - Uses RSA key-pair based authentication - Provides better security through short-lived tokens - Supports both STACKIT-generated and custom key pairs -3. **Token Flow** +2. **Token Flow** - Uses long-lived service account tokens - Simpler but less secure @@ -127,42 +120,10 @@ The SDK supports two authentication methods: The SDK searches for credentials in the following order: 1. Explicit configuration in code -2. Environment variables +2. Environment variables (KEY_PATH for KEY) 3. Credentials file (`$HOME/.stackit/credentials.json`) -For each authentication method, the try order is: -1. Workload Identity Federation Flow -2. Key Flow -3. Token Flow - -### Using the Workload Identity Fedearion Flow - -1. Create a service account trusted relation in the STACKIT Portal: - - - Navigate to `Service Accounts` → Select account → `Federated Identity Providers` → Add a Federated Identity Provider - - Configure the trusted issuer and the required assertions to trust in. (Link to official docs here after GA) - -2. Configure authentication using any of these methods: - - **A. Code Configuration** - - ```go - // Using wokload identity federation flow - config.WithWorkloadIdentityFederationAuth() - // With the custom path for the external OIDC token - config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token") - // For the service account - config.WithServiceAccountEmail("my-sa@sa-stackit.cloud") - ``` - - **B. Environment Variables** - - ```bash - # With the custom path for the external OIDC token - STACKIT_FEDERATED_TOKEN_FILE=/path/to/your/federated/token - # For the service account - STACKIT_SERVICE_ACCOUNT_EMAIL=my-sa@sa-stackit.cloud - ``` +For each authentication method, the key flow is attempted first, followed by the token flow. ### Using the Key Flow @@ -273,4 +234,4 @@ See the [release documentation](./RELEASE.md) for further information. ## License -Apache 2.0 +Apache 2.0 \ No newline at end of file diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index aaaa0636c..1e8466cac 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -13,7 +13,6 @@ ## v0.18.0 - **New:** Added duration utils -- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` ## v0.17.3 - **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3` diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 64758bd87..cb0357b19 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -14,8 +14,7 @@ func main() { // When creating a new API client without providing any configuration, it will setup default authentication. // The SDK will search for a valid service account key or token in several locations. - // It will first try to use the workload identity federation flow by looking into the variables STACKIT_FEDERATED_TOKEN_FILE, STACKIT_SERVICE_ACCOUNT_EMAIL and their default values, - // Then, it will try key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, + // It will first try to use the key flow, by looking into the variables STACKIT_SERVICE_ACCOUNT_KEY, STACKIT_SERVICE_ACCOUNT_KEY_PATH, // STACKIT_PRIVATE_KEY and STACKIT_PRIVATE_KEY_PATH. If the keys cannot be retrieved, it will check the credentials file located in STACKIT_CREDENTIALS_PATH, if specified, or in // $HOME/.stackit/credentials.json as a fallback. If the key are found and are valid, the KeyAuth flow is used. // If the key flow cannot be used, it will try to find a token in the STACKIT_SERVICE_ACCOUNT_TOKEN. If not present, it will @@ -36,27 +35,18 @@ func main() { // Create a new API client, that will authenticate using the provided bearer token token := "TOKEN" - dnsClient, err := dns.NewAPIClient(config.WithToken(token)) + _, err = dns.NewAPIClient(config.WithToken(token)) if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) os.Exit(1) } - // Check that you can make an authenticated request - getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() - - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) - } else { - fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) - } - // Create a new API client, that will authenticate using the key flow // If you created a service account key and provided your own RSA key pair, // you need to add the path to a PEM encoded file including the private key // using config.WithPrivateKeyPath("path/to/private_key.pem") saKeyPath := "/path/to/service_account_key.json" - dnsClient, err = dns.NewAPIClient( + dnsClient, err := dns.NewAPIClient( config.WithServiceAccountKeyPath(saKeyPath), ) if err != nil { @@ -65,34 +55,11 @@ func main() { } // Check that you can make an authenticated request - getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() - - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) - } else { - fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) - } - - // Create a new API client, that will authenticate using the wif flow - // You need to create a service account key and configure the federate identity provider, - // then you can init the SDK setting fields - dnsClient, err = dns.NewAPIClient( - config.WithWorkloadIdentityFederationAuth(), - config.WithTokenEndpoint("custom token endpoint"), - config.WithWorkloadIdentityFederationTokenPath("/path/to/your/federated/token"), - config.WithServiceAccountEmail("my-sa@sa-stackit.cloud"), - ) - if err != nil { - fmt.Fprintf(os.Stderr, "[DNS API] Creating API client: %v\n", err) - os.Exit(1) - } - - // Check that you can make an authenticated request - getZoneResp, err = dnsClient.ListZones(context.Background(), projectId).Execute() + getZoneResp, err := dnsClient.ListZones(context.Background(), projectId).Execute() if err != nil { fmt.Fprintf(os.Stderr, "[DNS API] Error when calling `ZoneApi.GetZones`: %v\n", err) } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} +} \ No newline at end of file From 125637a0f1cbd68e18b8173ba021521ec3340836 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Sun, 21 Dec 2025 23:40:49 +0100 Subject: [PATCH 12/15] remove docs from PR Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index cb0357b19..839999938 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -62,4 +62,4 @@ func main() { } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} \ No newline at end of file +} From cd47f0e33a4ec37e229273124e347dd2ccdad4af Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Sun, 21 Dec 2025 23:58:35 +0100 Subject: [PATCH 13/15] remove docs from PR Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index 839999938..cb0357b19 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -62,4 +62,4 @@ func main() { } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} +} \ No newline at end of file From 39961b0b9ad03f61e38b8fd69a25642f3c8e8701 Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Mon, 22 Dec 2025 00:03:43 +0100 Subject: [PATCH 14/15] remove docs from PR Signed-off-by: Jorge Turrado --- examples/authentication/authentication.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/authentication/authentication.go b/examples/authentication/authentication.go index cb0357b19..839999938 100644 --- a/examples/authentication/authentication.go +++ b/examples/authentication/authentication.go @@ -62,4 +62,4 @@ func main() { } else { fmt.Printf("[DNS API] Number of zones: %v\n", len(*getZoneResp.Zones)) } -} \ No newline at end of file +} From 960a8fc17ae996002e919c43232e5d3bfdad23fa Mon Sep 17 00:00:00 2001 From: Jorge Turrado Date: Tue, 23 Dec 2025 15:28:03 +0100 Subject: [PATCH 15/15] add static token Signed-off-by: Jorge Turrado --- core/auth/auth.go | 5 +- core/clients/workload_identity_flow.go | 27 +- core/clients/workload_identity_flow_test.go | 314 ++------------------ core/config/config.go | 51 ++-- 4 files changed, 63 insertions(+), 334 deletions(-) diff --git a/core/auth/auth.go b/core/auth/auth.go index e3b10bc46..450361c60 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -238,8 +238,9 @@ func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTrippe TokenUrl: cfg.TokenCustomUrl, BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, ClientID: cfg.ServiceAccountEmail, - FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath, - TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration, + FederatedTokenFilePath: cfg.ServiceAccountFederatedTokenPath, + TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration, + FederatedToken: cfg.ServiceAccountFederatedToken, } if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go index 65b6fc461..0046ec864 100644 --- a/core/clients/workload_identity_flow.go +++ b/core/clients/workload_identity_flow.go @@ -59,6 +59,7 @@ type WorkloadIdentityFederationFlow struct { type WorkloadIdentityFederationFlowConfig struct { TokenUrl string ClientID string + FederatedToken string // Static token string. This is optional, if not set the token will be read from file. FederatedTokenFilePath string TokenExpiration string // Not supported yet BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil @@ -139,7 +140,7 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo c.config.ClientID = getEnvOrDefault(clientIDEnv, "") } - if c.config.FederatedTokenFilePath == "" { + if c.config.FederatedToken == "" && c.config.FederatedTokenFilePath == "" { c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath) } @@ -161,12 +162,6 @@ func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlo return err } - // // Init the token - // _, err = c.GetAccessToken() - // if err != nil { - // return err - // } - if c.config.BackgroundTokenRefreshContext != nil { go continuousRefreshToken(c) } @@ -181,8 +176,10 @@ func (c *WorkloadIdentityFederationFlow) validate() error { if c.config.TokenUrl == "" { return fmt.Errorf("token URL cannot be empty") } - if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { - return fmt.Errorf("error reading federated token file - %w", err) + if c.config.FederatedToken == "" { + if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) + } } if c.tokenExpirationLeeway < 0 { return fmt.Errorf("token expiration leeway cannot be negative") @@ -192,10 +189,14 @@ func (c *WorkloadIdentityFederationFlow) validate() error { } // createAccessToken creates an access token using self signed JWT -func (c *WorkloadIdentityFederationFlow) createAccessToken() (err error) { - clientAssertion, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) - if err != nil { - return fmt.Errorf("error reading service account assertion - %w", err) +func (c *WorkloadIdentityFederationFlow) createAccessToken() error { + clientAssertion := c.config.FederatedToken + if clientAssertion == "" { + var err error + clientAssertion, err = c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) + if err != nil { + return fmt.Errorf("error reading service account assertion - %w", err) + } } res, err := c.requestToken(c.config.ClientID, clientAssertion) diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go index ef8f7a15f..4a9e07161 100644 --- a/core/clients/workload_identity_flow_test.go +++ b/core/clients/workload_identity_flow_test.go @@ -158,6 +158,7 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { name string clientID string validAssertion bool + injectToken bool wantErr bool }{ { @@ -166,6 +167,13 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { validAssertion: true, wantErr: false, }, + { + name: "injected token ok", + clientID: "test@stackit.cloud", + validAssertion: true, + injectToken: true, + wantErr: false, + }, { name: "invalid assertion", clientID: "test@stackit.cloud", @@ -243,12 +251,6 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { flowConfig.TokenUrl = authServer.URL flowConfig.ClientID = tt.clientID - file, err := os.CreateTemp("", "*.token") - if err != nil { - log.Fatal(err) - } - defer os.Remove(file.Name()) - flowConfig.FederatedTokenFilePath = file.Name() subject := "wrong" if tt.validAssertion { @@ -258,7 +260,18 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { if err != nil { t.Fatalf("failed to create token: %v", err) } - os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + + if tt.injectToken { + flowConfig.FederatedToken = token + } else { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + flowConfig.FederatedTokenFilePath = file.Name() + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + } if err := flow.Init(flowConfig); err != nil { t.Errorf("KeyFlow.Init() error = %v", err) @@ -277,290 +290,3 @@ func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { }) } } - -// func TestRequestToken(t *testing.T) { -// testCases := []struct { -// name string -// grant string -// assertion string -// mockResponse *http.Response -// mockError error -// expectedError error -// }{ -// { -// name: "Success", -// grant: "test_grant", -// assertion: "test_assertion", -// mockResponse: &http.Response{ -// StatusCode: 200, -// Body: io.NopCloser(strings.NewReader(`{"access_token": "test_token"}`)), -// }, -// mockError: nil, -// expectedError: nil, -// }, -// { -// name: "Error", -// grant: "test_grant", -// assertion: "test_assertion", -// mockResponse: nil, -// mockError: fmt.Errorf("request error"), -// expectedError: fmt.Errorf("request error"), -// }, -// } - -// for _, tt := range testCases { -// t.Run(tt.name, func(t *testing.T) { -// keyFlow := &KeyFlow{} -// privateKeyBytes, err := generatePrivateKey() -// if err != nil { -// t.Fatalf("Error generating private key: %s", err) -// } -// keyFlowConfig := &KeyFlowConfig{ -// AuthHTTPClient: &http.Client{ -// Transport: mockTransportFn{func(_ *http.Request) (*http.Response, error) { -// return tt.mockResponse, tt.mockError -// }}, -// }, -// ServiceAccountKey: fixtureServiceAccountKey(), -// PrivateKey: string(privateKeyBytes), -// HTTPTransport: http.DefaultTransport, -// } -// err = keyFlow.Init(keyFlowConfig) -// if err != nil { -// t.Fatalf("failed to initialize key flow: %v", err) -// } - -// res, err := keyFlow.requestToken(tt.grant, tt.assertion) -// defer func() { -// if res != nil { -// tempErr := res.Body.Close() -// if tempErr != nil { -// t.Errorf("closing request token response: %s", tempErr.Error()) -// } -// } -// }() -// if tt.expectedError != nil { -// if err == nil { -// t.Errorf("Expected error '%v' but no error was returned", tt.expectedError) -// } else if errors.Is(err, tt.expectedError) { -// t.Errorf("Error is not correct. Expected %v, got %v", tt.expectedError, err) -// } -// } else { -// if err != nil { -// t.Errorf("Expected no error but error was returned: %v", err) -// } -// if !cmp.Equal(tt.mockResponse, res, cmp.AllowUnexported(strings.Reader{})) { -// t.Errorf("The returned result is wrong. Expected %v, got %v", tt.mockResponse, res) -// } -// } -// }) -// } -// } - -// func TestKeyFlow_Do(t *testing.T) { -// t.Parallel() - -// tests := []struct { -// name string -// handlerFn func(tb testing.TB) http.HandlerFunc -// want int -// wantErr bool -// }{ -// { -// name: "success", -// handlerFn: func(tb testing.TB) http.HandlerFunc { -// tb.Helper() - -// return func(w http.ResponseWriter, r *http.Request) { -// if r.Header.Get("Authorization") != "Bearer "+testBearerToken { -// tb.Errorf("expected Authorization header to be 'Bearer %s', but got %s", testBearerToken, r.Header.Get("Authorization")) -// } - -// w.Header().Set("Content-Type", "application/json") -// w.WriteHeader(http.StatusOK) -// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) -// } -// }, -// want: http.StatusOK, -// wantErr: false, -// }, -// { -// name: "success with code 500", -// handlerFn: func(_ testing.TB) http.HandlerFunc { -// return func(w http.ResponseWriter, _ *http.Request) { -// w.Header().Set("Content-Type", "text/html") -// w.WriteHeader(http.StatusInternalServerError) -// _, _ = fmt.Fprintln(w, `Internal Server Error`) -// } -// }, -// want: http.StatusInternalServerError, -// wantErr: false, -// }, -// { -// name: "success with custom transport", -// handlerFn: func(tb testing.TB) http.HandlerFunc { -// tb.Helper() - -// return func(w http.ResponseWriter, r *http.Request) { -// if r.Header.Get("User-Agent") != "custom_transport" { -// tb.Errorf("expected User-Agent header to be 'custom_transport', but got %s", r.Header.Get("User-Agent")) -// } - -// w.Header().Set("Content-Type", "application/json") -// w.WriteHeader(http.StatusOK) -// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) -// } -// }, -// want: http.StatusOK, -// wantErr: false, -// }, -// { -// name: "fail with custom proxy", -// handlerFn: func(testing.TB) http.HandlerFunc { -// return func(w http.ResponseWriter, _ *http.Request) { -// w.Header().Set("Content-Type", "application/json") -// w.WriteHeader(http.StatusOK) -// _, _ = fmt.Fprintln(w, `{"status":"ok"}`) -// } -// }, -// want: 0, -// wantErr: true, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// ctx := context.Background() -// ctx, cancel := context.WithCancel(ctx) -// t.Cleanup(cancel) // This cancels the refresher goroutine - -// privateKeyBytes, err := generatePrivateKey() -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// keyFlow := &KeyFlow{} -// keyFlowConfig := &KeyFlowConfig{ -// ServiceAccountKey: fixtureServiceAccountKey(), -// PrivateKey: string(privateKeyBytes), -// BackgroundTokenRefreshContext: ctx, -// HTTPTransport: func() http.RoundTripper { -// switch tt.name { -// case "success with custom transport": -// return mockTransportFn{ -// fn: func(req *http.Request) (*http.Response, error) { -// req.Header.Set("User-Agent", "custom_transport") -// return http.DefaultTransport.RoundTrip(req) -// }, -// } -// case "fail with custom proxy": -// return &http.Transport{ -// Proxy: func(_ *http.Request) (*url.URL, error) { -// return nil, fmt.Errorf("proxy error") -// }, -// } -// default: -// return http.DefaultTransport -// } -// }(), -// AuthHTTPClient: &http.Client{ -// Transport: mockTransportFn{ -// fn: func(_ *http.Request) (*http.Response, error) { -// res := httptest.NewRecorder() -// res.WriteHeader(http.StatusOK) -// res.Header().Set("Content-Type", "application/json") - -// token := &TokenResponseBody{ -// AccessToken: testBearerToken, -// ExpiresIn: 2147483647, -// TokenType: "Bearer", -// } - -// if err := json.NewEncoder(res.Body).Encode(token); err != nil { -// t.Logf("no error is expected, but got %v", err) -// } - -// return res.Result(), nil -// }, -// }, -// }, -// } -// err = keyFlow.Init(keyFlowConfig) -// if err != nil { -// t.Fatalf("failed to initialize key flow: %v", err) -// } - -// go continuousRefreshToken(keyFlow) - -// tokenCtx, tokenCancel := context.WithTimeout(context.Background(), 1*time.Second) - -// token: -// for { -// select { -// case <-tokenCtx.Done(): -// t.Error(tokenCtx.Err()) -// case <-time.After(50 * time.Millisecond): -// keyFlow.tokenMutex.RLock() -// if keyFlow.token != nil { -// keyFlow.tokenMutex.RUnlock() -// tokenCancel() -// break token -// } - -// keyFlow.tokenMutex.RUnlock() -// } -// } - -// server := httptest.NewServer(tt.handlerFn(t)) -// t.Cleanup(server.Close) - -// u, err := url.Parse(server.URL) -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// httpClient := &http.Client{ -// Transport: keyFlow, -// } - -// res, err := httpClient.Do(req) - -// if tt.wantErr { -// if err == nil { -// t.Errorf("error is expected, but got %v", err) -// } -// } else { -// if err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// if res.StatusCode != tt.want { -// t.Errorf("expected status code %d, but got %d", tt.want, res.StatusCode) -// } - -// // Defer discard and close the body -// t.Cleanup(func() { -// if _, err := io.Copy(io.Discard, res.Body); err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } - -// if err := res.Body.Close(); err != nil { -// t.Errorf("no error is expected, but got %v", err) -// } -// }) -// } -// }) -// } -// } - -// type mockTransportFn struct { -// fn func(req *http.Request) (*http.Response, error) -// } - -// func (m mockTransportFn) RoundTrip(req *http.Request) (*http.Response, error) { -// return m.fn(req) -// } diff --git a/core/config/config.go b/core/config/config.go index ae2d8c498..dd9dd98f4 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -75,29 +75,30 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` - WorkloadIdentityFederationTokenExpiration string `json:"workloadIdentityFederationTokenExpiration,omitempty"` - WorkloadIdentityFederationFederatedTokenPath string `json:"workloadIdentityFederationFederatedTokenPath,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` - CustomAuth http.RoundTripper - Servers ServerConfigurations - OperationServers map[string]ServerConfigurations - HTTPClient *http.Client - Middleware []Middleware + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"` + ServiceAccountFederatedToken string `json:"serviceAccountFederatedToken,omitempty"` + ServiceAccountFederatedTokenPath string `json:"serviceAccountFederatedTokenPath,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` + CustomAuth http.RoundTripper + Servers ServerConfigurations + OperationServers map[string]ServerConfigurations + HTTPClient *http.Client + Middleware []Middleware // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. // The goroutine is killed whenever this context is canceled. @@ -249,7 +250,7 @@ func WithWorkloadIdentityFederationAuth() ConfigurationOption { // WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { return func(config *Configuration) error { - config.WorkloadIdentityFederationFederatedTokenPath = path + config.ServiceAccountFederatedTokenPath = path return nil } } @@ -257,7 +258,7 @@ func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { // WithWorkloadIdentityFederationTokenExpiration returns a ConfigurationOption that sets the token expiration for workload identity federation flow func WithWorkloadIdentityFederationTokenExpiration(expiration string) ConfigurationOption { return func(config *Configuration) error { - config.WorkloadIdentityFederationTokenExpiration = expiration + config.ServiceAccountFederatedTokenExpiration = expiration return nil } }