From 1d17eef42586537ad8998ad1a9b27ade6a3f53a5 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Tue, 3 Aug 2021 23:39:26 +0200 Subject: [PATCH 1/9] feat(authentication/oauth2): add base of server --- authentication/provider/oauth2/access.go | 52 ++++++++++++++-- authentication/provider/oauth2/authorize.go | 11 ++++ .../provider/oauth2/configuration.go | 49 +++++++++++++++ .../provider/oauth2/configuration_test.go | 26 ++++++++ authentication/provider/oauth2/option.go | 21 +++++++ authentication/provider/oauth2/option_test.go | 23 +++++++ authentication/provider/oauth2/server.go | 60 +++++++++++++++++++ authentication/provider/oauth2/server_test.go | 18 ++++++ 8 files changed, 255 insertions(+), 5 deletions(-) create mode 100644 authentication/provider/oauth2/configuration.go create mode 100644 authentication/provider/oauth2/configuration_test.go create mode 100644 authentication/provider/oauth2/option.go create mode 100644 authentication/provider/oauth2/option_test.go create mode 100644 authentication/provider/oauth2/server.go create mode 100644 authentication/provider/oauth2/server_test.go diff --git a/authentication/provider/oauth2/access.go b/authentication/provider/oauth2/access.go index 864b017..439c8ca 100644 --- a/authentication/provider/oauth2/access.go +++ b/authentication/provider/oauth2/access.go @@ -6,14 +6,56 @@ package oauth2 import ( "context" + "net/http" "time" ) -type AccessToken interface { - GetClient() Client - GetToken() string - IsExpired() bool - GetUserID() string +// AccessRequestType is the type for OAuth2 param `grant_type` +type AccessRequestType string + +const ( + AUTHORIZATION_CODE AccessRequestType = "authorization_code" + REFRESH_TOKEN AccessRequestType = "refresh_token" + PASSWORD AccessRequestType = "password" + CLIENT_CREDENTIALS AccessRequestType = "client_credentials" + ASSERTION AccessRequestType = "assertion" + IMPLICIT AccessRequestType = "__implicit" +) + +// AccessRequest is a request for access tokens +type AccessRequest struct { + Type AccessRequestType + Code string + Client Client + AuthorizeInfo *AuthorizeInfo + AccessInfo *AccessInfo + + // Force finish to use this access data, to allow access data reuse + ForceAccessInfo *AccessInfo + RedirectURI string + Scope string + Username string + Password string + AssertionType string + Assertion string + + // Set if request is authorized + Authorized bool + + // Token expiration in seconds. Change if different from default + Expiration int32 + + // Set if a refresh token should be generated + GenerateRefresh bool + + // Data to be passed to storage. Not used by the library. + UserData interface{} + + // HttpRequest *http.Request for special use + HttpRequest *http.Request + + // Optional code_verifier as described in rfc7636 + CodeVerifier string } type accessCtxKey struct{} diff --git a/authentication/provider/oauth2/authorize.go b/authentication/provider/oauth2/authorize.go index 4da0dad..267dd73 100644 --- a/authentication/provider/oauth2/authorize.go +++ b/authentication/provider/oauth2/authorize.go @@ -6,6 +6,17 @@ package oauth2 import "time" +// AuthorizeRequestType is the type for OAuth param `response_type` +type AuthorizeRequestType string + +const ( + CODE AuthorizeRequestType = "code" + TOKEN AuthorizeRequestType = "token" + + PKCE_PLAIN = "plain" + PKCE_S256 = "S256" +) + // AuthorizeInfo info. type AuthorizeInfo struct { // Client information diff --git a/authentication/provider/oauth2/configuration.go b/authentication/provider/oauth2/configuration.go new file mode 100644 index 0000000..4e02516 --- /dev/null +++ b/authentication/provider/oauth2/configuration.go @@ -0,0 +1,49 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import "time" + +// AllowedAuthorizeType is a collection of allowed auth request types +type AllowedAuthorizeType []AuthorizeRequestType + +// Exists returns true if the auth type exists in the list +func (t AllowedAuthorizeType) Exists(rt AuthorizeRequestType) bool { + for _, k := range t { + if k == rt { + return true + } + } + + return false +} + +// AllowedAccessType is a collection of allowed access request types +type AllowedAccessType []AccessRequestType + +// Exists returns true if the access type exists in the list +func (t AllowedAccessType) Exists(rt AccessRequestType) bool { + for _, k := range t { + if k == rt { + return true + } + } + + return false +} + +type Configuration struct { + AuthorizationExpiration time.Duration + + AccessExpiration time.Duration + + AllowGetAccessRequest bool + + // List of allowed authorize types (only CODE by default) + AllowedAuthorizeTypes AllowedAuthorizeType + + // List of allowed access types (only AUTHORIZATION_CODE by default) + AllowedAccessTypes AllowedAccessType +} diff --git a/authentication/provider/oauth2/configuration_test.go b/authentication/provider/oauth2/configuration_test.go new file mode 100644 index 0000000..9bcfcfb --- /dev/null +++ b/authentication/provider/oauth2/configuration_test.go @@ -0,0 +1,26 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllowedAuthorizeType(t *testing.T) { + types := AllowedAuthorizeType{CODE} + + assert.True(t, types.Exists(CODE)) + assert.False(t, types.Exists(TOKEN)) +} + +func TestAllowedAccessType(t *testing.T) { + types := AllowedAccessType{AUTHORIZATION_CODE, CLIENT_CREDENTIALS} + + assert.True(t, types.Exists(AUTHORIZATION_CODE)) + assert.True(t, types.Exists(CLIENT_CREDENTIALS)) + assert.False(t, types.Exists(ASSERTION)) +} diff --git a/authentication/provider/oauth2/option.go b/authentication/provider/oauth2/option.go new file mode 100644 index 0000000..94522b0 --- /dev/null +++ b/authentication/provider/oauth2/option.go @@ -0,0 +1,21 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +// Option type. +type Option func(server *Server) + +// WithConfig add config to server +func WithConfig(cfg *Configuration) Option { + return func(server *Server) { + server.cfg = cfg + } +} + +func WithStorage(storage StorageProvider) Option { + return func(server *Server) { + + } +} diff --git a/authentication/provider/oauth2/option_test.go b/authentication/provider/oauth2/option_test.go new file mode 100644 index 0000000..c34d004 --- /dev/null +++ b/authentication/provider/oauth2/option_test.go @@ -0,0 +1,23 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithConfig(t *testing.T) { + cfg := &Configuration{} + + opt := WithConfig(cfg) + + server := &Server{} + + opt(server) + + assert.Same(t, cfg, server.cfg) +} diff --git a/authentication/provider/oauth2/server.go b/authentication/provider/oauth2/server.go new file mode 100644 index 0000000..2914d8a --- /dev/null +++ b/authentication/provider/oauth2/server.go @@ -0,0 +1,60 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "net/http" +) + +type Server struct { + cfg *Configuration +} + +func NewServer(options ...Option) *Server { + s := &Server{} + + for _, opt := range options { + opt(s) + } + + return s +} + +func (s *Server) HandleAccessRequest(w http.ResponseWriter, r *http.Request) *AccessRequest { + // Only allow GET or POST + if r.Method == http.MethodGet { + if !s.cfg.AllowGetAccessRequest { + //s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "GET request not allowed") + return nil + } + } else if r.Method != http.MethodPost { + //s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "request must be POST") + return nil + } + + if err := r.ParseForm(); err != nil { + //s.setErrorAndLog(w, E_INVALID_REQUEST, err, "access_request=%s", "parsing error") + return nil + } + + grantType := AccessRequestType(r.FormValue("grant_type")) + if s.cfg.AllowedAccessTypes.Exists(grantType) { + switch grantType { + case AUTHORIZATION_CODE: + // return s.handleAuthorizationCodeRequest(w, r) + case REFRESH_TOKEN: + // return s.handleRefreshTokenRequest(w, r) + case PASSWORD: + // return s.handlePasswordRequest(w, r) + case CLIENT_CREDENTIALS: + // return s.handleClientCredentialsRequest(w, r) + case ASSERTION: + // return s.handleAssertionRequest(w, r) + } + } + + // s.setErrorAndLog(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") + return nil +} diff --git a/authentication/provider/oauth2/server_test.go b/authentication/provider/oauth2/server_test.go new file mode 100644 index 0000000..8a63e31 --- /dev/null +++ b/authentication/provider/oauth2/server_test.go @@ -0,0 +1,18 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestServer(t *testing.T) { + cfg := &Configuration{} + server := NewServer(WithConfig(cfg)) + + assert.Same(t, cfg, server.cfg) +} From 55da8d8a06971b526e4f1640f872ad75f1e952bc Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Fri, 13 Aug 2021 23:33:26 +0200 Subject: [PATCH 2/9] feat(authentication/oauth2): add base of server --- authentication/provider/oauth2/access.go | 6 +- .../provider/oauth2/configuration.go | 39 +- authentication/provider/oauth2/error.go | 65 ++++ authentication/provider/oauth2/handler.go | 335 ++++++++++++++++++ .../provider/oauth2/mock_user_provider.go | 46 +++ authentication/provider/oauth2/option.go | 23 ++ authentication/provider/oauth2/option_test.go | 50 +++ authentication/provider/oauth2/response.go | 166 +++++++++ authentication/provider/oauth2/server.go | 59 +-- authentication/provider/oauth2/storage.go | 2 + .../provider/oauth2/token/generator.go | 1 + .../provider/oauth2/token/mock_generator.go | 38 ++ go.mod | 1 + go.sum | 2 + user/mock_user.go | 14 + user/mock_user_password_salt.go | 14 + user/user.go | 2 + 17 files changed, 816 insertions(+), 47 deletions(-) create mode 100644 authentication/provider/oauth2/error.go create mode 100644 authentication/provider/oauth2/handler.go create mode 100644 authentication/provider/oauth2/response.go create mode 100644 authentication/provider/oauth2/token/mock_generator.go diff --git a/authentication/provider/oauth2/access.go b/authentication/provider/oauth2/access.go index 439c8ca..428940b 100644 --- a/authentication/provider/oauth2/access.go +++ b/authentication/provider/oauth2/access.go @@ -43,7 +43,7 @@ type AccessRequest struct { Authorized bool // Token expiration in seconds. Change if different from default - Expiration int32 + Expiration time.Duration // Set if a refresh token should be generated GenerateRefresh bool @@ -80,7 +80,7 @@ type AccessInfo struct { Client Client // Authorize data, for authorization code - AuthorizeData *AuthorizeInfo + AuthorizeInfo *AuthorizeInfo // Previous access data, for refresh token AccessInfo *AccessInfo @@ -92,7 +92,7 @@ type AccessInfo struct { RefreshToken string // Token expiration in seconds - ExpiresIn int32 + ExpiresIn int64 // Requested scope Scope string diff --git a/authentication/provider/oauth2/configuration.go b/authentication/provider/oauth2/configuration.go index 4e02516..06047fd 100644 --- a/authentication/provider/oauth2/configuration.go +++ b/authentication/provider/oauth2/configuration.go @@ -35,15 +35,46 @@ func (t AllowedAccessType) Exists(rt AccessRequestType) bool { } type Configuration struct { + PrefixURI string + + // Token type to return + TokenType string + + // old + AuthorizationExpiration time.Duration AccessExpiration time.Duration - AllowGetAccessRequest bool - - // List of allowed authorize types (only CODE by default) + // List of allowed authorize types (only CODE by default). AllowedAuthorizeTypes AllowedAuthorizeType - // List of allowed access types (only AUTHORIZATION_CODE by default) + // List of allowed access types (only AUTHORIZATION_CODE by default). AllowedAccessTypes AllowedAccessType + + // HTTP status code to return for errors - default 200 + // Only used if response was created from server. + ErrorStatusCode int + + // If true allows client secret also in params, else only in + // Authorization header - default false. + AllowClientSecretInParams bool + + // If true allows access request using GET, else only POST - default false. + AllowGetAccessRequest bool + + // Require PKCE for code flows for public OAuth clients - default false. + RequirePKCEForPublicClients bool + + // Separator to support multiple URIs in Client.GetRedirectUri(). + // If blank (the default), don't allow multiple URIs. + RedirectUriSeparator string + + // RetainTokenAfter Refresh allows the server to retain the access and + // refresh token for re-use - default false. + RetainTokenAfterRefresh bool +} + +func NewConfiguration() *Configuration { + return &Configuration{} } diff --git a/authentication/provider/oauth2/error.go b/authentication/provider/oauth2/error.go new file mode 100644 index 0000000..e3bf27d --- /dev/null +++ b/authentication/provider/oauth2/error.go @@ -0,0 +1,65 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +type DefaultErrorID string + +func (e DefaultErrorID) String() string { + return string(e) +} + +const ( + E_INVALID_REQUEST DefaultErrorID = "invalid_request" + E_UNAUTHORIZED_CLIENT DefaultErrorID = "unauthorized_client" + E_ACCESS_DENIED DefaultErrorID = "access_denied" + E_UNSUPPORTED_RESPONSE_TYPE DefaultErrorID = "unsupported_response_type" + E_INVALID_SCOPE DefaultErrorID = "invalid_scope" + E_SERVER_ERROR DefaultErrorID = "server_error" + E_TEMPORARILY_UNAVAILABLE DefaultErrorID = "temporarily_unavailable" + E_UNSUPPORTED_GRANT_TYPE DefaultErrorID = "unsupported_grant_type" + E_INVALID_GRANT DefaultErrorID = "invalid_grant" + E_INVALID_CLIENT DefaultErrorID = "invalid_client" +) + +var ( + deferror *DefaultErrors = NewDefaultErrors() +) + +// Default errors and messages +type DefaultErrors struct { + errormap map[DefaultErrorID]string +} + +// NewDefaultErrors initializes OAuth2 error codes and descriptions. +// http://tools.ietf.org/html/rfc6749#section-4.1.2.1 +// http://tools.ietf.org/html/rfc6749#section-4.2.2.1 +// http://tools.ietf.org/html/rfc6749#section-5.2 +// http://tools.ietf.org/html/rfc6749#section-7.2 +func NewDefaultErrors() *DefaultErrors { + r := &DefaultErrors{ + errormap: make(map[DefaultErrorID]string), + } + + r.errormap[E_INVALID_REQUEST] = "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed." + r.errormap[E_UNAUTHORIZED_CLIENT] = "The client is not authorized to request a token using this method." + r.errormap[E_ACCESS_DENIED] = "The resource owner or authorization server denied the request." + r.errormap[E_UNSUPPORTED_RESPONSE_TYPE] = "The authorization server does not support obtaining a token using this method." + r.errormap[E_INVALID_SCOPE] = "The requested scope is invalid, unknown, or malformed." + r.errormap[E_SERVER_ERROR] = "The authorization server encountered an unexpected condition that prevented it from fulfilling the request." + r.errormap[E_TEMPORARILY_UNAVAILABLE] = "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server." + r.errormap[E_UNSUPPORTED_GRANT_TYPE] = "The authorization grant type is not supported by the authorization server." + r.errormap[E_INVALID_GRANT] = "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client." + r.errormap[E_INVALID_CLIENT] = "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method)." + + return r +} + +func (e *DefaultErrors) Get(id DefaultErrorID) string { + if m, ok := e.errormap[id]; ok { + return m + } + + return id.String() +} diff --git a/authentication/provider/oauth2/handler.go b/authentication/provider/oauth2/handler.go new file mode 100644 index 0000000..3908696 --- /dev/null +++ b/authentication/provider/oauth2/handler.go @@ -0,0 +1,335 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "encoding/json" + "errors" + "net/http" + "strings" + + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/credential" +) + +var ( + ErrRequestMustBePost = errors.New("request must be POST") +) + +func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case s.cfg.PrefixURI + "/token": + s.handleTokenRequest(w, r) + case s.cfg.PrefixURI + "/authorize": + s.handleAuthorizeRequest(w, r) + } +} + +func (s Server) handleTokenRequest(w http.ResponseWriter, r *http.Request) { + // Only allow GET or POST + if r.Method == http.MethodGet { + if !s.cfg.AllowGetAccessRequest { + s.error(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "GET request not allowed") + + return + } + } else if r.Method != http.MethodPost { + s.error(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "request must be POST") + + return + } + + if err := r.ParseForm(); err != nil { + s.error(w, E_INVALID_REQUEST, err, "access_request=%s", "parsing error") + + return + } + + var ar *AccessRequest + + grantType := AccessRequestType(r.FormValue("grant_type")) + if s.cfg.AllowedAccessTypes.Exists(grantType) { + switch grantType { + case AUTHORIZATION_CODE: + // s.handleAuthorizationCodeRequest(w, r) + ar.Authorized = true + case REFRESH_TOKEN: + // s.handleRefreshTokenRequest(w, r) + ar.Authorized = true + case PASSWORD: + ar = s.handlePasswordRequest(w, r) + + user, err := s.userProvider.Authenticate(ar.Username, ar.Password) + if err != nil { + s.error(w, E_ACCESS_DENIED, err, "get_user=%s", "failed") + + return + } + + ar.Authorized = true + ar.UserData = user.GetID() + + case CLIENT_CREDENTIALS: + // s.handleClientCredentialsRequest(w, r) + ar.Authorized = true + case ASSERTION: + // s.handleAssertionRequest(w, r) + ar.Authorized = false + default: + s.error(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") + + return + } + + s.FinishAccessRequest(w, r, ar) + } + +} + +func (s Server) handleAuthorizeRequest(w http.ResponseWriter, r *http.Request) { + +} + +func (s Server) handlePasswordRequest(w http.ResponseWriter, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ar := &AccessRequest{ + Type: PASSWORD, + Username: r.FormValue("username"), + Password: r.FormValue("password"), + Scope: r.FormValue("scope"), + GenerateRefresh: true, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "username" and "password" is required + if ar.Username == "" || ar.Password == "" { + s.error(w, E_INVALID_GRANT, nil, "handle_password=%s", "username and password required") + + return nil + } + + // must have a valid client + if ar.Client = s.getClient(auth, s.storage, w, true); ar.Client == nil { + return nil + } + + // set redirect uri + ar.RedirectURI = FirstURI(ar.Client.GetRedirectURI(), s.cfg.RedirectUriSeparator) + + /* + user, err := s.userProvider.Authenticate(username, password) + if err != nil && errors.Is(err, ErrUserNotFound) { + s.error(w, E_ACCESS_DENIED, nil, "get_user=%s", "username or password is invalid") + + return + } else if err != nil { + s.error(w, E_ACCESS_DENIED, nil, "get_user=%s", "username or password is invalid") + } + */ + + return ar +} + +// Returns the first uri from an uri list +func FirstURI(baseUriList string, separator string) string { + if separator == "" { + return baseUriList + } + + if slist := strings.Split(baseUriList, separator); len(slist) > 0 { + return slist[0] + } + + return "" +} + +// getClientAuth checks client basic authentication in params if allowed, +// otherwise gets it from the header. +// Sets an error on the response if no auth is present or a server error occurs. +func (s Server) getClientAuth(w http.ResponseWriter, r *http.Request, allowQueryParams bool) *credential.UsernamePasswordCredential { + ctx := r.Context() + + // creds := credential.FromContext(ctx) + + if allowQueryParams { + // Allow for auth without password + if _, hasSecret := r.Form["client_secret"]; hasSecret { + auth := credential.NewUsernamePasswordCredential( + r.FormValue("client_id"), + r.FormValue("client_secret"), + ) + + if auth.GetPrincipal() != "" { + return auth.(*credential.UsernamePasswordCredential) + } + } + } + + auth := credential.FromContext(ctx) + + /* + auth, err := CheckBasicAuth(r) + if err != nil { + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "get_client_auth=%s", "check auth error") + return nil + } + */ + if auth == nil { + s.error(w, E_INVALID_REQUEST, errors.New("Client authentication not sent"), "get_client_auth=%s", "client authentication not sent") + + return nil + } + + return auth.(*credential.UsernamePasswordCredential) +} + +// getClient looks up and authenticates the basic auth using the given +// storage. Sets an error on the response if auth fails or a server error occurs. +func (s Server) getClient(creds *credential.UsernamePasswordCredential, storage StorageProvider, w http.ResponseWriter, allowEmptySecret bool) Client { + client, err := storage.LoadClient(creds.GetPrincipal().(string)) + if errors.Is(err, ErrClientNotFound) { + s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found") + + return nil + } else if err != nil { + s.error(w, E_SERVER_ERROR, err, "get_client=%s", "error finding client") + + return nil + } + + if client == nil { + s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client is nil") + + return nil + } + + if c, ok := client.(ClientSecretMatcher); ok { + if creds.GetCredentials() != nil && !c.SecretMatches(creds.GetCredentials().(string)) { + s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) + + return nil + } else if creds.GetCredentials() == nil && !allowEmptySecret { + s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) + + return nil + } + } else if !allowEmptySecret { + s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) + + return nil + } + + if client.GetRedirectURI() == "" { + s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client redirect uri is empty") + + return nil + } + + return client +} + +func (s Server) error(w http.ResponseWriter, responseError DefaultErrorID, internalError error, debugFormat string, debugArgs ...interface{}) { + format := "error=%v, internal_error=%#v " + debugFormat + + s.logger.Error( + format, + logger.WithFormat(append([]interface{}{responseError, internalError}, debugArgs...)...), + ) +} + +func (s *Server) FinishAccessRequest(w http.ResponseWriter, r *http.Request, ar *AccessRequest) { + // don't process if is already an error + /*if w.IsError { + return + }*/ + + redirectURI := r.FormValue("redirect_uri") + // Get redirect uri from AccessRequest if it's there (e.g., refresh token request) + if ar.RedirectURI != "" { + redirectURI = ar.RedirectURI + } + + if !ar.Authorized { + s.error(w, E_ACCESS_DENIED, nil, "finish_access_request=%s", "authorization failed") + + return + } + + var ret *AccessInfo + var err error + + if ar.ForceAccessInfo == nil { + // generate access token + ret = &AccessInfo{ + Client: ar.Client, + AuthorizeInfo: ar.AuthorizeInfo, + AccessInfo: ar.AccessInfo, + RedirectURI: redirectURI, + CreatedAt: s.now(), + ExpiresIn: int64(ar.Expiration.Seconds()), + UserData: ar.UserData, + Scope: ar.Scope, + } + + // generate access token + ret.AccessToken, ret.RefreshToken, err = s.tokenGenerator.GenerateAccessToken(ar.GenerateRefresh) + if err != nil { + s.error(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error generating token") + + return + } + } else { + ret = ar.ForceAccessInfo + } + + // save access token + if err = s.storage.SaveAccess(ret); err != nil { + s.error(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") + + return + } + + // remove authorization token + if ret.AuthorizeInfo != nil { + s.storage.RemoveAuthorize(ret.AuthorizeInfo.Code) + } + + // remove previous access token + if ret.AccessInfo != nil && !s.cfg.RetainTokenAfterRefresh { + if ret.AccessInfo.RefreshToken != "" { + s.storage.RemoveRefresh(ret.AccessInfo.RefreshToken) + } + + s.storage.RemoveAccess(ret.AccessInfo.AccessToken) + } + + output := map[string]interface{}{ + "access_token": ret.AccessToken, + "token_type": s.cfg.TokenType, + "expires_in": ret.ExpiresIn, + } + + if ret.RefreshToken != "" { + output["refresh_token"] = ret.RefreshToken + } + + if ret.Scope != "" { + output["scope"] = ret.Scope + } + + if err := json.NewEncoder(w).Encode(output); err != nil { + s.error(w, E_SERVER_ERROR, err, "finish_access_request=%s", "serialize response failed") + + return + } + +} diff --git a/authentication/provider/oauth2/mock_user_provider.go b/authentication/provider/oauth2/mock_user_provider.go index 21d5495..f7da3b2 100644 --- a/authentication/provider/oauth2/mock_user_provider.go +++ b/authentication/provider/oauth2/mock_user_provider.go @@ -12,6 +12,52 @@ type MockUserProvider struct { mock.Mock } +// Authenticate provides a mock function with given fields: username, password +func (_m *MockUserProvider) Authenticate(username string, password string) (user.User, error) { + ret := _m.Called(username, password) + + var r0 user.User + if rf, ok := ret.Get(0).(func(string, string) user.User); ok { + r0 = rf(username, password) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(user.User) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(username, password) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LoadByUsername provides a mock function with given fields: username +func (_m *MockUserProvider) LoadByUsername(username string) (user.User, error) { + ret := _m.Called(username) + + var r0 user.User + if rf, ok := ret.Get(0).(func(string) user.User); ok { + r0 = rf(username) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(user.User) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(username) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // LoadUser provides a mock function with given fields: id func (_m *MockUserProvider) LoadUser(id string) (user.User, error) { ret := _m.Called(id) diff --git a/authentication/provider/oauth2/option.go b/authentication/provider/oauth2/option.go index 94522b0..6c6f1e7 100644 --- a/authentication/provider/oauth2/option.go +++ b/authentication/provider/oauth2/option.go @@ -4,6 +4,11 @@ package oauth2 +import ( + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" +) + // Option type. type Option func(server *Server) @@ -16,6 +21,24 @@ func WithConfig(cfg *Configuration) Option { func WithStorage(storage StorageProvider) Option { return func(server *Server) { + server.storage = storage + } +} + +func WithUserProvider(userProvider UserProvider) Option { + return func(server *Server) { + server.userProvider = userProvider + } +} + +func WithLogger(logger logger.Logger) Option { + return func(server *Server) { + server.logger = logger + } +} +func WithTokenGenerator(tokenGenerator token.Generator) Option { + return func(server *Server) { + server.tokenGenerator = tokenGenerator } } diff --git a/authentication/provider/oauth2/option_test.go b/authentication/provider/oauth2/option_test.go index c34d004..3a63d80 100644 --- a/authentication/provider/oauth2/option_test.go +++ b/authentication/provider/oauth2/option_test.go @@ -7,6 +7,8 @@ package oauth2 import ( "testing" + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" "github.com/stretchr/testify/assert" ) @@ -21,3 +23,51 @@ func TestWithConfig(t *testing.T) { assert.Same(t, cfg, server.cfg) } + +func TestWithLogger(t *testing.T) { + logger := &logger.Nop{} + + opt := WithLogger(logger) + + server := &Server{} + + opt(server) + + assert.Same(t, logger, server.logger) +} + +func TestWithStorage(t *testing.T) { + storageMock := &MockStorageProvider{} + + opt := WithStorage(storageMock) + + server := &Server{} + + opt(server) + + assert.Same(t, storageMock, server.storage) +} + +func TestWithUserProvider(t *testing.T) { + userProviderMock := &MockUserProvider{} + + opt := WithUserProvider(userProviderMock) + + server := &Server{} + + opt(server) + + assert.Same(t, userProviderMock, server.userProvider) +} + +func TestWithTokenGenerator(t *testing.T) { + tokenGeneratorMock := &token.MockGenerator{} + + opt := WithTokenGenerator(tokenGeneratorMock) + + server := &Server{} + + opt(server) + + assert.Same(t, tokenGeneratorMock, server.tokenGenerator) +} diff --git a/authentication/provider/oauth2/response.go b/authentication/provider/oauth2/response.go new file mode 100644 index 0000000..2894a64 --- /dev/null +++ b/authentication/provider/oauth2/response.go @@ -0,0 +1,166 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "errors" + "fmt" + "net/http" + "net/url" +) + +var ( + ErrNotARedirectResponse = errors.New("Not a redirect response") +) + +// Data for response output +type ResponseData map[string]interface{} + +// Response type enum +type ResponseType int + +const ( + DATA ResponseType = iota + REDIRECT +) + +// Server response +type Response struct { + Type ResponseType + StatusCode int + StatusText string + ErrorStatusCode int + URL string + Output ResponseData + Headers http.Header + IsError bool + ErrorID DefaultErrorID + InternalError error + RedirectInFragment bool + + // Storage to use in this response - required + Storage StorageProvider +} + +func NewResponse(storage StorageProvider) *Response { + r := &Response{ + Type: DATA, + StatusCode: 200, + ErrorStatusCode: 200, + Output: make(ResponseData), + Headers: make(http.Header), + IsError: false, + // Storage: storage.Clone(), + } + r.Headers.Add( + "Cache-Control", + "no-cache, no-store, max-age=0, must-revalidate", + ) + r.Headers.Add("Pragma", "no-cache") + r.Headers.Add("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") + + return r +} + +// SetError sets an error id and description on the Response +// state and uri are left blank +func (r *Response) SetError(id DefaultErrorID, description string) { + r.SetErrorURI(id, description, "", "") +} + +// SetErrorState sets an error id, description, and state on the Response +// uri is left blank +func (r *Response) SetErrorState(id DefaultErrorID, description string, state string) { + r.SetErrorURI(id, description, "", state) +} + +// SetErrorURI sets an error id, description, state, and uri on the Response +func (r *Response) SetErrorURI(id DefaultErrorID, description string, uri string, state string) { + // get default error message + if description == "" { + description = deferror.Get(id) + } + + // set error parameters + r.IsError = true + r.ErrorID = id + r.StatusCode = r.ErrorStatusCode + + if r.StatusCode != http.StatusOK { + r.StatusText = description + } else { + r.StatusText = "" + } + + r.Output = make(ResponseData) // clear output + r.Output["error"] = id + r.Output["error_description"] = description + + if uri != "" { + r.Output["error_uri"] = uri + } + + if state != "" { + r.Output["state"] = state + } +} + +// SetRedirect changes the response to redirect to the given url +func (r *Response) SetRedirect(url string) { + // set redirect parameters + r.Type = REDIRECT + r.URL = url +} + +// SetRedirectFragment sets redirect values to be passed in fragment instead of as query parameters +func (r *Response) SetRedirectFragment(f bool) { + r.RedirectInFragment = f +} + +// GetRedirectURL returns the redirect url with all query string parameters +func (r *Response) GetRedirectURL() (string, error) { + if r.Type != REDIRECT { + return "", ErrNotARedirectResponse + } + + u, err := url.Parse(r.URL) + if err != nil { + return "", fmt.Errorf("parse url failed: %w", err) + } + + var q url.Values + if r.RedirectInFragment { + // start with empty set for fragment + q = url.Values{} + } else { + // add parameters to existing query + q = u.Query() + } + + // add parameters + for n, v := range r.Output { + q.Set(n, fmt.Sprint(v)) + } + + // https://tools.ietf.org/html/rfc6749#section-4.2.2 + // Fragment should be encoded as application/x-www-form-urlencoded (%-escaped, spaces are represented as '+') + // The stdlib URL#String() doesn't make that easy to accomplish, so build this ourselves + if r.RedirectInFragment { + u.Fragment = "" + redirectURI := u.String() + "#" + q.Encode() + + return redirectURI, nil + } + + // Otherwise, update the query and encode normally + u.RawQuery = q.Encode() + u.Fragment = "" + + return u.String(), nil +} + +func (r *Response) Close() { + // r.Storage.Close() +} diff --git a/authentication/provider/oauth2/server.go b/authentication/provider/oauth2/server.go index 2914d8a..1739019 100644 --- a/authentication/provider/oauth2/server.go +++ b/authentication/provider/oauth2/server.go @@ -5,15 +5,31 @@ package oauth2 import ( - "net/http" + "time" + + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token/random" ) type Server struct { - cfg *Configuration + cfg *Configuration + logger logger.Logger + storage StorageProvider + userProvider UserProvider + tokenGenerator token.Generator + now func() time.Time } func NewServer(options ...Option) *Server { - s := &Server{} + cfg := NewConfiguration() + + s := &Server{ + cfg: cfg, + logger: &logger.Nop{}, + tokenGenerator: random.NewTokenGenerator(&random.Configuration{}), + now: time.Now, + } for _, opt := range options { opt(s) @@ -21,40 +37,3 @@ func NewServer(options ...Option) *Server { return s } - -func (s *Server) HandleAccessRequest(w http.ResponseWriter, r *http.Request) *AccessRequest { - // Only allow GET or POST - if r.Method == http.MethodGet { - if !s.cfg.AllowGetAccessRequest { - //s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "GET request not allowed") - return nil - } - } else if r.Method != http.MethodPost { - //s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "request must be POST") - return nil - } - - if err := r.ParseForm(); err != nil { - //s.setErrorAndLog(w, E_INVALID_REQUEST, err, "access_request=%s", "parsing error") - return nil - } - - grantType := AccessRequestType(r.FormValue("grant_type")) - if s.cfg.AllowedAccessTypes.Exists(grantType) { - switch grantType { - case AUTHORIZATION_CODE: - // return s.handleAuthorizationCodeRequest(w, r) - case REFRESH_TOKEN: - // return s.handleRefreshTokenRequest(w, r) - case PASSWORD: - // return s.handlePasswordRequest(w, r) - case CLIENT_CREDENTIALS: - // return s.handleClientCredentialsRequest(w, r) - case ASSERTION: - // return s.handleAssertionRequest(w, r) - } - } - - // s.setErrorAndLog(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") - return nil -} diff --git a/authentication/provider/oauth2/storage.go b/authentication/provider/oauth2/storage.go index f80a32a..742b439 100644 --- a/authentication/provider/oauth2/storage.go +++ b/authentication/provider/oauth2/storage.go @@ -49,6 +49,8 @@ type AuthorizeProvider interface { //go:generate mockery --name=UserProvider --inpackage --case underscore type UserProvider interface { LoadUser(id string) (user.User, error) + LoadByUsername(username string) (user.User, error) + Authenticate(username string, password string) (user.User, error) } //go:generate mockery --name=StorageProvider --inpackage --case underscore diff --git a/authentication/provider/oauth2/token/generator.go b/authentication/provider/oauth2/token/generator.go index cd39ec9..43da7a7 100644 --- a/authentication/provider/oauth2/token/generator.go +++ b/authentication/provider/oauth2/token/generator.go @@ -4,6 +4,7 @@ package token +//go:generate mockery --name=Generator --inpackage --case underscore type Generator interface { GenerateAccessToken(generateRefresh bool) (accessToken string, refreshToken string, err error) } diff --git a/authentication/provider/oauth2/token/mock_generator.go b/authentication/provider/oauth2/token/mock_generator.go new file mode 100644 index 0000000..c56ce5e --- /dev/null +++ b/authentication/provider/oauth2/token/mock_generator.go @@ -0,0 +1,38 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. + +package token + +import mock "github.com/stretchr/testify/mock" + +// MockGenerator is an autogenerated mock type for the Generator type +type MockGenerator struct { + mock.Mock +} + +// GenerateAccessToken provides a mock function with given fields: generateRefresh +func (_m *MockGenerator) GenerateAccessToken(generateRefresh bool) (string, string, error) { + ret := _m.Called(generateRefresh) + + var r0 string + if rf, ok := ret.Get(0).(func(bool) string); ok { + r0 = rf(generateRefresh) + } else { + r0 = ret.Get(0).(string) + } + + var r1 string + if rf, ok := ret.Get(1).(func(bool) string); ok { + r1 = rf(generateRefresh) + } else { + r1 = ret.Get(1).(string) + } + + var r2 error + if rf, ok := ret.Get(2).(func(bool) error); ok { + r2 = rf(generateRefresh) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} diff --git a/go.mod b/go.mod index f8913c3..726d900 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.15 require ( github.com/gilcrest/alice v1.0.0 + github.com/hyperscale-stack/logger v1.0.0 // indirect github.com/hyperscale-stack/secure v1.0.0 // indirect github.com/rs/zerolog v1.20.0 github.com/stretchr/objx v0.3.0 // indirect diff --git a/go.sum b/go.sum index b8a0423..635bbd0 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gilcrest/alice v1.0.0 h1:5+CasxidJEUHmgghQxLOl09uYhOlavDfDgNZhyR62LU= github.com/gilcrest/alice v1.0.0/go.mod h1:q5HRhK5WEyU1pDBIIfmYapVGLd/IAAPwiO8LNxKADpw= +github.com/hyperscale-stack/logger v1.0.0 h1:ZsjQp1DEXwHunoVhi0L5/yZOg0Npsy7LkQy10Mi2o6I= +github.com/hyperscale-stack/logger v1.0.0/go.mod h1:X1apoFZZ8/AKspix5ylBetZzv+HEU2Ive/+3qHJhyOw= github.com/hyperscale-stack/secure v1.0.0 h1:ayGoa/Y/0RcAcP767WKjla1r9KlR+Tul5DPI/jE9dP0= github.com/hyperscale-stack/secure v1.0.0/go.mod h1:PY+BMJQI2aP+YYA3C7R0bFTS/XGJ4xPCYjBp9rEqmtQ= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/user/mock_user.go b/user/mock_user.go index 1342d82..544d051 100644 --- a/user/mock_user.go +++ b/user/mock_user.go @@ -9,6 +9,20 @@ type MockUser struct { mock.Mock } +// GetID provides a mock function with given fields: +func (_m *MockUser) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + // GetPassword provides a mock function with given fields: func (_m *MockUser) GetPassword() string { ret := _m.Called() diff --git a/user/mock_user_password_salt.go b/user/mock_user_password_salt.go index 9909151..d08db06 100644 --- a/user/mock_user_password_salt.go +++ b/user/mock_user_password_salt.go @@ -9,6 +9,20 @@ type MockUserPasswordSalt struct { mock.Mock } +// GetID provides a mock function with given fields: +func (_m *MockUserPasswordSalt) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + // GetPassword provides a mock function with given fields: func (_m *MockUserPasswordSalt) GetPassword() string { ret := _m.Called() diff --git a/user/user.go b/user/user.go index 74dee72..a3e55be 100644 --- a/user/user.go +++ b/user/user.go @@ -7,6 +7,8 @@ package user // User interface provides core user information //go:generate mockery --name=User --inpackage --case underscore type User interface { + GetID() string + // GetRoles returns the roles granted to the user. GetRoles() []string From d2f57d17aacd07478d7478c5a97023b810c4fa84 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Sat, 14 Aug 2021 01:43:12 +0200 Subject: [PATCH 3/9] feat(authentication/oauth2): add base of token endpoint --- .gitignore | 2 + authentication/provider/oauth2/access.go | 566 +++++++++++++++++- authentication/provider/oauth2/access_test.go | 6 +- authentication/provider/oauth2/authorize.go | 47 +- .../provider/oauth2/authorize_test.go | 4 +- .../provider/oauth2/configuration.go | 2 +- authentication/provider/oauth2/handler.go | 335 ----------- .../provider/oauth2/mock_access_provider.go | 12 +- .../oauth2/mock_authorize_provider.go | 12 +- .../provider/oauth2/mock_refresh_provider.go | 12 +- .../provider/oauth2/mock_storage_provider.go | 36 +- .../oauth2_authentication_provider_test.go | 8 +- authentication/provider/oauth2/response.go | 6 +- .../provider/oauth2/response_json.go | 44 ++ .../provider/oauth2/response_json_test.go | 111 ++++ authentication/provider/oauth2/server.go | 13 + authentication/provider/oauth2/storage.go | 12 +- .../oauth2/storage/in_memory_storage.go | 18 +- .../oauth2/storage/in_memory_storage_test.go | 6 +- authentication/provider/oauth2/urivalidate.go | 125 ++++ .../provider/oauth2/urivalidate_test.go | 163 +++++ authentication/provider/oauth2/util.go | 129 ++++ authentication/provider/oauth2/util_test.go | 221 +++++++ .../oauth2_auth_by_access_token_test.go | 6 +- 24 files changed, 1470 insertions(+), 426 deletions(-) delete mode 100644 authentication/provider/oauth2/handler.go create mode 100644 authentication/provider/oauth2/response_json.go create mode 100644 authentication/provider/oauth2/response_json_test.go create mode 100644 authentication/provider/oauth2/urivalidate.go create mode 100644 authentication/provider/oauth2/urivalidate_test.go create mode 100644 authentication/provider/oauth2/util.go create mode 100644 authentication/provider/oauth2/util_test.go diff --git a/.gitignore b/.gitignore index 93117e4..125dc4e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ vendor/ *.swp + +*.go_ diff --git a/authentication/provider/oauth2/access.go b/authentication/provider/oauth2/access.go index 428940b..c857410 100644 --- a/authentication/provider/oauth2/access.go +++ b/authentication/provider/oauth2/access.go @@ -6,8 +6,15 @@ package oauth2 import ( "context" + "crypto/sha256" + "encoding/base64" + "errors" "net/http" + "strings" "time" + + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/credential" ) // AccessRequestType is the type for OAuth2 param `grant_type` @@ -27,11 +34,11 @@ type AccessRequest struct { Type AccessRequestType Code string Client Client - AuthorizeInfo *AuthorizeInfo - AccessInfo *AccessInfo + AuthorizeData *AuthorizeData + AccessData *AccessData // Force finish to use this access data, to allow access data reuse - ForceAccessInfo *AccessInfo + ForceAccessData *AccessData RedirectURI string Scope string Username string @@ -61,8 +68,8 @@ type AccessRequest struct { type accessCtxKey struct{} // AccessTokenFromContext returns the Access Token info associated with the ctx. -func AccessTokenFromContext(ctx context.Context) *AccessInfo { - if a, ok := ctx.Value(accessCtxKey{}).(*AccessInfo); ok { +func AccessTokenFromContext(ctx context.Context) *AccessData { + if a, ok := ctx.Value(accessCtxKey{}).(*AccessData); ok { return a } @@ -70,20 +77,20 @@ func AccessTokenFromContext(ctx context.Context) *AccessInfo { } // AccessTokenToContext returns new context with Access Token info. -func AccessTokenToContext(ctx context.Context, access *AccessInfo) context.Context { +func AccessTokenToContext(ctx context.Context, access *AccessData) context.Context { return context.WithValue(ctx, accessCtxKey{}, access) } -// AccessInfo represents an access grant (tokens, expiration, client, etc). -type AccessInfo struct { +// AccessData represents an access grant (tokens, expiration, client, etc). +type AccessData struct { // Client information Client Client // Authorize data, for authorization code - AuthorizeInfo *AuthorizeInfo + AuthorizeData *AuthorizeData // Previous access data, for refresh token - AccessInfo *AccessInfo + AccessData *AccessData // Access token AccessToken string @@ -108,16 +115,549 @@ type AccessInfo struct { } // IsExpired returns true if access expired. -func (i *AccessInfo) IsExpired() bool { +func (i *AccessData) IsExpired() bool { return i.IsExpiredAt(time.Now()) } // IsExpiredAt returns true if access expires at time 't'. -func (i *AccessInfo) IsExpiredAt(t time.Time) bool { +func (i *AccessData) IsExpiredAt(t time.Time) bool { return i.ExpireAt().Before(t) } // ExpireAt returns the expiration date. -func (i *AccessInfo) ExpireAt() time.Time { +func (i *AccessData) ExpireAt() time.Time { return i.CreatedAt.Add(time.Duration(i.ExpiresIn) * time.Second) } + +// HandleAccessRequest is the http.HandlerFunc for handling access token requests +func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessRequest { + // Only allow GET or POST + if r.Method == http.MethodGet { + if !s.cfg.AllowGetAccessRequest { + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "GET request not allowed") + + return nil + } + } else if r.Method != http.MethodPost { + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "request must be POST") + + return nil + } + + if err := r.ParseForm(); err != nil { + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "access_request=%s", "parsing error") + + return nil + } + + grantType := AccessRequestType(r.FormValue("grant_type")) + if !s.cfg.AllowedAccessTypes.Exists(grantType) { + s.setErrorAndLog(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") + + return nil + } + + switch grantType { + case AUTHORIZATION_CODE: + return s.handleAuthorizationCodeRequest(w, r) + case REFRESH_TOKEN: + return s.handleRefreshTokenRequest(w, r) + case PASSWORD: + return s.handlePasswordRequest(w, r) + case CLIENT_CREDENTIALS: + return s.handleClientCredentialsRequest(w, r) + default: + return s.handleAssertionRequest(w, r) + } +} + +func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: AUTHORIZATION_CODE, + Code: r.FormValue("code"), + CodeVerifier: r.FormValue("code_verifier"), + RedirectURI: r.FormValue("redirect_uri"), + GenerateRefresh: true, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "code" is required + if ret.Code == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "code is required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // must be a valid authorization code + var err error + + ret.AuthorizeData, err = w.Storage.LoadAuthorize(ret.Code) + if err != nil { + s.setErrorAndLog(w, E_INVALID_GRANT, err, "auth_code_request=%s", "error loading authorize data") + + return nil + } + + if ret.AuthorizeData == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "authorization data is nil") + + return nil + } + + if ret.AuthorizeData.Client == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "authorization client is nil") + + return nil + } + + if ret.AuthorizeData.Client.GetRedirectURI() == "" { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "client redirect uri is empty") + + return nil + } + + if ret.AuthorizeData.IsExpiredAt(s.now()) { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "authorization data is expired") + + return nil + } + + // code must be from the client + if ret.AuthorizeData.Client.GetID() != ret.Client.GetID() { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "client code does not match") + + return nil + } + + // check redirect uri + if ret.RedirectURI == "" { + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + } + if realRedirectURI, err := ValidateURIList(ret.Client.GetRedirectURI(), ret.RedirectURI, s.cfg.RedirectURISeparator); err != nil { + + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "auth_code_request=%s", "error validating client redirect") + return nil + } else { + ret.RedirectURI = realRedirectURI + } + if ret.AuthorizeData.RedirectURI != ret.RedirectURI { + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Redirect uri is different"), "auth_code_request=%s", "client redirect does not match authorization data") + + return nil + } + + // Verify PKCE, if present in the authorization data + if len(ret.AuthorizeData.CodeChallenge) > 0 { + // https://tools.ietf.org/html/rfc7636#section-4.1 + if matched := pkceMatcher.MatchString(ret.CodeVerifier); !matched { + s.setErrorAndLog( + w, + E_INVALID_REQUEST, + errors.New("code_verifier has invalid format"), + "auth_code_request=%s", + "pkce code challenge verifier does not match", + ) + + return nil + } + + // https: //tools.ietf.org/html/rfc7636#section-4.6 + codeVerifier := "" + switch ret.AuthorizeData.CodeChallengeMethod { + case "", PKCE_PLAIN: + codeVerifier = ret.CodeVerifier + case PKCE_S256: + hash := sha256.Sum256([]byte(ret.CodeVerifier)) + codeVerifier = base64.RawURLEncoding.EncodeToString(hash[:]) + default: + s.setErrorAndLog( + w, + E_INVALID_REQUEST, + nil, + "auth_code_request=%s", + "pkce transform algorithm not supported (rfc7636)", + ) + + return nil + } + + if codeVerifier != ret.AuthorizeData.CodeChallenge { + s.setErrorAndLog( + w, + E_INVALID_GRANT, + errors.New("code_verifier failed comparison with code_challenge"), + "auth_code_request=%s", + "pkce code verifier does not match challenge", + ) + + return nil + } + } + + // set rest of data + ret.Scope = ret.AuthorizeData.Scope + ret.UserData = ret.AuthorizeData.UserData + + return ret +} + +func extraScopes(access_scopes, refresh_scopes string) bool { + access_scopes_list := strings.Split(access_scopes, " ") + refresh_scopes_list := strings.Split(refresh_scopes, " ") + + access_map := make(map[string]int) + + for _, scope := range access_scopes_list { + if scope == "" { + continue + } + + access_map[scope] = 1 + } + + for _, scope := range refresh_scopes_list { + if scope == "" { + continue + } + + if _, ok := access_map[scope]; !ok { + return true + } + } + + return false +} + +func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: REFRESH_TOKEN, + Code: r.FormValue("refresh_token"), + Scope: r.FormValue("scope"), + GenerateRefresh: true, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "refresh_token" is required + if ret.Code == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "refresh_token=%s", "refresh_token is required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // must be a valid refresh code + var err error + + ret.AccessData, err = w.Storage.LoadRefresh(ret.Code) + if err != nil { + s.setErrorAndLog(w, E_INVALID_GRANT, err, "refresh_token=%s", "error loading access data") + + return nil + } + + if ret.AccessData == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data is nil") + + return nil + } + + if ret.AccessData.Client == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data client is nil") + + return nil + } + + if ret.AccessData.Client.GetRedirectURI() == "" { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data client redirect uri is empty") + + return nil + } + + // client must be the same as the previous token + if ret.AccessData.Client.GetID() != ret.Client.GetID() { + s.setErrorAndLog( + w, + E_INVALID_CLIENT, + errors.New("Client id must be the same from previous token"), + "refresh_token=%s, current=%v, previous=%v", + "client mismatch", + ret.Client.GetID(), + ret.AccessData.Client.GetID(), + ) + + return nil + } + + // set rest of data + ret.RedirectURI = ret.AccessData.RedirectURI + ret.UserData = ret.AccessData.UserData + + if ret.Scope == "" { + ret.Scope = ret.AccessData.Scope + } + + if extraScopes(ret.AccessData.Scope, ret.Scope) { + msg := "the requested scope must not include any scope not originally granted by the resource owner" + s.setErrorAndLog(w, E_ACCESS_DENIED, errors.New(msg), "refresh_token=%s", msg) + + return nil + } + + return ret +} + +func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: PASSWORD, + Username: r.FormValue("username"), + Password: r.FormValue("password"), + Scope: r.FormValue("scope"), + GenerateRefresh: true, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "username" and "password" is required + if ret.Username == "" || ret.Password == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "handle_password=%s", "username and pass required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // set redirect uri + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + + return ret +} + +func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: CLIENT_CREDENTIALS, + Scope: r.FormValue("scope"), + GenerateRefresh: false, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // set redirect uri + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + + return ret +} + +func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: ASSERTION, + Scope: r.FormValue("scope"), + AssertionType: r.FormValue("assertion_type"), + Assertion: r.FormValue("assertion"), + GenerateRefresh: false, // assertion should NOT generate a refresh token, per the RFC + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "assertion_type" and "assertion" is required + if ret.AssertionType == "" || ret.Assertion == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "handle_assertion_request=%s", "assertion and assertion_type required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // set redirect uri + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + + return ret +} + +func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessRequest) { + // don't process if is already an error + if w.IsError { + return + } + + redirectUri := r.FormValue("redirect_uri") + + // Get redirect uri from AccessRequest if it's there (e.g., refresh token request) + if ar.RedirectURI != "" { + redirectUri = ar.RedirectURI + } + + if ar.Authorized { + var ret *AccessData + var err error + + if ar.ForceAccessData == nil { + // generate access token + ret = &AccessData{ + Client: ar.Client, + AuthorizeData: ar.AuthorizeData, + AccessData: ar.AccessData, + RedirectURI: redirectUri, + CreatedAt: s.now(), + ExpiresIn: int64(ar.Expiration.Seconds()), + UserData: ar.UserData, + Scope: ar.Scope, + } + + // generate access token + // @TODO: add ret at first arg for GenerateAccessToken + ret.AccessToken, ret.RefreshToken, err = s.tokenGenerator.GenerateAccessToken(ar.GenerateRefresh) + if err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error generating token") + + return + } + } else { + ret = ar.ForceAccessData + } + + // save access token + if err = w.Storage.SaveAccess(ret); err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") + + return + } + + // remove authorization token + if ret.AuthorizeData != nil { + w.Storage.RemoveAuthorize(ret.AuthorizeData.Code) + } + + // remove previous access token + if ret.AccessData != nil && !s.cfg.RetainTokenAfterRefresh { + if ret.AccessData.RefreshToken != "" { + w.Storage.RemoveRefresh(ret.AccessData.RefreshToken) + } + + w.Storage.RemoveAccess(ret.AccessData.AccessToken) + } + + // output data + w.Output["access_token"] = ret.AccessToken + w.Output["token_type"] = s.cfg.TokenType + w.Output["expires_in"] = ret.ExpiresIn + + if ret.RefreshToken != "" { + w.Output["refresh_token"] = ret.RefreshToken + } + + if ret.Scope != "" { + w.Output["scope"] = ret.Scope + } + } else { + s.setErrorAndLog(w, E_ACCESS_DENIED, nil, "finish_access_request=%s", "authorization failed") + } +} + +// Helper Functions + +// getClient looks up and authenticates the basic auth using the given +// storage. Sets an error on the response if auth fails or a server error occurs. +func (s Server) getClient(creds *credential.UsernamePasswordCredential, storage StorageProvider, w *Response) Client { + client, err := storage.LoadClient(creds.GetPrincipal().(string)) + if errors.Is(err, ErrClientNotFound) { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found") + + return nil + } + + if err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "get_client=%s", "error finding client") + + return nil + } + + if client == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client is nil") + + return nil + } + + if !CheckClientSecret(client, creds.GetCredentials().(string)) { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) + + return nil + } + + if client.GetRedirectURI() == "" { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client redirect uri is empty") + + return nil + } + + return client +} + +// setErrorAndLog sets the response error and internal error (if non-nil) and logs them along with the provided debug format string and arguments. +func (s Server) setErrorAndLog(w *Response, responseError DefaultErrorID, internalError error, debugFormat string, debugArgs ...interface{}) { + format := "error=%v, internal_error=%#v " + debugFormat + + w.InternalError = internalError + w.SetError(responseError, "") + + s.logger.Error( + format, + logger.WithFormat(append([]interface{}{responseError, internalError}, debugArgs...)...), + ) +} diff --git a/authentication/provider/oauth2/access_test.go b/authentication/provider/oauth2/access_test.go index 2b2e584..576bd84 100644 --- a/authentication/provider/oauth2/access_test.go +++ b/authentication/provider/oauth2/access_test.go @@ -12,11 +12,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAccessInfo(t *testing.T) { +func TestAccessData(t *testing.T) { cat, err := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") assert.NoError(t, err) - ai := &AccessInfo{ + ai := &AccessData{ CreatedAt: cat, ExpiresIn: 10, } @@ -27,7 +27,7 @@ func TestAccessInfo(t *testing.T) { func TestAccessTokenContext(t *testing.T) { ctx := context.Background() - ai := &AccessInfo{ + ai := &AccessData{ CreatedAt: time.Now(), ExpiresIn: 10, } diff --git a/authentication/provider/oauth2/authorize.go b/authentication/provider/oauth2/authorize.go index 267dd73..d0caa18 100644 --- a/authentication/provider/oauth2/authorize.go +++ b/authentication/provider/oauth2/authorize.go @@ -4,7 +4,11 @@ package oauth2 -import "time" +import ( + "net/http" + "regexp" + "time" +) // AuthorizeRequestType is the type for OAuth param `response_type` type AuthorizeRequestType string @@ -17,8 +21,39 @@ const ( PKCE_S256 = "S256" ) -// AuthorizeInfo info. -type AuthorizeInfo struct { +var ( + pkceMatcher = regexp.MustCompile("^[a-zA-Z0-9~._-]{43,128}$") +) + +// Authorize request information +type AuthorizeRequest struct { + Type AuthorizeRequestType + Client Client + Scope string + RedirectUri string + State string + + // Set if request is authorized + Authorized bool + + // Token expiration in seconds. Change if different from default. + // If type = TOKEN, this expiration will be for the ACCESS token. + Expiration int32 + + // Data to be passed to storage. Not used by the library. + UserData interface{} + + // HttpRequest *http.Request for special use + HttpRequest *http.Request + + // Optional code_challenge as described in rfc7636 + CodeChallenge string + // Optional code_challenge_method as described in rfc7636 + CodeChallengeMethod string +} + +// AuthorizeData info. +type AuthorizeData struct { // Client information Client Client @@ -51,16 +86,16 @@ type AuthorizeInfo struct { } // IsExpired is true if authorization expired. -func (i *AuthorizeInfo) IsExpired() bool { +func (i *AuthorizeData) IsExpired() bool { return i.IsExpiredAt(time.Now()) } // IsExpired is true if authorization expires at time 't'. -func (i *AuthorizeInfo) IsExpiredAt(t time.Time) bool { +func (i *AuthorizeData) IsExpiredAt(t time.Time) bool { return i.ExpireAt().Before(t) } // ExpireAt returns the expiration date. -func (i *AuthorizeInfo) ExpireAt() time.Time { +func (i *AuthorizeData) ExpireAt() time.Time { return i.CreatedAt.Add(time.Duration(i.ExpiresIn) * time.Second) } diff --git a/authentication/provider/oauth2/authorize_test.go b/authentication/provider/oauth2/authorize_test.go index 8a15920..d21a455 100644 --- a/authentication/provider/oauth2/authorize_test.go +++ b/authentication/provider/oauth2/authorize_test.go @@ -11,11 +11,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAuthorizeInfo(t *testing.T) { +func TestAuthorizeData(t *testing.T) { cat, err := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") assert.NoError(t, err) - ai := &AuthorizeInfo{ + ai := &AuthorizeData{ CreatedAt: cat, ExpiresIn: 10, } diff --git a/authentication/provider/oauth2/configuration.go b/authentication/provider/oauth2/configuration.go index 06047fd..0f12738 100644 --- a/authentication/provider/oauth2/configuration.go +++ b/authentication/provider/oauth2/configuration.go @@ -68,7 +68,7 @@ type Configuration struct { // Separator to support multiple URIs in Client.GetRedirectUri(). // If blank (the default), don't allow multiple URIs. - RedirectUriSeparator string + RedirectURISeparator string // RetainTokenAfter Refresh allows the server to retain the access and // refresh token for re-use - default false. diff --git a/authentication/provider/oauth2/handler.go b/authentication/provider/oauth2/handler.go deleted file mode 100644 index 3908696..0000000 --- a/authentication/provider/oauth2/handler.go +++ /dev/null @@ -1,335 +0,0 @@ -// Copyright 2021 Hyperscale. All rights reserved. -// Use of this source code is governed by a MIT -// license that can be found in the LICENSE file. - -package oauth2 - -import ( - "encoding/json" - "errors" - "net/http" - "strings" - - "github.com/hyperscale-stack/logger" - "github.com/hyperscale-stack/security/authentication/credential" -) - -var ( - ErrRequestMustBePost = errors.New("request must be POST") -) - -func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case s.cfg.PrefixURI + "/token": - s.handleTokenRequest(w, r) - case s.cfg.PrefixURI + "/authorize": - s.handleAuthorizeRequest(w, r) - } -} - -func (s Server) handleTokenRequest(w http.ResponseWriter, r *http.Request) { - // Only allow GET or POST - if r.Method == http.MethodGet { - if !s.cfg.AllowGetAccessRequest { - s.error(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "GET request not allowed") - - return - } - } else if r.Method != http.MethodPost { - s.error(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "request must be POST") - - return - } - - if err := r.ParseForm(); err != nil { - s.error(w, E_INVALID_REQUEST, err, "access_request=%s", "parsing error") - - return - } - - var ar *AccessRequest - - grantType := AccessRequestType(r.FormValue("grant_type")) - if s.cfg.AllowedAccessTypes.Exists(grantType) { - switch grantType { - case AUTHORIZATION_CODE: - // s.handleAuthorizationCodeRequest(w, r) - ar.Authorized = true - case REFRESH_TOKEN: - // s.handleRefreshTokenRequest(w, r) - ar.Authorized = true - case PASSWORD: - ar = s.handlePasswordRequest(w, r) - - user, err := s.userProvider.Authenticate(ar.Username, ar.Password) - if err != nil { - s.error(w, E_ACCESS_DENIED, err, "get_user=%s", "failed") - - return - } - - ar.Authorized = true - ar.UserData = user.GetID() - - case CLIENT_CREDENTIALS: - // s.handleClientCredentialsRequest(w, r) - ar.Authorized = true - case ASSERTION: - // s.handleAssertionRequest(w, r) - ar.Authorized = false - default: - s.error(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") - - return - } - - s.FinishAccessRequest(w, r, ar) - } - -} - -func (s Server) handleAuthorizeRequest(w http.ResponseWriter, r *http.Request) { - -} - -func (s Server) handlePasswordRequest(w http.ResponseWriter, r *http.Request) *AccessRequest { - // get client authentication - auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) - if auth == nil { - return nil - } - - // generate access token - ar := &AccessRequest{ - Type: PASSWORD, - Username: r.FormValue("username"), - Password: r.FormValue("password"), - Scope: r.FormValue("scope"), - GenerateRefresh: true, - Expiration: s.cfg.AccessExpiration, - HttpRequest: r, - } - - // "username" and "password" is required - if ar.Username == "" || ar.Password == "" { - s.error(w, E_INVALID_GRANT, nil, "handle_password=%s", "username and password required") - - return nil - } - - // must have a valid client - if ar.Client = s.getClient(auth, s.storage, w, true); ar.Client == nil { - return nil - } - - // set redirect uri - ar.RedirectURI = FirstURI(ar.Client.GetRedirectURI(), s.cfg.RedirectUriSeparator) - - /* - user, err := s.userProvider.Authenticate(username, password) - if err != nil && errors.Is(err, ErrUserNotFound) { - s.error(w, E_ACCESS_DENIED, nil, "get_user=%s", "username or password is invalid") - - return - } else if err != nil { - s.error(w, E_ACCESS_DENIED, nil, "get_user=%s", "username or password is invalid") - } - */ - - return ar -} - -// Returns the first uri from an uri list -func FirstURI(baseUriList string, separator string) string { - if separator == "" { - return baseUriList - } - - if slist := strings.Split(baseUriList, separator); len(slist) > 0 { - return slist[0] - } - - return "" -} - -// getClientAuth checks client basic authentication in params if allowed, -// otherwise gets it from the header. -// Sets an error on the response if no auth is present or a server error occurs. -func (s Server) getClientAuth(w http.ResponseWriter, r *http.Request, allowQueryParams bool) *credential.UsernamePasswordCredential { - ctx := r.Context() - - // creds := credential.FromContext(ctx) - - if allowQueryParams { - // Allow for auth without password - if _, hasSecret := r.Form["client_secret"]; hasSecret { - auth := credential.NewUsernamePasswordCredential( - r.FormValue("client_id"), - r.FormValue("client_secret"), - ) - - if auth.GetPrincipal() != "" { - return auth.(*credential.UsernamePasswordCredential) - } - } - } - - auth := credential.FromContext(ctx) - - /* - auth, err := CheckBasicAuth(r) - if err != nil { - s.setErrorAndLog(w, E_INVALID_REQUEST, err, "get_client_auth=%s", "check auth error") - return nil - } - */ - if auth == nil { - s.error(w, E_INVALID_REQUEST, errors.New("Client authentication not sent"), "get_client_auth=%s", "client authentication not sent") - - return nil - } - - return auth.(*credential.UsernamePasswordCredential) -} - -// getClient looks up and authenticates the basic auth using the given -// storage. Sets an error on the response if auth fails or a server error occurs. -func (s Server) getClient(creds *credential.UsernamePasswordCredential, storage StorageProvider, w http.ResponseWriter, allowEmptySecret bool) Client { - client, err := storage.LoadClient(creds.GetPrincipal().(string)) - if errors.Is(err, ErrClientNotFound) { - s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found") - - return nil - } else if err != nil { - s.error(w, E_SERVER_ERROR, err, "get_client=%s", "error finding client") - - return nil - } - - if client == nil { - s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client is nil") - - return nil - } - - if c, ok := client.(ClientSecretMatcher); ok { - if creds.GetCredentials() != nil && !c.SecretMatches(creds.GetCredentials().(string)) { - s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) - - return nil - } else if creds.GetCredentials() == nil && !allowEmptySecret { - s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) - - return nil - } - } else if !allowEmptySecret { - s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) - - return nil - } - - if client.GetRedirectURI() == "" { - s.error(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client redirect uri is empty") - - return nil - } - - return client -} - -func (s Server) error(w http.ResponseWriter, responseError DefaultErrorID, internalError error, debugFormat string, debugArgs ...interface{}) { - format := "error=%v, internal_error=%#v " + debugFormat - - s.logger.Error( - format, - logger.WithFormat(append([]interface{}{responseError, internalError}, debugArgs...)...), - ) -} - -func (s *Server) FinishAccessRequest(w http.ResponseWriter, r *http.Request, ar *AccessRequest) { - // don't process if is already an error - /*if w.IsError { - return - }*/ - - redirectURI := r.FormValue("redirect_uri") - // Get redirect uri from AccessRequest if it's there (e.g., refresh token request) - if ar.RedirectURI != "" { - redirectURI = ar.RedirectURI - } - - if !ar.Authorized { - s.error(w, E_ACCESS_DENIED, nil, "finish_access_request=%s", "authorization failed") - - return - } - - var ret *AccessInfo - var err error - - if ar.ForceAccessInfo == nil { - // generate access token - ret = &AccessInfo{ - Client: ar.Client, - AuthorizeInfo: ar.AuthorizeInfo, - AccessInfo: ar.AccessInfo, - RedirectURI: redirectURI, - CreatedAt: s.now(), - ExpiresIn: int64(ar.Expiration.Seconds()), - UserData: ar.UserData, - Scope: ar.Scope, - } - - // generate access token - ret.AccessToken, ret.RefreshToken, err = s.tokenGenerator.GenerateAccessToken(ar.GenerateRefresh) - if err != nil { - s.error(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error generating token") - - return - } - } else { - ret = ar.ForceAccessInfo - } - - // save access token - if err = s.storage.SaveAccess(ret); err != nil { - s.error(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") - - return - } - - // remove authorization token - if ret.AuthorizeInfo != nil { - s.storage.RemoveAuthorize(ret.AuthorizeInfo.Code) - } - - // remove previous access token - if ret.AccessInfo != nil && !s.cfg.RetainTokenAfterRefresh { - if ret.AccessInfo.RefreshToken != "" { - s.storage.RemoveRefresh(ret.AccessInfo.RefreshToken) - } - - s.storage.RemoveAccess(ret.AccessInfo.AccessToken) - } - - output := map[string]interface{}{ - "access_token": ret.AccessToken, - "token_type": s.cfg.TokenType, - "expires_in": ret.ExpiresIn, - } - - if ret.RefreshToken != "" { - output["refresh_token"] = ret.RefreshToken - } - - if ret.Scope != "" { - output["scope"] = ret.Scope - } - - if err := json.NewEncoder(w).Encode(output); err != nil { - s.error(w, E_SERVER_ERROR, err, "finish_access_request=%s", "serialize response failed") - - return - } - -} diff --git a/authentication/provider/oauth2/mock_access_provider.go b/authentication/provider/oauth2/mock_access_provider.go index 0630619..c1203fb 100644 --- a/authentication/provider/oauth2/mock_access_provider.go +++ b/authentication/provider/oauth2/mock_access_provider.go @@ -10,15 +10,15 @@ type MockAccessProvider struct { } // LoadAccess provides a mock function with given fields: token -func (_m *MockAccessProvider) LoadAccess(token string) (*AccessInfo, error) { +func (_m *MockAccessProvider) LoadAccess(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -47,11 +47,11 @@ func (_m *MockAccessProvider) RemoveAccess(token string) error { } // SaveAccess provides a mock function with given fields: _a0 -func (_m *MockAccessProvider) SaveAccess(_a0 *AccessInfo) error { +func (_m *MockAccessProvider) SaveAccess(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/mock_authorize_provider.go b/authentication/provider/oauth2/mock_authorize_provider.go index a9b1cbe..83e3e33 100644 --- a/authentication/provider/oauth2/mock_authorize_provider.go +++ b/authentication/provider/oauth2/mock_authorize_provider.go @@ -10,15 +10,15 @@ type MockAuthorizeProvider struct { } // LoadAuthorize provides a mock function with given fields: code -func (_m *MockAuthorizeProvider) LoadAuthorize(code string) (*AuthorizeInfo, error) { +func (_m *MockAuthorizeProvider) LoadAuthorize(code string) (*AuthorizeData, error) { ret := _m.Called(code) - var r0 *AuthorizeInfo - if rf, ok := ret.Get(0).(func(string) *AuthorizeInfo); ok { + var r0 *AuthorizeData + if rf, ok := ret.Get(0).(func(string) *AuthorizeData); ok { r0 = rf(code) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AuthorizeInfo) + r0 = ret.Get(0).(*AuthorizeData) } } @@ -47,11 +47,11 @@ func (_m *MockAuthorizeProvider) RemoveAuthorize(code string) error { } // SaveAuthorize provides a mock function with given fields: _a0 -func (_m *MockAuthorizeProvider) SaveAuthorize(_a0 *AuthorizeInfo) error { +func (_m *MockAuthorizeProvider) SaveAuthorize(_a0 *AuthorizeData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AuthorizeInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AuthorizeData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/mock_refresh_provider.go b/authentication/provider/oauth2/mock_refresh_provider.go index 67b23c5..f8d073b 100644 --- a/authentication/provider/oauth2/mock_refresh_provider.go +++ b/authentication/provider/oauth2/mock_refresh_provider.go @@ -10,15 +10,15 @@ type MockRefreshProvider struct { } // LoadRefresh provides a mock function with given fields: token -func (_m *MockRefreshProvider) LoadRefresh(token string) (*AccessInfo, error) { +func (_m *MockRefreshProvider) LoadRefresh(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -47,11 +47,11 @@ func (_m *MockRefreshProvider) RemoveRefresh(token string) error { } // SaveRefresh provides a mock function with given fields: _a0 -func (_m *MockRefreshProvider) SaveRefresh(_a0 *AccessInfo) error { +func (_m *MockRefreshProvider) SaveRefresh(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/mock_storage_provider.go b/authentication/provider/oauth2/mock_storage_provider.go index f1ad0ed..53b43aa 100644 --- a/authentication/provider/oauth2/mock_storage_provider.go +++ b/authentication/provider/oauth2/mock_storage_provider.go @@ -10,15 +10,15 @@ type MockStorageProvider struct { } // LoadAccess provides a mock function with given fields: token -func (_m *MockStorageProvider) LoadAccess(token string) (*AccessInfo, error) { +func (_m *MockStorageProvider) LoadAccess(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -33,15 +33,15 @@ func (_m *MockStorageProvider) LoadAccess(token string) (*AccessInfo, error) { } // LoadAuthorize provides a mock function with given fields: code -func (_m *MockStorageProvider) LoadAuthorize(code string) (*AuthorizeInfo, error) { +func (_m *MockStorageProvider) LoadAuthorize(code string) (*AuthorizeData, error) { ret := _m.Called(code) - var r0 *AuthorizeInfo - if rf, ok := ret.Get(0).(func(string) *AuthorizeInfo); ok { + var r0 *AuthorizeData + if rf, ok := ret.Get(0).(func(string) *AuthorizeData); ok { r0 = rf(code) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AuthorizeInfo) + r0 = ret.Get(0).(*AuthorizeData) } } @@ -79,15 +79,15 @@ func (_m *MockStorageProvider) LoadClient(id string) (Client, error) { } // LoadRefresh provides a mock function with given fields: token -func (_m *MockStorageProvider) LoadRefresh(token string) (*AccessInfo, error) { +func (_m *MockStorageProvider) LoadRefresh(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -158,11 +158,11 @@ func (_m *MockStorageProvider) RemoveRefresh(token string) error { } // SaveAccess provides a mock function with given fields: _a0 -func (_m *MockStorageProvider) SaveAccess(_a0 *AccessInfo) error { +func (_m *MockStorageProvider) SaveAccess(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) @@ -172,11 +172,11 @@ func (_m *MockStorageProvider) SaveAccess(_a0 *AccessInfo) error { } // SaveAuthorize provides a mock function with given fields: _a0 -func (_m *MockStorageProvider) SaveAuthorize(_a0 *AuthorizeInfo) error { +func (_m *MockStorageProvider) SaveAuthorize(_a0 *AuthorizeData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AuthorizeInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AuthorizeData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) @@ -200,11 +200,11 @@ func (_m *MockStorageProvider) SaveClient(_a0 Client) error { } // SaveRefresh provides a mock function with given fields: _a0 -func (_m *MockStorageProvider) SaveRefresh(_a0 *AccessInfo) error { +func (_m *MockStorageProvider) SaveRefresh(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/oauth2_authentication_provider_test.go b/authentication/provider/oauth2/oauth2_authentication_provider_test.go index 34c99d3..28d01b8 100644 --- a/authentication/provider/oauth2/oauth2_authentication_provider_test.go +++ b/authentication/provider/oauth2/oauth2_authentication_provider_test.go @@ -155,7 +155,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithTokenExpired(t accessStorageMock := &MockAccessProvider{} - access := &AccessInfo{ + access := &AccessData{ AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, UserData: userMock, @@ -186,7 +186,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithUserNotFound(t accessStorageMock := &MockAccessProvider{} - access := &AccessInfo{ + access := &AccessData{ AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, CreatedAt: time.Now(), @@ -227,7 +227,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithToken(t *testi RedirectURI: "https://connect.myservice.tld", } - access := &AccessInfo{ + access := &AccessData{ Client: client, AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, @@ -282,7 +282,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithBadUserDataTyp RedirectURI: "https://connect.myservice.tld", } - access := &AccessInfo{ + access := &AccessData{ Client: client, AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, diff --git a/authentication/provider/oauth2/response.go b/authentication/provider/oauth2/response.go index 2894a64..985c16f 100644 --- a/authentication/provider/oauth2/response.go +++ b/authentication/provider/oauth2/response.go @@ -52,7 +52,7 @@ func NewResponse(storage StorageProvider) *Response { Output: make(ResponseData), Headers: make(http.Header), IsError: false, - // Storage: storage.Clone(), + Storage: storage, // Clone ? } r.Headers.Add( "Cache-Control", @@ -160,7 +160,3 @@ func (r *Response) GetRedirectURL() (string, error) { return u.String(), nil } - -func (r *Response) Close() { - // r.Storage.Close() -} diff --git a/authentication/provider/oauth2/response_json.go b/authentication/provider/oauth2/response_json.go new file mode 100644 index 0000000..730232b --- /dev/null +++ b/authentication/provider/oauth2/response_json.go @@ -0,0 +1,44 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "encoding/json" + "net/http" +) + +// OutputJSON encodes the Response to JSON and writes to the http.ResponseWriter +func OutputJSON(rs *Response, w http.ResponseWriter, r *http.Request) error { + // Add headers + for i, k := range rs.Headers { + for _, v := range k { + w.Header().Add(i, v) + } + } + + if rs.Type == REDIRECT { + // Output redirect with parameters + u, err := rs.GetRedirectURL() + if err != nil { + return err + } + + w.Header().Add("Location", u) + w.WriteHeader(302) + + return nil + + } + + // set content type if the response doesn't already have one associated with it + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", "application/json") + } + + w.WriteHeader(rs.StatusCode) + + return json.NewEncoder(w).Encode(rs.Output) + +} diff --git a/authentication/provider/oauth2/response_json_test.go b/authentication/provider/oauth2/response_json_test.go new file mode 100644 index 0000000..e82b5a0 --- /dev/null +++ b/authentication/provider/oauth2/response_json_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResponseJSON(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.Output["access_token"] = "1234" + r.Output["token_type"] = "5678" + + err = OutputJSON(r, w, req) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", w.Result().Header.Get("Content-Type")) + + // parse output json + output := make(map[string]interface{}) + err = json.Unmarshal(w.Body.Bytes(), &output) + assert.NoError(t, err) + + assert.Contains(t, output, "access_token") + assert.Equal(t, "1234", output["access_token"]) + + assert.Contains(t, output, "token_type") + assert.Equal(t, "5678", output["token_type"]) +} + +func TestErrorResponseJSON(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + r.ErrorStatusCode = 500 + r.SetError(E_INVALID_REQUEST, "") + + err = OutputJSON(r, w, req) + assert.NoError(t, err) + + assert.Equal(t, 500, w.Code) + + assert.Equal(t, "application/json", w.Result().Header.Get("Content-Type")) + + // parse output json + output := make(map[string]interface{}) + err = json.Unmarshal(w.Body.Bytes(), &output) + assert.NoError(t, err) + + assert.Contains(t, output, "error") + assert.Equal(t, E_INVALID_REQUEST.String(), output["error"]) +} + +func TestRedirectResponseJSON(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + r.SetRedirect("http://localhost:14000") + + err = OutputJSON(r, w, req) + assert.NoError(t, err) + + assert.Equal(t, http.StatusFound, w.Code) + + assert.Equal(t, "http://localhost:14000", w.Result().Header.Get("Location")) +} + +func TestRedirectResponseJSONWithError(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + r.SetRedirect(":14000") + + err = OutputJSON(r, w, req) + assert.EqualError(t, err, "parse url failed: parse \":14000\": missing protocol scheme") + + assert.Equal(t, http.StatusOK, w.Code) + + assert.Equal(t, "", w.Result().Header.Get("Location")) +} diff --git a/authentication/provider/oauth2/server.go b/authentication/provider/oauth2/server.go index 1739019..0e88710 100644 --- a/authentication/provider/oauth2/server.go +++ b/authentication/provider/oauth2/server.go @@ -5,6 +5,7 @@ package oauth2 import ( + "errors" "time" "github.com/hyperscale-stack/logger" @@ -12,6 +13,10 @@ import ( "github.com/hyperscale-stack/security/authentication/provider/oauth2/token/random" ) +var ( + ErrRequestMustBePost = errors.New("request must be POST") +) + type Server struct { cfg *Configuration logger logger.Logger @@ -37,3 +42,11 @@ func NewServer(options ...Option) *Server { return s } + +// NewResponse creates a new response for the server +func (s *Server) NewResponse() *Response { + r := NewResponse(s.storage) + r.ErrorStatusCode = s.cfg.ErrorStatusCode + + return r +} diff --git a/authentication/provider/oauth2/storage.go b/authentication/provider/oauth2/storage.go index 742b439..f605904 100644 --- a/authentication/provider/oauth2/storage.go +++ b/authentication/provider/oauth2/storage.go @@ -27,22 +27,22 @@ type ClientProvider interface { //go:generate mockery --name=AccessProvider --inpackage --case underscore type AccessProvider interface { - SaveAccess(*AccessInfo) error - LoadAccess(token string) (*AccessInfo, error) + SaveAccess(*AccessData) error + LoadAccess(token string) (*AccessData, error) RemoveAccess(token string) error } //go:generate mockery --name=RefreshProvider --inpackage --case underscore type RefreshProvider interface { - SaveRefresh(*AccessInfo) error - LoadRefresh(token string) (*AccessInfo, error) + SaveRefresh(*AccessData) error + LoadRefresh(token string) (*AccessData, error) RemoveRefresh(token string) error } //go:generate mockery --name=AuthorizeProvider --inpackage --case underscore type AuthorizeProvider interface { - SaveAuthorize(*AuthorizeInfo) error - LoadAuthorize(code string) (*AuthorizeInfo, error) + SaveAuthorize(*AuthorizeData) error + LoadAuthorize(code string) (*AuthorizeData, error) RemoveAuthorize(code string) error } diff --git a/authentication/provider/oauth2/storage/in_memory_storage.go b/authentication/provider/oauth2/storage/in_memory_storage.go index b7c2028..f7ceefc 100644 --- a/authentication/provider/oauth2/storage/in_memory_storage.go +++ b/authentication/provider/oauth2/storage/in_memory_storage.go @@ -43,15 +43,15 @@ func (s *InMemoryStorage) RemoveClient(id string) error { return nil } -func (s *InMemoryStorage) SaveAccess(access *oauth2.AccessInfo) error { +func (s *InMemoryStorage) SaveAccess(access *oauth2.AccessData) error { s.accesses.Store(access.AccessToken, access) return nil } -func (s *InMemoryStorage) LoadAccess(token string) (*oauth2.AccessInfo, error) { +func (s *InMemoryStorage) LoadAccess(token string) (*oauth2.AccessData, error) { if access, ok := s.accesses.Load(token); ok { - return access.(*oauth2.AccessInfo), nil + return access.(*oauth2.AccessData), nil } return nil, oauth2.ErrAccessNotFound @@ -63,15 +63,15 @@ func (s *InMemoryStorage) RemoveAccess(token string) error { return nil } -func (s *InMemoryStorage) SaveRefresh(access *oauth2.AccessInfo) error { +func (s *InMemoryStorage) SaveRefresh(access *oauth2.AccessData) error { s.refreshs.Store(access.RefreshToken, access) return nil } -func (s *InMemoryStorage) LoadRefresh(token string) (*oauth2.AccessInfo, error) { +func (s *InMemoryStorage) LoadRefresh(token string) (*oauth2.AccessData, error) { if access, ok := s.refreshs.Load(token); ok { - return access.(*oauth2.AccessInfo), nil + return access.(*oauth2.AccessData), nil } return nil, oauth2.ErrRefreshNotFound @@ -83,15 +83,15 @@ func (s *InMemoryStorage) RemoveRefresh(token string) error { return nil } -func (s *InMemoryStorage) SaveAuthorize(authorize *oauth2.AuthorizeInfo) error { +func (s *InMemoryStorage) SaveAuthorize(authorize *oauth2.AuthorizeData) error { s.authorizes.Store(authorize.Code, authorize) return nil } -func (s *InMemoryStorage) LoadAuthorize(code string) (*oauth2.AuthorizeInfo, error) { +func (s *InMemoryStorage) LoadAuthorize(code string) (*oauth2.AuthorizeData, error) { if authorize, ok := s.authorizes.Load(code); ok { - return authorize.(*oauth2.AuthorizeInfo), nil + return authorize.(*oauth2.AuthorizeData), nil } return nil, oauth2.ErrAuthorizeNotFound diff --git a/authentication/provider/oauth2/storage/in_memory_storage_test.go b/authentication/provider/oauth2/storage/in_memory_storage_test.go index 97cdd7e..ff39411 100644 --- a/authentication/provider/oauth2/storage/in_memory_storage_test.go +++ b/authentication/provider/oauth2/storage/in_memory_storage_test.go @@ -40,7 +40,7 @@ func TestInMemoryStorage(t *testing.T) { assert.Nil(t, client2) // Access Token - access := &oauth2.AccessInfo{ + access := &oauth2.AccessData{ AccessToken: "OKjQ0VjYmJxP8N0TzXH5lxvIOZj4bCM0DlsCvuiL96HCQEhJ8A9ozY8jJ5Ep38vaVvn082fgApThX7NZ7pktKn57A667kEeWLPW0KVA3x1flYdBvkIvHOAZYyvUeKK9q", } @@ -63,7 +63,7 @@ func TestInMemoryStorage(t *testing.T) { assert.Nil(t, access2) // Refresh Token - access = &oauth2.AccessInfo{ + access = &oauth2.AccessData{ RefreshToken: "2oQDkOWnbqtJoEs24MkVEB4WNJnqyoAIErvSJRhjg562K8GznWLbLZuStQodKvReSedAqufswaSZduhlgOuCNcQj9aGbCKPAnXUVvmX7Vmgvryp9PaZVbuqj0HfzN9tD", } @@ -86,7 +86,7 @@ func TestInMemoryStorage(t *testing.T) { assert.Nil(t, access2) // Authorize Code - authorize := &oauth2.AuthorizeInfo{ + authorize := &oauth2.AuthorizeData{ Code: "Je4dJ5RFPRJwuSmuitSo8tX7s3uFOP84sEufxjdqJhiiPABdbxeGofGvvX7LBdvy2ZrwDZy3a6cOF8vgquUlr8yAvA9VpDz4Kv2bZxm0WEl4y3SJSvYPnwBOxRHI5pxK", } diff --git a/authentication/provider/oauth2/urivalidate.go b/authentication/provider/oauth2/urivalidate.go new file mode 100644 index 0000000..e3533ae --- /dev/null +++ b/authentication/provider/oauth2/urivalidate.go @@ -0,0 +1,125 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "errors" + "fmt" + "net/url" + "strings" +) + +// error returned when validation don't match +type URIValidationError string + +func (e URIValidationError) Error() string { + return string(e) +} + +func newURIValidationError(msg string, base string, redirect string) URIValidationError { + return URIValidationError(fmt.Sprintf("%s: %s / %s", msg, base, redirect)) +} + +// Parse urls, resolving uri references to base url +func ParseURLs(baseUrl, redirectUrl string) (retBaseUrl, retRedirectUrl *url.URL, err error) { + var base, redirect *url.URL + // parse base url + if base, err = url.Parse(baseUrl); err != nil { + return nil, nil, err + } + + // parse redirect url + if redirect, err = url.Parse(redirectUrl); err != nil { + return nil, nil, err + } + + // must not have fragment + if base.Fragment != "" || redirect.Fragment != "" { + return nil, nil, newURIValidationError("url must not include fragment.", baseUrl, redirectUrl) + } + + // Scheme must match + if redirect.Scheme != base.Scheme { + return nil, nil, newURIValidationError("scheme mismatch", baseUrl, redirectUrl) + } + + // Host must match + if redirect.Host != base.Host { + return nil, nil, newURIValidationError("host mismatch", baseUrl, redirectUrl) + } + + // resolve references to base url + retBaseUrl = (&url.URL{Scheme: base.Scheme, Host: base.Host, Path: "/"}).ResolveReference(&url.URL{Path: base.Path}) + retRedirectUrl = (&url.URL{Scheme: base.Scheme, Host: base.Host, Path: "/"}).ResolveReference(&url.URL{Path: redirect.Path, RawQuery: redirect.RawQuery}) + + return +} + +// ValidateUriList validates that redirectUri is contained in baseUriList. +// baseUriList may be a string separated by separator. +// If separator is blank, validate only 1 URI. +func ValidateURIList(baseUriList string, redirectUri string, separator string) (realRedirectUri string, err error) { + // make a list of uris + var slist []string + if separator != "" { + slist = strings.Split(baseUriList, separator) + } else { + slist = make([]string, 0) + slist = append(slist, baseUriList) + } + + for _, sitem := range slist { + realRedirectUri, err = ValidateURI(sitem, redirectUri) + // validated, return no error + if err == nil { + return realRedirectUri, nil + } + + // if there was an error that is not a validation error, return it + if _, iok := err.(URIValidationError); !iok { + return "", err + } + } + + return "", newURIValidationError("urls don't validate", baseUriList, redirectUri) +} + +// ValidateURI validates that redirectUri is contained in baseUri. +func ValidateURI(baseUri string, redirectUri string) (realRedirectUri string, err error) { + if baseUri == "" || redirectUri == "" { + return "", errors.New("urls cannot be blank") + } + + base, redirect, err := ParseURLs(baseUri, redirectUri) + if err != nil { + return "", err + } + + // allow exact path matches + if base.Path == redirect.Path { + return redirect.String(), nil + } + + // ensure prefix matches are actually subpaths + requiredPrefix := strings.TrimRight(base.Path, "/") + "/" + if !strings.HasPrefix(redirect.Path, requiredPrefix) { + return "", newURIValidationError("path prefix doesn't match", baseUri, redirectUri) + } + + return redirect.String(), nil +} + +// FirstURI returns the first uri from an uri list. +func FirstURI(baseUriList string, separator string) string { + if separator == "" { + return baseUriList + } + + if slist := strings.Split(baseUriList, separator); len(slist) > 0 { + return slist[0] + } + + return "" +} diff --git a/authentication/provider/oauth2/urivalidate_test.go b/authentication/provider/oauth2/urivalidate_test.go new file mode 100644 index 0000000..2b498ec --- /dev/null +++ b/authentication/provider/oauth2/urivalidate_test.go @@ -0,0 +1,163 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestURIValidate(t *testing.T) { + valid := [][]string{ + { + // Exact match + "http://localhost:14000/appauth", + "http://localhost:14000/appauth", + "http://localhost:14000/appauth", + }, + { + // Trailing slash + "http://www.google.com/myapp", + "http://www.google.com/myapp/", + "http://www.google.com/myapp/", + }, + { + // Exact match with trailing slash + "http://www.google.com/myapp/", + "http://www.google.com/myapp/", + "http://www.google.com/myapp/", + }, + { + // Subpath + "http://www.google.com/myapp", + "http://www.google.com/myapp/interface/implementation", + "http://www.google.com/myapp/interface/implementation", + }, + { + // Subpath with trailing slash + "http://www.google.com/myapp/", + "http://www.google.com/myapp/interface/implementation", + "http://www.google.com/myapp/interface/implementation", + }, + { + // Subpath with things that are close to path traversals, but aren't + "http://www.google.com/myapp", + "http://www.google.com/myapp/.../..implementation../...", + "http://www.google.com/myapp/.../..implementation../...", + }, + { + // If the allowed basepath contains path traversals, allow them? + "http://www.google.com/traversal/../allowed", + "http://www.google.com/traversal/../allowed/with/subpath", + "http://www.google.com/allowed/with/subpath", + }, + { + // Backslashes + "https://mysafewebsite.com/secure/redirect", + "https://mysafewebsite.com/secure/redirect/\\../\\../\\../evil", + "https://mysafewebsite.com/secure/redirect/%5C../%5C../%5C../evil", + }, + { + // Backslashes + "https://mysafewebsite.com/secure/redirect", + "https://mysafewebsite.com/secure/redirect/\\..\\../\\../evil", + "https://mysafewebsite.com/secure/redirect/%5C..%5C../%5C../evil", + }, + { + // Query string must be kept + "http://www.google.com/myapp/redir", + "http://www.google.com/myapp/redir?a=1&b=2", + "http://www.google.com/myapp/redir?a=1&b=2", + }, + } + for _, v := range valid { + realRedirectURI, err := ValidateURI(v[0], v[1]) + assert.NoError(t, err) + assert.Equal(t, v[2], realRedirectURI) + } + + invalid := [][]string{ + { + // Doesn't satisfy base path + "http://localhost:14000/appauth", + "http://localhost:14000/app", + }, + { + // Doesn't satisfy base path + "http://localhost:14000/app/", + "http://localhost:14000/app", + }, + { + // Not a subpath of base path + "http://localhost:14000/appauth", + "http://localhost:14000/appauthmodifiedpath", + }, + { + // Host mismatch + "http://www.google.com/myapp", + "http://www2.google.com/myapp", + }, + { + // Scheme mismatch + "http://www.google.com/myapp", + "https://www.google.com/myapp", + }, + { + // Path traversal + "http://www.google.com/myapp", + "http://www.google.com/myapp/..", + }, + { + // Embedded path traversal + "http://www.google.com/myapp", + "http://www.google.com/myapp/../test", + }, + { + // Not a subpath + "http://www.google.com/myapp", + "http://www.google.com/myapp../test", + }, + { + // Backslashes + "https://mysafewebsite.com/secure/redirect", + "https://mysafewebsite.com/secure%2fredirect/../evil", + }, + } + for _, v := range invalid { + if _, err := ValidateURI(v[0], v[1]); err == nil { + t.Errorf("Expected ValidateURI(%s, %s) to fail", v[0], v[1]) + } + } +} + +func TestURIListValidate(t *testing.T) { + // V1 + if _, err := ValidateURIList("http://localhost:14000/appauth", "http://localhost:14000/appauth", ""); err != nil { + t.Errorf("V1: %s", err) + } + + // V2 + if _, err := ValidateURIList("http://localhost:14000/appauth", "http://localhost:14000/app", ""); err == nil { + t.Error("V2 should have failed") + } + + // V3 + if _, err := ValidateURIList("http://xxx:14000/appauth;http://localhost:14000/appauth", "http://localhost:14000/appauth", ";"); err != nil { + t.Errorf("V3: %s", err) + } + + // V4 + if _, err := ValidateURIList("http://xxx:14000/appauth;http://localhost:14000/appauth", "http://localhost:14000/app", ";"); err == nil { + t.Error("V4 should have failed") + } +} + +func TestFirstURI(t *testing.T) { + assert.Equal(t, "https://auth.mydomain.com/connect", FirstURI("https://auth.mydomain.com/connect mybundle://connect", " ")) + assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", " ")) + assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", "")) + assert.Equal(t, "", FirstURI("", " ")) +} diff --git a/authentication/provider/oauth2/util.go b/authentication/provider/oauth2/util.go new file mode 100644 index 0000000..67f3be5 --- /dev/null +++ b/authentication/provider/oauth2/util.go @@ -0,0 +1,129 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "crypto/subtle" + "encoding/base64" + "errors" + "net/http" + "net/url" + "strings" + + "github.com/hyperscale-stack/security/authentication/credential" +) + +// Parse basic authentication header +type BasicAuth struct { + Username string + Password string +} + +// Parse bearer authentication header +type BearerAuth struct { + Code string +} + +// CheckClientSecret determines whether the given secret matches a secret held by the client. +// Public clients return true for a secret of "" +func CheckClientSecret(client Client, secret string) bool { + switch client := client.(type) { + case ClientSecretMatcher: + // Prefer the more secure method of giving the secret to the client for comparison + return client.SecretMatches(secret) + default: + // Fallback to the less secure method of extracting the plain text secret from the client for comparison + return subtle.ConstantTimeCompare([]byte(client.GetSecret()), []byte(secret)) == 1 + } +} + +// Return authorization header data +func CheckBasicAuth(r *http.Request) (*BasicAuth, error) { + if r.Header.Get("Authorization") == "" { + return nil, nil + } + + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 || s[0] != "Basic" { + return nil, errors.New("Invalid authorization header") + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return nil, err + } + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return nil, errors.New("Invalid authorization message") + } + + // Decode the client_id and client_secret pairs as per + // https://tools.ietf.org/html/rfc6749#section-2.3.1 + + username, err := url.QueryUnescape(pair[0]) + if err != nil { + return nil, err + } + + password, err := url.QueryUnescape(pair[1]) + if err != nil { + return nil, err + } + + return &BasicAuth{Username: username, Password: password}, nil +} + +// Return "Bearer" token from request. The header has precedence over query string. +func CheckBearerAuth(r *http.Request) *BearerAuth { + authHeader := r.Header.Get("Authorization") + authForm := r.FormValue("code") + if authHeader == "" && authForm == "" { + return nil + } + token := authForm + if authHeader != "" { + s := strings.SplitN(authHeader, " ", 2) + if (len(s) != 2 || strings.ToLower(s[0]) != "bearer") && token == "" { + return nil + } + //Use authorization header token only if token type is bearer else query string access token would be returned + if len(s) > 0 && strings.ToLower(s[0]) == "bearer" { + token = s[1] + } + } + return &BearerAuth{Code: token} +} + +// getClientAuth checks client basic authentication in params if allowed, +// otherwise gets it from the header. +// Sets an error on the response if no auth is present or a server error occurs. +func (s Server) getClientAuth(w *Response, r *http.Request, allowQueryParams bool) *credential.UsernamePasswordCredential { + ctx := r.Context() + + // creds := credential.FromContext(ctx) + + if allowQueryParams { + // Allow for auth without password + if _, hasSecret := r.Form["client_secret"]; hasSecret { + auth := credential.NewUsernamePasswordCredential( + r.FormValue("client_id"), + r.FormValue("client_secret"), + ) + + if auth.GetPrincipal() != "" { + return auth.(*credential.UsernamePasswordCredential) + } + } + } + + auth := credential.FromContext(ctx) + if auth == nil { + s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Client authentication not sent"), "get_client_auth=%s", "client authentication not sent") + + return nil + } + + return auth.(*credential.UsernamePasswordCredential) +} diff --git a/authentication/provider/oauth2/util_test.go b/authentication/provider/oauth2/util_test.go new file mode 100644 index 0000000..910878d --- /dev/null +++ b/authentication/provider/oauth2/util_test.go @@ -0,0 +1,221 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "net/http" + "net/url" + "testing" + + "github.com/hyperscale-stack/security/authentication" + "github.com/stretchr/testify/assert" +) + +const ( + badAuthValue = "Digest XHHHHHHH" + badUsernameInAuthValue = "Basic dSUyc2VybmFtZTpwYXNzd29yZA==" // u%2sername:password + badPasswordInAuthValue = "Basic dXNlcm5hbWU6cGElMnN3b3Jk" // username:pa%2sword + goodAuthValue = "Basic Y2xpZW50K25hbWU6Y2xpZW50KyUyNGVjcmV0" + goodBearerAuthValue = "Bearer BGFVTDUJDp0ZXN0" +) + +func TestBasicAuth(t *testing.T) { + r := &http.Request{Header: make(http.Header)} + + // Without any header + if b, err := CheckBasicAuth(r); b != nil || err != nil { + t.Errorf("Validated basic auth without header") + } + + // with invalid header + r.Header.Set("Authorization", badAuthValue) + b, err := CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth") + return + } + + // with invalid username + r.Header.Set("Authorization", badUsernameInAuthValue) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth with bad username") + return + } + + // with invalid username + r.Header.Set("Authorization", badPasswordInAuthValue) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth with bad password") + return + } + + // with valid header + r.Header.Set("Authorization", goodAuthValue) + b, err = CheckBasicAuth(r) + if b == nil || err != nil { + t.Errorf("Could not extract basic auth") + return + } + + // check extracted auth data + if b.Username != "client name" || b.Password != "client $ecret" { + t.Errorf("Error decoding basic auth") + } +} + +func TestGetClientAuth(t *testing.T) { + + urlWithSecret, _ := url.Parse("http://host.tld/path?client_id=xxx&client_secret=yyy") + urlWithEmptySecret, _ := url.Parse("http://host.tld/path?client_id=xxx&client_secret=") + urlNoSecret, _ := url.Parse("http://host.tld/path?client_id=xxx") + + headerNoAuth := make(http.Header) + headerBadAuth := make(http.Header) + headerBadAuth.Set("Authorization", badAuthValue) + headerOKAuth := make(http.Header) + headerOKAuth.Set("Authorization", goodAuthValue) + + storageMock := &MockStorageProvider{} + + sconfig := NewConfiguration() + + server := NewServer(WithStorage(storageMock), WithConfig(sconfig)) + + var tests = []struct { + header http.Header + url *url.URL + allowQueryParams bool + expectAuth bool + }{ + {headerNoAuth, urlWithSecret, true, true}, + {headerNoAuth, urlWithSecret, false, false}, + {headerNoAuth, urlWithEmptySecret, true, true}, + {headerNoAuth, urlWithEmptySecret, false, false}, + {headerNoAuth, urlNoSecret, true, false}, + {headerNoAuth, urlNoSecret, false, false}, + + {headerBadAuth, urlWithSecret, true, true}, + {headerBadAuth, urlWithSecret, false, false}, + {headerBadAuth, urlWithEmptySecret, true, true}, + {headerBadAuth, urlWithEmptySecret, false, false}, + {headerBadAuth, urlNoSecret, true, false}, + {headerBadAuth, urlNoSecret, false, false}, + + {headerOKAuth, urlWithSecret, true, true}, + {headerOKAuth, urlWithSecret, false, true}, + {headerOKAuth, urlWithEmptySecret, true, true}, + {headerOKAuth, urlWithEmptySecret, false, true}, + {headerOKAuth, urlNoSecret, true, true}, + {headerOKAuth, urlNoSecret, false, true}, + } + + for _, tt := range tests { + w := new(Response) + r := &http.Request{Header: tt.header, URL: tt.url} + r.ParseForm() + + f := authentication.NewHTTPBasicFilter() + r = f.OnFilter(r) + + auth := server.getClientAuth(w, r, tt.allowQueryParams) + + if tt.expectAuth { + assert.NotNil(t, auth) + } else { + assert.Nil(t, auth) + } + } + +} + +func TestBearerAuth(t *testing.T) { + r := &http.Request{Header: make(http.Header)} + + // Without any header + if b := CheckBearerAuth(r); b != nil { + t.Errorf("Validated bearer auth without header") + } + + // with invalid header + r.Header.Set("Authorization", badAuthValue) + b := CheckBearerAuth(r) + if b != nil { + t.Errorf("Validated invalid auth") + return + } + + // with valid header + r.Header.Set("Authorization", goodBearerAuthValue) + b = CheckBearerAuth(r) + if b == nil { + t.Errorf("Could not extract bearer auth") + return + } + + // check extracted auth data + if b.Code != "BGFVTDUJDp0ZXN0" { + t.Errorf("Error decoding bearer auth") + } + + // extracts bearer auth from query string + url, _ := url.Parse("http://host.tld/path?code=XYZ") + r = &http.Request{URL: url} + r.ParseForm() + b = CheckBearerAuth(r) + if b.Code != "XYZ" { + t.Errorf("Error decoding bearer auth") + } +} + +// DefaultClient stores all data in struct variables. +type testClient struct { + ID string + Secret string + RedirectURI string + UserData interface{} +} + +func (d *testClient) GetID() string { + return d.ID +} + +func (d *testClient) GetSecret() string { + return d.Secret +} + +func (d *testClient) GetRedirectURI() string { + return d.RedirectURI +} + +func (d *testClient) GetUserData() interface{} { + return d.UserData +} + +func (d *testClient) CopyFrom(client Client) { + d.ID = client.GetID() + d.Secret = client.GetSecret() + d.RedirectURI = client.GetRedirectURI() + d.UserData = client.GetUserData() +} + +func TestCheckClientSecret(t *testing.T) { + { + client := &DefaultClient{ + Secret: "foo", + } + + assert.True(t, CheckClientSecret(client, "foo")) + } + + { + client := &testClient{ + Secret: "foo", + } + + assert.True(t, CheckClientSecret(client, "foo")) + } +} diff --git a/internal/integrations/oauth2_auth_by_access_token_test.go b/internal/integrations/oauth2_auth_by_access_token_test.go index 392b271..1ce63b2 100644 --- a/internal/integrations/oauth2_auth_by_access_token_test.go +++ b/internal/integrations/oauth2_auth_by_access_token_test.go @@ -37,7 +37,7 @@ func TestOauth2AuthByAccessTokenWithNoAuthHeader(t *testing.T) { storageProvider.SaveClient(client) - storageProvider.SaveAccess(&oauth2.AccessInfo{ + storageProvider.SaveAccess(&oauth2.AccessData{ Client: client, AccessToken: "I3SoKTVXi6QzMZAmDW2Fgw2MLX0msPGRN58bCDLDFthJmy6Qoy8FH5v10dbewR6PfAV3brKhepjnTJVhDplSHFe6qbF3J4YDkI5EzXG0S8X7snSoB6FtrPNFMmISuEmU", UserData: "8c87a032-755d-42f6-be96-0421948f6e94", @@ -84,7 +84,7 @@ func TestOauth2AuthByAccessTokenWithBadToken(t *testing.T) { storageProvider.SaveClient(client) - storageProvider.SaveAccess(&oauth2.AccessInfo{ + storageProvider.SaveAccess(&oauth2.AccessData{ Client: client, AccessToken: "I3SoKTVXi6QzMZAmDW2Fgw2MLX0msPGRN58bCDLDFthJmy6Qoy8FH5v10dbewR6PfAV3brKhepjnTJVhDplSHFe6qbF3J4YDkI5EzXG0S8X7snSoB6FtrPNFMmISuEmU", }) @@ -138,7 +138,7 @@ func TestOauth2AuthByAccessToken(t *testing.T) { storageProvider.SaveClient(client) - storageProvider.SaveAccess(&oauth2.AccessInfo{ + storageProvider.SaveAccess(&oauth2.AccessData{ Client: client, AccessToken: "I3SoKTVXi6QzMZAmDW2Fgw2MLX0msPGRN58bCDLDFthJmy6Qoy8FH5v10dbewR6PfAV3brKhepjnTJVhDplSHFe6qbF3J4YDkI5EzXG0S8X7snSoB6FtrPNFMmISuEmU", ExpiresIn: 60, From df5dfcc331697e37e6bb9d0db0e1a75e50b78aa9 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Sun, 15 Aug 2021 23:59:31 +0200 Subject: [PATCH 4/9] test(authentication/oauth2): add more many tests --- .../username_password_credential.go | 2 +- authentication/provider/oauth2/access.go | 8 +- authentication/provider/oauth2/access_test.go | 387 +++++++++++++++++- authentication/provider/oauth2/error_test.go | 23 ++ authentication/provider/oauth2/response.go | 2 +- .../provider/oauth2/response_test.go | 84 ++++ authentication/provider/oauth2/server.go | 5 - authentication/provider/oauth2/server_test.go | 13 + authentication/provider/oauth2/urivalidate.go | 20 +- .../provider/oauth2/urivalidate_test.go | 46 +++ authentication/provider/oauth2/util.go | 22 +- authentication/provider/oauth2/util_test.go | 28 +- 12 files changed, 608 insertions(+), 32 deletions(-) create mode 100644 authentication/provider/oauth2/error_test.go create mode 100644 authentication/provider/oauth2/response_test.go diff --git a/authentication/credential/username_password_credential.go b/authentication/credential/username_password_credential.go index 712fdbb..dc1b02b 100644 --- a/authentication/credential/username_password_credential.go +++ b/authentication/credential/username_password_credential.go @@ -17,7 +17,7 @@ type UsernamePasswordCredential struct { var _ Credential = (*UsernamePasswordCredential)(nil) // NewUsernamePasswordCredential constructor. -func NewUsernamePasswordCredential(principal string, credentials string) Credential { +func NewUsernamePasswordCredential(principal string, credentials string) *UsernamePasswordCredential { return &UsernamePasswordCredential{ credentials: credentials, principal: principal, diff --git a/authentication/provider/oauth2/access.go b/authentication/provider/oauth2/access.go index c857410..229c296 100644 --- a/authentication/provider/oauth2/access.go +++ b/authentication/provider/oauth2/access.go @@ -17,6 +17,10 @@ import ( "github.com/hyperscale-stack/security/authentication/credential" ) +var ( + ErrRequestMustBePost = errors.New("request must be POST") +) + // AccessRequestType is the type for OAuth2 param `grant_type` type AccessRequestType string @@ -134,12 +138,12 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques // Only allow GET or POST if r.Method == http.MethodGet { if !s.cfg.AllowGetAccessRequest { - s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "GET request not allowed") + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "GET request not allowed") return nil } } else if r.Method != http.MethodPost { - s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Request must be POST"), "access_request=%s", "request must be POST") + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "request must be POST") return nil } diff --git a/authentication/provider/oauth2/access_test.go b/authentication/provider/oauth2/access_test.go index 576bd84..5296947 100644 --- a/authentication/provider/oauth2/access_test.go +++ b/authentication/provider/oauth2/access_test.go @@ -6,10 +6,18 @@ package oauth2 import ( "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" "testing" "time" + "github.com/hyperscale-stack/security/authentication/credential" "github.com/stretchr/testify/assert" + mock "github.com/stretchr/testify/mock" ) func TestAccessData(t *testing.T) { @@ -43,6 +51,383 @@ func TestFromContextWithEmptyContext(t *testing.T) { ctx := context.Background() ai := AccessTokenFromContext(ctx) - assert.Nil(t, ai) } + +func TestServerHandleAccessRequestWithGetMethodNotAllowed(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodGet, "http://example.com/v1/me", nil) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +func TestServerHandleAccessRequestWithBadMethod(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodPut, "http://example.com/v1/me", nil) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +type mockReadCloser struct { + mock.Mock +} + +func (m *mockReadCloser) Read(p []byte) (n int, err error) { + args := m.Called(p) + + return args.Int(0), args.Error(1) +} + +func (m *mockReadCloser) Close() error { + args := m.Called() + + return args.Error(0) +} + +func TestServerHandleAccessRequestWithBadBody(t *testing.T) { + /*mockReadCloser := &mockReadCloser{} + // if Read is called, it will return error + mockReadCloser.On("Read", mock.AnythingOfType("[]uint8")).Return(0, fmt.Errorf("error reading")) + // if Close is called, it will return error + mockReadCloser.On("Close").Return(fmt.Errorf("error closing")) + */ + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/me?f$$", nil) + + req.Body = nil + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +func TestServerHandleAccessRequestWithEmptyBody(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/me?f$$", nil) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_UNSUPPORTED_GRANT_TYPE, w.ErrorID) +} + +func TestServerHandleAccessRequestWithPasswordGrandTypeWithInvalidRequest(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + data := url.Values{} + + data.Set("grant_type", "password") + + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +func TestServerGetClientWithErrClientNotFound(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(nil, ErrClientNotFound) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithErr(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(nil, errors.New("foo")) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_SERVER_ERROR, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithClientEmpry(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(nil, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithClientBadSecret(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + Secret: "bar", + } + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithClientBadRedirect(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + Secret: "foo", + RedirectURI: "", + } + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClient(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + Secret: "foo", + RedirectURI: "https://auth.mydomain.tld/connect", + } + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Same(t, client, c) + + storageMock.AssertExpectations(t) +} + +func TestExtraScopes(t *testing.T) { + assert.True(t, extraScopes("foo bar", "foo bar jar")) + assert.False(t, extraScopes("foo bar", "foo bar")) + + assert.False(t, extraScopes(" ", " ")) +} + +func TestServerHandlePasswordRequestWithEmptyUsernameAndPassword(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + } + + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "50542ad2-5983-4977-baab-ef3794f08c89").Return(nil, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + /* + data := url.Values{} + + data.Set("grant_type", "password") + data.Set("username", "") + data.Set("password", "") + */ + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", nil) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + //req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + + ctx := req.Context() + ctx = credential.ToContext(ctx, credential.NewUsernamePasswordCredential("", "")) + req = req.WithContext(ctx) + + ar := s.handlePasswordRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_GRANT, w.ErrorID) +} + +func TestServerHandlePasswordRequestWithClientNotFound(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + RedirectURISeparator: " ", + } + + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "50542ad2-5983-4977-baab-ef3794f08c89").Return(nil, ErrClientNotFound) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + data := url.Values{} + + data.Set("grant_type", "password") + data.Set("username", "user@domain.com") + data.Set("password", "passw0rd") + + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + ctx := req.Context() + ctx = credential.ToContext(ctx, credential.NewUsernamePasswordCredential("50542ad2-5983-4977-baab-ef3794f08c89", "mySecretPassw0rd!")) + req = req.WithContext(ctx) + + ar := s.handlePasswordRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerHandlePasswordRequest(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + RedirectURISeparator: " ", + } + + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + ID: "50542ad2-5983-4977-baab-ef3794f08c89", + Secret: "mySecretPassw0rd!", + RedirectURI: "https://auth.mydomain.tld/connect", + } + + storageMock.On("LoadClient", "50542ad2-5983-4977-baab-ef3794f08c89").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + data := url.Values{} + + data.Set("grant_type", "password") + data.Set("username", "user@domain.com") + data.Set("password", "passw0rd") + + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + ctx := req.Context() + ctx = credential.ToContext(ctx, credential.NewUsernamePasswordCredential("50542ad2-5983-4977-baab-ef3794f08c89", "mySecretPassw0rd!")) + req = req.WithContext(ctx) + + ar := s.handlePasswordRequest(w, req) + assert.NotNil(t, ar) + + assert.Same(t, client, ar.Client) + + storageMock.AssertExpectations(t) +} diff --git a/authentication/provider/oauth2/error_test.go b/authentication/provider/oauth2/error_test.go new file mode 100644 index 0000000..c2b4e2c --- /dev/null +++ b/authentication/provider/oauth2/error_test.go @@ -0,0 +1,23 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorGet(t *testing.T) { + for _, item := range []struct { + id DefaultErrorID + expedted string + }{ + {E_ACCESS_DENIED, "The resource owner or authorization server denied the request."}, + {DefaultErrorID("foo"), "foo"}, + } { + assert.Equal(t, item.expedted, deferror.Get(item.id)) + } +} diff --git a/authentication/provider/oauth2/response.go b/authentication/provider/oauth2/response.go index 985c16f..4ff6612 100644 --- a/authentication/provider/oauth2/response.go +++ b/authentication/provider/oauth2/response.go @@ -12,7 +12,7 @@ import ( ) var ( - ErrNotARedirectResponse = errors.New("Not a redirect response") + ErrNotARedirectResponse = errors.New("not a redirect response") ) // Data for response output diff --git a/authentication/provider/oauth2/response_test.go b/authentication/provider/oauth2/response_test.go new file mode 100644 index 0000000..4d9c2c5 --- /dev/null +++ b/authentication/provider/oauth2/response_test.go @@ -0,0 +1,84 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResponseGetRedirectURLWithDataRequest(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.Type = DATA + + url, err := r.GetRedirectURL() + assert.EqualError(t, err, ErrNotARedirectResponse.Error()) + + assert.Equal(t, "", url) +} + +func TestResponseGetRedirectURLWithRedirectInFragment(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.Output["foo"] = "bar" + r.SetRedirect("https://oauth.mydomain.tld/connect") + r.SetRedirectFragment(true) + + url, err := r.GetRedirectURL() + assert.NoError(t, err) + + assert.Equal(t, "https://oauth.mydomain.tld/connect#foo=bar", url) +} + +func TestResponseSetErrorURI(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.SetErrorURI(E_ACCESS_DENIED, "access denied", "https://oauth.mydomain.tld/connect", "foobar") + + assert.True(t, r.IsError) + assert.Equal(t, http.StatusOK, r.ErrorStatusCode) + assert.Equal(t, E_ACCESS_DENIED, r.ErrorID) + + assert.Equal(t, "", r.StatusText) + + assert.Contains(t, r.Output, "error_uri") + assert.Equal(t, "https://oauth.mydomain.tld/connect", r.Output["error_uri"]) + + assert.Contains(t, r.Output, "state") + assert.Equal(t, "foobar", r.Output["state"]) + + assert.Contains(t, r.Output, "error") + assert.Equal(t, E_ACCESS_DENIED, r.Output["error"]) + + assert.Contains(t, r.Output, "error_description") + assert.Equal(t, "access denied", r.Output["error_description"]) + +} + +func TestResponseSetErrorState(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.SetErrorState(E_ACCESS_DENIED, "", "foobar") + + assert.Contains(t, r.Output, "error") + assert.Equal(t, E_ACCESS_DENIED, r.Output["error"]) + + assert.Contains(t, r.Output, "error_description") + assert.Equal(t, "The resource owner or authorization server denied the request.", r.Output["error_description"]) + + assert.Contains(t, r.Output, "state") + assert.Equal(t, "foobar", r.Output["state"]) +} diff --git a/authentication/provider/oauth2/server.go b/authentication/provider/oauth2/server.go index 0e88710..6de953c 100644 --- a/authentication/provider/oauth2/server.go +++ b/authentication/provider/oauth2/server.go @@ -5,7 +5,6 @@ package oauth2 import ( - "errors" "time" "github.com/hyperscale-stack/logger" @@ -13,10 +12,6 @@ import ( "github.com/hyperscale-stack/security/authentication/provider/oauth2/token/random" ) -var ( - ErrRequestMustBePost = errors.New("request must be POST") -) - type Server struct { cfg *Configuration logger logger.Logger diff --git a/authentication/provider/oauth2/server_test.go b/authentication/provider/oauth2/server_test.go index 8a63e31..e8c8aef 100644 --- a/authentication/provider/oauth2/server_test.go +++ b/authentication/provider/oauth2/server_test.go @@ -5,6 +5,7 @@ package oauth2 import ( + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -16,3 +17,15 @@ func TestServer(t *testing.T) { assert.Same(t, cfg, server.cfg) } + +func TestServerNewResponse(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + } + server := NewServer(WithConfig(cfg)) + + response := server.NewResponse() + assert.NotNil(t, response) + + assert.Equal(t, cfg.ErrorStatusCode, response.ErrorStatusCode) +} diff --git a/authentication/provider/oauth2/urivalidate.go b/authentication/provider/oauth2/urivalidate.go index e3533ae..6713f6b 100644 --- a/authentication/provider/oauth2/urivalidate.go +++ b/authentication/provider/oauth2/urivalidate.go @@ -11,7 +11,11 @@ import ( "strings" ) -// error returned when validation don't match +var ( + ErrNoBlank = errors.New("urls cannot be blank") +) + +// error returned when validation don't match. type URIValidationError string func (e URIValidationError) Error() string { @@ -22,17 +26,17 @@ func newURIValidationError(msg string, base string, redirect string) URIValidati return URIValidationError(fmt.Sprintf("%s: %s / %s", msg, base, redirect)) } -// Parse urls, resolving uri references to base url +// ParseURLs resolving uri references to base url. func ParseURLs(baseUrl, redirectUrl string) (retBaseUrl, retRedirectUrl *url.URL, err error) { var base, redirect *url.URL // parse base url if base, err = url.Parse(baseUrl); err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("parse base url failed: %w", err) } // parse redirect url if redirect, err = url.Parse(redirectUrl); err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("parse redirect url failed: %w", err) } // must not have fragment @@ -89,7 +93,7 @@ func ValidateURIList(baseUriList string, redirectUri string, separator string) ( // ValidateURI validates that redirectUri is contained in baseUri. func ValidateURI(baseUri string, redirectUri string) (realRedirectUri string, err error) { if baseUri == "" || redirectUri == "" { - return "", errors.New("urls cannot be blank") + return "", ErrNoBlank } base, redirect, err := ParseURLs(baseUri, redirectUri) @@ -117,9 +121,5 @@ func FirstURI(baseUriList string, separator string) string { return baseUriList } - if slist := strings.Split(baseUriList, separator); len(slist) > 0 { - return slist[0] - } - - return "" + return strings.Split(baseUriList, separator)[0] } diff --git a/authentication/provider/oauth2/urivalidate_test.go b/authentication/provider/oauth2/urivalidate_test.go index 2b498ec..dbcd2e5 100644 --- a/authentication/provider/oauth2/urivalidate_test.go +++ b/authentication/provider/oauth2/urivalidate_test.go @@ -80,6 +80,14 @@ func TestURIValidate(t *testing.T) { } invalid := [][]string{ + { + "", + "", + }, + { + "\n", + "\n", + }, { // Doesn't satisfy base path "http://localhost:14000/appauth", @@ -153,6 +161,11 @@ func TestURIListValidate(t *testing.T) { if _, err := ValidateURIList("http://xxx:14000/appauth;http://localhost:14000/appauth", "http://localhost:14000/app", ";"); err == nil { t.Error("V4 should have failed") } + + // V5 + if _, err := ValidateURIList("\n", "\n", ";"); err == nil { + t.Error("V5 should have failed") + } } func TestFirstURI(t *testing.T) { @@ -160,4 +173,37 @@ func TestFirstURI(t *testing.T) { assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", " ")) assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", "")) assert.Equal(t, "", FirstURI("", " ")) + assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", ";")) +} + +func TestNewURIValidationError(t *testing.T) { + err := newURIValidationError("scheme mismatch", "http://www.google.com/myapp", "http://www.google.com/myapp../test") + + assert.EqualError(t, err, "scheme mismatch: http://www.google.com/myapp / http://www.google.com/myapp../test") +} + +func TestParseURLs(t *testing.T) { + invalid := [][]string{ + { + "\n", + "\n", + }, + { + "https://google.com", + "\n", + }, + { + "https://google.com#foo", + "https://google.com#foo", + }, + { + "http://google.com", + "https://google.com", + }, + } + for _, v := range invalid { + if _, err := ValidateURI(v[0], v[1]); err == nil { + t.Errorf("Expected ValidateURI(%s, %s) to fail", v[0], v[1]) + } + } } diff --git a/authentication/provider/oauth2/util.go b/authentication/provider/oauth2/util.go index 67f3be5..1f59340 100644 --- a/authentication/provider/oauth2/util.go +++ b/authentication/provider/oauth2/util.go @@ -8,11 +8,18 @@ import ( "crypto/subtle" "encoding/base64" "errors" + "fmt" "net/http" "net/url" "strings" "github.com/hyperscale-stack/security/authentication/credential" + "github.com/hyperscale-stack/security/http/header" +) + +var ( + ErrInvalidAuthorizationHeader = errors.New("invalid authorization header") + ErrInvalidAuthorizationMessage = errors.New("invalid authorization message") ) // Parse basic authentication header @@ -45,18 +52,19 @@ func CheckBasicAuth(r *http.Request) (*BasicAuth, error) { return nil, nil } - s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - if len(s) != 2 || s[0] != "Basic" { - return nil, errors.New("Invalid authorization header") + b64, ok := header.ExtractAuthorizationValue("Basic", r.Header.Get("Authorization")) + if !ok { + return nil, ErrInvalidAuthorizationHeader } - b, err := base64.StdEncoding.DecodeString(s[1]) + b, err := base64.StdEncoding.DecodeString(b64) if err != nil { - return nil, err + return nil, fmt.Errorf("decode basic auth failed: %w", err) } + pair := strings.SplitN(string(b), ":", 2) if len(pair) != 2 { - return nil, errors.New("Invalid authorization message") + return nil, ErrInvalidAuthorizationMessage } // Decode the client_id and client_secret pairs as per @@ -113,7 +121,7 @@ func (s Server) getClientAuth(w *Response, r *http.Request, allowQueryParams boo ) if auth.GetPrincipal() != "" { - return auth.(*credential.UsernamePasswordCredential) + return auth } } } diff --git a/authentication/provider/oauth2/util_test.go b/authentication/provider/oauth2/util_test.go index 910878d..1a44726 100644 --- a/authentication/provider/oauth2/util_test.go +++ b/authentication/provider/oauth2/util_test.go @@ -14,11 +14,13 @@ import ( ) const ( - badAuthValue = "Digest XHHHHHHH" - badUsernameInAuthValue = "Basic dSUyc2VybmFtZTpwYXNzd29yZA==" // u%2sername:password - badPasswordInAuthValue = "Basic dXNlcm5hbWU6cGElMnN3b3Jk" // username:pa%2sword - goodAuthValue = "Basic Y2xpZW50K25hbWU6Y2xpZW50KyUyNGVjcmV0" - goodBearerAuthValue = "Bearer BGFVTDUJDp0ZXN0" + badAuthValue = "Digest XHHHHHHH" + badBasicAuthValue = "Basic €€€" + badBasicAuthWithBadFormat = "Basic Zm9vCg==" // foo + badUsernameInAuthValue = "Basic dSUyc2VybmFtZTpwYXNzd29yZA==" // u%2sername:password + badPasswordInAuthValue = "Basic dXNlcm5hbWU6cGElMnN3b3Jk" // username:pa%2sword + goodAuthValue = "Basic Y2xpZW50K25hbWU6Y2xpZW50KyUyNGVjcmV0" + goodBearerAuthValue = "Bearer BGFVTDUJDp0ZXN0" ) func TestBasicAuth(t *testing.T) { @@ -37,6 +39,22 @@ func TestBasicAuth(t *testing.T) { return } + // with invalid value + r.Header.Set("Authorization", badBasicAuthValue) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth") + return + } + + // with invalid format + r.Header.Set("Authorization", badBasicAuthWithBadFormat) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth") + return + } + // with invalid username r.Header.Set("Authorization", badUsernameInAuthValue) b, err = CheckBasicAuth(r) From 8795a49eed648f30966efb81fbbeb8781d5ce320 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Mon, 16 Aug 2021 00:29:37 +0200 Subject: [PATCH 5/9] fix(authentication/oauth2): lint --- authentication/provider/oauth2/access.go | 149 +++++++++++------- authentication/provider/oauth2/authorize.go | 4 +- .../provider/oauth2/configuration.go | 8 +- authentication/provider/oauth2/error.go | 2 +- authentication/provider/oauth2/option.go | 2 +- authentication/provider/oauth2/response.go | 18 +-- .../provider/oauth2/response_json.go | 5 +- authentication/provider/oauth2/server.go | 2 +- authentication/provider/oauth2/urivalidate.go | 1 + authentication/provider/oauth2/util.go | 29 ++-- 10 files changed, 126 insertions(+), 94 deletions(-) diff --git a/authentication/provider/oauth2/access.go b/authentication/provider/oauth2/access.go index 229c296..fa39c96 100644 --- a/authentication/provider/oauth2/access.go +++ b/authentication/provider/oauth2/access.go @@ -18,10 +18,15 @@ import ( ) var ( - ErrRequestMustBePost = errors.New("request must be POST") + ErrRequestMustBePost = errors.New("request must be POST") + ErrExtraScope = errors.New("the requested scope must not include any scope not originally granted by the resource owner") + ErrClientIDNotSame = errors.New("client id must be the same from previous token") + ErrCodeChallengeNotSame = errors.New("code_verifier failed comparison with code_challenge") + ErrCodeVerifierInvalidFormat = errors.New("code_verifier has invalid format") + ErrRedirectURINotSame = errors.New("redirect uri is different") ) -// AccessRequestType is the type for OAuth2 param `grant_type` +// AccessRequestType is the type for OAuth2 param `grant_type`. type AccessRequestType string const ( @@ -33,7 +38,7 @@ const ( IMPLICIT AccessRequestType = "__implicit" ) -// AccessRequest is a request for access tokens +// AccessRequest is a request for access tokens. type AccessRequest struct { Type AccessRequestType Code string @@ -133,7 +138,7 @@ func (i *AccessData) ExpireAt() time.Time { return i.CreatedAt.Add(time.Duration(i.ExpiresIn) * time.Second) } -// HandleAccessRequest is the http.HandlerFunc for handling access token requests +// HandleAccessRequest is the http.HandlerFunc for handling access token requests. func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessRequest { // Only allow GET or POST if r.Method == http.MethodGet { @@ -161,6 +166,7 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques return nil } + //nolint: exhaustive switch grantType { case AUTHORIZATION_CODE: return s.handleAuthorizationCodeRequest(w, r) @@ -170,8 +176,12 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques return s.handlePasswordRequest(w, r) case CLIENT_CREDENTIALS: return s.handleClientCredentialsRequest(w, r) - default: + case ASSERTION: return s.handleAssertionRequest(w, r) + default: + s.setErrorAndLog(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") + + return nil } } @@ -250,15 +260,17 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A if ret.RedirectURI == "" { ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) } - if realRedirectURI, err := ValidateURIList(ret.Client.GetRedirectURI(), ret.RedirectURI, s.cfg.RedirectURISeparator); err != nil { + if realRedirectURI, err := ValidateURIList(ret.Client.GetRedirectURI(), ret.RedirectURI, s.cfg.RedirectURISeparator); err != nil { s.setErrorAndLog(w, E_INVALID_REQUEST, err, "auth_code_request=%s", "error validating client redirect") + return nil } else { ret.RedirectURI = realRedirectURI } + if ret.AuthorizeData.RedirectURI != ret.RedirectURI { - s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Redirect uri is different"), "auth_code_request=%s", "client redirect does not match authorization data") + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrRedirectURINotSame, "auth_code_request=%s", "client redirect does not match authorization data") return nil } @@ -270,7 +282,7 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A s.setErrorAndLog( w, E_INVALID_REQUEST, - errors.New("code_verifier has invalid format"), + ErrCodeVerifierInvalidFormat, "auth_code_request=%s", "pkce code challenge verifier does not match", ) @@ -280,6 +292,7 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A // https: //tools.ietf.org/html/rfc7636#section-4.6 codeVerifier := "" + switch ret.AuthorizeData.CodeChallengeMethod { case "", PKCE_PLAIN: codeVerifier = ret.CodeVerifier @@ -302,7 +315,7 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A s.setErrorAndLog( w, E_INVALID_GRANT, - errors.New("code_verifier failed comparison with code_challenge"), + ErrCodeChallengeNotSame, "auth_code_request=%s", "pkce code verifier does not match challenge", ) @@ -407,7 +420,7 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access s.setErrorAndLog( w, E_INVALID_CLIENT, - errors.New("Client id must be the same from previous token"), + ErrClientIDNotSame, "refresh_token=%s, current=%v, previous=%v", "client mismatch", ret.Client.GetID(), @@ -426,8 +439,7 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access } if extraScopes(ret.AccessData.Scope, ret.Scope) { - msg := "the requested scope must not include any scope not originally granted by the resource owner" - s.setErrorAndLog(w, E_ACCESS_DENIED, errors.New(msg), "refresh_token=%s", msg) + s.setErrorAndLog(w, E_ACCESS_DENIED, ErrExtraScope, "refresh_token=%s", ErrExtraScope.Error()) return nil } @@ -435,6 +447,7 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access return ret } +//nolint:dupl func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) @@ -498,6 +511,7 @@ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *A return ret } +//nolint:dupl func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) @@ -547,70 +561,85 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq redirectUri = ar.RedirectURI } - if ar.Authorized { - var ret *AccessData - var err error - - if ar.ForceAccessData == nil { - // generate access token - ret = &AccessData{ - Client: ar.Client, - AuthorizeData: ar.AuthorizeData, - AccessData: ar.AccessData, - RedirectURI: redirectUri, - CreatedAt: s.now(), - ExpiresIn: int64(ar.Expiration.Seconds()), - UserData: ar.UserData, - Scope: ar.Scope, - } + if !ar.Authorized { + s.setErrorAndLog(w, E_ACCESS_DENIED, nil, "finish_access_request=%s", "authorization failed") - // generate access token - // @TODO: add ret at first arg for GenerateAccessToken - ret.AccessToken, ret.RefreshToken, err = s.tokenGenerator.GenerateAccessToken(ar.GenerateRefresh) - if err != nil { - s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error generating token") + return + } - return - } - } else { - ret = ar.ForceAccessData + var ret *AccessData + + var err error + + if ar.ForceAccessData == nil { + // generate access token + ret = &AccessData{ + Client: ar.Client, + AuthorizeData: ar.AuthorizeData, + AccessData: ar.AccessData, + RedirectURI: redirectUri, + CreatedAt: s.now(), + ExpiresIn: int64(ar.Expiration.Seconds()), + UserData: ar.UserData, + Scope: ar.Scope, } - // save access token - if err = w.Storage.SaveAccess(ret); err != nil { - s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") + // generate access token + // @TODO: add ret at first arg for GenerateAccessToken + ret.AccessToken, ret.RefreshToken, err = s.tokenGenerator.GenerateAccessToken(ar.GenerateRefresh) + if err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error generating token") return } + } else { + ret = ar.ForceAccessData + } + + // save access token + if err = w.Storage.SaveAccess(ret); err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") + + return + } - // remove authorization token - if ret.AuthorizeData != nil { - w.Storage.RemoveAuthorize(ret.AuthorizeData.Code) + // remove authorization token + if ret.AuthorizeData != nil { + if err := w.Storage.RemoveAuthorize(ret.AuthorizeData.Code); err != nil { + s.logger.Error("oauth2: remove autorize code failed", logger.WithLabels(map[string]interface{}{ + "code": ret.AuthorizeData.Code, + })) } + } - // remove previous access token - if ret.AccessData != nil && !s.cfg.RetainTokenAfterRefresh { - if ret.AccessData.RefreshToken != "" { - w.Storage.RemoveRefresh(ret.AccessData.RefreshToken) + // remove previous access token + if ret.AccessData != nil && !s.cfg.RetainTokenAfterRefresh { + if ret.AccessData.RefreshToken != "" { + if err := w.Storage.RemoveRefresh(ret.AccessData.RefreshToken); err != nil { + s.logger.Error("oauth2: remove refresh token failed", logger.WithLabels(map[string]interface{}{ + "refresh_token": ret.AccessData.RefreshToken, + })) } + } - w.Storage.RemoveAccess(ret.AccessData.AccessToken) + if err := w.Storage.RemoveAccess(ret.AccessData.AccessToken); err != nil { + s.logger.Error("oauth2: remove access token failed", logger.WithLabels(map[string]interface{}{ + "access_token": ret.AccessData.AccessToken, + })) } + } - // output data - w.Output["access_token"] = ret.AccessToken - w.Output["token_type"] = s.cfg.TokenType - w.Output["expires_in"] = ret.ExpiresIn + // output data + w.Output["access_token"] = ret.AccessToken + w.Output["token_type"] = s.cfg.TokenType + w.Output["expires_in"] = ret.ExpiresIn - if ret.RefreshToken != "" { - w.Output["refresh_token"] = ret.RefreshToken - } + if ret.RefreshToken != "" { + w.Output["refresh_token"] = ret.RefreshToken + } - if ret.Scope != "" { - w.Output["scope"] = ret.Scope - } - } else { - s.setErrorAndLog(w, E_ACCESS_DENIED, nil, "finish_access_request=%s", "authorization failed") + if ret.Scope != "" { + w.Output["scope"] = ret.Scope } } diff --git a/authentication/provider/oauth2/authorize.go b/authentication/provider/oauth2/authorize.go index d0caa18..091df8c 100644 --- a/authentication/provider/oauth2/authorize.go +++ b/authentication/provider/oauth2/authorize.go @@ -10,7 +10,7 @@ import ( "time" ) -// AuthorizeRequestType is the type for OAuth param `response_type` +// AuthorizeRequestType is the type for OAuth param `response_type`. type AuthorizeRequestType string const ( @@ -25,7 +25,7 @@ var ( pkceMatcher = regexp.MustCompile("^[a-zA-Z0-9~._-]{43,128}$") ) -// Authorize request information +// Authorize request information. type AuthorizeRequest struct { Type AuthorizeRequestType Client Client diff --git a/authentication/provider/oauth2/configuration.go b/authentication/provider/oauth2/configuration.go index 0f12738..9aa86bd 100644 --- a/authentication/provider/oauth2/configuration.go +++ b/authentication/provider/oauth2/configuration.go @@ -6,10 +6,10 @@ package oauth2 import "time" -// AllowedAuthorizeType is a collection of allowed auth request types +// AllowedAuthorizeType is a collection of allowed auth request types. type AllowedAuthorizeType []AuthorizeRequestType -// Exists returns true if the auth type exists in the list +// Exists returns true if the auth type exists in the list. func (t AllowedAuthorizeType) Exists(rt AuthorizeRequestType) bool { for _, k := range t { if k == rt { @@ -20,10 +20,10 @@ func (t AllowedAuthorizeType) Exists(rt AuthorizeRequestType) bool { return false } -// AllowedAccessType is a collection of allowed access request types +// AllowedAccessType is a collection of allowed access request types. type AllowedAccessType []AccessRequestType -// Exists returns true if the access type exists in the list +// Exists returns true if the access type exists in the list. func (t AllowedAccessType) Exists(rt AccessRequestType) bool { for _, k := range t { if k == rt { diff --git a/authentication/provider/oauth2/error.go b/authentication/provider/oauth2/error.go index e3bf27d..a711947 100644 --- a/authentication/provider/oauth2/error.go +++ b/authentication/provider/oauth2/error.go @@ -27,7 +27,7 @@ var ( deferror *DefaultErrors = NewDefaultErrors() ) -// Default errors and messages +// Default errors and messages. type DefaultErrors struct { errormap map[DefaultErrorID]string } diff --git a/authentication/provider/oauth2/option.go b/authentication/provider/oauth2/option.go index 6c6f1e7..837c11d 100644 --- a/authentication/provider/oauth2/option.go +++ b/authentication/provider/oauth2/option.go @@ -12,7 +12,7 @@ import ( // Option type. type Option func(server *Server) -// WithConfig add config to server +// WithConfig add config to server. func WithConfig(cfg *Configuration) Option { return func(server *Server) { server.cfg = cfg diff --git a/authentication/provider/oauth2/response.go b/authentication/provider/oauth2/response.go index 4ff6612..a908eeb 100644 --- a/authentication/provider/oauth2/response.go +++ b/authentication/provider/oauth2/response.go @@ -15,10 +15,10 @@ var ( ErrNotARedirectResponse = errors.New("not a redirect response") ) -// Data for response output +// Data for response output. type ResponseData map[string]interface{} -// Response type enum +// Response type enum. type ResponseType int const ( @@ -26,7 +26,7 @@ const ( REDIRECT ) -// Server response +// Server response. type Response struct { Type ResponseType StatusCode int @@ -65,18 +65,18 @@ func NewResponse(storage StorageProvider) *Response { } // SetError sets an error id and description on the Response -// state and uri are left blank +// state and uri are left blank. func (r *Response) SetError(id DefaultErrorID, description string) { r.SetErrorURI(id, description, "", "") } // SetErrorState sets an error id, description, and state on the Response -// uri is left blank +// uri is left blank. func (r *Response) SetErrorState(id DefaultErrorID, description string, state string) { r.SetErrorURI(id, description, "", state) } -// SetErrorURI sets an error id, description, state, and uri on the Response +// SetErrorURI sets an error id, description, state, and uri on the Response. func (r *Response) SetErrorURI(id DefaultErrorID, description string, uri string, state string) { // get default error message if description == "" { @@ -107,19 +107,19 @@ func (r *Response) SetErrorURI(id DefaultErrorID, description string, uri string } } -// SetRedirect changes the response to redirect to the given url +// SetRedirect changes the response to redirect to the given url. func (r *Response) SetRedirect(url string) { // set redirect parameters r.Type = REDIRECT r.URL = url } -// SetRedirectFragment sets redirect values to be passed in fragment instead of as query parameters +// SetRedirectFragment sets redirect values to be passed in fragment instead of as query parameters. func (r *Response) SetRedirectFragment(f bool) { r.RedirectInFragment = f } -// GetRedirectURL returns the redirect url with all query string parameters +// GetRedirectURL returns the redirect url with all query string parameters. func (r *Response) GetRedirectURL() (string, error) { if r.Type != REDIRECT { return "", ErrNotARedirectResponse diff --git a/authentication/provider/oauth2/response_json.go b/authentication/provider/oauth2/response_json.go index 730232b..ce8b320 100644 --- a/authentication/provider/oauth2/response_json.go +++ b/authentication/provider/oauth2/response_json.go @@ -9,7 +9,7 @@ import ( "net/http" ) -// OutputJSON encodes the Response to JSON and writes to the http.ResponseWriter +// OutputJSON encodes the Response to JSON and writes to the http.ResponseWriter. func OutputJSON(rs *Response, w http.ResponseWriter, r *http.Request) error { // Add headers for i, k := range rs.Headers { @@ -29,7 +29,6 @@ func OutputJSON(rs *Response, w http.ResponseWriter, r *http.Request) error { w.WriteHeader(302) return nil - } // set content type if the response doesn't already have one associated with it @@ -39,6 +38,6 @@ func OutputJSON(rs *Response, w http.ResponseWriter, r *http.Request) error { w.WriteHeader(rs.StatusCode) + //nolint:wrapcheck return json.NewEncoder(w).Encode(rs.Output) - } diff --git a/authentication/provider/oauth2/server.go b/authentication/provider/oauth2/server.go index 6de953c..384ed8a 100644 --- a/authentication/provider/oauth2/server.go +++ b/authentication/provider/oauth2/server.go @@ -38,7 +38,7 @@ func NewServer(options ...Option) *Server { return s } -// NewResponse creates a new response for the server +// NewResponse creates a new response for the server. func (s *Server) NewResponse() *Response { r := NewResponse(s.storage) r.ErrorStatusCode = s.cfg.ErrorStatusCode diff --git a/authentication/provider/oauth2/urivalidate.go b/authentication/provider/oauth2/urivalidate.go index 6713f6b..f66b2dd 100644 --- a/authentication/provider/oauth2/urivalidate.go +++ b/authentication/provider/oauth2/urivalidate.go @@ -82,6 +82,7 @@ func ValidateURIList(baseUriList string, redirectUri string, separator string) ( } // if there was an error that is not a validation error, return it + //nolint:errorlint if _, iok := err.(URIValidationError); !iok { return "", err } diff --git a/authentication/provider/oauth2/util.go b/authentication/provider/oauth2/util.go index 1f59340..e6119e1 100644 --- a/authentication/provider/oauth2/util.go +++ b/authentication/provider/oauth2/util.go @@ -20,21 +20,22 @@ import ( var ( ErrInvalidAuthorizationHeader = errors.New("invalid authorization header") ErrInvalidAuthorizationMessage = errors.New("invalid authorization message") + ErrClientAuthenticationNotSent = errors.New("Client authentication not sent") ) -// Parse basic authentication header +// Parse basic authentication header. type BasicAuth struct { Username string Password string } -// Parse bearer authentication header +// Parse bearer authentication header. type BearerAuth struct { Code string } // CheckClientSecret determines whether the given secret matches a secret held by the client. -// Public clients return true for a secret of "" +// Public clients return true for a secret of "". func CheckClientSecret(client Client, secret string) bool { switch client := client.(type) { case ClientSecretMatcher: @@ -46,7 +47,7 @@ func CheckClientSecret(client Client, secret string) bool { } } -// Return authorization header data +// Return authorization header data. func CheckBasicAuth(r *http.Request) (*BasicAuth, error) { if r.Header.Get("Authorization") == "" { return nil, nil @@ -72,12 +73,12 @@ func CheckBasicAuth(r *http.Request) (*BasicAuth, error) { username, err := url.QueryUnescape(pair[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("unescape username failed: %w", err) } password, err := url.QueryUnescape(pair[1]) if err != nil { - return nil, err + return nil, fmt.Errorf("unescape password failed: %w", err) } return &BasicAuth{Username: username, Password: password}, nil @@ -87,20 +88,22 @@ func CheckBasicAuth(r *http.Request) (*BasicAuth, error) { func CheckBearerAuth(r *http.Request) *BearerAuth { authHeader := r.Header.Get("Authorization") authForm := r.FormValue("code") + if authHeader == "" && authForm == "" { return nil } + token := authForm + if authHeader != "" { - s := strings.SplitN(authHeader, " ", 2) - if (len(s) != 2 || strings.ToLower(s[0]) != "bearer") && token == "" { + v, ok := header.ExtractAuthorizationValue("Bearer", authHeader) + if !ok { return nil } - //Use authorization header token only if token type is bearer else query string access token would be returned - if len(s) > 0 && strings.ToLower(s[0]) == "bearer" { - token = s[1] - } + + token = v } + return &BearerAuth{Code: token} } @@ -128,7 +131,7 @@ func (s Server) getClientAuth(w *Response, r *http.Request, allowQueryParams boo auth := credential.FromContext(ctx) if auth == nil { - s.setErrorAndLog(w, E_INVALID_REQUEST, errors.New("Client authentication not sent"), "get_client_auth=%s", "client authentication not sent") + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrClientAuthenticationNotSent, "get_client_auth=%s", "client authentication not sent") return nil } From 756f9b5bf96b5214e4a248f5ae95f3c82f81d957 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Mon, 16 Aug 2021 00:51:52 +0200 Subject: [PATCH 6/9] feat(authentication/oauth2): add endpoint {prefox}/token support --- authentication/provider/oauth2/handler.go | 74 +++++++++++++++++++ authentication/provider/oauth2/option.go | 7 ++ authentication/provider/oauth2/option_test.go | 13 ++++ authentication/provider/oauth2/server.go | 3 + go.mod | 10 +-- go.sum | 39 +++++++--- 6 files changed, 132 insertions(+), 14 deletions(-) create mode 100644 authentication/provider/oauth2/handler.go diff --git a/authentication/provider/oauth2/handler.go b/authentication/provider/oauth2/handler.go new file mode 100644 index 0000000..a604aa5 --- /dev/null +++ b/authentication/provider/oauth2/handler.go @@ -0,0 +1,74 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "errors" + "net/http" +) + +var ( + ErrOAuthError = errors.New("oauth2 error") +) + +func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case s.cfg.PrefixURI + "/token": + s.handleTokenRequest(w, r) + /* + case s.cfg.PrefixURI + "/authorize": + s.handleAuthorizeRequest(w, r) + */ + } +} + +func (s Server) handleTokenRequest(w http.ResponseWriter, r *http.Request) { + resp := s.NewResponse() + + requestType := "" + + if ar := s.HandleAccessRequest(resp, r); ar != nil { + requestType = string(ar.Type) + + switch ar.Type { + case AUTHORIZATION_CODE: + ar.Authorized = true + case REFRESH_TOKEN: + ar.Authorized = true + case PASSWORD: + user, err := s.userProvider.Authenticate(ar.Username, ar.Password) + if err != nil { + s.setErrorAndLog(resp, E_ACCESS_DENIED, err, "get_user=%s", "failed") + } else { + ar.Authorized = true + ar.UserData = user.GetID() + } + case CLIENT_CREDENTIALS: + ar.Authorized = true + } + + s.FinishAccessRequest(resp, r, ar) + } + + var err error + + if resp.IsError { + if resp.InternalError != nil { + err = resp.InternalError + } else { + err = ErrOAuthError + } + + s.logger.Error(err.Error()) + + s.emitter.Dispatch("oauth."+requestType+".failed", resp) + } else { + s.emitter.Dispatch("oauth."+requestType+".succeeded", resp) + } + + if err := OutputJSON(resp, w, r); err != nil { + s.logger.Error(err.Error()) + } +} diff --git a/authentication/provider/oauth2/option.go b/authentication/provider/oauth2/option.go index 837c11d..1a42967 100644 --- a/authentication/provider/oauth2/option.go +++ b/authentication/provider/oauth2/option.go @@ -5,6 +5,7 @@ package oauth2 import ( + "github.com/euskadi31/go-eventemitter" "github.com/hyperscale-stack/logger" "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" ) @@ -42,3 +43,9 @@ func WithTokenGenerator(tokenGenerator token.Generator) Option { server.tokenGenerator = tokenGenerator } } + +func WithEventEmitter(emitter eventemitter.EventEmitter) Option { + return func(server *Server) { + server.emitter = emitter + } +} diff --git a/authentication/provider/oauth2/option_test.go b/authentication/provider/oauth2/option_test.go index 3a63d80..11bd23e 100644 --- a/authentication/provider/oauth2/option_test.go +++ b/authentication/provider/oauth2/option_test.go @@ -7,6 +7,7 @@ package oauth2 import ( "testing" + "github.com/euskadi31/go-eventemitter" "github.com/hyperscale-stack/logger" "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" "github.com/stretchr/testify/assert" @@ -71,3 +72,15 @@ func TestWithTokenGenerator(t *testing.T) { assert.Same(t, tokenGeneratorMock, server.tokenGenerator) } + +func TestWithEventEmitter(t *testing.T) { + emitter := eventemitter.New() + + opt := WithEventEmitter(emitter) + + server := &Server{} + + opt(server) + + assert.Same(t, emitter, server.emitter) +} diff --git a/authentication/provider/oauth2/server.go b/authentication/provider/oauth2/server.go index 384ed8a..3b733a0 100644 --- a/authentication/provider/oauth2/server.go +++ b/authentication/provider/oauth2/server.go @@ -7,6 +7,7 @@ package oauth2 import ( "time" + "github.com/euskadi31/go-eventemitter" "github.com/hyperscale-stack/logger" "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" "github.com/hyperscale-stack/security/authentication/provider/oauth2/token/random" @@ -18,6 +19,7 @@ type Server struct { storage StorageProvider userProvider UserProvider tokenGenerator token.Generator + emitter eventemitter.EventEmitter now func() time.Time } @@ -29,6 +31,7 @@ func NewServer(options ...Option) *Server { logger: &logger.Nop{}, tokenGenerator: random.NewTokenGenerator(&random.Configuration{}), now: time.Now, + emitter: eventemitter.New(), } for _, opt := range options { diff --git a/go.mod b/go.mod index 726d900..6d64245 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/hyperscale-stack/security go 1.15 require ( + github.com/euskadi31/go-eventemitter v1.1.1 github.com/gilcrest/alice v1.0.0 - github.com/hyperscale-stack/logger v1.0.0 // indirect - github.com/hyperscale-stack/secure v1.0.0 // indirect - github.com/rs/zerolog v1.20.0 + github.com/hyperscale-stack/logger v1.0.0 + github.com/hyperscale-stack/secure v1.0.0 + github.com/rs/zerolog v1.23.0 github.com/stretchr/objx v0.3.0 // indirect github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e ) diff --git a/go.sum b/go.sum index 635bbd0..33eadf7 100644 --- a/go.sum +++ b/go.sum @@ -1,39 +1,60 @@ -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/euskadi31/go-eventemitter v1.1.1 h1:VrDOCjM1uaQfq3XGlrZXce+wL6JtiecPtKPwS2jATo8= +github.com/euskadi31/go-eventemitter v1.1.1/go.mod h1:2FzLg7x4u58JiXblyOwQJ+O+tAcj05CFKnN2Yz8yXEY= github.com/gilcrest/alice v1.0.0 h1:5+CasxidJEUHmgghQxLOl09uYhOlavDfDgNZhyR62LU= github.com/gilcrest/alice v1.0.0/go.mod h1:q5HRhK5WEyU1pDBIIfmYapVGLd/IAAPwiO8LNxKADpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/hyperscale-stack/logger v1.0.0 h1:ZsjQp1DEXwHunoVhi0L5/yZOg0Npsy7LkQy10Mi2o6I= github.com/hyperscale-stack/logger v1.0.0/go.mod h1:X1apoFZZ8/AKspix5ylBetZzv+HEU2Ive/+3qHJhyOw= github.com/hyperscale-stack/secure v1.0.0 h1:ayGoa/Y/0RcAcP767WKjla1r9KlR+Tul5DPI/jE9dP0= github.com/hyperscale-stack/secure v1.0.0/go.mod h1:PY+BMJQI2aP+YYA3C7R0bFTS/XGJ4xPCYjBp9rEqmtQ= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= -github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= +github.com/rs/zerolog v1.23.0 h1:UskrK+saS9P9Y789yNNulYKdARjPZuS35B8gJF2x60g= +github.com/rs/zerolog v1.23.0/go.mod h1:6c7hFfxPOy7TacJc4Fcdi24/J0NKYGzjG8FWRI916Qo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As= github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= -golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e h1:VvfwVmMH40bpMeizC9/K7ipM5Qjucuu16RWfneFPyhQ= +golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 690e47eca3b66d81d8990456229108a960caf050 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Mon, 16 Aug 2021 00:54:28 +0200 Subject: [PATCH 7/9] fix(authentication/oauth2): lint --- authentication/provider/oauth2/handler.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/authentication/provider/oauth2/handler.go b/authentication/provider/oauth2/handler.go index a604aa5..ec8a0b6 100644 --- a/authentication/provider/oauth2/handler.go +++ b/authentication/provider/oauth2/handler.go @@ -14,6 +14,8 @@ var ( ) func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + //nolint:wsl,gocritic + //@TODO: WIP switch r.URL.Path { case s.cfg.PrefixURI + "/token": s.handleTokenRequest(w, r) @@ -32,6 +34,7 @@ func (s Server) handleTokenRequest(w http.ResponseWriter, r *http.Request) { if ar := s.HandleAccessRequest(resp, r); ar != nil { requestType = string(ar.Type) + //nolint:exhaustive switch ar.Type { case AUTHORIZATION_CODE: ar.Authorized = true From 85a577d5ad4a33b47cc5a0f789fbfe55096439e2 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Thu, 28 Nov 2024 22:23:40 +0100 Subject: [PATCH 8/9] fix: linter --- .github/workflows/go.yml | 71 ++++++----- .golangci.yml | 110 ++++++++---------- authentication/provider.go | 1 + authentication/provider/dao/user_provider.go | 1 + authentication/provider/oauth2/access_test.go | 23 ---- user/user.go | 2 + 6 files changed, 90 insertions(+), 118 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 6c78b9b..a9bc0e7 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,51 +2,50 @@ name: Go on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: - build: name: Build runs-on: ubuntu-latest steps: - - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ^1.16 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - name: Build - run: go build -race -v ./... - - - name: Test - run: | - go test -race -cover -coverprofile ./coverage.out.tmp ./... - cat ./coverage.out.tmp | grep -v '.pb.go' | grep -v 'mock_' > ./coverage.out - rm ./coverage.out.tmp - - - name: Run golangci-lint - uses: golangci/golangci-lint-action@v2 - with: - version: v1.41.1 - - - name: Coveralls - uses: shogo82148/actions-goveralls@v1 - with: - path-to-profile: coverage.out + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.16 + id: go + + - name: Build + run: go build -race -v ./... + + - name: Test + run: | + go test -race -cover -coverprofile ./coverage.out.tmp ./... + cat ./coverage.out.tmp | grep -v '.pb.go' | grep -v 'mock_' > ./coverage.out + rm ./coverage.out.tmp + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: v1.62.2 + skip-cache: true + + - name: Coveralls + uses: shogo82148/actions-goveralls@v1 + with: + path-to-profile: coverage.out finish: needs: build runs-on: ubuntu-latest steps: - - name: Coveralls Finished - uses: coverallsapp/github-action@master - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - parallel-finished: true + - name: Coveralls Finished + uses: coverallsapp/github-action@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + parallel-finished: true diff --git a/.golangci.yml b/.golangci.yml index f073f8a..bfbff43 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,88 +1,41 @@ -run: - concurrency: 4 - deadline: 1m - issues-exit-code: 1 - tests: false - skip-files: - - ".*_mock\\.go" - - "mock_.*\\.go" - - ".*/pkg/mod/.*$" - -output: - format: colored-line-number - print-issued-lines: true - print-linter-name: true - -linters-settings: - errcheck: - check-type-assertions: false - check-blank: false - govet: - check-shadowing: false - revive: - ignore-generated-header: true - severity: warning - gofmt: - simplify: true - gocyclo: - min-complexity: 18 - maligned: - suggest-new: true - dupl: - threshold: 50 - goconst: - min-len: 3 - min-occurrences: 2 - depguard: - list-type: blacklist - include-go-root: false - packages: - - github.com/davecgh/go-spew/spew - misspell: - locale: US - ignore-words: - - cancelled - goimports: - local-prefixes: go.opentelemetry.io - - +issues: + exclude-case-sensitive: false + exclude-dirs-use-default: true + exclude-files: + - .*_mock\.go + - mock_.*\.go + - .*/pkg/mod/.*$ + - .*/go/src/.*\.go + exclude-generated: strict + exclude-use-default: true + max-issues-per-linter: 50 linters: disable-all: true enable: - - deadcode - - depguard - errcheck - - gas - goconst - gocyclo - gofmt - revive - govet - ineffassign - - megacheck - misspell - - structcheck - typecheck - unconvert - - varcheck - gosimple - staticcheck - unused - asciicheck - bodyclose - dogsled - - dupl - durationcheck - errorlint - exhaustive - - exportloopref - forbidigo - forcetypeassert - gocritic - godot - - goerr113 - gosec - - ifshort - nestif - nilerr - nlreturn @@ -90,8 +43,47 @@ linters: - prealloc - predeclared - sqlclosecheck - - tagliatelle - whitespace - wrapcheck - wsl fast: false +linters-settings: + depguard: + rules: + main: + allow: + - $all + dupl: + threshold: 99 + errcheck: + check-blank: false + check-type-assertions: false + goconst: + min-len: 3 + min-occurrences: 2 + gocyclo: + min-complexity: 18 + gofmt: + simplify: true + goimports: + local-prefixes: go.opentelemetry.io + govet: + disable: + - shadow + misspell: + ignore-words: + - cancelled + locale: US + revive: + ignore-generated-header: true + severity: warning +output: + formats: + - format: colored-line-number + print-issued-lines: true + print-linter-name: true +run: + concurrency: 4 + issues-exit-code: 1 + tests: false + timeout: 1m diff --git a/authentication/provider.go b/authentication/provider.go index 0817142..cef470b 100644 --- a/authentication/provider.go +++ b/authentication/provider.go @@ -11,6 +11,7 @@ import ( ) // Provider Service interface for encoding passwords +// //go:generate mockery --name=Provider --inpackage --case underscore type Provider interface { Authenticate(r *http.Request, creds credential.Credential) (*http.Request, error) diff --git a/authentication/provider/dao/user_provider.go b/authentication/provider/dao/user_provider.go index d307998..e7a7843 100644 --- a/authentication/provider/dao/user_provider.go +++ b/authentication/provider/dao/user_provider.go @@ -7,6 +7,7 @@ package dao import "github.com/hyperscale-stack/security/user" // UserProvider interface which loads user-specific data. +// //go:generate mockery --name=UserProvider --inpackage --case underscore type UserProvider interface { LoadUserByUsername(username string) (user.User, error) diff --git a/authentication/provider/oauth2/access_test.go b/authentication/provider/oauth2/access_test.go index 5296947..2b2f376 100644 --- a/authentication/provider/oauth2/access_test.go +++ b/authentication/provider/oauth2/access_test.go @@ -17,7 +17,6 @@ import ( "github.com/hyperscale-stack/security/authentication/credential" "github.com/stretchr/testify/assert" - mock "github.com/stretchr/testify/mock" ) func TestAccessData(t *testing.T) { @@ -88,29 +87,7 @@ func TestServerHandleAccessRequestWithBadMethod(t *testing.T) { assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) } -type mockReadCloser struct { - mock.Mock -} - -func (m *mockReadCloser) Read(p []byte) (n int, err error) { - args := m.Called(p) - - return args.Int(0), args.Error(1) -} - -func (m *mockReadCloser) Close() error { - args := m.Called() - - return args.Error(0) -} - func TestServerHandleAccessRequestWithBadBody(t *testing.T) { - /*mockReadCloser := &mockReadCloser{} - // if Read is called, it will return error - mockReadCloser.On("Read", mock.AnythingOfType("[]uint8")).Return(0, fmt.Errorf("error reading")) - // if Close is called, it will return error - mockReadCloser.On("Close").Return(fmt.Errorf("error closing")) - */ cfg := &Configuration{ ErrorStatusCode: http.StatusOK, AllowGetAccessRequest: false, diff --git a/user/user.go b/user/user.go index a3e55be..f46594d 100644 --- a/user/user.go +++ b/user/user.go @@ -5,6 +5,7 @@ package user // User interface provides core user information +// //go:generate mockery --name=User --inpackage --case underscore type User interface { GetID() string @@ -38,6 +39,7 @@ type PasswordSalt interface { } // UserPasswordSalt interface. +// //go:generate mockery --name=UserPasswordSalt --inpackage --case underscore //nolint:golint type UserPasswordSalt interface { From 6f4ffecfa78a7311340014c59937247aded03977 Mon Sep 17 00:00:00 2001 From: Axel Etcheverry Date: Mon, 24 Mar 2025 22:59:51 +0100 Subject: [PATCH 9/9] fix: lint --- .github/workflows/go.yml | 4 +- .golangci.yml | 131 ++++++++++-------- .../dao/dao_authentication_provider.go | 1 + authentication/provider/oauth2/access.go | 2 + .../oauth2/oauth2_authentication_provider.go | 3 + .../oauth2/storage/in_memory_storage.go | 8 +- authentication/provider/oauth2/util.go | 2 +- 7 files changed, 86 insertions(+), 65 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index a9bc0e7..3f32004 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -30,9 +30,9 @@ jobs: rm ./coverage.out.tmp - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v7 with: - version: v1.62.2 + version: latest skip-cache: true - name: Coveralls diff --git a/.golangci.yml b/.golangci.yml index bfbff43..0caf621 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,89 +1,104 @@ +formatters: + enable: + - gofmt + exclusions: + paths: + - .*_mock\.go + - mock_.*\.go + - .*/pkg/mod/.*$ + - .*/go/src/.*\.go + - third_party$ + - builtin$ + - examples$ + settings: + gofmt: + simplify: true + goimports: + local-prefixes: + - go.opentelemetry.io issues: - exclude-case-sensitive: false - exclude-dirs-use-default: true - exclude-files: - - .*_mock\.go - - mock_.*\.go - - .*/pkg/mod/.*$ - - .*/go/src/.*\.go - exclude-generated: strict - exclude-use-default: true max-issues-per-linter: 50 linters: - disable-all: true + default: none enable: - - errcheck - - goconst - - gocyclo - - gofmt - - revive - - govet - - ineffassign - - misspell - - typecheck - - unconvert - - gosimple - - staticcheck - - unused - asciicheck - bodyclose - dogsled - durationcheck + - errcheck - errorlint - exhaustive - forbidigo - forcetypeassert + - goconst - gocritic + - gocyclo - godot - gosec + - govet + - ineffassign + - misspell - nestif - nilerr - nlreturn - noctx - prealloc - predeclared + - revive - sqlclosecheck + - staticcheck + - unconvert + - unused - whitespace - wrapcheck - wsl - fast: false -linters-settings: - depguard: - rules: - main: - allow: - - $all - dupl: - threshold: 99 - errcheck: - check-blank: false - check-type-assertions: false - goconst: - min-len: 3 - min-occurrences: 2 - gocyclo: - min-complexity: 18 - gofmt: - simplify: true - goimports: - local-prefixes: go.opentelemetry.io - govet: - disable: - - shadow - misspell: - ignore-words: - - cancelled - locale: US - revive: - ignore-generated-header: true - severity: warning + exclusions: + paths: + - .*_mock\.go + - mock_.*\.go + - .*/pkg/mod/.*$ + - .*/go/src/.*\.go + - third_party$ + - builtin$ + - examples$ + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + settings: + depguard: + rules: + main: + allow: + - $all + dupl: + threshold: 99 + errcheck: + check-blank: false + check-type-assertions: false + goconst: + min-len: 3 + min-occurrences: 2 + gocyclo: + min-complexity: 18 + govet: + disable: + - shadow + misspell: + ignore-rules: + - cancelled + locale: US + revive: + severity: warning output: formats: - - format: colored-line-number - print-issued-lines: true - print-linter-name: true + text: + path: stdout + print-issued-lines: true + print-linter-name: true run: concurrency: 4 issues-exit-code: 1 tests: false - timeout: 1m +version: "2" diff --git a/authentication/provider/dao/dao_authentication_provider.go b/authentication/provider/dao/dao_authentication_provider.go index c6dcb61..79c9108 100644 --- a/authentication/provider/dao/dao_authentication_provider.go +++ b/authentication/provider/dao/dao_authentication_provider.go @@ -51,6 +51,7 @@ func (p *DaoAuthenticationProvider) Authenticate(r *http.Request, creds credenti return r, ErrBadAuthenticationFormat } + // nolint:forcetypeassert u, err := p.userProvider.LoadUserByUsername(auth.GetPrincipal().(string)) if err != nil { return r, fmt.Errorf("user provider failed: %w", err) diff --git a/authentication/provider/oauth2/access.go b/authentication/provider/oauth2/access.go index fa39c96..132bb4f 100644 --- a/authentication/provider/oauth2/access.go +++ b/authentication/provider/oauth2/access.go @@ -648,6 +648,7 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq // getClient looks up and authenticates the basic auth using the given // storage. Sets an error on the response if auth fails or a server error occurs. func (s Server) getClient(creds *credential.UsernamePasswordCredential, storage StorageProvider, w *Response) Client { + // nolint:forcetypeassert client, err := storage.LoadClient(creds.GetPrincipal().(string)) if errors.Is(err, ErrClientNotFound) { s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found") @@ -667,6 +668,7 @@ func (s Server) getClient(creds *credential.UsernamePasswordCredential, storage return nil } + // nolint:forcetypeassert if !CheckClientSecret(client, creds.GetCredentials().(string)) { s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) diff --git a/authentication/provider/oauth2/oauth2_authentication_provider.go b/authentication/provider/oauth2/oauth2_authentication_provider.go index 9e45f67..fbcb95c 100644 --- a/authentication/provider/oauth2/oauth2_authentication_provider.go +++ b/authentication/provider/oauth2/oauth2_authentication_provider.go @@ -65,6 +65,7 @@ func (p *OAuth2AuthenticationProvider) IsSupported(creds credential.Credential) func (p *OAuth2AuthenticationProvider) authenticateByToken(r *http.Request, creds *credential.TokenCredential) (*http.Request, error) { ctx := r.Context() + // nolint:forcetypeassert token, err := p.accessStorage.LoadAccess(creds.GetPrincipal().(string)) if err != nil { return r, fmt.Errorf("load access token failed: %w", err) @@ -96,11 +97,13 @@ func (p *OAuth2AuthenticationProvider) authenticateByToken(r *http.Request, cred func (p *OAuth2AuthenticationProvider) authenticateByClient(r *http.Request, creds *credential.UsernamePasswordCredential) (*http.Request, error) { ctx := r.Context() + // nolint:forcetypeassert client, err := p.clientStorage.LoadClient(creds.GetPrincipal().(string)) if err != nil { return r, fmt.Errorf("load client info failed: %w", err) } + // nolint:forcetypeassert if c, ok := client.(ClientSecretMatcher); ok { if c.SecretMatches(creds.GetCredentials().(string)) { creds.SetAuthenticated(true) diff --git a/authentication/provider/oauth2/storage/in_memory_storage.go b/authentication/provider/oauth2/storage/in_memory_storage.go index f7ceefc..2fafbe1 100644 --- a/authentication/provider/oauth2/storage/in_memory_storage.go +++ b/authentication/provider/oauth2/storage/in_memory_storage.go @@ -31,7 +31,7 @@ func (s *InMemoryStorage) SaveClient(client oauth2.Client) error { func (s *InMemoryStorage) LoadClient(id string) (oauth2.Client, error) { if client, ok := s.clients.Load(id); ok { - return client.(oauth2.Client), nil + return client.(oauth2.Client), nil // nolint:forcetypeassert } return nil, oauth2.ErrClientNotFound @@ -51,7 +51,7 @@ func (s *InMemoryStorage) SaveAccess(access *oauth2.AccessData) error { func (s *InMemoryStorage) LoadAccess(token string) (*oauth2.AccessData, error) { if access, ok := s.accesses.Load(token); ok { - return access.(*oauth2.AccessData), nil + return access.(*oauth2.AccessData), nil // nolint:forcetypeassert } return nil, oauth2.ErrAccessNotFound @@ -71,7 +71,7 @@ func (s *InMemoryStorage) SaveRefresh(access *oauth2.AccessData) error { func (s *InMemoryStorage) LoadRefresh(token string) (*oauth2.AccessData, error) { if access, ok := s.refreshs.Load(token); ok { - return access.(*oauth2.AccessData), nil + return access.(*oauth2.AccessData), nil // nolint:forcetypeassert } return nil, oauth2.ErrRefreshNotFound @@ -91,7 +91,7 @@ func (s *InMemoryStorage) SaveAuthorize(authorize *oauth2.AuthorizeData) error { func (s *InMemoryStorage) LoadAuthorize(code string) (*oauth2.AuthorizeData, error) { if authorize, ok := s.authorizes.Load(code); ok { - return authorize.(*oauth2.AuthorizeData), nil + return authorize.(*oauth2.AuthorizeData), nil // nolint:forcetypeassert } return nil, oauth2.ErrAuthorizeNotFound diff --git a/authentication/provider/oauth2/util.go b/authentication/provider/oauth2/util.go index e6119e1..382e250 100644 --- a/authentication/provider/oauth2/util.go +++ b/authentication/provider/oauth2/util.go @@ -136,5 +136,5 @@ func (s Server) getClientAuth(w *Response, r *http.Request, allowQueryParams boo return nil } - return auth.(*credential.UsernamePasswordCredential) + return auth.(*credential.UsernamePasswordCredential) // nolint:forcetypeassert }