Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aggregator/pkg/middlewares/anonymous_auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

Expand Down
65 changes: 12 additions & 53 deletions aggregator/pkg/middlewares/hmac_auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,23 @@ import (
"context"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-ccv/protocol/common/hmac"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

type HMACAuthMiddleware struct {
apiKeyConfig *model.APIKeyConfig
logger logger.Logger
hmac auth.HMACAuth
logger logger.Logger
}

// NewHMACAuthMiddleware creates a new HMAC authentication middleware.
func NewHMACAuthMiddleware(config *model.APIKeyConfig, lggr logger.Logger) *HMACAuthMiddleware {
func NewHMACAuthMiddleware(config *auth.APIKeyConfig, lggr logger.Logger) *HMACAuthMiddleware {
return &HMACAuthMiddleware{
apiKeyConfig: config,
logger: lggr,
hmac: *auth.NewHMACAuth(config, lggr),
logger: lggr,
}
}

Expand All @@ -43,56 +40,18 @@ func (m *HMACAuthMiddleware) Intercept(ctx context.Context, req any, info *grpc.
return handler(ctx, req)
}

// If some but not all HMAC headers are present, this is an error
if apiKey == "" {
m.logger.Warnf("Authentication failed: missing authorization header")
return nil, status.Error(codes.Unauthenticated, "missing authorization header")
}
if timestamp == "" {
m.logger.Warnf("Authentication failed: missing x-authorization-timestamp header")
return nil, status.Error(codes.Unauthenticated, "missing x-authorization-timestamp header")
}
if providedSignature == "" {
m.logger.Warnf("Authentication failed: missing x-authorization-signature-sha256 header")
return nil, status.Error(codes.Unauthenticated, "missing x-authorization-signature-sha256 header")
}

client, exists := m.apiKeyConfig.GetClientByAPIKey(apiKey)
if !exists {
m.logger.Warnf("Authentication failed: invalid or disabled API key")
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}

if len(client.Secrets) == 0 {
m.logger.Errorf("Client %s has no secrets configured", client.ClientID)
return nil, status.Error(codes.Internal, "authentication configuration error")
}

if err := hmac.ValidateTimestamp(timestamp); err != nil {
m.logger.Warnf("Authentication failed for client %s: %v", client.ClientID, err)
return nil, status.Error(codes.Unauthenticated, "invalid or expired timestamp")
}

body, err := hmac.SerializeRequestBody(req)
if err != nil {
m.logger.Errorf("Failed to serialize request body: %v", err)
return nil, status.Error(codes.Internal, "request serialization error")
m.logger.Error("Unable to seralize request body")
// Pass through to other auth mechanisms
return handler(ctx, req)
}

bodyHash := hmac.ComputeBodyHash(body)

stringToSign := hmac.GenerateStringToSign(hmac.HTTPMethodPost, info.FullMethod, bodyHash, apiKey, timestamp)

if !hmac.ValidateSignature(stringToSign, providedSignature, client.Secrets) {
m.logger.Warnf("Authentication failed for client %s: invalid signature", client.ClientID)
return nil, status.Error(codes.Unauthenticated, "invalid signature")
ctx, err = m.hmac.Authorize(ctx, body, hmac.HTTPMethodPost, info.FullMethod, apiKey, timestamp, providedSignature)
if err != nil {
return nil, err
}

identity := auth.CreateCallerIdentity(client.ClientID, false)
ctx = auth.ToContext(ctx, identity)

m.logger.Debugf("Successfully authenticated client: %s", client.ClientID)

return handler(ctx, req)
}

Expand Down
15 changes: 7 additions & 8 deletions aggregator/pkg/middlewares/hmac_auth_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-common/pkg/logger"

hmacutil "github.com/smartcontractkit/chainlink-ccv/protocol/common/hmac"
Expand All @@ -40,9 +39,9 @@ func generateTestSignature(
}

// Test helper: creates test API key configuration.
func createTestAPIKeyConfig() *model.APIKeyConfig {
return &model.APIKeyConfig{
Clients: map[string]*model.APIClient{
func createTestAPIKeyConfig() *auth.APIKeyConfig {
return &auth.APIKeyConfig{
Clients: map[string]*auth.APIClient{
testAPIKey1: {
ClientID: "client-1",
Description: "Test client 1",
Expand Down Expand Up @@ -152,7 +151,7 @@ func TestHMACAuthMiddleware(t *testing.T) {
},
expectError: true,
expectedErrorCode: codes.Unauthenticated,
expectedErrorMsg: "missing authorization header",
expectedErrorMsg: "missing api key",
validateIdentity: false,
},
{
Expand All @@ -165,7 +164,7 @@ func TestHMACAuthMiddleware(t *testing.T) {
},
expectError: true,
expectedErrorCode: codes.Unauthenticated,
expectedErrorMsg: "missing x-authorization-timestamp header",
expectedErrorMsg: "missing timestamp",
validateIdentity: false,
},
{
Expand All @@ -179,7 +178,7 @@ func TestHMACAuthMiddleware(t *testing.T) {
},
expectError: true,
expectedErrorCode: codes.Unauthenticated,
expectedErrorMsg: "missing x-authorization-signature-sha256 header",
expectedErrorMsg: "missing signature",
validateIdentity: false,
},
{
Expand Down
10 changes: 5 additions & 5 deletions aggregator/pkg/middlewares/rate_limiting_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model"
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/rate_limiting"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-ccv/protocol"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)
Expand All @@ -23,13 +23,13 @@ var _ protocol.HealthReporter = (*RateLimitingMiddleware)(nil)
type RateLimitingMiddleware struct {
store limiter.Store
config model.RateLimitingConfig
apiConfig model.APIKeyConfig
apiConfig auth.APIKeyConfig
enabled bool
lggr logger.SugaredLogger
}

// NewRateLimitingMiddleware creates a new rate limiting middleware with the given store and configuration.
func NewRateLimitingMiddleware(store limiter.Store, config model.RateLimitingConfig, apiConfig model.APIKeyConfig, lggr logger.SugaredLogger) *RateLimitingMiddleware {
func NewRateLimitingMiddleware(store limiter.Store, config model.RateLimitingConfig, apiConfig auth.APIKeyConfig, lggr logger.SugaredLogger) *RateLimitingMiddleware {
if store == nil {
return &RateLimitingMiddleware{enabled: false}
}
Expand All @@ -49,7 +49,7 @@ func (m *RateLimitingMiddleware) buildKey(callerID, method string) string {

func (m *RateLimitingMiddleware) getLimitConfig(callerID, method string) (model.RateLimitConfig, bool) {
// Get the API client configuration for the caller
var apiClient *model.APIClient
var apiClient *auth.APIClient
for _, client := range m.apiConfig.Clients {
if client.ClientID == callerID {
apiClient = client
Expand Down Expand Up @@ -123,7 +123,7 @@ func (m *RateLimitingMiddleware) Intercept(ctx context.Context, req any, info *g
}

// NewRateLimitingMiddlewareFromConfig creates a rate limiting middleware from configuration.
func NewRateLimitingMiddlewareFromConfig(config model.RateLimitingConfig, apiConfig model.APIKeyConfig, lggr logger.SugaredLogger) (*RateLimitingMiddleware, error) {
func NewRateLimitingMiddlewareFromConfig(config model.RateLimitingConfig, apiConfig auth.APIKeyConfig, lggr logger.SugaredLogger) (*RateLimitingMiddleware, error) {
if !config.Enabled {
return &RateLimitingMiddleware{enabled: false}, nil
}
Expand Down
18 changes: 9 additions & 9 deletions aggregator/pkg/middlewares/rate_limiting_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

Expand Down Expand Up @@ -45,8 +45,8 @@ func TestRateLimitingMiddleware_DefaultLimits(t *testing.T) {
"/test.Service/Method": {LimitPerMinute: 5},
},
}
apiConfig := model.APIKeyConfig{
Clients: map[string]*model.APIClient{
apiConfig := auth.APIKeyConfig{
Clients: map[string]*auth.APIClient{
"test-key": {
ClientID: "test-caller",
Enabled: true,
Expand Down Expand Up @@ -90,8 +90,8 @@ func TestRateLimitingMiddleware_GroupLimits(t *testing.T) {
"/test.Service/Method": {LimitPerMinute: 10},
},
}
apiConfig := model.APIKeyConfig{
Clients: map[string]*model.APIClient{
apiConfig := auth.APIKeyConfig{
Clients: map[string]*auth.APIClient{
"test-key": {
ClientID: "test-caller",
Enabled: true,
Expand Down Expand Up @@ -139,8 +139,8 @@ func TestRateLimitingMiddleware_MostRestrictiveGroup(t *testing.T) {
"/test.Service/Method": {LimitPerMinute: 10},
},
}
apiConfig := model.APIKeyConfig{
Clients: map[string]*model.APIClient{
apiConfig := auth.APIKeyConfig{
Clients: map[string]*auth.APIClient{
"test-key": {
ClientID: "test-caller",
Enabled: true,
Expand Down Expand Up @@ -190,8 +190,8 @@ func TestRateLimitingMiddleware_CallerSpecificOverridesGroup(t *testing.T) {
"/test.Service/Method": {LimitPerMinute: 10},
},
}
apiConfig := model.APIKeyConfig{
Clients: map[string]*model.APIClient{
apiConfig := auth.APIKeyConfig{
Clients: map[string]*auth.APIClient{
"test-key": {
ClientID: "test-caller",
Enabled: true,
Expand Down
2 changes: 1 addition & 1 deletion aggregator/pkg/middlewares/require_auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/scope"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

Expand Down
53 changes: 6 additions & 47 deletions aggregator/pkg/model/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strconv"
"strings"

"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-ccv/protocol"
)

Expand Down Expand Up @@ -105,21 +106,6 @@ type ServerConfig struct {
MaxSendMsgSizeBytes int `toml:"maxSendMsgSizeBytes"`
}

// APIClient represents a configured client for API access.
type APIClient struct {
ClientID string `toml:"clientId"`
Description string `toml:"description,omitempty"`
Enabled bool `toml:"enabled"`
Secrets map[string]string `toml:"secrets,omitempty"`
Groups []string `toml:"groups,omitempty"`
}

// APIKeyConfig represents the configuration for API key management.
type APIKeyConfig struct {
// Clients maps API keys to client configurations
Clients map[string]*APIClient `toml:"clients"`
}

// AggregationConfig represents the configuration for the aggregation system.
type AggregationConfig struct {
// ChannelBufferSize controls the size of the aggregation request channel buffer
Expand Down Expand Up @@ -230,7 +216,7 @@ type RateLimitingConfig struct {

// GetEffectiveLimit resolves the effective rate limit for a given caller and method.
// Priority order: 1) Specific caller limit, 2) Group limits (most restrictive), 3) Default limit.
func (c *RateLimitingConfig) GetEffectiveLimit(callerID, method string, apiClient *APIClient) *RateLimitConfig {
func (c *RateLimitingConfig) GetEffectiveLimit(callerID, method string, apiClient *auth.APIClient) *RateLimitConfig {
// 1. Check specific caller limit (highest priority)
if callerLimits, exists := c.Limits[callerID]; exists {
if limit, exists := callerLimits[method]; exists {
Expand All @@ -252,7 +238,7 @@ func (c *RateLimitingConfig) GetEffectiveLimit(callerID, method string, apiClien
}

// getMostRestrictiveGroupLimit finds the most restrictive rate limit from all groups the API client belongs to.
func (c *RateLimitingConfig) getMostRestrictiveGroupLimit(apiClient *APIClient, method string) *RateLimitConfig {
func (c *RateLimitingConfig) getMostRestrictiveGroupLimit(apiClient *auth.APIClient, method string) *RateLimitConfig {
if apiClient == nil {
return nil
}
Expand Down Expand Up @@ -300,39 +286,12 @@ type BeholderConfig struct {
TraceBatchTimeout int64 `toml:"TraceBatchTimeout"`
}

// GetClientByAPIKey returns the client configuration for a given API key.
func (c *APIKeyConfig) GetClientByAPIKey(apiKey string) (*APIClient, bool) {
client, exists := c.Clients[apiKey]
if !exists || !client.Enabled {
return nil, false
}
return client, true
}

// ValidateAPIKey validates an API key against the configuration.
func (c *APIKeyConfig) ValidateAPIKey(apiKey string) error {
if strings.TrimSpace(apiKey) == "" {
return errors.New("api key cannot be empty")
}

client, exists := c.GetClientByAPIKey(apiKey)
if !exists {
return errors.New("invalid or disabled api key")
}

if client.ClientID == "" {
return errors.New("client id cannot be empty")
}

return nil
}

// AggregatorConfig is the root configuration for the pb.
type AggregatorConfig struct {
Committee *Committee `toml:"committee"`
Server ServerConfig `toml:"server"`
Storage *StorageConfig `toml:"storage"`
APIKeys APIKeyConfig `toml:"-"`
APIKeys auth.APIKeyConfig `toml:"-"`
Aggregation AggregationConfig `toml:"aggregation"`
OrphanRecovery OrphanRecoveryConfig `toml:"orphanRecovery"`
RateLimiting RateLimitingConfig `toml:"rateLimiting"`
Expand Down Expand Up @@ -382,7 +341,7 @@ func (c *AggregatorConfig) SetDefaults() {
c.Storage.ConnMaxIdleTime = 300 // 5 minutes
}
if c.APIKeys.Clients == nil {
c.APIKeys.Clients = make(map[string]*APIClient)
c.APIKeys.Clients = make(map[string]*auth.APIClient)
}
// Default orphan recovery: enabled with 5 minute interval
if c.OrphanRecovery.IntervalSeconds == 0 {
Expand Down Expand Up @@ -695,7 +654,7 @@ func (c *AggregatorConfig) LoadFromEnvironment() error {
return errors.New("AGGREGATOR_API_KEYS_JSON environment variable is required")
}

var apiKeyConfig APIKeyConfig
var apiKeyConfig auth.APIKeyConfig
if err := json.Unmarshal([]byte(apiKeysJSON), &apiKeyConfig); err != nil {
return fmt.Errorf("failed to parse AGGREGATOR_API_KEYS_JSON: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion aggregator/pkg/scope/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (

"github.com/google/uuid"

"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common"
"github.com/smartcontractkit/chainlink-ccv/common/auth"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

Expand Down
Loading