From fdc3c06819b6933d269739e156dfb614aa758ec5 Mon Sep 17 00:00:00 2001 From: Kodey Kilday-Thomas Date: Tue, 30 Dec 2025 11:12:42 +0000 Subject: [PATCH 1/2] wip: move auth into common --- .../middlewares/anonymous_auth_middleware.go | 2 +- .../anonymous_auth_middleware_test.go | 2 +- .../pkg/middlewares/hmac_auth_middleware.go | 65 +++------------- .../middlewares/hmac_auth_middleware_test.go | 9 +-- .../middlewares/rate_limiting_middleware.go | 10 +-- .../rate_limiting_middleware_test.go | 18 ++--- .../middlewares/require_auth_middleware.go | 2 +- aggregator/pkg/model/config.go | 53 ++----------- aggregator/pkg/scope/scope.go | 2 +- aggregator/tests/utils.go | 15 ++-- build/devenv/services/aggregator.go | 7 +- common/auth/auth.go | 76 +++++++++++++++++++ common/auth/config.go | 48 ++++++++++++ {aggregator/pkg => common}/auth/identity.go | 0 .../pkg => common}/auth/identity_test.go | 0 15 files changed, 176 insertions(+), 133 deletions(-) create mode 100644 common/auth/auth.go create mode 100644 common/auth/config.go rename {aggregator/pkg => common}/auth/identity.go (100%) rename {aggregator/pkg => common}/auth/identity_test.go (100%) diff --git a/aggregator/pkg/middlewares/anonymous_auth_middleware.go b/aggregator/pkg/middlewares/anonymous_auth_middleware.go index 182e0ccec..6a11754ef 100644 --- a/aggregator/pkg/middlewares/anonymous_auth_middleware.go +++ b/aggregator/pkg/middlewares/anonymous_auth_middleware.go @@ -10,7 +10,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) diff --git a/aggregator/pkg/middlewares/anonymous_auth_middleware_test.go b/aggregator/pkg/middlewares/anonymous_auth_middleware_test.go index 2769da17d..c3e8d053c 100644 --- a/aggregator/pkg/middlewares/anonymous_auth_middleware_test.go +++ b/aggregator/pkg/middlewares/anonymous_auth_middleware_test.go @@ -10,7 +10,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) diff --git a/aggregator/pkg/middlewares/hmac_auth_middleware.go b/aggregator/pkg/middlewares/hmac_auth_middleware.go index bfe3c328f..f260872e7 100644 --- a/aggregator/pkg/middlewares/hmac_auth_middleware.go +++ b/aggregator/pkg/middlewares/hmac_auth_middleware.go @@ -4,26 +4,23 @@ import ( "context" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-ccv/protocol/common/hmac" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) type HMACAuthMiddleware struct { - apiKeyConfig *model.APIKeyConfig - logger logger.Logger + hmac auth.HMACAuth + logger logger.Logger } // NewHMACAuthMiddleware creates a new HMAC authentication middleware. -func NewHMACAuthMiddleware(config *model.APIKeyConfig, lggr logger.Logger) *HMACAuthMiddleware { +func NewHMACAuthMiddleware(config *auth.APIKeyConfig, lggr logger.Logger) *HMACAuthMiddleware { return &HMACAuthMiddleware{ - apiKeyConfig: config, - logger: lggr, + hmac: *auth.NewHMACAuth(config, lggr), + logger: lggr, } } @@ -43,56 +40,18 @@ func (m *HMACAuthMiddleware) Intercept(ctx context.Context, req any, info *grpc. return handler(ctx, req) } - // If some but not all HMAC headers are present, this is an error - if apiKey == "" { - m.logger.Warnf("Authentication failed: missing authorization header") - return nil, status.Error(codes.Unauthenticated, "missing authorization header") - } - if timestamp == "" { - m.logger.Warnf("Authentication failed: missing x-authorization-timestamp header") - return nil, status.Error(codes.Unauthenticated, "missing x-authorization-timestamp header") - } - if providedSignature == "" { - m.logger.Warnf("Authentication failed: missing x-authorization-signature-sha256 header") - return nil, status.Error(codes.Unauthenticated, "missing x-authorization-signature-sha256 header") - } - - client, exists := m.apiKeyConfig.GetClientByAPIKey(apiKey) - if !exists { - m.logger.Warnf("Authentication failed: invalid or disabled API key") - return nil, status.Error(codes.Unauthenticated, "invalid credentials") - } - - if len(client.Secrets) == 0 { - m.logger.Errorf("Client %s has no secrets configured", client.ClientID) - return nil, status.Error(codes.Internal, "authentication configuration error") - } - - if err := hmac.ValidateTimestamp(timestamp); err != nil { - m.logger.Warnf("Authentication failed for client %s: %v", client.ClientID, err) - return nil, status.Error(codes.Unauthenticated, "invalid or expired timestamp") - } - body, err := hmac.SerializeRequestBody(req) if err != nil { - m.logger.Errorf("Failed to serialize request body: %v", err) - return nil, status.Error(codes.Internal, "request serialization error") + m.logger.Error("Unable to seralize request body") + // Pass through to other auth mechanisms + return handler(ctx, req) } - bodyHash := hmac.ComputeBodyHash(body) - - stringToSign := hmac.GenerateStringToSign(hmac.HTTPMethodPost, info.FullMethod, bodyHash, apiKey, timestamp) - - if !hmac.ValidateSignature(stringToSign, providedSignature, client.Secrets) { - m.logger.Warnf("Authentication failed for client %s: invalid signature", client.ClientID) - return nil, status.Error(codes.Unauthenticated, "invalid signature") + ctx, err = m.hmac.Authorize(ctx, body, hmac.HTTPMethodPost, info.FullMethod, apiKey, timestamp, providedSignature) + if err != nil { + return nil, err } - identity := auth.CreateCallerIdentity(client.ClientID, false) - ctx = auth.ToContext(ctx, identity) - - m.logger.Debugf("Successfully authenticated client: %s", client.ClientID) - return handler(ctx, req) } diff --git a/aggregator/pkg/middlewares/hmac_auth_middleware_test.go b/aggregator/pkg/middlewares/hmac_auth_middleware_test.go index 17045d7d1..5ad60c8bd 100644 --- a/aggregator/pkg/middlewares/hmac_auth_middleware_test.go +++ b/aggregator/pkg/middlewares/hmac_auth_middleware_test.go @@ -13,8 +13,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-common/pkg/logger" hmacutil "github.com/smartcontractkit/chainlink-ccv/protocol/common/hmac" @@ -40,9 +39,9 @@ func generateTestSignature( } // Test helper: creates test API key configuration. -func createTestAPIKeyConfig() *model.APIKeyConfig { - return &model.APIKeyConfig{ - Clients: map[string]*model.APIClient{ +func createTestAPIKeyConfig() *auth.APIKeyConfig { + return &auth.APIKeyConfig{ + Clients: map[string]*auth.APIClient{ testAPIKey1: { ClientID: "client-1", Description: "Test client 1", diff --git a/aggregator/pkg/middlewares/rate_limiting_middleware.go b/aggregator/pkg/middlewares/rate_limiting_middleware.go index bf6d4c998..95c4f749d 100644 --- a/aggregator/pkg/middlewares/rate_limiting_middleware.go +++ b/aggregator/pkg/middlewares/rate_limiting_middleware.go @@ -11,9 +11,9 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/rate_limiting" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-ccv/protocol" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) @@ -23,13 +23,13 @@ var _ protocol.HealthReporter = (*RateLimitingMiddleware)(nil) type RateLimitingMiddleware struct { store limiter.Store config model.RateLimitingConfig - apiConfig model.APIKeyConfig + apiConfig auth.APIKeyConfig enabled bool lggr logger.SugaredLogger } // NewRateLimitingMiddleware creates a new rate limiting middleware with the given store and configuration. -func NewRateLimitingMiddleware(store limiter.Store, config model.RateLimitingConfig, apiConfig model.APIKeyConfig, lggr logger.SugaredLogger) *RateLimitingMiddleware { +func NewRateLimitingMiddleware(store limiter.Store, config model.RateLimitingConfig, apiConfig auth.APIKeyConfig, lggr logger.SugaredLogger) *RateLimitingMiddleware { if store == nil { return &RateLimitingMiddleware{enabled: false} } @@ -49,7 +49,7 @@ func (m *RateLimitingMiddleware) buildKey(callerID, method string) string { func (m *RateLimitingMiddleware) getLimitConfig(callerID, method string) (model.RateLimitConfig, bool) { // Get the API client configuration for the caller - var apiClient *model.APIClient + var apiClient *auth.APIClient for _, client := range m.apiConfig.Clients { if client.ClientID == callerID { apiClient = client @@ -123,7 +123,7 @@ func (m *RateLimitingMiddleware) Intercept(ctx context.Context, req any, info *g } // NewRateLimitingMiddlewareFromConfig creates a rate limiting middleware from configuration. -func NewRateLimitingMiddlewareFromConfig(config model.RateLimitingConfig, apiConfig model.APIKeyConfig, lggr logger.SugaredLogger) (*RateLimitingMiddleware, error) { +func NewRateLimitingMiddlewareFromConfig(config model.RateLimitingConfig, apiConfig auth.APIKeyConfig, lggr logger.SugaredLogger) (*RateLimitingMiddleware, error) { if !config.Enabled { return &RateLimitingMiddleware{enabled: false}, nil } diff --git a/aggregator/pkg/middlewares/rate_limiting_middleware_test.go b/aggregator/pkg/middlewares/rate_limiting_middleware_test.go index baae00a95..e2e39dfbd 100644 --- a/aggregator/pkg/middlewares/rate_limiting_middleware_test.go +++ b/aggregator/pkg/middlewares/rate_limiting_middleware_test.go @@ -10,8 +10,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) @@ -45,8 +45,8 @@ func TestRateLimitingMiddleware_DefaultLimits(t *testing.T) { "/test.Service/Method": {LimitPerMinute: 5}, }, } - apiConfig := model.APIKeyConfig{ - Clients: map[string]*model.APIClient{ + apiConfig := auth.APIKeyConfig{ + Clients: map[string]*auth.APIClient{ "test-key": { ClientID: "test-caller", Enabled: true, @@ -90,8 +90,8 @@ func TestRateLimitingMiddleware_GroupLimits(t *testing.T) { "/test.Service/Method": {LimitPerMinute: 10}, }, } - apiConfig := model.APIKeyConfig{ - Clients: map[string]*model.APIClient{ + apiConfig := auth.APIKeyConfig{ + Clients: map[string]*auth.APIClient{ "test-key": { ClientID: "test-caller", Enabled: true, @@ -139,8 +139,8 @@ func TestRateLimitingMiddleware_MostRestrictiveGroup(t *testing.T) { "/test.Service/Method": {LimitPerMinute: 10}, }, } - apiConfig := model.APIKeyConfig{ - Clients: map[string]*model.APIClient{ + apiConfig := auth.APIKeyConfig{ + Clients: map[string]*auth.APIClient{ "test-key": { ClientID: "test-caller", Enabled: true, @@ -190,8 +190,8 @@ func TestRateLimitingMiddleware_CallerSpecificOverridesGroup(t *testing.T) { "/test.Service/Method": {LimitPerMinute: 10}, }, } - apiConfig := model.APIKeyConfig{ - Clients: map[string]*model.APIClient{ + apiConfig := auth.APIKeyConfig{ + Clients: map[string]*auth.APIClient{ "test-key": { ClientID: "test-caller", Enabled: true, diff --git a/aggregator/pkg/middlewares/require_auth_middleware.go b/aggregator/pkg/middlewares/require_auth_middleware.go index 8c6244724..436a0ccfc 100644 --- a/aggregator/pkg/middlewares/require_auth_middleware.go +++ b/aggregator/pkg/middlewares/require_auth_middleware.go @@ -7,8 +7,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/scope" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) diff --git a/aggregator/pkg/model/config.go b/aggregator/pkg/model/config.go index 9a2c76350..852f8c1b5 100644 --- a/aggregator/pkg/model/config.go +++ b/aggregator/pkg/model/config.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-ccv/protocol" ) @@ -105,21 +106,6 @@ type ServerConfig struct { MaxSendMsgSizeBytes int `toml:"maxSendMsgSizeBytes"` } -// APIClient represents a configured client for API access. -type APIClient struct { - ClientID string `toml:"clientId"` - Description string `toml:"description,omitempty"` - Enabled bool `toml:"enabled"` - Secrets map[string]string `toml:"secrets,omitempty"` - Groups []string `toml:"groups,omitempty"` -} - -// APIKeyConfig represents the configuration for API key management. -type APIKeyConfig struct { - // Clients maps API keys to client configurations - Clients map[string]*APIClient `toml:"clients"` -} - // AggregationConfig represents the configuration for the aggregation system. type AggregationConfig struct { // ChannelBufferSize controls the size of the aggregation request channel buffer @@ -230,7 +216,7 @@ type RateLimitingConfig struct { // GetEffectiveLimit resolves the effective rate limit for a given caller and method. // Priority order: 1) Specific caller limit, 2) Group limits (most restrictive), 3) Default limit. -func (c *RateLimitingConfig) GetEffectiveLimit(callerID, method string, apiClient *APIClient) *RateLimitConfig { +func (c *RateLimitingConfig) GetEffectiveLimit(callerID, method string, apiClient *auth.APIClient) *RateLimitConfig { // 1. Check specific caller limit (highest priority) if callerLimits, exists := c.Limits[callerID]; exists { if limit, exists := callerLimits[method]; exists { @@ -252,7 +238,7 @@ func (c *RateLimitingConfig) GetEffectiveLimit(callerID, method string, apiClien } // getMostRestrictiveGroupLimit finds the most restrictive rate limit from all groups the API client belongs to. -func (c *RateLimitingConfig) getMostRestrictiveGroupLimit(apiClient *APIClient, method string) *RateLimitConfig { +func (c *RateLimitingConfig) getMostRestrictiveGroupLimit(apiClient *auth.APIClient, method string) *RateLimitConfig { if apiClient == nil { return nil } @@ -300,39 +286,12 @@ type BeholderConfig struct { TraceBatchTimeout int64 `toml:"TraceBatchTimeout"` } -// GetClientByAPIKey returns the client configuration for a given API key. -func (c *APIKeyConfig) GetClientByAPIKey(apiKey string) (*APIClient, bool) { - client, exists := c.Clients[apiKey] - if !exists || !client.Enabled { - return nil, false - } - return client, true -} - -// ValidateAPIKey validates an API key against the configuration. -func (c *APIKeyConfig) ValidateAPIKey(apiKey string) error { - if strings.TrimSpace(apiKey) == "" { - return errors.New("api key cannot be empty") - } - - client, exists := c.GetClientByAPIKey(apiKey) - if !exists { - return errors.New("invalid or disabled api key") - } - - if client.ClientID == "" { - return errors.New("client id cannot be empty") - } - - return nil -} - // AggregatorConfig is the root configuration for the pb. type AggregatorConfig struct { Committee *Committee `toml:"committee"` Server ServerConfig `toml:"server"` Storage *StorageConfig `toml:"storage"` - APIKeys APIKeyConfig `toml:"-"` + APIKeys auth.APIKeyConfig `toml:"-"` Aggregation AggregationConfig `toml:"aggregation"` OrphanRecovery OrphanRecoveryConfig `toml:"orphanRecovery"` RateLimiting RateLimitingConfig `toml:"rateLimiting"` @@ -382,7 +341,7 @@ func (c *AggregatorConfig) SetDefaults() { c.Storage.ConnMaxIdleTime = 300 // 5 minutes } if c.APIKeys.Clients == nil { - c.APIKeys.Clients = make(map[string]*APIClient) + c.APIKeys.Clients = make(map[string]*auth.APIClient) } // Default orphan recovery: enabled with 5 minute interval if c.OrphanRecovery.IntervalSeconds == 0 { @@ -695,7 +654,7 @@ func (c *AggregatorConfig) LoadFromEnvironment() error { return errors.New("AGGREGATOR_API_KEYS_JSON environment variable is required") } - var apiKeyConfig APIKeyConfig + var apiKeyConfig auth.APIKeyConfig if err := json.Unmarshal([]byte(apiKeysJSON), &apiKeyConfig); err != nil { return fmt.Errorf("failed to parse AGGREGATOR_API_KEYS_JSON: %w", err) } diff --git a/aggregator/pkg/scope/scope.go b/aggregator/pkg/scope/scope.go index edf399354..542cf0111 100644 --- a/aggregator/pkg/scope/scope.go +++ b/aggregator/pkg/scope/scope.go @@ -6,8 +6,8 @@ import ( "github.com/google/uuid" - "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) diff --git a/aggregator/tests/utils.go b/aggregator/tests/utils.go index 607fc2de9..06b350ae8 100644 --- a/aggregator/tests/utils.go +++ b/aggregator/tests/utils.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/test/bufconn" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-ccv/protocol/common/logging" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -88,8 +89,8 @@ func CreateServerAndClient(t *testing.T, options ...ConfigOption) (committeepb.C dummyConfig := &model.AggregatorConfig{ Storage: &model.StorageConfig{}, - APIKeys: model.APIKeyConfig{ - Clients: make(map[string]*model.APIClient), + APIKeys: auth.APIKeyConfig{ + Clients: make(map[string]*auth.APIClient), }, } for _, option := range options { @@ -130,8 +131,8 @@ func CreateServerOnly(t *testing.T, options ...ConfigOption) (*bufconn.Listener, Monitoring: model.MonitoringConfig{ Enabled: false, }, - APIKeys: model.APIKeyConfig{ - Clients: make(map[string]*model.APIClient), + APIKeys: auth.APIKeyConfig{ + Clients: make(map[string]*auth.APIClient), }, RateLimiting: model.RateLimitingConfig{ Enabled: true, @@ -151,7 +152,7 @@ func CreateServerOnly(t *testing.T, options ...ConfigOption) (*bufconn.Listener, }, } - config.APIKeys.Clients[defaultAPIKey] = &model.APIClient{ + config.APIKeys.Clients[defaultAPIKey] = &auth.APIClient{ ClientID: "test-client", Description: "Test client for integration tests", Enabled: true, @@ -200,8 +201,8 @@ func CreateAuthenticatedClient(t *testing.T, listener *bufconn.Listener, options dummyConfig := &model.AggregatorConfig{ Storage: &model.StorageConfig{}, - APIKeys: model.APIKeyConfig{ - Clients: make(map[string]*model.APIClient), + APIKeys: auth.APIKeyConfig{ + Clients: make(map[string]*auth.APIClient), }, } for _, option := range options { diff --git a/build/devenv/services/aggregator.go b/build/devenv/services/aggregator.go index 8e94ba2cd..64d36f410 100644 --- a/build/devenv/services/aggregator.go +++ b/build/devenv/services/aggregator.go @@ -19,6 +19,7 @@ import ( aggregator "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/configuration" "github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model" + "github.com/smartcontractkit/chainlink-ccv/common/auth" "github.com/smartcontractkit/chainlink-ccv/devenv/internal/util" "github.com/smartcontractkit/chainlink-testing-framework/framework" ) @@ -96,11 +97,11 @@ type AggregatorInput struct { SharedTLSCerts *TLSCertPaths `toml:"-"` } -func (a *AggregatorInput) GetAPIKeys() (model.APIKeyConfig, error) { - var apiKeyConfig model.APIKeyConfig +func (a *AggregatorInput) GetAPIKeys() (auth.APIKeyConfig, error) { + var apiKeyConfig auth.APIKeyConfig err := json.Unmarshal([]byte(a.Env.APIKeysJSON), &apiKeyConfig) if err != nil { - return model.APIKeyConfig{}, fmt.Errorf("failed to unmarshal API keys JSON: %w", err) + return auth.APIKeyConfig{}, fmt.Errorf("failed to unmarshal API keys JSON: %w", err) } return apiKeyConfig, nil } diff --git a/common/auth/auth.go b/common/auth/auth.go new file mode 100644 index 000000000..823fad37a --- /dev/null +++ b/common/auth/auth.go @@ -0,0 +1,76 @@ +package auth + +import ( + "context" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/smartcontractkit/chainlink-ccv/protocol" + "github.com/smartcontractkit/chainlink-ccv/protocol/common/hmac" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +type HMACAuth struct { + config *APIKeyConfig + logger logger.Logger +} + +func NewHMACAuth(config *APIKeyConfig, lggr logger.Logger) *HMACAuth { + return &HMACAuth{ + config: config, + logger: lggr, + } +} + +func (h *HMACAuth) Authorize(ctx context.Context, body protocol.ByteSlice, method, path, apiKey, timestamp, signature string) (context.Context, error) { + if err := validateParams(h.logger, apiKey, timestamp, signature); err != nil { + return nil, err + } + + client, exists := h.config.GetClientByAPIKey(apiKey) + if !exists { + h.logger.Warnf("Authentication failed: invalid or disabled API key") + return nil, status.Error(codes.Unauthenticated, "invalid credentials") + } + + if len(client.Secrets) == 0 { + h.logger.Errorf("Client %s has no secrets configured", client.ClientID) + return nil, status.Error(codes.Internal, "authentication configuration error") + } + + if err := hmac.ValidateTimestamp(timestamp); err != nil { + h.logger.Warnf("Authentication failed for client %s: %v", client.ClientID, err) + return nil, status.Error(codes.Unauthenticated, "invalid or expired timestamp") + } + + bodyHash := hmac.ComputeBodyHash(body) + stringToSign := hmac.GenerateStringToSign(method, path, bodyHash, apiKey, timestamp) + + if !hmac.ValidateSignature(stringToSign, signature, client.Secrets) { + h.logger.Warnf("Authentication failed for client %s: invalid signature", client.ClientID) + return nil, status.Error(codes.Unauthenticated, "invalid signature") + } + + identity := CreateCallerIdentity(client.ClientID, false) + h.logger.Debugf("Successfully authenticated client: %s", client.ClientID) + + return ToContext(ctx, identity), nil +} + +func validateParams(logger logger.Logger, apiKey, timestamp, signature string) error { + if apiKey == "" { + logger.Warnf("Authentication failed: missing api key") + return status.Error(codes.Unauthenticated, "missing api key") + } + if timestamp == "" { + logger.Warnf("Authentication failed: missing timestamp") + return status.Error(codes.Unauthenticated, "mmissing timestamp") + } + if signature == "" { + logger.Warnf("Authentication failed: missing signature") + return status.Error(codes.Unauthenticated, "missing signature") + } + + return nil +} diff --git a/common/auth/config.go b/common/auth/config.go new file mode 100644 index 000000000..4741370d0 --- /dev/null +++ b/common/auth/config.go @@ -0,0 +1,48 @@ +package auth + +import ( + "errors" + "strings" +) + +// APIClient represents a configured client for API access. +type APIClient struct { + ClientID string `toml:"clientId"` + Description string `toml:"description,omitempty"` + Enabled bool `toml:"enabled"` + Secrets map[string]string `toml:"secrets,omitempty"` + Groups []string `toml:"groups,omitempty"` +} + +// APIKeyConfig represents the configuration for API key management. +type APIKeyConfig struct { + // Clients maps API keys to client configurations + Clients map[string]*APIClient `toml:"clients"` +} + +// GetClientByAPIKey returns the client configuration for a given API key. +func (c *APIKeyConfig) GetClientByAPIKey(apiKey string) (*APIClient, bool) { + client, exists := c.Clients[apiKey] + if !exists || !client.Enabled { + return nil, false + } + return client, true +} + +// ValidateAPIKey validates an API key against the configuration. +func (c *APIKeyConfig) ValidateAPIKey(apiKey string) error { + if strings.TrimSpace(apiKey) == "" { + return errors.New("api key cannot be empty") + } + + client, exists := c.GetClientByAPIKey(apiKey) + if !exists { + return errors.New("invalid or disabled api key") + } + + if client.ClientID == "" { + return errors.New("client id cannot be empty") + } + + return nil +} diff --git a/aggregator/pkg/auth/identity.go b/common/auth/identity.go similarity index 100% rename from aggregator/pkg/auth/identity.go rename to common/auth/identity.go diff --git a/aggregator/pkg/auth/identity_test.go b/common/auth/identity_test.go similarity index 100% rename from aggregator/pkg/auth/identity_test.go rename to common/auth/identity_test.go From e5861b3a2db7eba65fe933a029ac1783b8659ddc Mon Sep 17 00:00:00 2001 From: Kodey Kilday-Thomas Date: Tue, 30 Dec 2025 11:49:03 +0000 Subject: [PATCH 2/2] wip: fix test --- aggregator/pkg/middlewares/hmac_auth_middleware_test.go | 6 +++--- common/auth/auth.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/aggregator/pkg/middlewares/hmac_auth_middleware_test.go b/aggregator/pkg/middlewares/hmac_auth_middleware_test.go index 5ad60c8bd..090159599 100644 --- a/aggregator/pkg/middlewares/hmac_auth_middleware_test.go +++ b/aggregator/pkg/middlewares/hmac_auth_middleware_test.go @@ -151,7 +151,7 @@ func TestHMACAuthMiddleware(t *testing.T) { }, expectError: true, expectedErrorCode: codes.Unauthenticated, - expectedErrorMsg: "missing authorization header", + expectedErrorMsg: "missing api key", validateIdentity: false, }, { @@ -164,7 +164,7 @@ func TestHMACAuthMiddleware(t *testing.T) { }, expectError: true, expectedErrorCode: codes.Unauthenticated, - expectedErrorMsg: "missing x-authorization-timestamp header", + expectedErrorMsg: "missing timestamp", validateIdentity: false, }, { @@ -178,7 +178,7 @@ func TestHMACAuthMiddleware(t *testing.T) { }, expectError: true, expectedErrorCode: codes.Unauthenticated, - expectedErrorMsg: "missing x-authorization-signature-sha256 header", + expectedErrorMsg: "missing signature", validateIdentity: false, }, { diff --git a/common/auth/auth.go b/common/auth/auth.go index 823fad37a..cde97a677 100644 --- a/common/auth/auth.go +++ b/common/auth/auth.go @@ -65,7 +65,7 @@ func validateParams(logger logger.Logger, apiKey, timestamp, signature string) e } if timestamp == "" { logger.Warnf("Authentication failed: missing timestamp") - return status.Error(codes.Unauthenticated, "mmissing timestamp") + return status.Error(codes.Unauthenticated, "missing timestamp") } if signature == "" { logger.Warnf("Authentication failed: missing signature")