From 066aa48c2be4d14246d5d6d6c7876a6736fa55ae Mon Sep 17 00:00:00 2001 From: Chris Stockton Date: Thu, 5 Feb 2026 09:57:07 -0700 Subject: [PATCH] feat: add metadata field to all hooks Adds constructors for all hook input types: * MFAVerificationAttemptInput * PasswordVerificationAttemptInput * CustomAccessTokenInput * SendSMSInput, * SendEmailInput To consistently populate metadata fields: * `name` - Hook Name * `uuid` - Request UUID * `time` - Request Time * `ip_address` Request IP Address This improves observability and security auditing by guaranteeing that all hook invocations include request metadata. It also enables new use cases by passing the request IP address. For example more advanced methods for rate limiting login or MFA attempts may now be implemented. --- internal/api/e2e_test.go | 59 +++++++++++++++++++++++ internal/api/hooks_test.go | 48 +++++++++++-------- internal/api/mail.go | 11 +++-- internal/api/mfa.go | 41 ++++++++-------- internal/api/phone.go | 11 +++-- internal/api/token.go | 11 +++-- internal/hooks/v0hooks/v0hooks.go | 79 +++++++++++++++++++++++++++++-- internal/tokens/service.go | 11 +++-- 8 files changed, 207 insertions(+), 64 deletions(-) diff --git a/internal/api/e2e_test.go b/internal/api/e2e_test.go index 81c2cc7ec..ad2b1ce97 100644 --- a/internal/api/e2e_test.go +++ b/internal/api/e2e_test.go @@ -66,7 +66,12 @@ func runVerifyBeforeUserCreatedHook( hookReq := &v0hooks.BeforeUserCreatedInput{} err := call.Unmarshal(hookReq) require.NoError(t, err) + + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) require.Equal(t, v0hooks.BeforeUserCreated, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) u := hookReq.User require.Equal(t, expUser.ID, u.ID) @@ -103,7 +108,12 @@ func runVerifyAfterUserCreatedHook( hookReq := &v0hooks.AfterUserCreatedInput{} err := call.Unmarshal(hookReq) require.NoError(t, err) + + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) require.Equal(t, v0hooks.AfterUserCreated, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) u := hookReq.User require.Equal(t, expUser.ID, u.ID) @@ -176,6 +186,12 @@ func signupAndConfirmEmail( err = call.Unmarshal(hookReq) require.NoError(t, err) + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.SendEmail, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + // verify that the latest user from find user matches OTP otpHash := crypto.GenerateTokenHash( expUser.GetEmail(), hookReq.EmailData.Token) @@ -285,6 +301,12 @@ func TestE2EHooks(t *testing.T) { err = call.Unmarshal(hookReq) require.NoError(t, err) + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.SendSMS, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + latestUser, err := models.FindUserByID(inst.Conn, signupUser.ID) require.NoError(t, err) require.NotNil(t, latestUser) @@ -383,6 +405,12 @@ func TestE2EHooks(t *testing.T) { err = call.Unmarshal(hookReq) require.NoError(t, err) + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.SendSMS, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + require.Equal(t, currentUser.ID, hookReq.User.ID) require.Equal(t, currentUser.Aud, hookReq.User.Aud) require.Equal(t, currentUser.Phone, hookReq.User.Phone) @@ -924,6 +952,13 @@ func TestE2EHooks(t *testing.T) { hookReq := &v0hooks.CustomAccessTokenInput{} err := call.Unmarshal(hookReq) require.NoError(t, err) + + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.CustomizeAccessToken, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + require.Equal(t, currentUser.ID, hookReq.UserID) require.Equal(t, currentUser.ID.String(), hookReq.Claims.Subject) } @@ -1127,6 +1162,12 @@ func TestE2EHooks(t *testing.T) { err = call.Unmarshal(hookReq) require.NoError(t, err) + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.SendEmail, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + // hook user matches the signup user require.Equal(t, signupUser.ID, hookReq.User.ID) require.Equal(t, signupUser.Aud, hookReq.User.Aud) @@ -1240,6 +1281,12 @@ func TestE2EHooks(t *testing.T) { err = call.Unmarshal(hookReq) require.NoError(t, err) + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.SendEmail, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + // verify there is an ott generated ott, err := models.FindOneTimeToken( inst.Conn, @@ -1343,6 +1390,12 @@ func TestE2EHooks(t *testing.T) { err = call.Unmarshal(hookReq) require.NoError(t, err) + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.SendEmail, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + // hook user matches the signup user require.Equal(t, signupUser.ID, hookReq.User.ID) require.Equal(t, signupUser.Aud, hookReq.User.Aud) @@ -1453,6 +1506,12 @@ func TestE2EHooks(t *testing.T) { err = call.Unmarshal(hookReq) require.NoError(t, err) + require.NotNil(t, hookReq.Metadata) + require.NotEmpty(t, hookReq.Metadata.IPAddress) + require.Equal(t, v0hooks.SendEmail, hookReq.Metadata.Name) + require.NotEqual(t, uuid.Nil, hookReq.Metadata.UUID) + require.False(t, hookReq.Metadata.Time.IsZero()) + // verify there is an ott generated ott, err := models.FindOneTimeToken( inst.Conn, diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go index d3e5bdaba..68020c1cb 100644 --- a/internal/api/hooks_test.go +++ b/internal/api/hooks_test.go @@ -77,12 +77,6 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { // setup mock requests for hooks defer gock.OffAll() - input := v0hooks.SendSMSInput{ - User: ts.TestUser, - SMS: v0hooks.SMS{ - OTP: "123456", - }, - } testURL := "http://localhost:54321/functions/v1/custom-sms-sender" ts.Config.Hook.SendSMS.URI = testURL @@ -126,8 +120,16 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { ts.Run(tc.description, func() { req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil) + input := v0hooks.NewSendSMSInput( + req, + ts.TestUser, + v0hooks.SMS{ + OTP: "123456", + }, + ) + var output v0hooks.SendSMSOutput - err := ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output) + err := ts.API.hooksMgr.InvokeHook(ts.API.db, req, input, &output) if !tc.expectError { require.NoError(ts.T(), err) @@ -143,12 +145,6 @@ func (ts *HooksTestSuite) TestRunHTTPHook() { func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { defer gock.OffAll() - input := v0hooks.SendSMSInput{ - User: ts.TestUser, - SMS: v0hooks.SMS{ - OTP: "123456", - }, - } testURL := "http://localhost:54321/functions/v1/custom-sms-sender" ts.Config.Hook.SendSMS.URI = testURL @@ -169,8 +165,16 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) require.NoError(ts.T(), err) + input := v0hooks.NewSendSMSInput( + req, + ts.TestUser, + v0hooks.SMS{ + OTP: "123456", + }, + ) + var output v0hooks.SendSMSOutput - err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output) + err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, input, &output) require.NoError(ts.T(), err) // Ensure that all expected HTTP interactions (mocks) have been called @@ -180,12 +184,6 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() { defer gock.OffAll() - input := v0hooks.SendSMSInput{ - User: ts.TestUser, - SMS: v0hooks.SMS{ - OTP: "123456", - }, - } testURL := "http://localhost:54321/functions/v1/custom-sms-sender" ts.Config.Hook.SendSMS.URI = testURL @@ -198,8 +196,16 @@ func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() { req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil) require.NoError(ts.T(), err) + input := v0hooks.NewSendSMSInput( + req, + ts.TestUser, + v0hooks.SMS{ + OTP: "123456", + }, + ) + var output v0hooks.SendSMSOutput - err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output) + err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, input, &output) require.Error(ts.T(), err, "Expected an error due to wrong content type") require.Contains(ts.T(), err.Error(), "Invalid JSON response.") require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called") diff --git a/internal/api/mail.go b/internal/api/mail.go index dc7b8a58d..87c4841c3 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -860,12 +860,13 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, emailData.FactorType = params.factorType } - input := v0hooks.SendEmailInput{ - User: u, - EmailData: emailData, - } + input := v0hooks.NewSendEmailInput( + r, + u, + emailData, + ) output := v0hooks.SendEmailOutput{} - return a.hooksMgr.InvokeHook(tx, r, &input, &output) + return a.hooksMgr.InvokeHook(tx, r, input, &output) } // Increment email send operations here, since this metric is meant to count number of mail diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 6415e834a..4b6467ead 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -408,16 +408,17 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error phone := factor.Phone.String() if config.Hook.SendSMS.Enabled { - input := v0hooks.SendSMSInput{ - User: user, - SMS: v0hooks.SMS{ + input := v0hooks.NewSendSMSInput( + r, + user, + v0hooks.SMS{ OTP: otp, SMSType: "mfa", Phone: phone, }, - } + ) output := v0hooks.SendSMSOutput{} - err := a.hooksMgr.InvokeHook(db, r, &input, &output) + err := a.hooksMgr.InvokeHook(db, r, input, &output) if err != nil { return apierrors.NewInternalServerError("error invoking hook") } @@ -648,15 +649,16 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V }) if config.Hook.MFAVerificationAttempt.Enabled { - input := v0hooks.MFAVerificationAttemptInput{ - UserID: user.ID, - FactorID: factor.ID, - FactorType: factor.FactorType, - Valid: valid, - } + input := v0hooks.NewMFAVerificationAttemptInput( + r, + user.ID, + factor.ID, + factor.FactorType, + valid, + ) output := v0hooks.MFAVerificationAttemptOutput{} - err := a.hooksMgr.InvokeHook(nil, r, &input, &output) + err := a.hooksMgr.InvokeHook(nil, r, input, &output) if err != nil { return err } @@ -799,15 +801,16 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * valid = subtle.ConstantTimeCompare([]byte(otpCode), []byte(params.Code)) == 1 } if config.Hook.MFAVerificationAttempt.Enabled { - input := v0hooks.MFAVerificationAttemptInput{ - UserID: user.ID, - FactorID: factor.ID, - FactorType: factor.FactorType, - Valid: valid, - } + input := v0hooks.NewMFAVerificationAttemptInput( + r, + user.ID, + factor.ID, + factor.FactorType, + valid, + ) output := v0hooks.MFAVerificationAttemptOutput{} - err := a.hooksMgr.InvokeHook(nil, r, &input, &output) + err := a.hooksMgr.InvokeHook(nil, r, input, &output) if err != nil { return err } diff --git a/internal/api/phone.go b/internal/api/phone.go index fbd940bcb..77f46ca29 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -95,15 +95,16 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use otp = crypto.GenerateOtp(config.Sms.OtpLength) if config.Hook.SendSMS.Enabled { - input := v0hooks.SendSMSInput{ - User: user, - SMS: v0hooks.SMS{ + input := v0hooks.NewSendSMSInput( + r, + user, + v0hooks.SMS{ OTP: otp, Phone: phone, }, - } + ) output := v0hooks.SendSMSOutput{} - err := a.hooksMgr.InvokeHook(tx, r, &input, &output) + err := a.hooksMgr.InvokeHook(tx, r, input, &output) if err != nil { return "", err } diff --git a/internal/api/token.go b/internal/api/token.go index 38bcf31fb..3a1acea09 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -152,12 +152,13 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri } if config.Hook.PasswordVerificationAttempt.Enabled { - input := v0hooks.PasswordVerificationAttemptInput{ - UserID: user.ID, - Valid: isValidPassword, - } + input := v0hooks.NewPasswordVerificationAttemptInput( + r, + user.ID, + isValidPassword, + ) output := v0hooks.PasswordVerificationAttemptOutput{} - if err := a.hooksMgr.InvokeHook(nil, r, &input, &output); err != nil { + if err := a.hooksMgr.InvokeHook(nil, r, input, &output); err != nil { return err } diff --git a/internal/hooks/v0hooks/v0hooks.go b/internal/hooks/v0hooks/v0hooks.go index c7c102bf4..dd33f09dc 100644 --- a/internal/hooks/v0hooks/v0hooks.go +++ b/internal/hooks/v0hooks/v0hooks.go @@ -115,20 +115,50 @@ type AccessTokenClaims struct { } type MFAVerificationAttemptInput struct { + Metadata *Metadata `json:"metadata"` UserID uuid.UUID `json:"user_id"` FactorID uuid.UUID `json:"factor_id"` FactorType string `json:"factor_type"` Valid bool `json:"valid"` } +func NewMFAVerificationAttemptInput( + r *http.Request, + userID uuid.UUID, + factorID uuid.UUID, + factorType string, + valid bool, +) *MFAVerificationAttemptInput { + return &MFAVerificationAttemptInput{ + Metadata: NewMetadata(r, MFAVerification), + UserID: userID, + FactorID: factorID, + FactorType: factorType, + Valid: valid, + } +} + type MFAVerificationAttemptOutput struct { Decision string `json:"decision"` Message string `json:"message"` } type PasswordVerificationAttemptInput struct { - UserID uuid.UUID `json:"user_id"` - Valid bool `json:"valid"` + Metadata *Metadata `json:"metadata"` + UserID uuid.UUID `json:"user_id"` + Valid bool `json:"valid"` +} + +func NewPasswordVerificationAttemptInput( + r *http.Request, + userID uuid.UUID, + valid bool, +) *PasswordVerificationAttemptInput { + return &PasswordVerificationAttemptInput{ + Metadata: NewMetadata(r, PasswordVerification), + UserID: userID, + Valid: valid, + } } type PasswordVerificationAttemptOutput struct { @@ -138,11 +168,26 @@ type PasswordVerificationAttemptOutput struct { } type CustomAccessTokenInput struct { + Metadata *Metadata `json:"metadata"` UserID uuid.UUID `json:"user_id"` Claims *AccessTokenClaims `json:"claims"` AuthenticationMethod string `json:"authentication_method"` } +func NewCustomAccessTokenInput( + r *http.Request, + userID uuid.UUID, + claims *AccessTokenClaims, + authenticationMethod string, +) *CustomAccessTokenInput { + return &CustomAccessTokenInput{ + Metadata: NewMetadata(r, CustomizeAccessToken), + UserID: userID, + Claims: claims, + AuthenticationMethod: authenticationMethod, + } +} + type CustomAccessTokenOutput struct { Claims map[string]any `json:"claims"` } @@ -178,17 +223,43 @@ func (o *CustomAccessTokenOutput) UnmarshalJSON(b []byte) error { } type SendSMSInput struct { - User *models.User `json:"user,omitempty"` - SMS SMS `json:"sms,omitempty"` + Metadata *Metadata `json:"metadata"` + User *models.User `json:"user,omitempty"` + SMS SMS `json:"sms,omitempty"` +} + +func NewSendSMSInput( + r *http.Request, + user *models.User, + sms SMS, +) *SendSMSInput { + return &SendSMSInput{ + Metadata: NewMetadata(r, SendSMS), + User: user, + SMS: sms, + } } type SendSMSOutput struct { } type SendEmailInput struct { + Metadata *Metadata `json:"metadata"` User *models.User `json:"user"` EmailData mailer.EmailData `json:"email_data"` } +func NewSendEmailInput( + r *http.Request, + user *models.User, + emailData mailer.EmailData, +) *SendEmailInput { + return &SendEmailInput{ + Metadata: NewMetadata(r, SendEmail), + User: user, + EmailData: emailData, + } +} + type SendEmailOutput struct { } diff --git a/internal/tokens/service.go b/internal/tokens/service.go index b2430ddca..11d807a8b 100644 --- a/internal/tokens/service.go +++ b/internal/tokens/service.go @@ -696,11 +696,12 @@ func (s *Service) GenerateAccessToken(r *http.Request, tx *storage.Connection, p var gotrueClaims jwt.Claims = claims if config.Hook.CustomAccessToken.Enabled { - input := &v0hooks.CustomAccessTokenInput{ - UserID: params.User.ID, - Claims: claims, - AuthenticationMethod: params.AuthenticationMethod.String(), - } + input := v0hooks.NewCustomAccessTokenInput( + r, + params.User.ID, + claims, + params.AuthenticationMethod.String(), + ) output := &v0hooks.CustomAccessTokenOutput{}