Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
*.out
.vscode
*.cov
Dockerfile
Dockerfile
# ralph progress logs
progress-*.txt
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,27 @@ There are several ways to adjust functionality of the library:

All of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.

Additionally, `ErrorHandlerFunc` type is available in the middleware package for custom error handling on authentication failures:

```go
options := auth.Opts{
// ... other options ...
ErrorHandler: middleware.ErrorHandlerFunc(func(w http.ResponseWriter, r *http.Request, code int, err error) {
// return JSON error for API routes
if strings.HasPrefix(r.URL.Path, "/api/") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"})
return
}
// redirect to login page for web routes
http.Redirect(w, r, "/login", http.StatusFound)
}),
}
```

By default (when `ErrorHandler` is nil), the middleware returns "Unauthorized" (401) or "Access denied" (403) text responses.

### Implementing black list logic or some other filters

Restricting some users or some tokens is two step process:
Expand Down
15 changes: 8 additions & 7 deletions v2/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,17 @@ type Opts struct {
AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar`
UseGravatar bool // for email based auth (verified provider) use gravatar service

AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
ErrorHandler middleware.ErrorHandlerFunc // custom error handler for auth failures
}

// NewService initializes everything
func NewService(opts Opts) (res *Service) {

res = &Service{
opts: opts,
logger: opts.Logger,
Expand All @@ -87,6 +87,7 @@ func NewService(opts Opts) (res *Service) {
AdminPasswd: opts.AdminPasswd,
BasicAuthChecker: opts.BasicAuthChecker,
RefreshCache: opts.RefreshCache,
ErrorHandler: opts.ErrorHandler,
},
issuer: opts.Issuer,
useGravatar: opts.UseGravatar,
Expand Down
33 changes: 28 additions & 5 deletions v2/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Authenticator struct {
AdminPasswd string
BasicAuthChecker BasicAuthFunc
RefreshCache RefreshCache
ErrorHandler ErrorHandlerFunc // custom error handler for auth failures
}

// RefreshCache defines interface storing and retrieving refreshed tokens
Expand All @@ -45,6 +46,13 @@ type TokenService interface {
// The second return parameter `User` need for add user claims into context of request.
type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error)

// ErrorHandlerFunc type is an adapter to allow custom error handling for auth failures.
// It receives the suggested HTTP status code and the error that caused the auth failure.
// The handler can respond with custom status codes, HTML pages, redirects, or JSON responses.
// Status codes are typically http.StatusUnauthorized (401) for auth failures
// or http.StatusForbidden (403) for permission denied.
type ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, statusCode int, err error)

// adminUser sets claims for an optional basic auth
var adminUser = token.User{
ID: "admin",
Expand Down Expand Up @@ -73,7 +81,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
return
}
a.Logf("[DEBUG] auth failed, %v", err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
a.errResponse(w, r, http.StatusUnauthorized, err)
}

f := func(h http.Handler) http.Handler {
Expand Down Expand Up @@ -208,12 +216,12 @@ func (a *Authenticator) AdminOnly(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
a.errResponse(w, r, http.StatusUnauthorized, err)
return
}

if !user.IsAdmin() {
http.Error(w, "Access denied", http.StatusForbidden)
a.errResponse(w, r, http.StatusForbidden, fmt.Errorf("user %s is not admin", user.Name))
return
}
next.ServeHTTP(w, r)
Expand Down Expand Up @@ -242,6 +250,21 @@ func (a *Authenticator) basicAdminUser(r *http.Request) bool {
return true
}

// errResponse calls ErrorHandler if set, otherwise returns default http error
func (a *Authenticator) errResponse(w http.ResponseWriter, r *http.Request, code int, err error) {
if a.ErrorHandler != nil {
a.ErrorHandler(w, r, code, err)
return
}
// preserve original error messages for backward compatibility
switch code {
case http.StatusForbidden:
http.Error(w, "Access denied", code)
default:
http.Error(w, "Unauthorized", code)
}
}

// RBAC middleware allows role based control for routes
// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth
func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
Expand All @@ -250,7 +273,7 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
user, err := token.GetUserInfo(r)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
a.errResponse(w, r, http.StatusUnauthorized, err)
return
}

Expand All @@ -262,7 +285,7 @@ func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler {
}
}
if !matched {
http.Error(w, "Access denied", http.StatusForbidden)
a.errResponse(w, r, http.StatusForbidden, fmt.Errorf("user %s role %s not in allowed roles", user.Name, user.Role))
return
}
h.ServeHTTP(w, r)
Expand Down
228 changes: 228 additions & 0 deletions v2/middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,231 @@ func (c *testRefreshCache) Set(key string, value token.Claims) {
defer c.Unlock()
c.data[key] = value
}

func TestAuthWithCustomErrorHandler(t *testing.T) {
a := makeTestAuth(t)

var capturedErr error
var capturedStatusCode int
a.ErrorHandler = func(w http.ResponseWriter, r *http.Request, statusCode int, err error) {
capturedErr = err
capturedStatusCode = statusCode
http.Error(w, "Custom error page", http.StatusTeapot)
}

mux := http.NewServeMux()
handler := func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(201)
}
mux.Handle("/auth", a.Auth(http.HandlerFunc(handler)))
server := httptest.NewServer(mux)
defer server.Close()

client := &http.Client{Timeout: 5 * time.Second}

t.Run("custom handler called on auth failure", func(t *testing.T) {
capturedErr = nil
capturedStatusCode = 0
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusUnauthorized, capturedStatusCode)
assert.NotNil(t, capturedErr)
assert.Contains(t, capturedErr.Error(), "can't get token")

data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "Custom error page\n", string(data))
})

t.Run("custom handler called with invalid token", func(t *testing.T) {
capturedErr = nil
capturedStatusCode = 0
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.NoError(t, err)
req.Header.Add("X-JWT", "invalid.token.here")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusUnauthorized, capturedStatusCode)
assert.NotNil(t, capturedErr)
})

t.Run("valid token bypasses error handler", func(t *testing.T) {
capturedErr = nil
capturedStatusCode = 0
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.NoError(t, err)
req.Header.Add("X-JWT", testJwtValid)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, 201, resp.StatusCode)
assert.Nil(t, capturedErr)
assert.Equal(t, 0, capturedStatusCode)
})
}

func TestAdminOnlyWithCustomErrorHandler(t *testing.T) {
a := makeTestAuth(t)

var capturedErr error
var capturedStatusCode int
a.ErrorHandler = func(w http.ResponseWriter, r *http.Request, statusCode int, err error) {
capturedErr = err
capturedStatusCode = statusCode
http.Error(w, "Custom admin error", http.StatusTeapot)
}

mux := http.NewServeMux()
handler := func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(201)
}
mux.Handle("/admin", a.AdminOnly(http.HandlerFunc(handler)))
server := httptest.NewServer(mux)
defer server.Close()

client := &http.Client{Timeout: 5 * time.Second}

t.Run("custom handler called on auth failure", func(t *testing.T) {
capturedErr = nil
capturedStatusCode = 0
req, err := http.NewRequest("GET", server.URL+"/admin", http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusUnauthorized, capturedStatusCode)
assert.NotNil(t, capturedErr)
})

t.Run("custom handler called when not admin", func(t *testing.T) {
// make user non-admin for this test
originalAdmin := adminUser.Attributes["admin"]
adminUser.SetAdmin(false)
defer func() { adminUser.Attributes["admin"] = originalAdmin }()

capturedErr = nil
capturedStatusCode = 0
req, err := http.NewRequest("GET", server.URL+"/admin", http.NoBody)
require.NoError(t, err)
req.SetBasicAuth("admin", "123456")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusForbidden, capturedStatusCode)
assert.NotNil(t, capturedErr)
assert.Contains(t, capturedErr.Error(), "not admin")
})

t.Run("admin access bypasses error handler", func(t *testing.T) {
// ensure admin is set
adminUser.SetAdmin(true)

capturedErr = nil
capturedStatusCode = 0
req, err := http.NewRequest("GET", server.URL+"/admin", http.NoBody)
require.NoError(t, err)
req.SetBasicAuth("admin", "123456")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, 201, resp.StatusCode)
assert.Nil(t, capturedErr)
assert.Equal(t, 0, capturedStatusCode)
})
}

func TestRBACWithCustomErrorHandler(t *testing.T) {
a := makeTestAuth(t)

var capturedErr error
var capturedStatusCode int
a.ErrorHandler = func(w http.ResponseWriter, r *http.Request, statusCode int, err error) {
capturedErr = err
capturedStatusCode = statusCode
http.Error(w, "Custom RBAC error", http.StatusTeapot)
}

mux := http.NewServeMux()
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(201)
})
mux.Handle("/employees", a.RBAC("employee")(handler))
server := httptest.NewServer(mux)
defer server.Close()

client := &http.Client{Timeout: 5 * time.Second}

t.Run("custom handler called on auth failure", func(t *testing.T) {
capturedErr = nil
capturedStatusCode = 0
req, err := http.NewRequest("GET", server.URL+"/employees", http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusUnauthorized, capturedStatusCode)
assert.NotNil(t, capturedErr)
})

t.Run("custom handler called when role mismatch", func(t *testing.T) {
capturedErr = nil
capturedStatusCode = 0
expiration := int(365 * 24 * time.Hour.Seconds())
req, err := http.NewRequest("GET", server.URL+"/employees", http.NoBody)
require.NoError(t, err)
// testJwtValid does not have an employee role
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusForbidden, capturedStatusCode)
assert.NotNil(t, capturedErr)
assert.Contains(t, capturedErr.Error(), "not in allowed roles")
})

t.Run("correct role bypasses error handler", func(t *testing.T) {
capturedErr = nil
capturedStatusCode = 0
expiration := int(365 * 24 * time.Hour.Seconds())
req, err := http.NewRequest("GET", server.URL+"/employees", http.NoBody)
require.NoError(t, err)
// testJwtWithRole has employee role
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtWithRole, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, 201, resp.StatusCode)
assert.Nil(t, capturedErr)
assert.Equal(t, 0, capturedStatusCode)
})
}
Loading