From 1d8f9540b10814ceb08ad32589e50a0951cf8efe Mon Sep 17 00:00:00 2001 From: Umputun Date: Mon, 19 Jan 2026 14:32:57 -0600 Subject: [PATCH 1/2] feat(middleware): add customizable error handler for auth failures Add ErrorHandler option to customize error responses when authentication fails. The handler receives the HTTP status code (401/403) and the error, allowing custom HTML pages, redirects, or JSON responses. Related to #264 --- v2/auth.go | 27 ++++- v2/middleware/auth.go | 33 +++++- v2/middleware/auth_test.go | 228 +++++++++++++++++++++++++++++++++++++ 3 files changed, 277 insertions(+), 11 deletions(-) diff --git a/v2/auth.go b/v2/auth.go index efb4ddb6..537e4d8f 100644 --- a/v2/auth.go +++ b/v2/auth.go @@ -68,17 +68,31 @@ type Opts struct { AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar` UseGravatar bool // for email based auth (verified provider) use gravatar service - AdminPasswd string // if presented, allows basic auth with user admin and given password - BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored - AudienceReader token.Audience // list of allowed aud values, default (empty) allows any - AudSecrets bool // allow multiple secrets (secret per aud) - Logger logger.L // logger interface, default is no logging at all - RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens + AdminPasswd string // if presented, allows basic auth with user admin and given password + BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored + AudienceReader token.Audience // list of allowed aud values, default (empty) allows any + AudSecrets bool // allow multiple secrets (secret per aud) + Logger logger.L // logger interface, default is no logging at all + RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens + ErrorHandler middleware.ErrorHandlerFunc // custom error handler for auth failures } // NewService initializes everything func NewService(opts Opts) (res *Service) { + errorHandler := opts.ErrorHandler + if errorHandler == nil { + // default handler preserves original error messages for backward compatibility + errorHandler = func(w http.ResponseWriter, _ *http.Request, code int, _ error) { + switch code { + case http.StatusForbidden: + http.Error(w, "Access denied", code) + default: + http.Error(w, "Unauthorized", code) + } + } + } + res = &Service{ opts: opts, logger: opts.Logger, @@ -87,6 +101,7 @@ func NewService(opts Opts) (res *Service) { AdminPasswd: opts.AdminPasswd, BasicAuthChecker: opts.BasicAuthChecker, RefreshCache: opts.RefreshCache, + ErrorHandler: errorHandler, }, issuer: opts.Issuer, useGravatar: opts.UseGravatar, diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go index e25968f8..bb375495 100644 --- a/v2/middleware/auth.go +++ b/v2/middleware/auth.go @@ -24,6 +24,7 @@ type Authenticator struct { AdminPasswd string BasicAuthChecker BasicAuthFunc RefreshCache RefreshCache + ErrorHandler ErrorHandlerFunc // custom error handler for auth failures } // RefreshCache defines interface storing and retrieving refreshed tokens @@ -45,6 +46,13 @@ type TokenService interface { // The second return parameter `User` need for add user claims into context of request. type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error) +// ErrorHandlerFunc type is an adapter to allow custom error handling for auth failures. +// It receives the suggested HTTP status code and the error that caused the auth failure. +// The handler can respond with custom status codes, HTML pages, redirects, or JSON responses. +// Status codes are typically http.StatusUnauthorized (401) for auth failures +// or http.StatusForbidden (403) for permission denied. +type ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, statusCode int, err error) + // adminUser sets claims for an optional basic auth var adminUser = token.User{ ID: "admin", @@ -73,7 +81,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { return } a.Logf("[DEBUG] auth failed, %v", err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) + a.errResponse(w, r, http.StatusUnauthorized, err) } f := func(h http.Handler) http.Handler { @@ -208,12 +216,12 @@ func (a *Authenticator) AdminOnly(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { user, err := token.GetUserInfo(r) if err != nil { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + a.errResponse(w, r, http.StatusUnauthorized, err) return } if !user.IsAdmin() { - http.Error(w, "Access denied", http.StatusForbidden) + a.errResponse(w, r, http.StatusForbidden, fmt.Errorf("user %s is not admin", user.Name)) return } next.ServeHTTP(w, r) @@ -242,6 +250,21 @@ func (a *Authenticator) basicAdminUser(r *http.Request) bool { return true } +// errResponse calls ErrorHandler if set, otherwise returns default http error +func (a *Authenticator) errResponse(w http.ResponseWriter, r *http.Request, code int, err error) { + if a.ErrorHandler != nil { + a.ErrorHandler(w, r, code, err) + return + } + // preserve original error messages for backward compatibility + switch code { + case http.StatusForbidden: + http.Error(w, "Access denied", code) + default: + http.Error(w, "Unauthorized", code) + } +} + // RBAC middleware allows role based control for routes // this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { @@ -250,7 +273,7 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { user, err := token.GetUserInfo(r) if err != nil { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + a.errResponse(w, r, http.StatusUnauthorized, err) return } @@ -262,7 +285,7 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { } } if !matched { - http.Error(w, "Access denied", http.StatusForbidden) + a.errResponse(w, r, http.StatusForbidden, fmt.Errorf("user %s role %s not in allowed roles", user.Name, user.Role)) return } h.ServeHTTP(w, r) diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go index 9ce53229..3eeec4b3 100644 --- a/v2/middleware/auth_test.go +++ b/v2/middleware/auth_test.go @@ -565,3 +565,231 @@ func (c *testRefreshCache) Set(key string, value token.Claims) { defer c.Unlock() c.data[key] = value } + +func TestAuthWithCustomErrorHandler(t *testing.T) { + a := makeTestAuth(t) + + var capturedErr error + var capturedStatusCode int + a.ErrorHandler = func(w http.ResponseWriter, r *http.Request, statusCode int, err error) { + capturedErr = err + capturedStatusCode = statusCode + http.Error(w, "Custom error page", http.StatusTeapot) + } + + mux := http.NewServeMux() + handler := func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(201) + } + mux.Handle("/auth", a.Auth(http.HandlerFunc(handler))) + server := httptest.NewServer(mux) + defer server.Close() + + client := &http.Client{Timeout: 5 * time.Second} + + t.Run("custom handler called on auth failure", func(t *testing.T) { + capturedErr = nil + capturedStatusCode = 0 + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusTeapot, resp.StatusCode) + assert.Equal(t, http.StatusUnauthorized, capturedStatusCode) + assert.NotNil(t, capturedErr) + assert.Contains(t, capturedErr.Error(), "can't get token") + + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "Custom error page\n", string(data)) + }) + + t.Run("custom handler called with invalid token", func(t *testing.T) { + capturedErr = nil + capturedStatusCode = 0 + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.Header.Add("X-JWT", "invalid.token.here") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusTeapot, resp.StatusCode) + assert.Equal(t, http.StatusUnauthorized, capturedStatusCode) + assert.NotNil(t, capturedErr) + }) + + t.Run("valid token bypasses error handler", func(t *testing.T) { + capturedErr = nil + capturedStatusCode = 0 + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.Header.Add("X-JWT", testJwtValid) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 201, resp.StatusCode) + assert.Nil(t, capturedErr) + assert.Equal(t, 0, capturedStatusCode) + }) +} + +func TestAdminOnlyWithCustomErrorHandler(t *testing.T) { + a := makeTestAuth(t) + + var capturedErr error + var capturedStatusCode int + a.ErrorHandler = func(w http.ResponseWriter, r *http.Request, statusCode int, err error) { + capturedErr = err + capturedStatusCode = statusCode + http.Error(w, "Custom admin error", http.StatusTeapot) + } + + mux := http.NewServeMux() + handler := func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(201) + } + mux.Handle("/admin", a.AdminOnly(http.HandlerFunc(handler))) + server := httptest.NewServer(mux) + defer server.Close() + + client := &http.Client{Timeout: 5 * time.Second} + + t.Run("custom handler called on auth failure", func(t *testing.T) { + capturedErr = nil + capturedStatusCode = 0 + req, err := http.NewRequest("GET", server.URL+"/admin", http.NoBody) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusTeapot, resp.StatusCode) + assert.Equal(t, http.StatusUnauthorized, capturedStatusCode) + assert.NotNil(t, capturedErr) + }) + + t.Run("custom handler called when not admin", func(t *testing.T) { + // make user non-admin for this test + originalAdmin := adminUser.Attributes["admin"] + adminUser.SetAdmin(false) + defer func() { adminUser.Attributes["admin"] = originalAdmin }() + + capturedErr = nil + capturedStatusCode = 0 + req, err := http.NewRequest("GET", server.URL+"/admin", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("admin", "123456") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusTeapot, resp.StatusCode) + assert.Equal(t, http.StatusForbidden, capturedStatusCode) + assert.NotNil(t, capturedErr) + assert.Contains(t, capturedErr.Error(), "not admin") + }) + + t.Run("admin access bypasses error handler", func(t *testing.T) { + // ensure admin is set + adminUser.SetAdmin(true) + + capturedErr = nil + capturedStatusCode = 0 + req, err := http.NewRequest("GET", server.URL+"/admin", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("admin", "123456") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 201, resp.StatusCode) + assert.Nil(t, capturedErr) + assert.Equal(t, 0, capturedStatusCode) + }) +} + +func TestRBACWithCustomErrorHandler(t *testing.T) { + a := makeTestAuth(t) + + var capturedErr error + var capturedStatusCode int + a.ErrorHandler = func(w http.ResponseWriter, r *http.Request, statusCode int, err error) { + capturedErr = err + capturedStatusCode = statusCode + http.Error(w, "Custom RBAC error", http.StatusTeapot) + } + + mux := http.NewServeMux() + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(201) + }) + mux.Handle("/employees", a.RBAC("employee")(handler)) + server := httptest.NewServer(mux) + defer server.Close() + + client := &http.Client{Timeout: 5 * time.Second} + + t.Run("custom handler called on auth failure", func(t *testing.T) { + capturedErr = nil + capturedStatusCode = 0 + req, err := http.NewRequest("GET", server.URL+"/employees", http.NoBody) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusTeapot, resp.StatusCode) + assert.Equal(t, http.StatusUnauthorized, capturedStatusCode) + assert.NotNil(t, capturedErr) + }) + + t.Run("custom handler called when role mismatch", func(t *testing.T) { + capturedErr = nil + capturedStatusCode = 0 + expiration := int(365 * 24 * time.Hour.Seconds()) + req, err := http.NewRequest("GET", server.URL+"/employees", http.NoBody) + require.NoError(t, err) + // testJwtValid does not have an employee role + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusTeapot, resp.StatusCode) + assert.Equal(t, http.StatusForbidden, capturedStatusCode) + assert.NotNil(t, capturedErr) + assert.Contains(t, capturedErr.Error(), "not in allowed roles") + }) + + t.Run("correct role bypasses error handler", func(t *testing.T) { + capturedErr = nil + capturedStatusCode = 0 + expiration := int(365 * 24 * time.Hour.Seconds()) + req, err := http.NewRequest("GET", server.URL+"/employees", http.NoBody) + require.NoError(t, err) + // testJwtWithRole has employee role + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtWithRole, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 201, resp.StatusCode) + assert.Nil(t, capturedErr) + assert.Equal(t, 0, capturedStatusCode) + }) +} From ef37152cd59143070f5717e972fc68e9ea32b4ba Mon Sep 17 00:00:00 2001 From: Umputun Date: Mon, 19 Jan 2026 15:06:58 -0600 Subject: [PATCH 2/2] fix: address code review findings - remove duplicate default error handler logic from NewService (errResponse already handles the default case) - add ErrorHandler documentation to README.md - add progress-*.txt to .gitignore --- .gitignore | 4 +++- README.md | 21 +++++++++++++++++++++ v2/auth.go | 16 +--------------- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 2c09a372..17e1eea9 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,6 @@ *.out .vscode *.cov -Dockerfile \ No newline at end of file +Dockerfile +# ralph progress logs +progress-*.txt diff --git a/README.md b/README.md index 2acd35c7..74e51d78 100644 --- a/README.md +++ b/README.md @@ -448,6 +448,27 @@ There are several ways to adjust functionality of the library: All of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`. +Additionally, `ErrorHandlerFunc` type is available in the middleware package for custom error handling on authentication failures: + +```go +options := auth.Opts{ + // ... other options ... + ErrorHandler: middleware.ErrorHandlerFunc(func(w http.ResponseWriter, r *http.Request, code int, err error) { + // return JSON error for API routes + if strings.HasPrefix(r.URL.Path, "/api/") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"}) + return + } + // redirect to login page for web routes + http.Redirect(w, r, "/login", http.StatusFound) + }), +} +``` + +By default (when `ErrorHandler` is nil), the middleware returns "Unauthorized" (401) or "Access denied" (403) text responses. + ### Implementing black list logic or some other filters Restricting some users or some tokens is two step process: diff --git a/v2/auth.go b/v2/auth.go index 537e4d8f..13a440de 100644 --- a/v2/auth.go +++ b/v2/auth.go @@ -79,20 +79,6 @@ type Opts struct { // NewService initializes everything func NewService(opts Opts) (res *Service) { - - errorHandler := opts.ErrorHandler - if errorHandler == nil { - // default handler preserves original error messages for backward compatibility - errorHandler = func(w http.ResponseWriter, _ *http.Request, code int, _ error) { - switch code { - case http.StatusForbidden: - http.Error(w, "Access denied", code) - default: - http.Error(w, "Unauthorized", code) - } - } - } - res = &Service{ opts: opts, logger: opts.Logger, @@ -101,7 +87,7 @@ func NewService(opts Opts) (res *Service) { AdminPasswd: opts.AdminPasswd, BasicAuthChecker: opts.BasicAuthChecker, RefreshCache: opts.RefreshCache, - ErrorHandler: errorHandler, + ErrorHandler: opts.ErrorHandler, }, issuer: opts.Issuer, useGravatar: opts.UseGravatar,