diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 6c78b9b..3f32004 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,51 +2,50 @@ name: Go on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: - build: name: Build runs-on: ubuntu-latest steps: - - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ^1.16 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - - name: Build - run: go build -race -v ./... - - - name: Test - run: | - go test -race -cover -coverprofile ./coverage.out.tmp ./... - cat ./coverage.out.tmp | grep -v '.pb.go' | grep -v 'mock_' > ./coverage.out - rm ./coverage.out.tmp - - - name: Run golangci-lint - uses: golangci/golangci-lint-action@v2 - with: - version: v1.41.1 - - - name: Coveralls - uses: shogo82148/actions-goveralls@v1 - with: - path-to-profile: coverage.out + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.16 + id: go + + - name: Build + run: go build -race -v ./... + + - name: Test + run: | + go test -race -cover -coverprofile ./coverage.out.tmp ./... + cat ./coverage.out.tmp | grep -v '.pb.go' | grep -v 'mock_' > ./coverage.out + rm ./coverage.out.tmp + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v7 + with: + version: latest + skip-cache: true + + - name: Coveralls + uses: shogo82148/actions-goveralls@v1 + with: + path-to-profile: coverage.out finish: needs: build runs-on: ubuntu-latest steps: - - name: Coveralls Finished - uses: coverallsapp/github-action@master - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - parallel-finished: true + - name: Coveralls Finished + uses: coverallsapp/github-action@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + parallel-finished: true diff --git a/.gitignore b/.gitignore index 93117e4..125dc4e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ vendor/ *.swp + +*.go_ diff --git a/.golangci.yml b/.golangci.yml index f073f8a..0caf621 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,97 +1,104 @@ -run: - concurrency: 4 - deadline: 1m - issues-exit-code: 1 - tests: false - skip-files: - - ".*_mock\\.go" - - "mock_.*\\.go" - - ".*/pkg/mod/.*$" - -output: - format: colored-line-number - print-issued-lines: true - print-linter-name: true - -linters-settings: - errcheck: - check-type-assertions: false - check-blank: false - govet: - check-shadowing: false - revive: - ignore-generated-header: true - severity: warning - gofmt: - simplify: true - gocyclo: - min-complexity: 18 - maligned: - suggest-new: true - dupl: - threshold: 50 - goconst: - min-len: 3 - min-occurrences: 2 - depguard: - list-type: blacklist - include-go-root: false - packages: - - github.com/davecgh/go-spew/spew - misspell: - locale: US - ignore-words: - - cancelled - goimports: - local-prefixes: go.opentelemetry.io - - -linters: - disable-all: true +formatters: enable: - - deadcode - - depguard - - errcheck - - gas - - goconst - - gocyclo - gofmt - - revive - - govet - - ineffassign - - megacheck - - misspell - - structcheck - - typecheck - - unconvert - - varcheck - - gosimple - - staticcheck - - unused + exclusions: + paths: + - .*_mock\.go + - mock_.*\.go + - .*/pkg/mod/.*$ + - .*/go/src/.*\.go + - third_party$ + - builtin$ + - examples$ + settings: + gofmt: + simplify: true + goimports: + local-prefixes: + - go.opentelemetry.io +issues: + max-issues-per-linter: 50 +linters: + default: none + enable: - asciicheck - bodyclose - dogsled - - dupl - durationcheck + - errcheck - errorlint - exhaustive - - exportloopref - forbidigo - forcetypeassert + - goconst - gocritic + - gocyclo - godot - - goerr113 - gosec - - ifshort + - govet + - ineffassign + - misspell - nestif - nilerr - nlreturn - noctx - prealloc - predeclared + - revive - sqlclosecheck - - tagliatelle + - staticcheck + - unconvert + - unused - whitespace - wrapcheck - wsl - fast: false + exclusions: + paths: + - .*_mock\.go + - mock_.*\.go + - .*/pkg/mod/.*$ + - .*/go/src/.*\.go + - third_party$ + - builtin$ + - examples$ + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + settings: + depguard: + rules: + main: + allow: + - $all + dupl: + threshold: 99 + errcheck: + check-blank: false + check-type-assertions: false + goconst: + min-len: 3 + min-occurrences: 2 + gocyclo: + min-complexity: 18 + govet: + disable: + - shadow + misspell: + ignore-rules: + - cancelled + locale: US + revive: + severity: warning +output: + formats: + text: + path: stdout + print-issued-lines: true + print-linter-name: true +run: + concurrency: 4 + issues-exit-code: 1 + tests: false +version: "2" diff --git a/authentication/credential/username_password_credential.go b/authentication/credential/username_password_credential.go index 712fdbb..dc1b02b 100644 --- a/authentication/credential/username_password_credential.go +++ b/authentication/credential/username_password_credential.go @@ -17,7 +17,7 @@ type UsernamePasswordCredential struct { var _ Credential = (*UsernamePasswordCredential)(nil) // NewUsernamePasswordCredential constructor. -func NewUsernamePasswordCredential(principal string, credentials string) Credential { +func NewUsernamePasswordCredential(principal string, credentials string) *UsernamePasswordCredential { return &UsernamePasswordCredential{ credentials: credentials, principal: principal, diff --git a/authentication/provider.go b/authentication/provider.go index 0817142..cef470b 100644 --- a/authentication/provider.go +++ b/authentication/provider.go @@ -11,6 +11,7 @@ import ( ) // Provider Service interface for encoding passwords +// //go:generate mockery --name=Provider --inpackage --case underscore type Provider interface { Authenticate(r *http.Request, creds credential.Credential) (*http.Request, error) diff --git a/authentication/provider/dao/dao_authentication_provider.go b/authentication/provider/dao/dao_authentication_provider.go index c6dcb61..79c9108 100644 --- a/authentication/provider/dao/dao_authentication_provider.go +++ b/authentication/provider/dao/dao_authentication_provider.go @@ -51,6 +51,7 @@ func (p *DaoAuthenticationProvider) Authenticate(r *http.Request, creds credenti return r, ErrBadAuthenticationFormat } + // nolint:forcetypeassert u, err := p.userProvider.LoadUserByUsername(auth.GetPrincipal().(string)) if err != nil { return r, fmt.Errorf("user provider failed: %w", err) diff --git a/authentication/provider/dao/user_provider.go b/authentication/provider/dao/user_provider.go index d307998..e7a7843 100644 --- a/authentication/provider/dao/user_provider.go +++ b/authentication/provider/dao/user_provider.go @@ -7,6 +7,7 @@ package dao import "github.com/hyperscale-stack/security/user" // UserProvider interface which loads user-specific data. +// //go:generate mockery --name=UserProvider --inpackage --case underscore type UserProvider interface { LoadUserByUsername(username string) (user.User, error) diff --git a/authentication/provider/oauth2/access.go b/authentication/provider/oauth2/access.go index 864b017..132bb4f 100644 --- a/authentication/provider/oauth2/access.go +++ b/authentication/provider/oauth2/access.go @@ -6,21 +6,79 @@ package oauth2 import ( "context" + "crypto/sha256" + "encoding/base64" + "errors" + "net/http" + "strings" "time" + + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/credential" +) + +var ( + ErrRequestMustBePost = errors.New("request must be POST") + ErrExtraScope = errors.New("the requested scope must not include any scope not originally granted by the resource owner") + ErrClientIDNotSame = errors.New("client id must be the same from previous token") + ErrCodeChallengeNotSame = errors.New("code_verifier failed comparison with code_challenge") + ErrCodeVerifierInvalidFormat = errors.New("code_verifier has invalid format") + ErrRedirectURINotSame = errors.New("redirect uri is different") ) -type AccessToken interface { - GetClient() Client - GetToken() string - IsExpired() bool - GetUserID() string +// AccessRequestType is the type for OAuth2 param `grant_type`. +type AccessRequestType string + +const ( + AUTHORIZATION_CODE AccessRequestType = "authorization_code" + REFRESH_TOKEN AccessRequestType = "refresh_token" + PASSWORD AccessRequestType = "password" + CLIENT_CREDENTIALS AccessRequestType = "client_credentials" + ASSERTION AccessRequestType = "assertion" + IMPLICIT AccessRequestType = "__implicit" +) + +// AccessRequest is a request for access tokens. +type AccessRequest struct { + Type AccessRequestType + Code string + Client Client + AuthorizeData *AuthorizeData + AccessData *AccessData + + // Force finish to use this access data, to allow access data reuse + ForceAccessData *AccessData + RedirectURI string + Scope string + Username string + Password string + AssertionType string + Assertion string + + // Set if request is authorized + Authorized bool + + // Token expiration in seconds. Change if different from default + Expiration time.Duration + + // Set if a refresh token should be generated + GenerateRefresh bool + + // Data to be passed to storage. Not used by the library. + UserData interface{} + + // HttpRequest *http.Request for special use + HttpRequest *http.Request + + // Optional code_verifier as described in rfc7636 + CodeVerifier string } type accessCtxKey struct{} // AccessTokenFromContext returns the Access Token info associated with the ctx. -func AccessTokenFromContext(ctx context.Context) *AccessInfo { - if a, ok := ctx.Value(accessCtxKey{}).(*AccessInfo); ok { +func AccessTokenFromContext(ctx context.Context) *AccessData { + if a, ok := ctx.Value(accessCtxKey{}).(*AccessData); ok { return a } @@ -28,20 +86,20 @@ func AccessTokenFromContext(ctx context.Context) *AccessInfo { } // AccessTokenToContext returns new context with Access Token info. -func AccessTokenToContext(ctx context.Context, access *AccessInfo) context.Context { +func AccessTokenToContext(ctx context.Context, access *AccessData) context.Context { return context.WithValue(ctx, accessCtxKey{}, access) } -// AccessInfo represents an access grant (tokens, expiration, client, etc). -type AccessInfo struct { +// AccessData represents an access grant (tokens, expiration, client, etc). +type AccessData struct { // Client information Client Client // Authorize data, for authorization code - AuthorizeData *AuthorizeInfo + AuthorizeData *AuthorizeData // Previous access data, for refresh token - AccessInfo *AccessInfo + AccessData *AccessData // Access token AccessToken string @@ -50,7 +108,7 @@ type AccessInfo struct { RefreshToken string // Token expiration in seconds - ExpiresIn int32 + ExpiresIn int64 // Requested scope Scope string @@ -66,16 +124,575 @@ type AccessInfo struct { } // IsExpired returns true if access expired. -func (i *AccessInfo) IsExpired() bool { +func (i *AccessData) IsExpired() bool { return i.IsExpiredAt(time.Now()) } // IsExpiredAt returns true if access expires at time 't'. -func (i *AccessInfo) IsExpiredAt(t time.Time) bool { +func (i *AccessData) IsExpiredAt(t time.Time) bool { return i.ExpireAt().Before(t) } // ExpireAt returns the expiration date. -func (i *AccessInfo) ExpireAt() time.Time { +func (i *AccessData) ExpireAt() time.Time { return i.CreatedAt.Add(time.Duration(i.ExpiresIn) * time.Second) } + +// HandleAccessRequest is the http.HandlerFunc for handling access token requests. +func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessRequest { + // Only allow GET or POST + if r.Method == http.MethodGet { + if !s.cfg.AllowGetAccessRequest { + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "GET request not allowed") + + return nil + } + } else if r.Method != http.MethodPost { + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrRequestMustBePost, "access_request=%s", "request must be POST") + + return nil + } + + if err := r.ParseForm(); err != nil { + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "access_request=%s", "parsing error") + + return nil + } + + grantType := AccessRequestType(r.FormValue("grant_type")) + if !s.cfg.AllowedAccessTypes.Exists(grantType) { + s.setErrorAndLog(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") + + return nil + } + + //nolint: exhaustive + switch grantType { + case AUTHORIZATION_CODE: + return s.handleAuthorizationCodeRequest(w, r) + case REFRESH_TOKEN: + return s.handleRefreshTokenRequest(w, r) + case PASSWORD: + return s.handlePasswordRequest(w, r) + case CLIENT_CREDENTIALS: + return s.handleClientCredentialsRequest(w, r) + case ASSERTION: + return s.handleAssertionRequest(w, r) + default: + s.setErrorAndLog(w, E_UNSUPPORTED_GRANT_TYPE, nil, "access_request=%s", "unknown grant type") + + return nil + } +} + +func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: AUTHORIZATION_CODE, + Code: r.FormValue("code"), + CodeVerifier: r.FormValue("code_verifier"), + RedirectURI: r.FormValue("redirect_uri"), + GenerateRefresh: true, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "code" is required + if ret.Code == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "code is required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // must be a valid authorization code + var err error + + ret.AuthorizeData, err = w.Storage.LoadAuthorize(ret.Code) + if err != nil { + s.setErrorAndLog(w, E_INVALID_GRANT, err, "auth_code_request=%s", "error loading authorize data") + + return nil + } + + if ret.AuthorizeData == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "authorization data is nil") + + return nil + } + + if ret.AuthorizeData.Client == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "authorization client is nil") + + return nil + } + + if ret.AuthorizeData.Client.GetRedirectURI() == "" { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "auth_code_request=%s", "client redirect uri is empty") + + return nil + } + + if ret.AuthorizeData.IsExpiredAt(s.now()) { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "authorization data is expired") + + return nil + } + + // code must be from the client + if ret.AuthorizeData.Client.GetID() != ret.Client.GetID() { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "auth_code_request=%s", "client code does not match") + + return nil + } + + // check redirect uri + if ret.RedirectURI == "" { + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + } + + if realRedirectURI, err := ValidateURIList(ret.Client.GetRedirectURI(), ret.RedirectURI, s.cfg.RedirectURISeparator); err != nil { + s.setErrorAndLog(w, E_INVALID_REQUEST, err, "auth_code_request=%s", "error validating client redirect") + + return nil + } else { + ret.RedirectURI = realRedirectURI + } + + if ret.AuthorizeData.RedirectURI != ret.RedirectURI { + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrRedirectURINotSame, "auth_code_request=%s", "client redirect does not match authorization data") + + return nil + } + + // Verify PKCE, if present in the authorization data + if len(ret.AuthorizeData.CodeChallenge) > 0 { + // https://tools.ietf.org/html/rfc7636#section-4.1 + if matched := pkceMatcher.MatchString(ret.CodeVerifier); !matched { + s.setErrorAndLog( + w, + E_INVALID_REQUEST, + ErrCodeVerifierInvalidFormat, + "auth_code_request=%s", + "pkce code challenge verifier does not match", + ) + + return nil + } + + // https: //tools.ietf.org/html/rfc7636#section-4.6 + codeVerifier := "" + + switch ret.AuthorizeData.CodeChallengeMethod { + case "", PKCE_PLAIN: + codeVerifier = ret.CodeVerifier + case PKCE_S256: + hash := sha256.Sum256([]byte(ret.CodeVerifier)) + codeVerifier = base64.RawURLEncoding.EncodeToString(hash[:]) + default: + s.setErrorAndLog( + w, + E_INVALID_REQUEST, + nil, + "auth_code_request=%s", + "pkce transform algorithm not supported (rfc7636)", + ) + + return nil + } + + if codeVerifier != ret.AuthorizeData.CodeChallenge { + s.setErrorAndLog( + w, + E_INVALID_GRANT, + ErrCodeChallengeNotSame, + "auth_code_request=%s", + "pkce code verifier does not match challenge", + ) + + return nil + } + } + + // set rest of data + ret.Scope = ret.AuthorizeData.Scope + ret.UserData = ret.AuthorizeData.UserData + + return ret +} + +func extraScopes(access_scopes, refresh_scopes string) bool { + access_scopes_list := strings.Split(access_scopes, " ") + refresh_scopes_list := strings.Split(refresh_scopes, " ") + + access_map := make(map[string]int) + + for _, scope := range access_scopes_list { + if scope == "" { + continue + } + + access_map[scope] = 1 + } + + for _, scope := range refresh_scopes_list { + if scope == "" { + continue + } + + if _, ok := access_map[scope]; !ok { + return true + } + } + + return false +} + +func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: REFRESH_TOKEN, + Code: r.FormValue("refresh_token"), + Scope: r.FormValue("scope"), + GenerateRefresh: true, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "refresh_token" is required + if ret.Code == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "refresh_token=%s", "refresh_token is required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // must be a valid refresh code + var err error + + ret.AccessData, err = w.Storage.LoadRefresh(ret.Code) + if err != nil { + s.setErrorAndLog(w, E_INVALID_GRANT, err, "refresh_token=%s", "error loading access data") + + return nil + } + + if ret.AccessData == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data is nil") + + return nil + } + + if ret.AccessData.Client == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data client is nil") + + return nil + } + + if ret.AccessData.Client.GetRedirectURI() == "" { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "refresh_token=%s", "access data client redirect uri is empty") + + return nil + } + + // client must be the same as the previous token + if ret.AccessData.Client.GetID() != ret.Client.GetID() { + s.setErrorAndLog( + w, + E_INVALID_CLIENT, + ErrClientIDNotSame, + "refresh_token=%s, current=%v, previous=%v", + "client mismatch", + ret.Client.GetID(), + ret.AccessData.Client.GetID(), + ) + + return nil + } + + // set rest of data + ret.RedirectURI = ret.AccessData.RedirectURI + ret.UserData = ret.AccessData.UserData + + if ret.Scope == "" { + ret.Scope = ret.AccessData.Scope + } + + if extraScopes(ret.AccessData.Scope, ret.Scope) { + s.setErrorAndLog(w, E_ACCESS_DENIED, ErrExtraScope, "refresh_token=%s", ErrExtraScope.Error()) + + return nil + } + + return ret +} + +//nolint:dupl +func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: PASSWORD, + Username: r.FormValue("username"), + Password: r.FormValue("password"), + Scope: r.FormValue("scope"), + GenerateRefresh: true, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "username" and "password" is required + if ret.Username == "" || ret.Password == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "handle_password=%s", "username and pass required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // set redirect uri + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + + return ret +} + +func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: CLIENT_CREDENTIALS, + Scope: r.FormValue("scope"), + GenerateRefresh: false, + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // set redirect uri + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + + return ret +} + +//nolint:dupl +func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest { + // get client authentication + auth := s.getClientAuth(w, r, s.cfg.AllowClientSecretInParams) + if auth == nil { + return nil + } + + // generate access token + ret := &AccessRequest{ + Type: ASSERTION, + Scope: r.FormValue("scope"), + AssertionType: r.FormValue("assertion_type"), + Assertion: r.FormValue("assertion"), + GenerateRefresh: false, // assertion should NOT generate a refresh token, per the RFC + Expiration: s.cfg.AccessExpiration, + HttpRequest: r, + } + + // "assertion_type" and "assertion" is required + if ret.AssertionType == "" || ret.Assertion == "" { + s.setErrorAndLog(w, E_INVALID_GRANT, nil, "handle_assertion_request=%s", "assertion and assertion_type required") + + return nil + } + + // must have a valid client + if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + return nil + } + + // set redirect uri + ret.RedirectURI = FirstURI(ret.Client.GetRedirectURI(), s.cfg.RedirectURISeparator) + + return ret +} + +func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessRequest) { + // don't process if is already an error + if w.IsError { + return + } + + redirectUri := r.FormValue("redirect_uri") + + // Get redirect uri from AccessRequest if it's there (e.g., refresh token request) + if ar.RedirectURI != "" { + redirectUri = ar.RedirectURI + } + + if !ar.Authorized { + s.setErrorAndLog(w, E_ACCESS_DENIED, nil, "finish_access_request=%s", "authorization failed") + + return + } + + var ret *AccessData + + var err error + + if ar.ForceAccessData == nil { + // generate access token + ret = &AccessData{ + Client: ar.Client, + AuthorizeData: ar.AuthorizeData, + AccessData: ar.AccessData, + RedirectURI: redirectUri, + CreatedAt: s.now(), + ExpiresIn: int64(ar.Expiration.Seconds()), + UserData: ar.UserData, + Scope: ar.Scope, + } + + // generate access token + // @TODO: add ret at first arg for GenerateAccessToken + ret.AccessToken, ret.RefreshToken, err = s.tokenGenerator.GenerateAccessToken(ar.GenerateRefresh) + if err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error generating token") + + return + } + } else { + ret = ar.ForceAccessData + } + + // save access token + if err = w.Storage.SaveAccess(ret); err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") + + return + } + + // remove authorization token + if ret.AuthorizeData != nil { + if err := w.Storage.RemoveAuthorize(ret.AuthorizeData.Code); err != nil { + s.logger.Error("oauth2: remove autorize code failed", logger.WithLabels(map[string]interface{}{ + "code": ret.AuthorizeData.Code, + })) + } + } + + // remove previous access token + if ret.AccessData != nil && !s.cfg.RetainTokenAfterRefresh { + if ret.AccessData.RefreshToken != "" { + if err := w.Storage.RemoveRefresh(ret.AccessData.RefreshToken); err != nil { + s.logger.Error("oauth2: remove refresh token failed", logger.WithLabels(map[string]interface{}{ + "refresh_token": ret.AccessData.RefreshToken, + })) + } + } + + if err := w.Storage.RemoveAccess(ret.AccessData.AccessToken); err != nil { + s.logger.Error("oauth2: remove access token failed", logger.WithLabels(map[string]interface{}{ + "access_token": ret.AccessData.AccessToken, + })) + } + } + + // output data + w.Output["access_token"] = ret.AccessToken + w.Output["token_type"] = s.cfg.TokenType + w.Output["expires_in"] = ret.ExpiresIn + + if ret.RefreshToken != "" { + w.Output["refresh_token"] = ret.RefreshToken + } + + if ret.Scope != "" { + w.Output["scope"] = ret.Scope + } +} + +// Helper Functions + +// getClient looks up and authenticates the basic auth using the given +// storage. Sets an error on the response if auth fails or a server error occurs. +func (s Server) getClient(creds *credential.UsernamePasswordCredential, storage StorageProvider, w *Response) Client { + // nolint:forcetypeassert + client, err := storage.LoadClient(creds.GetPrincipal().(string)) + if errors.Is(err, ErrClientNotFound) { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found") + + return nil + } + + if err != nil { + s.setErrorAndLog(w, E_SERVER_ERROR, err, "get_client=%s", "error finding client") + + return nil + } + + if client == nil { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client is nil") + + return nil + } + + // nolint:forcetypeassert + if !CheckClientSecret(client, creds.GetCredentials().(string)) { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s, client_id=%v", "client check failed", client.GetID()) + + return nil + } + + if client.GetRedirectURI() == "" { + s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "client redirect uri is empty") + + return nil + } + + return client +} + +// setErrorAndLog sets the response error and internal error (if non-nil) and logs them along with the provided debug format string and arguments. +func (s Server) setErrorAndLog(w *Response, responseError DefaultErrorID, internalError error, debugFormat string, debugArgs ...interface{}) { + format := "error=%v, internal_error=%#v " + debugFormat + + w.InternalError = internalError + w.SetError(responseError, "") + + s.logger.Error( + format, + logger.WithFormat(append([]interface{}{responseError, internalError}, debugArgs...)...), + ) +} diff --git a/authentication/provider/oauth2/access_test.go b/authentication/provider/oauth2/access_test.go index 2b2e584..2b2f376 100644 --- a/authentication/provider/oauth2/access_test.go +++ b/authentication/provider/oauth2/access_test.go @@ -6,17 +6,24 @@ package oauth2 import ( "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" "testing" "time" + "github.com/hyperscale-stack/security/authentication/credential" "github.com/stretchr/testify/assert" ) -func TestAccessInfo(t *testing.T) { +func TestAccessData(t *testing.T) { cat, err := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") assert.NoError(t, err) - ai := &AccessInfo{ + ai := &AccessData{ CreatedAt: cat, ExpiresIn: 10, } @@ -27,7 +34,7 @@ func TestAccessInfo(t *testing.T) { func TestAccessTokenContext(t *testing.T) { ctx := context.Background() - ai := &AccessInfo{ + ai := &AccessData{ CreatedAt: time.Now(), ExpiresIn: 10, } @@ -43,6 +50,361 @@ func TestFromContextWithEmptyContext(t *testing.T) { ctx := context.Background() ai := AccessTokenFromContext(ctx) - assert.Nil(t, ai) } + +func TestServerHandleAccessRequestWithGetMethodNotAllowed(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodGet, "http://example.com/v1/me", nil) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +func TestServerHandleAccessRequestWithBadMethod(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodPut, "http://example.com/v1/me", nil) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +func TestServerHandleAccessRequestWithBadBody(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/me?f$$", nil) + + req.Body = nil + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +func TestServerHandleAccessRequestWithEmptyBody(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/me?f$$", nil) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_UNSUPPORTED_GRANT_TYPE, w.ErrorID) +} + +func TestServerHandleAccessRequestWithPasswordGrandTypeWithInvalidRequest(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + } + s := NewServer(WithConfig(cfg)) + + w := s.NewResponse() + + data := url.Values{} + + data.Set("grant_type", "password") + + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + + ar := s.HandleAccessRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_REQUEST, w.ErrorID) +} + +func TestServerGetClientWithErrClientNotFound(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(nil, ErrClientNotFound) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithErr(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(nil, errors.New("foo")) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_SERVER_ERROR, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithClientEmpry(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(nil, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithClientBadSecret(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + Secret: "bar", + } + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClientWithClientBadRedirect(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + Secret: "foo", + RedirectURI: "", + } + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Nil(t, c) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerGetClient(t *testing.T) { + cfg := &Configuration{} + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + Secret: "foo", + RedirectURI: "https://auth.mydomain.tld/connect", + } + + storageMock.On("LoadClient", "9b48f589-735c-476c-aa5f-eae9e2422d01").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + creds := credential.NewUsernamePasswordCredential("9b48f589-735c-476c-aa5f-eae9e2422d01", "foo") + + c := s.getClient(creds, storageMock, w) + assert.Same(t, client, c) + + storageMock.AssertExpectations(t) +} + +func TestExtraScopes(t *testing.T) { + assert.True(t, extraScopes("foo bar", "foo bar jar")) + assert.False(t, extraScopes("foo bar", "foo bar")) + + assert.False(t, extraScopes(" ", " ")) +} + +func TestServerHandlePasswordRequestWithEmptyUsernameAndPassword(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + } + + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "50542ad2-5983-4977-baab-ef3794f08c89").Return(nil, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + /* + data := url.Values{} + + data.Set("grant_type", "password") + data.Set("username", "") + data.Set("password", "") + */ + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", nil) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + //req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + + ctx := req.Context() + ctx = credential.ToContext(ctx, credential.NewUsernamePasswordCredential("", "")) + req = req.WithContext(ctx) + + ar := s.handlePasswordRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_INVALID_GRANT, w.ErrorID) +} + +func TestServerHandlePasswordRequestWithClientNotFound(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + RedirectURISeparator: " ", + } + + storageMock := &MockStorageProvider{} + + storageMock.On("LoadClient", "50542ad2-5983-4977-baab-ef3794f08c89").Return(nil, ErrClientNotFound) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + data := url.Values{} + + data.Set("grant_type", "password") + data.Set("username", "user@domain.com") + data.Set("password", "passw0rd") + + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + ctx := req.Context() + ctx = credential.ToContext(ctx, credential.NewUsernamePasswordCredential("50542ad2-5983-4977-baab-ef3794f08c89", "mySecretPassw0rd!")) + req = req.WithContext(ctx) + + ar := s.handlePasswordRequest(w, req) + assert.Nil(t, ar) + + assert.Equal(t, E_UNAUTHORIZED_CLIENT, w.ErrorID) + + storageMock.AssertExpectations(t) +} + +func TestServerHandlePasswordRequest(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + AllowGetAccessRequest: false, + AllowedAccessTypes: AllowedAccessType{ + PASSWORD, + }, + RedirectURISeparator: " ", + } + + storageMock := &MockStorageProvider{} + + client := &DefaultClient{ + ID: "50542ad2-5983-4977-baab-ef3794f08c89", + Secret: "mySecretPassw0rd!", + RedirectURI: "https://auth.mydomain.tld/connect", + } + + storageMock.On("LoadClient", "50542ad2-5983-4977-baab-ef3794f08c89").Return(client, nil) + + s := NewServer(WithConfig(cfg), WithStorage(storageMock)) + + w := s.NewResponse() + + data := url.Values{} + + data.Set("grant_type", "password") + data.Set("username", "user@domain.com") + data.Set("password", "passw0rd") + + req := httptest.NewRequest(http.MethodPost, "http://example.com/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + ctx := req.Context() + ctx = credential.ToContext(ctx, credential.NewUsernamePasswordCredential("50542ad2-5983-4977-baab-ef3794f08c89", "mySecretPassw0rd!")) + req = req.WithContext(ctx) + + ar := s.handlePasswordRequest(w, req) + assert.NotNil(t, ar) + + assert.Same(t, client, ar.Client) + + storageMock.AssertExpectations(t) +} diff --git a/authentication/provider/oauth2/authorize.go b/authentication/provider/oauth2/authorize.go index 4da0dad..091df8c 100644 --- a/authentication/provider/oauth2/authorize.go +++ b/authentication/provider/oauth2/authorize.go @@ -4,10 +4,56 @@ package oauth2 -import "time" +import ( + "net/http" + "regexp" + "time" +) -// AuthorizeInfo info. -type AuthorizeInfo struct { +// AuthorizeRequestType is the type for OAuth param `response_type`. +type AuthorizeRequestType string + +const ( + CODE AuthorizeRequestType = "code" + TOKEN AuthorizeRequestType = "token" + + PKCE_PLAIN = "plain" + PKCE_S256 = "S256" +) + +var ( + pkceMatcher = regexp.MustCompile("^[a-zA-Z0-9~._-]{43,128}$") +) + +// Authorize request information. +type AuthorizeRequest struct { + Type AuthorizeRequestType + Client Client + Scope string + RedirectUri string + State string + + // Set if request is authorized + Authorized bool + + // Token expiration in seconds. Change if different from default. + // If type = TOKEN, this expiration will be for the ACCESS token. + Expiration int32 + + // Data to be passed to storage. Not used by the library. + UserData interface{} + + // HttpRequest *http.Request for special use + HttpRequest *http.Request + + // Optional code_challenge as described in rfc7636 + CodeChallenge string + // Optional code_challenge_method as described in rfc7636 + CodeChallengeMethod string +} + +// AuthorizeData info. +type AuthorizeData struct { // Client information Client Client @@ -40,16 +86,16 @@ type AuthorizeInfo struct { } // IsExpired is true if authorization expired. -func (i *AuthorizeInfo) IsExpired() bool { +func (i *AuthorizeData) IsExpired() bool { return i.IsExpiredAt(time.Now()) } // IsExpired is true if authorization expires at time 't'. -func (i *AuthorizeInfo) IsExpiredAt(t time.Time) bool { +func (i *AuthorizeData) IsExpiredAt(t time.Time) bool { return i.ExpireAt().Before(t) } // ExpireAt returns the expiration date. -func (i *AuthorizeInfo) ExpireAt() time.Time { +func (i *AuthorizeData) ExpireAt() time.Time { return i.CreatedAt.Add(time.Duration(i.ExpiresIn) * time.Second) } diff --git a/authentication/provider/oauth2/authorize_test.go b/authentication/provider/oauth2/authorize_test.go index 8a15920..d21a455 100644 --- a/authentication/provider/oauth2/authorize_test.go +++ b/authentication/provider/oauth2/authorize_test.go @@ -11,11 +11,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAuthorizeInfo(t *testing.T) { +func TestAuthorizeData(t *testing.T) { cat, err := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") assert.NoError(t, err) - ai := &AuthorizeInfo{ + ai := &AuthorizeData{ CreatedAt: cat, ExpiresIn: 10, } diff --git a/authentication/provider/oauth2/configuration.go b/authentication/provider/oauth2/configuration.go new file mode 100644 index 0000000..9aa86bd --- /dev/null +++ b/authentication/provider/oauth2/configuration.go @@ -0,0 +1,80 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import "time" + +// AllowedAuthorizeType is a collection of allowed auth request types. +type AllowedAuthorizeType []AuthorizeRequestType + +// Exists returns true if the auth type exists in the list. +func (t AllowedAuthorizeType) Exists(rt AuthorizeRequestType) bool { + for _, k := range t { + if k == rt { + return true + } + } + + return false +} + +// AllowedAccessType is a collection of allowed access request types. +type AllowedAccessType []AccessRequestType + +// Exists returns true if the access type exists in the list. +func (t AllowedAccessType) Exists(rt AccessRequestType) bool { + for _, k := range t { + if k == rt { + return true + } + } + + return false +} + +type Configuration struct { + PrefixURI string + + // Token type to return + TokenType string + + // old + + AuthorizationExpiration time.Duration + + AccessExpiration time.Duration + + // List of allowed authorize types (only CODE by default). + AllowedAuthorizeTypes AllowedAuthorizeType + + // List of allowed access types (only AUTHORIZATION_CODE by default). + AllowedAccessTypes AllowedAccessType + + // HTTP status code to return for errors - default 200 + // Only used if response was created from server. + ErrorStatusCode int + + // If true allows client secret also in params, else only in + // Authorization header - default false. + AllowClientSecretInParams bool + + // If true allows access request using GET, else only POST - default false. + AllowGetAccessRequest bool + + // Require PKCE for code flows for public OAuth clients - default false. + RequirePKCEForPublicClients bool + + // Separator to support multiple URIs in Client.GetRedirectUri(). + // If blank (the default), don't allow multiple URIs. + RedirectURISeparator string + + // RetainTokenAfter Refresh allows the server to retain the access and + // refresh token for re-use - default false. + RetainTokenAfterRefresh bool +} + +func NewConfiguration() *Configuration { + return &Configuration{} +} diff --git a/authentication/provider/oauth2/configuration_test.go b/authentication/provider/oauth2/configuration_test.go new file mode 100644 index 0000000..9bcfcfb --- /dev/null +++ b/authentication/provider/oauth2/configuration_test.go @@ -0,0 +1,26 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllowedAuthorizeType(t *testing.T) { + types := AllowedAuthorizeType{CODE} + + assert.True(t, types.Exists(CODE)) + assert.False(t, types.Exists(TOKEN)) +} + +func TestAllowedAccessType(t *testing.T) { + types := AllowedAccessType{AUTHORIZATION_CODE, CLIENT_CREDENTIALS} + + assert.True(t, types.Exists(AUTHORIZATION_CODE)) + assert.True(t, types.Exists(CLIENT_CREDENTIALS)) + assert.False(t, types.Exists(ASSERTION)) +} diff --git a/authentication/provider/oauth2/error.go b/authentication/provider/oauth2/error.go new file mode 100644 index 0000000..a711947 --- /dev/null +++ b/authentication/provider/oauth2/error.go @@ -0,0 +1,65 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +type DefaultErrorID string + +func (e DefaultErrorID) String() string { + return string(e) +} + +const ( + E_INVALID_REQUEST DefaultErrorID = "invalid_request" + E_UNAUTHORIZED_CLIENT DefaultErrorID = "unauthorized_client" + E_ACCESS_DENIED DefaultErrorID = "access_denied" + E_UNSUPPORTED_RESPONSE_TYPE DefaultErrorID = "unsupported_response_type" + E_INVALID_SCOPE DefaultErrorID = "invalid_scope" + E_SERVER_ERROR DefaultErrorID = "server_error" + E_TEMPORARILY_UNAVAILABLE DefaultErrorID = "temporarily_unavailable" + E_UNSUPPORTED_GRANT_TYPE DefaultErrorID = "unsupported_grant_type" + E_INVALID_GRANT DefaultErrorID = "invalid_grant" + E_INVALID_CLIENT DefaultErrorID = "invalid_client" +) + +var ( + deferror *DefaultErrors = NewDefaultErrors() +) + +// Default errors and messages. +type DefaultErrors struct { + errormap map[DefaultErrorID]string +} + +// NewDefaultErrors initializes OAuth2 error codes and descriptions. +// http://tools.ietf.org/html/rfc6749#section-4.1.2.1 +// http://tools.ietf.org/html/rfc6749#section-4.2.2.1 +// http://tools.ietf.org/html/rfc6749#section-5.2 +// http://tools.ietf.org/html/rfc6749#section-7.2 +func NewDefaultErrors() *DefaultErrors { + r := &DefaultErrors{ + errormap: make(map[DefaultErrorID]string), + } + + r.errormap[E_INVALID_REQUEST] = "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed." + r.errormap[E_UNAUTHORIZED_CLIENT] = "The client is not authorized to request a token using this method." + r.errormap[E_ACCESS_DENIED] = "The resource owner or authorization server denied the request." + r.errormap[E_UNSUPPORTED_RESPONSE_TYPE] = "The authorization server does not support obtaining a token using this method." + r.errormap[E_INVALID_SCOPE] = "The requested scope is invalid, unknown, or malformed." + r.errormap[E_SERVER_ERROR] = "The authorization server encountered an unexpected condition that prevented it from fulfilling the request." + r.errormap[E_TEMPORARILY_UNAVAILABLE] = "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server." + r.errormap[E_UNSUPPORTED_GRANT_TYPE] = "The authorization grant type is not supported by the authorization server." + r.errormap[E_INVALID_GRANT] = "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client." + r.errormap[E_INVALID_CLIENT] = "Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method)." + + return r +} + +func (e *DefaultErrors) Get(id DefaultErrorID) string { + if m, ok := e.errormap[id]; ok { + return m + } + + return id.String() +} diff --git a/authentication/provider/oauth2/error_test.go b/authentication/provider/oauth2/error_test.go new file mode 100644 index 0000000..c2b4e2c --- /dev/null +++ b/authentication/provider/oauth2/error_test.go @@ -0,0 +1,23 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorGet(t *testing.T) { + for _, item := range []struct { + id DefaultErrorID + expedted string + }{ + {E_ACCESS_DENIED, "The resource owner or authorization server denied the request."}, + {DefaultErrorID("foo"), "foo"}, + } { + assert.Equal(t, item.expedted, deferror.Get(item.id)) + } +} diff --git a/authentication/provider/oauth2/handler.go b/authentication/provider/oauth2/handler.go new file mode 100644 index 0000000..ec8a0b6 --- /dev/null +++ b/authentication/provider/oauth2/handler.go @@ -0,0 +1,77 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "errors" + "net/http" +) + +var ( + ErrOAuthError = errors.New("oauth2 error") +) + +func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + //nolint:wsl,gocritic + //@TODO: WIP + switch r.URL.Path { + case s.cfg.PrefixURI + "/token": + s.handleTokenRequest(w, r) + /* + case s.cfg.PrefixURI + "/authorize": + s.handleAuthorizeRequest(w, r) + */ + } +} + +func (s Server) handleTokenRequest(w http.ResponseWriter, r *http.Request) { + resp := s.NewResponse() + + requestType := "" + + if ar := s.HandleAccessRequest(resp, r); ar != nil { + requestType = string(ar.Type) + + //nolint:exhaustive + switch ar.Type { + case AUTHORIZATION_CODE: + ar.Authorized = true + case REFRESH_TOKEN: + ar.Authorized = true + case PASSWORD: + user, err := s.userProvider.Authenticate(ar.Username, ar.Password) + if err != nil { + s.setErrorAndLog(resp, E_ACCESS_DENIED, err, "get_user=%s", "failed") + } else { + ar.Authorized = true + ar.UserData = user.GetID() + } + case CLIENT_CREDENTIALS: + ar.Authorized = true + } + + s.FinishAccessRequest(resp, r, ar) + } + + var err error + + if resp.IsError { + if resp.InternalError != nil { + err = resp.InternalError + } else { + err = ErrOAuthError + } + + s.logger.Error(err.Error()) + + s.emitter.Dispatch("oauth."+requestType+".failed", resp) + } else { + s.emitter.Dispatch("oauth."+requestType+".succeeded", resp) + } + + if err := OutputJSON(resp, w, r); err != nil { + s.logger.Error(err.Error()) + } +} diff --git a/authentication/provider/oauth2/mock_access_provider.go b/authentication/provider/oauth2/mock_access_provider.go index 0630619..c1203fb 100644 --- a/authentication/provider/oauth2/mock_access_provider.go +++ b/authentication/provider/oauth2/mock_access_provider.go @@ -10,15 +10,15 @@ type MockAccessProvider struct { } // LoadAccess provides a mock function with given fields: token -func (_m *MockAccessProvider) LoadAccess(token string) (*AccessInfo, error) { +func (_m *MockAccessProvider) LoadAccess(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -47,11 +47,11 @@ func (_m *MockAccessProvider) RemoveAccess(token string) error { } // SaveAccess provides a mock function with given fields: _a0 -func (_m *MockAccessProvider) SaveAccess(_a0 *AccessInfo) error { +func (_m *MockAccessProvider) SaveAccess(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/mock_authorize_provider.go b/authentication/provider/oauth2/mock_authorize_provider.go index a9b1cbe..83e3e33 100644 --- a/authentication/provider/oauth2/mock_authorize_provider.go +++ b/authentication/provider/oauth2/mock_authorize_provider.go @@ -10,15 +10,15 @@ type MockAuthorizeProvider struct { } // LoadAuthorize provides a mock function with given fields: code -func (_m *MockAuthorizeProvider) LoadAuthorize(code string) (*AuthorizeInfo, error) { +func (_m *MockAuthorizeProvider) LoadAuthorize(code string) (*AuthorizeData, error) { ret := _m.Called(code) - var r0 *AuthorizeInfo - if rf, ok := ret.Get(0).(func(string) *AuthorizeInfo); ok { + var r0 *AuthorizeData + if rf, ok := ret.Get(0).(func(string) *AuthorizeData); ok { r0 = rf(code) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AuthorizeInfo) + r0 = ret.Get(0).(*AuthorizeData) } } @@ -47,11 +47,11 @@ func (_m *MockAuthorizeProvider) RemoveAuthorize(code string) error { } // SaveAuthorize provides a mock function with given fields: _a0 -func (_m *MockAuthorizeProvider) SaveAuthorize(_a0 *AuthorizeInfo) error { +func (_m *MockAuthorizeProvider) SaveAuthorize(_a0 *AuthorizeData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AuthorizeInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AuthorizeData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/mock_refresh_provider.go b/authentication/provider/oauth2/mock_refresh_provider.go index 67b23c5..f8d073b 100644 --- a/authentication/provider/oauth2/mock_refresh_provider.go +++ b/authentication/provider/oauth2/mock_refresh_provider.go @@ -10,15 +10,15 @@ type MockRefreshProvider struct { } // LoadRefresh provides a mock function with given fields: token -func (_m *MockRefreshProvider) LoadRefresh(token string) (*AccessInfo, error) { +func (_m *MockRefreshProvider) LoadRefresh(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -47,11 +47,11 @@ func (_m *MockRefreshProvider) RemoveRefresh(token string) error { } // SaveRefresh provides a mock function with given fields: _a0 -func (_m *MockRefreshProvider) SaveRefresh(_a0 *AccessInfo) error { +func (_m *MockRefreshProvider) SaveRefresh(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/mock_storage_provider.go b/authentication/provider/oauth2/mock_storage_provider.go index f1ad0ed..53b43aa 100644 --- a/authentication/provider/oauth2/mock_storage_provider.go +++ b/authentication/provider/oauth2/mock_storage_provider.go @@ -10,15 +10,15 @@ type MockStorageProvider struct { } // LoadAccess provides a mock function with given fields: token -func (_m *MockStorageProvider) LoadAccess(token string) (*AccessInfo, error) { +func (_m *MockStorageProvider) LoadAccess(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -33,15 +33,15 @@ func (_m *MockStorageProvider) LoadAccess(token string) (*AccessInfo, error) { } // LoadAuthorize provides a mock function with given fields: code -func (_m *MockStorageProvider) LoadAuthorize(code string) (*AuthorizeInfo, error) { +func (_m *MockStorageProvider) LoadAuthorize(code string) (*AuthorizeData, error) { ret := _m.Called(code) - var r0 *AuthorizeInfo - if rf, ok := ret.Get(0).(func(string) *AuthorizeInfo); ok { + var r0 *AuthorizeData + if rf, ok := ret.Get(0).(func(string) *AuthorizeData); ok { r0 = rf(code) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AuthorizeInfo) + r0 = ret.Get(0).(*AuthorizeData) } } @@ -79,15 +79,15 @@ func (_m *MockStorageProvider) LoadClient(id string) (Client, error) { } // LoadRefresh provides a mock function with given fields: token -func (_m *MockStorageProvider) LoadRefresh(token string) (*AccessInfo, error) { +func (_m *MockStorageProvider) LoadRefresh(token string) (*AccessData, error) { ret := _m.Called(token) - var r0 *AccessInfo - if rf, ok := ret.Get(0).(func(string) *AccessInfo); ok { + var r0 *AccessData + if rf, ok := ret.Get(0).(func(string) *AccessData); ok { r0 = rf(token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*AccessInfo) + r0 = ret.Get(0).(*AccessData) } } @@ -158,11 +158,11 @@ func (_m *MockStorageProvider) RemoveRefresh(token string) error { } // SaveAccess provides a mock function with given fields: _a0 -func (_m *MockStorageProvider) SaveAccess(_a0 *AccessInfo) error { +func (_m *MockStorageProvider) SaveAccess(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) @@ -172,11 +172,11 @@ func (_m *MockStorageProvider) SaveAccess(_a0 *AccessInfo) error { } // SaveAuthorize provides a mock function with given fields: _a0 -func (_m *MockStorageProvider) SaveAuthorize(_a0 *AuthorizeInfo) error { +func (_m *MockStorageProvider) SaveAuthorize(_a0 *AuthorizeData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AuthorizeInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AuthorizeData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) @@ -200,11 +200,11 @@ func (_m *MockStorageProvider) SaveClient(_a0 Client) error { } // SaveRefresh provides a mock function with given fields: _a0 -func (_m *MockStorageProvider) SaveRefresh(_a0 *AccessInfo) error { +func (_m *MockStorageProvider) SaveRefresh(_a0 *AccessData) error { ret := _m.Called(_a0) var r0 error - if rf, ok := ret.Get(0).(func(*AccessInfo) error); ok { + if rf, ok := ret.Get(0).(func(*AccessData) error); ok { r0 = rf(_a0) } else { r0 = ret.Error(0) diff --git a/authentication/provider/oauth2/mock_user_provider.go b/authentication/provider/oauth2/mock_user_provider.go index 21d5495..f7da3b2 100644 --- a/authentication/provider/oauth2/mock_user_provider.go +++ b/authentication/provider/oauth2/mock_user_provider.go @@ -12,6 +12,52 @@ type MockUserProvider struct { mock.Mock } +// Authenticate provides a mock function with given fields: username, password +func (_m *MockUserProvider) Authenticate(username string, password string) (user.User, error) { + ret := _m.Called(username, password) + + var r0 user.User + if rf, ok := ret.Get(0).(func(string, string) user.User); ok { + r0 = rf(username, password) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(user.User) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(username, password) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LoadByUsername provides a mock function with given fields: username +func (_m *MockUserProvider) LoadByUsername(username string) (user.User, error) { + ret := _m.Called(username) + + var r0 user.User + if rf, ok := ret.Get(0).(func(string) user.User); ok { + r0 = rf(username) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(user.User) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(username) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // LoadUser provides a mock function with given fields: id func (_m *MockUserProvider) LoadUser(id string) (user.User, error) { ret := _m.Called(id) diff --git a/authentication/provider/oauth2/oauth2_authentication_provider.go b/authentication/provider/oauth2/oauth2_authentication_provider.go index 9e45f67..fbcb95c 100644 --- a/authentication/provider/oauth2/oauth2_authentication_provider.go +++ b/authentication/provider/oauth2/oauth2_authentication_provider.go @@ -65,6 +65,7 @@ func (p *OAuth2AuthenticationProvider) IsSupported(creds credential.Credential) func (p *OAuth2AuthenticationProvider) authenticateByToken(r *http.Request, creds *credential.TokenCredential) (*http.Request, error) { ctx := r.Context() + // nolint:forcetypeassert token, err := p.accessStorage.LoadAccess(creds.GetPrincipal().(string)) if err != nil { return r, fmt.Errorf("load access token failed: %w", err) @@ -96,11 +97,13 @@ func (p *OAuth2AuthenticationProvider) authenticateByToken(r *http.Request, cred func (p *OAuth2AuthenticationProvider) authenticateByClient(r *http.Request, creds *credential.UsernamePasswordCredential) (*http.Request, error) { ctx := r.Context() + // nolint:forcetypeassert client, err := p.clientStorage.LoadClient(creds.GetPrincipal().(string)) if err != nil { return r, fmt.Errorf("load client info failed: %w", err) } + // nolint:forcetypeassert if c, ok := client.(ClientSecretMatcher); ok { if c.SecretMatches(creds.GetCredentials().(string)) { creds.SetAuthenticated(true) diff --git a/authentication/provider/oauth2/oauth2_authentication_provider_test.go b/authentication/provider/oauth2/oauth2_authentication_provider_test.go index 34c99d3..28d01b8 100644 --- a/authentication/provider/oauth2/oauth2_authentication_provider_test.go +++ b/authentication/provider/oauth2/oauth2_authentication_provider_test.go @@ -155,7 +155,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithTokenExpired(t accessStorageMock := &MockAccessProvider{} - access := &AccessInfo{ + access := &AccessData{ AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, UserData: userMock, @@ -186,7 +186,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithUserNotFound(t accessStorageMock := &MockAccessProvider{} - access := &AccessInfo{ + access := &AccessData{ AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, CreatedAt: time.Now(), @@ -227,7 +227,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithToken(t *testi RedirectURI: "https://connect.myservice.tld", } - access := &AccessInfo{ + access := &AccessData{ Client: client, AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, @@ -282,7 +282,7 @@ func TestOAuth2AuthenticationProviderAuthenticateByAccessTokenWithBadUserDataTyp RedirectURI: "https://connect.myservice.tld", } - access := &AccessInfo{ + access := &AccessData{ Client: client, AccessToken: "wSxJOjDWo7qQ7kF5Tlg2l9XZYat6gq6GssF5D5I9aKtcEipJzoTba77vRhfscn1vNr0gBM9rSj5sZ3R6252FTlJpxWPUM1c8w2KkvaAAcyrWqNPVNNFX2qAxhpcatdbR", ExpiresIn: 60, diff --git a/authentication/provider/oauth2/option.go b/authentication/provider/oauth2/option.go new file mode 100644 index 0000000..1a42967 --- /dev/null +++ b/authentication/provider/oauth2/option.go @@ -0,0 +1,51 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "github.com/euskadi31/go-eventemitter" + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" +) + +// Option type. +type Option func(server *Server) + +// WithConfig add config to server. +func WithConfig(cfg *Configuration) Option { + return func(server *Server) { + server.cfg = cfg + } +} + +func WithStorage(storage StorageProvider) Option { + return func(server *Server) { + server.storage = storage + } +} + +func WithUserProvider(userProvider UserProvider) Option { + return func(server *Server) { + server.userProvider = userProvider + } +} + +func WithLogger(logger logger.Logger) Option { + return func(server *Server) { + server.logger = logger + } +} + +func WithTokenGenerator(tokenGenerator token.Generator) Option { + return func(server *Server) { + server.tokenGenerator = tokenGenerator + } +} + +func WithEventEmitter(emitter eventemitter.EventEmitter) Option { + return func(server *Server) { + server.emitter = emitter + } +} diff --git a/authentication/provider/oauth2/option_test.go b/authentication/provider/oauth2/option_test.go new file mode 100644 index 0000000..11bd23e --- /dev/null +++ b/authentication/provider/oauth2/option_test.go @@ -0,0 +1,86 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/euskadi31/go-eventemitter" + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" + "github.com/stretchr/testify/assert" +) + +func TestWithConfig(t *testing.T) { + cfg := &Configuration{} + + opt := WithConfig(cfg) + + server := &Server{} + + opt(server) + + assert.Same(t, cfg, server.cfg) +} + +func TestWithLogger(t *testing.T) { + logger := &logger.Nop{} + + opt := WithLogger(logger) + + server := &Server{} + + opt(server) + + assert.Same(t, logger, server.logger) +} + +func TestWithStorage(t *testing.T) { + storageMock := &MockStorageProvider{} + + opt := WithStorage(storageMock) + + server := &Server{} + + opt(server) + + assert.Same(t, storageMock, server.storage) +} + +func TestWithUserProvider(t *testing.T) { + userProviderMock := &MockUserProvider{} + + opt := WithUserProvider(userProviderMock) + + server := &Server{} + + opt(server) + + assert.Same(t, userProviderMock, server.userProvider) +} + +func TestWithTokenGenerator(t *testing.T) { + tokenGeneratorMock := &token.MockGenerator{} + + opt := WithTokenGenerator(tokenGeneratorMock) + + server := &Server{} + + opt(server) + + assert.Same(t, tokenGeneratorMock, server.tokenGenerator) +} + +func TestWithEventEmitter(t *testing.T) { + emitter := eventemitter.New() + + opt := WithEventEmitter(emitter) + + server := &Server{} + + opt(server) + + assert.Same(t, emitter, server.emitter) +} diff --git a/authentication/provider/oauth2/response.go b/authentication/provider/oauth2/response.go new file mode 100644 index 0000000..a908eeb --- /dev/null +++ b/authentication/provider/oauth2/response.go @@ -0,0 +1,162 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "errors" + "fmt" + "net/http" + "net/url" +) + +var ( + ErrNotARedirectResponse = errors.New("not a redirect response") +) + +// Data for response output. +type ResponseData map[string]interface{} + +// Response type enum. +type ResponseType int + +const ( + DATA ResponseType = iota + REDIRECT +) + +// Server response. +type Response struct { + Type ResponseType + StatusCode int + StatusText string + ErrorStatusCode int + URL string + Output ResponseData + Headers http.Header + IsError bool + ErrorID DefaultErrorID + InternalError error + RedirectInFragment bool + + // Storage to use in this response - required + Storage StorageProvider +} + +func NewResponse(storage StorageProvider) *Response { + r := &Response{ + Type: DATA, + StatusCode: 200, + ErrorStatusCode: 200, + Output: make(ResponseData), + Headers: make(http.Header), + IsError: false, + Storage: storage, // Clone ? + } + r.Headers.Add( + "Cache-Control", + "no-cache, no-store, max-age=0, must-revalidate", + ) + r.Headers.Add("Pragma", "no-cache") + r.Headers.Add("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") + + return r +} + +// SetError sets an error id and description on the Response +// state and uri are left blank. +func (r *Response) SetError(id DefaultErrorID, description string) { + r.SetErrorURI(id, description, "", "") +} + +// SetErrorState sets an error id, description, and state on the Response +// uri is left blank. +func (r *Response) SetErrorState(id DefaultErrorID, description string, state string) { + r.SetErrorURI(id, description, "", state) +} + +// SetErrorURI sets an error id, description, state, and uri on the Response. +func (r *Response) SetErrorURI(id DefaultErrorID, description string, uri string, state string) { + // get default error message + if description == "" { + description = deferror.Get(id) + } + + // set error parameters + r.IsError = true + r.ErrorID = id + r.StatusCode = r.ErrorStatusCode + + if r.StatusCode != http.StatusOK { + r.StatusText = description + } else { + r.StatusText = "" + } + + r.Output = make(ResponseData) // clear output + r.Output["error"] = id + r.Output["error_description"] = description + + if uri != "" { + r.Output["error_uri"] = uri + } + + if state != "" { + r.Output["state"] = state + } +} + +// SetRedirect changes the response to redirect to the given url. +func (r *Response) SetRedirect(url string) { + // set redirect parameters + r.Type = REDIRECT + r.URL = url +} + +// SetRedirectFragment sets redirect values to be passed in fragment instead of as query parameters. +func (r *Response) SetRedirectFragment(f bool) { + r.RedirectInFragment = f +} + +// GetRedirectURL returns the redirect url with all query string parameters. +func (r *Response) GetRedirectURL() (string, error) { + if r.Type != REDIRECT { + return "", ErrNotARedirectResponse + } + + u, err := url.Parse(r.URL) + if err != nil { + return "", fmt.Errorf("parse url failed: %w", err) + } + + var q url.Values + if r.RedirectInFragment { + // start with empty set for fragment + q = url.Values{} + } else { + // add parameters to existing query + q = u.Query() + } + + // add parameters + for n, v := range r.Output { + q.Set(n, fmt.Sprint(v)) + } + + // https://tools.ietf.org/html/rfc6749#section-4.2.2 + // Fragment should be encoded as application/x-www-form-urlencoded (%-escaped, spaces are represented as '+') + // The stdlib URL#String() doesn't make that easy to accomplish, so build this ourselves + if r.RedirectInFragment { + u.Fragment = "" + redirectURI := u.String() + "#" + q.Encode() + + return redirectURI, nil + } + + // Otherwise, update the query and encode normally + u.RawQuery = q.Encode() + u.Fragment = "" + + return u.String(), nil +} diff --git a/authentication/provider/oauth2/response_json.go b/authentication/provider/oauth2/response_json.go new file mode 100644 index 0000000..ce8b320 --- /dev/null +++ b/authentication/provider/oauth2/response_json.go @@ -0,0 +1,43 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "encoding/json" + "net/http" +) + +// OutputJSON encodes the Response to JSON and writes to the http.ResponseWriter. +func OutputJSON(rs *Response, w http.ResponseWriter, r *http.Request) error { + // Add headers + for i, k := range rs.Headers { + for _, v := range k { + w.Header().Add(i, v) + } + } + + if rs.Type == REDIRECT { + // Output redirect with parameters + u, err := rs.GetRedirectURL() + if err != nil { + return err + } + + w.Header().Add("Location", u) + w.WriteHeader(302) + + return nil + } + + // set content type if the response doesn't already have one associated with it + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", "application/json") + } + + w.WriteHeader(rs.StatusCode) + + //nolint:wrapcheck + return json.NewEncoder(w).Encode(rs.Output) +} diff --git a/authentication/provider/oauth2/response_json_test.go b/authentication/provider/oauth2/response_json_test.go new file mode 100644 index 0000000..e82b5a0 --- /dev/null +++ b/authentication/provider/oauth2/response_json_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResponseJSON(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.Output["access_token"] = "1234" + r.Output["token_type"] = "5678" + + err = OutputJSON(r, w, req) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", w.Result().Header.Get("Content-Type")) + + // parse output json + output := make(map[string]interface{}) + err = json.Unmarshal(w.Body.Bytes(), &output) + assert.NoError(t, err) + + assert.Contains(t, output, "access_token") + assert.Equal(t, "1234", output["access_token"]) + + assert.Contains(t, output, "token_type") + assert.Equal(t, "5678", output["token_type"]) +} + +func TestErrorResponseJSON(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + r.ErrorStatusCode = 500 + r.SetError(E_INVALID_REQUEST, "") + + err = OutputJSON(r, w, req) + assert.NoError(t, err) + + assert.Equal(t, 500, w.Code) + + assert.Equal(t, "application/json", w.Result().Header.Get("Content-Type")) + + // parse output json + output := make(map[string]interface{}) + err = json.Unmarshal(w.Body.Bytes(), &output) + assert.NoError(t, err) + + assert.Contains(t, output, "error") + assert.Equal(t, E_INVALID_REQUEST.String(), output["error"]) +} + +func TestRedirectResponseJSON(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + r.SetRedirect("http://localhost:14000") + + err = OutputJSON(r, w, req) + assert.NoError(t, err) + + assert.Equal(t, http.StatusFound, w.Code) + + assert.Equal(t, "http://localhost:14000", w.Result().Header.Get("Location")) +} + +func TestRedirectResponseJSONWithError(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:14000/appauth", nil) + assert.NoError(t, err) + + w := httptest.NewRecorder() + + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + r.SetRedirect(":14000") + + err = OutputJSON(r, w, req) + assert.EqualError(t, err, "parse url failed: parse \":14000\": missing protocol scheme") + + assert.Equal(t, http.StatusOK, w.Code) + + assert.Equal(t, "", w.Result().Header.Get("Location")) +} diff --git a/authentication/provider/oauth2/response_test.go b/authentication/provider/oauth2/response_test.go new file mode 100644 index 0000000..4d9c2c5 --- /dev/null +++ b/authentication/provider/oauth2/response_test.go @@ -0,0 +1,84 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResponseGetRedirectURLWithDataRequest(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.Type = DATA + + url, err := r.GetRedirectURL() + assert.EqualError(t, err, ErrNotARedirectResponse.Error()) + + assert.Equal(t, "", url) +} + +func TestResponseGetRedirectURLWithRedirectInFragment(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.Output["foo"] = "bar" + r.SetRedirect("https://oauth.mydomain.tld/connect") + r.SetRedirectFragment(true) + + url, err := r.GetRedirectURL() + assert.NoError(t, err) + + assert.Equal(t, "https://oauth.mydomain.tld/connect#foo=bar", url) +} + +func TestResponseSetErrorURI(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.SetErrorURI(E_ACCESS_DENIED, "access denied", "https://oauth.mydomain.tld/connect", "foobar") + + assert.True(t, r.IsError) + assert.Equal(t, http.StatusOK, r.ErrorStatusCode) + assert.Equal(t, E_ACCESS_DENIED, r.ErrorID) + + assert.Equal(t, "", r.StatusText) + + assert.Contains(t, r.Output, "error_uri") + assert.Equal(t, "https://oauth.mydomain.tld/connect", r.Output["error_uri"]) + + assert.Contains(t, r.Output, "state") + assert.Equal(t, "foobar", r.Output["state"]) + + assert.Contains(t, r.Output, "error") + assert.Equal(t, E_ACCESS_DENIED, r.Output["error"]) + + assert.Contains(t, r.Output, "error_description") + assert.Equal(t, "access denied", r.Output["error_description"]) + +} + +func TestResponseSetErrorState(t *testing.T) { + storageMock := &MockStorageProvider{} + + r := NewResponse(storageMock) + + r.SetErrorState(E_ACCESS_DENIED, "", "foobar") + + assert.Contains(t, r.Output, "error") + assert.Equal(t, E_ACCESS_DENIED, r.Output["error"]) + + assert.Contains(t, r.Output, "error_description") + assert.Equal(t, "The resource owner or authorization server denied the request.", r.Output["error_description"]) + + assert.Contains(t, r.Output, "state") + assert.Equal(t, "foobar", r.Output["state"]) +} diff --git a/authentication/provider/oauth2/server.go b/authentication/provider/oauth2/server.go new file mode 100644 index 0000000..3b733a0 --- /dev/null +++ b/authentication/provider/oauth2/server.go @@ -0,0 +1,50 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "time" + + "github.com/euskadi31/go-eventemitter" + "github.com/hyperscale-stack/logger" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token" + "github.com/hyperscale-stack/security/authentication/provider/oauth2/token/random" +) + +type Server struct { + cfg *Configuration + logger logger.Logger + storage StorageProvider + userProvider UserProvider + tokenGenerator token.Generator + emitter eventemitter.EventEmitter + now func() time.Time +} + +func NewServer(options ...Option) *Server { + cfg := NewConfiguration() + + s := &Server{ + cfg: cfg, + logger: &logger.Nop{}, + tokenGenerator: random.NewTokenGenerator(&random.Configuration{}), + now: time.Now, + emitter: eventemitter.New(), + } + + for _, opt := range options { + opt(s) + } + + return s +} + +// NewResponse creates a new response for the server. +func (s *Server) NewResponse() *Response { + r := NewResponse(s.storage) + r.ErrorStatusCode = s.cfg.ErrorStatusCode + + return r +} diff --git a/authentication/provider/oauth2/server_test.go b/authentication/provider/oauth2/server_test.go new file mode 100644 index 0000000..e8c8aef --- /dev/null +++ b/authentication/provider/oauth2/server_test.go @@ -0,0 +1,31 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestServer(t *testing.T) { + cfg := &Configuration{} + server := NewServer(WithConfig(cfg)) + + assert.Same(t, cfg, server.cfg) +} + +func TestServerNewResponse(t *testing.T) { + cfg := &Configuration{ + ErrorStatusCode: http.StatusOK, + } + server := NewServer(WithConfig(cfg)) + + response := server.NewResponse() + assert.NotNil(t, response) + + assert.Equal(t, cfg.ErrorStatusCode, response.ErrorStatusCode) +} diff --git a/authentication/provider/oauth2/storage.go b/authentication/provider/oauth2/storage.go index f80a32a..f605904 100644 --- a/authentication/provider/oauth2/storage.go +++ b/authentication/provider/oauth2/storage.go @@ -27,28 +27,30 @@ type ClientProvider interface { //go:generate mockery --name=AccessProvider --inpackage --case underscore type AccessProvider interface { - SaveAccess(*AccessInfo) error - LoadAccess(token string) (*AccessInfo, error) + SaveAccess(*AccessData) error + LoadAccess(token string) (*AccessData, error) RemoveAccess(token string) error } //go:generate mockery --name=RefreshProvider --inpackage --case underscore type RefreshProvider interface { - SaveRefresh(*AccessInfo) error - LoadRefresh(token string) (*AccessInfo, error) + SaveRefresh(*AccessData) error + LoadRefresh(token string) (*AccessData, error) RemoveRefresh(token string) error } //go:generate mockery --name=AuthorizeProvider --inpackage --case underscore type AuthorizeProvider interface { - SaveAuthorize(*AuthorizeInfo) error - LoadAuthorize(code string) (*AuthorizeInfo, error) + SaveAuthorize(*AuthorizeData) error + LoadAuthorize(code string) (*AuthorizeData, error) RemoveAuthorize(code string) error } //go:generate mockery --name=UserProvider --inpackage --case underscore type UserProvider interface { LoadUser(id string) (user.User, error) + LoadByUsername(username string) (user.User, error) + Authenticate(username string, password string) (user.User, error) } //go:generate mockery --name=StorageProvider --inpackage --case underscore diff --git a/authentication/provider/oauth2/storage/in_memory_storage.go b/authentication/provider/oauth2/storage/in_memory_storage.go index b7c2028..2fafbe1 100644 --- a/authentication/provider/oauth2/storage/in_memory_storage.go +++ b/authentication/provider/oauth2/storage/in_memory_storage.go @@ -31,7 +31,7 @@ func (s *InMemoryStorage) SaveClient(client oauth2.Client) error { func (s *InMemoryStorage) LoadClient(id string) (oauth2.Client, error) { if client, ok := s.clients.Load(id); ok { - return client.(oauth2.Client), nil + return client.(oauth2.Client), nil // nolint:forcetypeassert } return nil, oauth2.ErrClientNotFound @@ -43,15 +43,15 @@ func (s *InMemoryStorage) RemoveClient(id string) error { return nil } -func (s *InMemoryStorage) SaveAccess(access *oauth2.AccessInfo) error { +func (s *InMemoryStorage) SaveAccess(access *oauth2.AccessData) error { s.accesses.Store(access.AccessToken, access) return nil } -func (s *InMemoryStorage) LoadAccess(token string) (*oauth2.AccessInfo, error) { +func (s *InMemoryStorage) LoadAccess(token string) (*oauth2.AccessData, error) { if access, ok := s.accesses.Load(token); ok { - return access.(*oauth2.AccessInfo), nil + return access.(*oauth2.AccessData), nil // nolint:forcetypeassert } return nil, oauth2.ErrAccessNotFound @@ -63,15 +63,15 @@ func (s *InMemoryStorage) RemoveAccess(token string) error { return nil } -func (s *InMemoryStorage) SaveRefresh(access *oauth2.AccessInfo) error { +func (s *InMemoryStorage) SaveRefresh(access *oauth2.AccessData) error { s.refreshs.Store(access.RefreshToken, access) return nil } -func (s *InMemoryStorage) LoadRefresh(token string) (*oauth2.AccessInfo, error) { +func (s *InMemoryStorage) LoadRefresh(token string) (*oauth2.AccessData, error) { if access, ok := s.refreshs.Load(token); ok { - return access.(*oauth2.AccessInfo), nil + return access.(*oauth2.AccessData), nil // nolint:forcetypeassert } return nil, oauth2.ErrRefreshNotFound @@ -83,15 +83,15 @@ func (s *InMemoryStorage) RemoveRefresh(token string) error { return nil } -func (s *InMemoryStorage) SaveAuthorize(authorize *oauth2.AuthorizeInfo) error { +func (s *InMemoryStorage) SaveAuthorize(authorize *oauth2.AuthorizeData) error { s.authorizes.Store(authorize.Code, authorize) return nil } -func (s *InMemoryStorage) LoadAuthorize(code string) (*oauth2.AuthorizeInfo, error) { +func (s *InMemoryStorage) LoadAuthorize(code string) (*oauth2.AuthorizeData, error) { if authorize, ok := s.authorizes.Load(code); ok { - return authorize.(*oauth2.AuthorizeInfo), nil + return authorize.(*oauth2.AuthorizeData), nil // nolint:forcetypeassert } return nil, oauth2.ErrAuthorizeNotFound diff --git a/authentication/provider/oauth2/storage/in_memory_storage_test.go b/authentication/provider/oauth2/storage/in_memory_storage_test.go index 97cdd7e..ff39411 100644 --- a/authentication/provider/oauth2/storage/in_memory_storage_test.go +++ b/authentication/provider/oauth2/storage/in_memory_storage_test.go @@ -40,7 +40,7 @@ func TestInMemoryStorage(t *testing.T) { assert.Nil(t, client2) // Access Token - access := &oauth2.AccessInfo{ + access := &oauth2.AccessData{ AccessToken: "OKjQ0VjYmJxP8N0TzXH5lxvIOZj4bCM0DlsCvuiL96HCQEhJ8A9ozY8jJ5Ep38vaVvn082fgApThX7NZ7pktKn57A667kEeWLPW0KVA3x1flYdBvkIvHOAZYyvUeKK9q", } @@ -63,7 +63,7 @@ func TestInMemoryStorage(t *testing.T) { assert.Nil(t, access2) // Refresh Token - access = &oauth2.AccessInfo{ + access = &oauth2.AccessData{ RefreshToken: "2oQDkOWnbqtJoEs24MkVEB4WNJnqyoAIErvSJRhjg562K8GznWLbLZuStQodKvReSedAqufswaSZduhlgOuCNcQj9aGbCKPAnXUVvmX7Vmgvryp9PaZVbuqj0HfzN9tD", } @@ -86,7 +86,7 @@ func TestInMemoryStorage(t *testing.T) { assert.Nil(t, access2) // Authorize Code - authorize := &oauth2.AuthorizeInfo{ + authorize := &oauth2.AuthorizeData{ Code: "Je4dJ5RFPRJwuSmuitSo8tX7s3uFOP84sEufxjdqJhiiPABdbxeGofGvvX7LBdvy2ZrwDZy3a6cOF8vgquUlr8yAvA9VpDz4Kv2bZxm0WEl4y3SJSvYPnwBOxRHI5pxK", } diff --git a/authentication/provider/oauth2/token/generator.go b/authentication/provider/oauth2/token/generator.go index cd39ec9..43da7a7 100644 --- a/authentication/provider/oauth2/token/generator.go +++ b/authentication/provider/oauth2/token/generator.go @@ -4,6 +4,7 @@ package token +//go:generate mockery --name=Generator --inpackage --case underscore type Generator interface { GenerateAccessToken(generateRefresh bool) (accessToken string, refreshToken string, err error) } diff --git a/authentication/provider/oauth2/token/mock_generator.go b/authentication/provider/oauth2/token/mock_generator.go new file mode 100644 index 0000000..c56ce5e --- /dev/null +++ b/authentication/provider/oauth2/token/mock_generator.go @@ -0,0 +1,38 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. + +package token + +import mock "github.com/stretchr/testify/mock" + +// MockGenerator is an autogenerated mock type for the Generator type +type MockGenerator struct { + mock.Mock +} + +// GenerateAccessToken provides a mock function with given fields: generateRefresh +func (_m *MockGenerator) GenerateAccessToken(generateRefresh bool) (string, string, error) { + ret := _m.Called(generateRefresh) + + var r0 string + if rf, ok := ret.Get(0).(func(bool) string); ok { + r0 = rf(generateRefresh) + } else { + r0 = ret.Get(0).(string) + } + + var r1 string + if rf, ok := ret.Get(1).(func(bool) string); ok { + r1 = rf(generateRefresh) + } else { + r1 = ret.Get(1).(string) + } + + var r2 error + if rf, ok := ret.Get(2).(func(bool) error); ok { + r2 = rf(generateRefresh) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} diff --git a/authentication/provider/oauth2/urivalidate.go b/authentication/provider/oauth2/urivalidate.go new file mode 100644 index 0000000..f66b2dd --- /dev/null +++ b/authentication/provider/oauth2/urivalidate.go @@ -0,0 +1,126 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "errors" + "fmt" + "net/url" + "strings" +) + +var ( + ErrNoBlank = errors.New("urls cannot be blank") +) + +// error returned when validation don't match. +type URIValidationError string + +func (e URIValidationError) Error() string { + return string(e) +} + +func newURIValidationError(msg string, base string, redirect string) URIValidationError { + return URIValidationError(fmt.Sprintf("%s: %s / %s", msg, base, redirect)) +} + +// ParseURLs resolving uri references to base url. +func ParseURLs(baseUrl, redirectUrl string) (retBaseUrl, retRedirectUrl *url.URL, err error) { + var base, redirect *url.URL + // parse base url + if base, err = url.Parse(baseUrl); err != nil { + return nil, nil, fmt.Errorf("parse base url failed: %w", err) + } + + // parse redirect url + if redirect, err = url.Parse(redirectUrl); err != nil { + return nil, nil, fmt.Errorf("parse redirect url failed: %w", err) + } + + // must not have fragment + if base.Fragment != "" || redirect.Fragment != "" { + return nil, nil, newURIValidationError("url must not include fragment.", baseUrl, redirectUrl) + } + + // Scheme must match + if redirect.Scheme != base.Scheme { + return nil, nil, newURIValidationError("scheme mismatch", baseUrl, redirectUrl) + } + + // Host must match + if redirect.Host != base.Host { + return nil, nil, newURIValidationError("host mismatch", baseUrl, redirectUrl) + } + + // resolve references to base url + retBaseUrl = (&url.URL{Scheme: base.Scheme, Host: base.Host, Path: "/"}).ResolveReference(&url.URL{Path: base.Path}) + retRedirectUrl = (&url.URL{Scheme: base.Scheme, Host: base.Host, Path: "/"}).ResolveReference(&url.URL{Path: redirect.Path, RawQuery: redirect.RawQuery}) + + return +} + +// ValidateUriList validates that redirectUri is contained in baseUriList. +// baseUriList may be a string separated by separator. +// If separator is blank, validate only 1 URI. +func ValidateURIList(baseUriList string, redirectUri string, separator string) (realRedirectUri string, err error) { + // make a list of uris + var slist []string + if separator != "" { + slist = strings.Split(baseUriList, separator) + } else { + slist = make([]string, 0) + slist = append(slist, baseUriList) + } + + for _, sitem := range slist { + realRedirectUri, err = ValidateURI(sitem, redirectUri) + // validated, return no error + if err == nil { + return realRedirectUri, nil + } + + // if there was an error that is not a validation error, return it + //nolint:errorlint + if _, iok := err.(URIValidationError); !iok { + return "", err + } + } + + return "", newURIValidationError("urls don't validate", baseUriList, redirectUri) +} + +// ValidateURI validates that redirectUri is contained in baseUri. +func ValidateURI(baseUri string, redirectUri string) (realRedirectUri string, err error) { + if baseUri == "" || redirectUri == "" { + return "", ErrNoBlank + } + + base, redirect, err := ParseURLs(baseUri, redirectUri) + if err != nil { + return "", err + } + + // allow exact path matches + if base.Path == redirect.Path { + return redirect.String(), nil + } + + // ensure prefix matches are actually subpaths + requiredPrefix := strings.TrimRight(base.Path, "/") + "/" + if !strings.HasPrefix(redirect.Path, requiredPrefix) { + return "", newURIValidationError("path prefix doesn't match", baseUri, redirectUri) + } + + return redirect.String(), nil +} + +// FirstURI returns the first uri from an uri list. +func FirstURI(baseUriList string, separator string) string { + if separator == "" { + return baseUriList + } + + return strings.Split(baseUriList, separator)[0] +} diff --git a/authentication/provider/oauth2/urivalidate_test.go b/authentication/provider/oauth2/urivalidate_test.go new file mode 100644 index 0000000..dbcd2e5 --- /dev/null +++ b/authentication/provider/oauth2/urivalidate_test.go @@ -0,0 +1,209 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestURIValidate(t *testing.T) { + valid := [][]string{ + { + // Exact match + "http://localhost:14000/appauth", + "http://localhost:14000/appauth", + "http://localhost:14000/appauth", + }, + { + // Trailing slash + "http://www.google.com/myapp", + "http://www.google.com/myapp/", + "http://www.google.com/myapp/", + }, + { + // Exact match with trailing slash + "http://www.google.com/myapp/", + "http://www.google.com/myapp/", + "http://www.google.com/myapp/", + }, + { + // Subpath + "http://www.google.com/myapp", + "http://www.google.com/myapp/interface/implementation", + "http://www.google.com/myapp/interface/implementation", + }, + { + // Subpath with trailing slash + "http://www.google.com/myapp/", + "http://www.google.com/myapp/interface/implementation", + "http://www.google.com/myapp/interface/implementation", + }, + { + // Subpath with things that are close to path traversals, but aren't + "http://www.google.com/myapp", + "http://www.google.com/myapp/.../..implementation../...", + "http://www.google.com/myapp/.../..implementation../...", + }, + { + // If the allowed basepath contains path traversals, allow them? + "http://www.google.com/traversal/../allowed", + "http://www.google.com/traversal/../allowed/with/subpath", + "http://www.google.com/allowed/with/subpath", + }, + { + // Backslashes + "https://mysafewebsite.com/secure/redirect", + "https://mysafewebsite.com/secure/redirect/\\../\\../\\../evil", + "https://mysafewebsite.com/secure/redirect/%5C../%5C../%5C../evil", + }, + { + // Backslashes + "https://mysafewebsite.com/secure/redirect", + "https://mysafewebsite.com/secure/redirect/\\..\\../\\../evil", + "https://mysafewebsite.com/secure/redirect/%5C..%5C../%5C../evil", + }, + { + // Query string must be kept + "http://www.google.com/myapp/redir", + "http://www.google.com/myapp/redir?a=1&b=2", + "http://www.google.com/myapp/redir?a=1&b=2", + }, + } + for _, v := range valid { + realRedirectURI, err := ValidateURI(v[0], v[1]) + assert.NoError(t, err) + assert.Equal(t, v[2], realRedirectURI) + } + + invalid := [][]string{ + { + "", + "", + }, + { + "\n", + "\n", + }, + { + // Doesn't satisfy base path + "http://localhost:14000/appauth", + "http://localhost:14000/app", + }, + { + // Doesn't satisfy base path + "http://localhost:14000/app/", + "http://localhost:14000/app", + }, + { + // Not a subpath of base path + "http://localhost:14000/appauth", + "http://localhost:14000/appauthmodifiedpath", + }, + { + // Host mismatch + "http://www.google.com/myapp", + "http://www2.google.com/myapp", + }, + { + // Scheme mismatch + "http://www.google.com/myapp", + "https://www.google.com/myapp", + }, + { + // Path traversal + "http://www.google.com/myapp", + "http://www.google.com/myapp/..", + }, + { + // Embedded path traversal + "http://www.google.com/myapp", + "http://www.google.com/myapp/../test", + }, + { + // Not a subpath + "http://www.google.com/myapp", + "http://www.google.com/myapp../test", + }, + { + // Backslashes + "https://mysafewebsite.com/secure/redirect", + "https://mysafewebsite.com/secure%2fredirect/../evil", + }, + } + for _, v := range invalid { + if _, err := ValidateURI(v[0], v[1]); err == nil { + t.Errorf("Expected ValidateURI(%s, %s) to fail", v[0], v[1]) + } + } +} + +func TestURIListValidate(t *testing.T) { + // V1 + if _, err := ValidateURIList("http://localhost:14000/appauth", "http://localhost:14000/appauth", ""); err != nil { + t.Errorf("V1: %s", err) + } + + // V2 + if _, err := ValidateURIList("http://localhost:14000/appauth", "http://localhost:14000/app", ""); err == nil { + t.Error("V2 should have failed") + } + + // V3 + if _, err := ValidateURIList("http://xxx:14000/appauth;http://localhost:14000/appauth", "http://localhost:14000/appauth", ";"); err != nil { + t.Errorf("V3: %s", err) + } + + // V4 + if _, err := ValidateURIList("http://xxx:14000/appauth;http://localhost:14000/appauth", "http://localhost:14000/app", ";"); err == nil { + t.Error("V4 should have failed") + } + + // V5 + if _, err := ValidateURIList("\n", "\n", ";"); err == nil { + t.Error("V5 should have failed") + } +} + +func TestFirstURI(t *testing.T) { + assert.Equal(t, "https://auth.mydomain.com/connect", FirstURI("https://auth.mydomain.com/connect mybundle://connect", " ")) + assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", " ")) + assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", "")) + assert.Equal(t, "", FirstURI("", " ")) + assert.Equal(t, "mybundle://connect", FirstURI("mybundle://connect", ";")) +} + +func TestNewURIValidationError(t *testing.T) { + err := newURIValidationError("scheme mismatch", "http://www.google.com/myapp", "http://www.google.com/myapp../test") + + assert.EqualError(t, err, "scheme mismatch: http://www.google.com/myapp / http://www.google.com/myapp../test") +} + +func TestParseURLs(t *testing.T) { + invalid := [][]string{ + { + "\n", + "\n", + }, + { + "https://google.com", + "\n", + }, + { + "https://google.com#foo", + "https://google.com#foo", + }, + { + "http://google.com", + "https://google.com", + }, + } + for _, v := range invalid { + if _, err := ValidateURI(v[0], v[1]); err == nil { + t.Errorf("Expected ValidateURI(%s, %s) to fail", v[0], v[1]) + } + } +} diff --git a/authentication/provider/oauth2/util.go b/authentication/provider/oauth2/util.go new file mode 100644 index 0000000..382e250 --- /dev/null +++ b/authentication/provider/oauth2/util.go @@ -0,0 +1,140 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/hyperscale-stack/security/authentication/credential" + "github.com/hyperscale-stack/security/http/header" +) + +var ( + ErrInvalidAuthorizationHeader = errors.New("invalid authorization header") + ErrInvalidAuthorizationMessage = errors.New("invalid authorization message") + ErrClientAuthenticationNotSent = errors.New("Client authentication not sent") +) + +// Parse basic authentication header. +type BasicAuth struct { + Username string + Password string +} + +// Parse bearer authentication header. +type BearerAuth struct { + Code string +} + +// CheckClientSecret determines whether the given secret matches a secret held by the client. +// Public clients return true for a secret of "". +func CheckClientSecret(client Client, secret string) bool { + switch client := client.(type) { + case ClientSecretMatcher: + // Prefer the more secure method of giving the secret to the client for comparison + return client.SecretMatches(secret) + default: + // Fallback to the less secure method of extracting the plain text secret from the client for comparison + return subtle.ConstantTimeCompare([]byte(client.GetSecret()), []byte(secret)) == 1 + } +} + +// Return authorization header data. +func CheckBasicAuth(r *http.Request) (*BasicAuth, error) { + if r.Header.Get("Authorization") == "" { + return nil, nil + } + + b64, ok := header.ExtractAuthorizationValue("Basic", r.Header.Get("Authorization")) + if !ok { + return nil, ErrInvalidAuthorizationHeader + } + + b, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil, fmt.Errorf("decode basic auth failed: %w", err) + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return nil, ErrInvalidAuthorizationMessage + } + + // Decode the client_id and client_secret pairs as per + // https://tools.ietf.org/html/rfc6749#section-2.3.1 + + username, err := url.QueryUnescape(pair[0]) + if err != nil { + return nil, fmt.Errorf("unescape username failed: %w", err) + } + + password, err := url.QueryUnescape(pair[1]) + if err != nil { + return nil, fmt.Errorf("unescape password failed: %w", err) + } + + return &BasicAuth{Username: username, Password: password}, nil +} + +// Return "Bearer" token from request. The header has precedence over query string. +func CheckBearerAuth(r *http.Request) *BearerAuth { + authHeader := r.Header.Get("Authorization") + authForm := r.FormValue("code") + + if authHeader == "" && authForm == "" { + return nil + } + + token := authForm + + if authHeader != "" { + v, ok := header.ExtractAuthorizationValue("Bearer", authHeader) + if !ok { + return nil + } + + token = v + } + + return &BearerAuth{Code: token} +} + +// getClientAuth checks client basic authentication in params if allowed, +// otherwise gets it from the header. +// Sets an error on the response if no auth is present or a server error occurs. +func (s Server) getClientAuth(w *Response, r *http.Request, allowQueryParams bool) *credential.UsernamePasswordCredential { + ctx := r.Context() + + // creds := credential.FromContext(ctx) + + if allowQueryParams { + // Allow for auth without password + if _, hasSecret := r.Form["client_secret"]; hasSecret { + auth := credential.NewUsernamePasswordCredential( + r.FormValue("client_id"), + r.FormValue("client_secret"), + ) + + if auth.GetPrincipal() != "" { + return auth + } + } + } + + auth := credential.FromContext(ctx) + if auth == nil { + s.setErrorAndLog(w, E_INVALID_REQUEST, ErrClientAuthenticationNotSent, "get_client_auth=%s", "client authentication not sent") + + return nil + } + + return auth.(*credential.UsernamePasswordCredential) // nolint:forcetypeassert +} diff --git a/authentication/provider/oauth2/util_test.go b/authentication/provider/oauth2/util_test.go new file mode 100644 index 0000000..1a44726 --- /dev/null +++ b/authentication/provider/oauth2/util_test.go @@ -0,0 +1,239 @@ +// Copyright 2021 Hyperscale. All rights reserved. +// Use of this source code is governed by a MIT +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "net/http" + "net/url" + "testing" + + "github.com/hyperscale-stack/security/authentication" + "github.com/stretchr/testify/assert" +) + +const ( + badAuthValue = "Digest XHHHHHHH" + badBasicAuthValue = "Basic €€€" + badBasicAuthWithBadFormat = "Basic Zm9vCg==" // foo + badUsernameInAuthValue = "Basic dSUyc2VybmFtZTpwYXNzd29yZA==" // u%2sername:password + badPasswordInAuthValue = "Basic dXNlcm5hbWU6cGElMnN3b3Jk" // username:pa%2sword + goodAuthValue = "Basic Y2xpZW50K25hbWU6Y2xpZW50KyUyNGVjcmV0" + goodBearerAuthValue = "Bearer BGFVTDUJDp0ZXN0" +) + +func TestBasicAuth(t *testing.T) { + r := &http.Request{Header: make(http.Header)} + + // Without any header + if b, err := CheckBasicAuth(r); b != nil || err != nil { + t.Errorf("Validated basic auth without header") + } + + // with invalid header + r.Header.Set("Authorization", badAuthValue) + b, err := CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth") + return + } + + // with invalid value + r.Header.Set("Authorization", badBasicAuthValue) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth") + return + } + + // with invalid format + r.Header.Set("Authorization", badBasicAuthWithBadFormat) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth") + return + } + + // with invalid username + r.Header.Set("Authorization", badUsernameInAuthValue) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth with bad username") + return + } + + // with invalid username + r.Header.Set("Authorization", badPasswordInAuthValue) + b, err = CheckBasicAuth(r) + if b != nil || err == nil { + t.Errorf("Validated invalid auth with bad password") + return + } + + // with valid header + r.Header.Set("Authorization", goodAuthValue) + b, err = CheckBasicAuth(r) + if b == nil || err != nil { + t.Errorf("Could not extract basic auth") + return + } + + // check extracted auth data + if b.Username != "client name" || b.Password != "client $ecret" { + t.Errorf("Error decoding basic auth") + } +} + +func TestGetClientAuth(t *testing.T) { + + urlWithSecret, _ := url.Parse("http://host.tld/path?client_id=xxx&client_secret=yyy") + urlWithEmptySecret, _ := url.Parse("http://host.tld/path?client_id=xxx&client_secret=") + urlNoSecret, _ := url.Parse("http://host.tld/path?client_id=xxx") + + headerNoAuth := make(http.Header) + headerBadAuth := make(http.Header) + headerBadAuth.Set("Authorization", badAuthValue) + headerOKAuth := make(http.Header) + headerOKAuth.Set("Authorization", goodAuthValue) + + storageMock := &MockStorageProvider{} + + sconfig := NewConfiguration() + + server := NewServer(WithStorage(storageMock), WithConfig(sconfig)) + + var tests = []struct { + header http.Header + url *url.URL + allowQueryParams bool + expectAuth bool + }{ + {headerNoAuth, urlWithSecret, true, true}, + {headerNoAuth, urlWithSecret, false, false}, + {headerNoAuth, urlWithEmptySecret, true, true}, + {headerNoAuth, urlWithEmptySecret, false, false}, + {headerNoAuth, urlNoSecret, true, false}, + {headerNoAuth, urlNoSecret, false, false}, + + {headerBadAuth, urlWithSecret, true, true}, + {headerBadAuth, urlWithSecret, false, false}, + {headerBadAuth, urlWithEmptySecret, true, true}, + {headerBadAuth, urlWithEmptySecret, false, false}, + {headerBadAuth, urlNoSecret, true, false}, + {headerBadAuth, urlNoSecret, false, false}, + + {headerOKAuth, urlWithSecret, true, true}, + {headerOKAuth, urlWithSecret, false, true}, + {headerOKAuth, urlWithEmptySecret, true, true}, + {headerOKAuth, urlWithEmptySecret, false, true}, + {headerOKAuth, urlNoSecret, true, true}, + {headerOKAuth, urlNoSecret, false, true}, + } + + for _, tt := range tests { + w := new(Response) + r := &http.Request{Header: tt.header, URL: tt.url} + r.ParseForm() + + f := authentication.NewHTTPBasicFilter() + r = f.OnFilter(r) + + auth := server.getClientAuth(w, r, tt.allowQueryParams) + + if tt.expectAuth { + assert.NotNil(t, auth) + } else { + assert.Nil(t, auth) + } + } + +} + +func TestBearerAuth(t *testing.T) { + r := &http.Request{Header: make(http.Header)} + + // Without any header + if b := CheckBearerAuth(r); b != nil { + t.Errorf("Validated bearer auth without header") + } + + // with invalid header + r.Header.Set("Authorization", badAuthValue) + b := CheckBearerAuth(r) + if b != nil { + t.Errorf("Validated invalid auth") + return + } + + // with valid header + r.Header.Set("Authorization", goodBearerAuthValue) + b = CheckBearerAuth(r) + if b == nil { + t.Errorf("Could not extract bearer auth") + return + } + + // check extracted auth data + if b.Code != "BGFVTDUJDp0ZXN0" { + t.Errorf("Error decoding bearer auth") + } + + // extracts bearer auth from query string + url, _ := url.Parse("http://host.tld/path?code=XYZ") + r = &http.Request{URL: url} + r.ParseForm() + b = CheckBearerAuth(r) + if b.Code != "XYZ" { + t.Errorf("Error decoding bearer auth") + } +} + +// DefaultClient stores all data in struct variables. +type testClient struct { + ID string + Secret string + RedirectURI string + UserData interface{} +} + +func (d *testClient) GetID() string { + return d.ID +} + +func (d *testClient) GetSecret() string { + return d.Secret +} + +func (d *testClient) GetRedirectURI() string { + return d.RedirectURI +} + +func (d *testClient) GetUserData() interface{} { + return d.UserData +} + +func (d *testClient) CopyFrom(client Client) { + d.ID = client.GetID() + d.Secret = client.GetSecret() + d.RedirectURI = client.GetRedirectURI() + d.UserData = client.GetUserData() +} + +func TestCheckClientSecret(t *testing.T) { + { + client := &DefaultClient{ + Secret: "foo", + } + + assert.True(t, CheckClientSecret(client, "foo")) + } + + { + client := &testClient{ + Secret: "foo", + } + + assert.True(t, CheckClientSecret(client, "foo")) + } +} diff --git a/go.mod b/go.mod index f8913c3..6d64245 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,12 @@ module github.com/hyperscale-stack/security go 1.15 require ( + github.com/euskadi31/go-eventemitter v1.1.1 github.com/gilcrest/alice v1.0.0 - github.com/hyperscale-stack/secure v1.0.0 // indirect - github.com/rs/zerolog v1.20.0 + github.com/hyperscale-stack/logger v1.0.0 + github.com/hyperscale-stack/secure v1.0.0 + github.com/rs/zerolog v1.23.0 github.com/stretchr/objx v0.3.0 // indirect github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e ) diff --git a/go.sum b/go.sum index b8a0423..33eadf7 100644 --- a/go.sum +++ b/go.sum @@ -1,37 +1,60 @@ -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/euskadi31/go-eventemitter v1.1.1 h1:VrDOCjM1uaQfq3XGlrZXce+wL6JtiecPtKPwS2jATo8= +github.com/euskadi31/go-eventemitter v1.1.1/go.mod h1:2FzLg7x4u58JiXblyOwQJ+O+tAcj05CFKnN2Yz8yXEY= github.com/gilcrest/alice v1.0.0 h1:5+CasxidJEUHmgghQxLOl09uYhOlavDfDgNZhyR62LU= github.com/gilcrest/alice v1.0.0/go.mod h1:q5HRhK5WEyU1pDBIIfmYapVGLd/IAAPwiO8LNxKADpw= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/hyperscale-stack/logger v1.0.0 h1:ZsjQp1DEXwHunoVhi0L5/yZOg0Npsy7LkQy10Mi2o6I= +github.com/hyperscale-stack/logger v1.0.0/go.mod h1:X1apoFZZ8/AKspix5ylBetZzv+HEU2Ive/+3qHJhyOw= github.com/hyperscale-stack/secure v1.0.0 h1:ayGoa/Y/0RcAcP767WKjla1r9KlR+Tul5DPI/jE9dP0= github.com/hyperscale-stack/secure v1.0.0/go.mod h1:PY+BMJQI2aP+YYA3C7R0bFTS/XGJ4xPCYjBp9rEqmtQ= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= -github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= +github.com/rs/zerolog v1.23.0 h1:UskrK+saS9P9Y789yNNulYKdARjPZuS35B8gJF2x60g= +github.com/rs/zerolog v1.23.0/go.mod h1:6c7hFfxPOy7TacJc4Fcdi24/J0NKYGzjG8FWRI916Qo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As= github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= -golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e h1:VvfwVmMH40bpMeizC9/K7ipM5Qjucuu16RWfneFPyhQ= +golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/integrations/oauth2_auth_by_access_token_test.go b/internal/integrations/oauth2_auth_by_access_token_test.go index 392b271..1ce63b2 100644 --- a/internal/integrations/oauth2_auth_by_access_token_test.go +++ b/internal/integrations/oauth2_auth_by_access_token_test.go @@ -37,7 +37,7 @@ func TestOauth2AuthByAccessTokenWithNoAuthHeader(t *testing.T) { storageProvider.SaveClient(client) - storageProvider.SaveAccess(&oauth2.AccessInfo{ + storageProvider.SaveAccess(&oauth2.AccessData{ Client: client, AccessToken: "I3SoKTVXi6QzMZAmDW2Fgw2MLX0msPGRN58bCDLDFthJmy6Qoy8FH5v10dbewR6PfAV3brKhepjnTJVhDplSHFe6qbF3J4YDkI5EzXG0S8X7snSoB6FtrPNFMmISuEmU", UserData: "8c87a032-755d-42f6-be96-0421948f6e94", @@ -84,7 +84,7 @@ func TestOauth2AuthByAccessTokenWithBadToken(t *testing.T) { storageProvider.SaveClient(client) - storageProvider.SaveAccess(&oauth2.AccessInfo{ + storageProvider.SaveAccess(&oauth2.AccessData{ Client: client, AccessToken: "I3SoKTVXi6QzMZAmDW2Fgw2MLX0msPGRN58bCDLDFthJmy6Qoy8FH5v10dbewR6PfAV3brKhepjnTJVhDplSHFe6qbF3J4YDkI5EzXG0S8X7snSoB6FtrPNFMmISuEmU", }) @@ -138,7 +138,7 @@ func TestOauth2AuthByAccessToken(t *testing.T) { storageProvider.SaveClient(client) - storageProvider.SaveAccess(&oauth2.AccessInfo{ + storageProvider.SaveAccess(&oauth2.AccessData{ Client: client, AccessToken: "I3SoKTVXi6QzMZAmDW2Fgw2MLX0msPGRN58bCDLDFthJmy6Qoy8FH5v10dbewR6PfAV3brKhepjnTJVhDplSHFe6qbF3J4YDkI5EzXG0S8X7snSoB6FtrPNFMmISuEmU", ExpiresIn: 60, diff --git a/user/mock_user.go b/user/mock_user.go index 1342d82..544d051 100644 --- a/user/mock_user.go +++ b/user/mock_user.go @@ -9,6 +9,20 @@ type MockUser struct { mock.Mock } +// GetID provides a mock function with given fields: +func (_m *MockUser) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + // GetPassword provides a mock function with given fields: func (_m *MockUser) GetPassword() string { ret := _m.Called() diff --git a/user/mock_user_password_salt.go b/user/mock_user_password_salt.go index 9909151..d08db06 100644 --- a/user/mock_user_password_salt.go +++ b/user/mock_user_password_salt.go @@ -9,6 +9,20 @@ type MockUserPasswordSalt struct { mock.Mock } +// GetID provides a mock function with given fields: +func (_m *MockUserPasswordSalt) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + // GetPassword provides a mock function with given fields: func (_m *MockUserPasswordSalt) GetPassword() string { ret := _m.Called() diff --git a/user/user.go b/user/user.go index 74dee72..f46594d 100644 --- a/user/user.go +++ b/user/user.go @@ -5,8 +5,11 @@ package user // User interface provides core user information +// //go:generate mockery --name=User --inpackage --case underscore type User interface { + GetID() string + // GetRoles returns the roles granted to the user. GetRoles() []string @@ -36,6 +39,7 @@ type PasswordSalt interface { } // UserPasswordSalt interface. +// //go:generate mockery --name=UserPasswordSalt --inpackage --case underscore //nolint:golint type UserPasswordSalt interface {