diff --git a/go.mod b/go.mod index 18af9134b..61e87d177 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/crate-crypto/go-eth-kzg v1.3.0 // indirect github.com/crate-crypto/go-ipa v0.0.0-20240724233137-53bbb0ceb27a // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect + github.com/di-wu/parser v0.2.2 // indirect github.com/dprotaso/go-yit v0.0.0-20220510233725-9ba8df137936 // indirect github.com/ethereum/c-kzg-4844/v2 v2.1.0 // indirect github.com/ethereum/go-verkle v0.2.2 // indirect @@ -64,6 +65,7 @@ require ( github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/scim2/filter-parser/v2 v2.2.0 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/speakeasy-api/openapi-overlay v0.9.0 // indirect github.com/supranational/blst v0.3.14 // indirect diff --git a/go.sum b/go.sum index fcd1b1dc4..fe079a481 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,8 @@ github.com/decred/dcrd/crypto/blake256 v1.0.1 h1:7PltbUIQB7u/FfZ39+DGa/ShuMyJ5il github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= +github.com/di-wu/parser v0.2.2 h1:I9oHJ8spBXOeL7Wps0ffkFFFiXJf/pk7NX9lcAMqRMU= +github.com/di-wu/parser v0.2.2/go.mod h1:SLp58pW6WamdmznrVRrw2NTyn4wAvT9rrEFynKX7nYo= github.com/didip/tollbooth/v5 v5.1.1 h1:QpKFg56jsbNuQ6FFj++Z1gn2fbBsvAc1ZPLUaDOYW5k= github.com/didip/tollbooth/v5 v5.1.1/go.mod h1:d9rzwOULswrD3YIrAQmP3bfjxab32Df4IaO6+D25l9g= github.com/dprotaso/go-yit v0.0.0-20191028211022-135eb7262960/go.mod h1:9HQzr9D/0PGwMEbC3d5AB7oi67+h4TsQqItC1GVYG58= @@ -430,6 +432,8 @@ github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3ci github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/scim2/filter-parser/v2 v2.2.0 h1:QGadEcsmypxg8gYChRSM2j1edLyE/2j72j+hdmI4BJM= +github.com/scim2/filter-parser/v2 v2.2.0/go.mod h1:jWnkDToqX/Y0ugz0P5VvpVEUKcWcyHHj+X+je9ce5JA= github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35 h1:eajwn6K3weW5cd1ZXLu2sJ4pvwlBiCWY4uDejOr73gM= github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35/go.mod h1:wozgYq9WEBQBaIJe4YZ0qTSFAMxmcwBhQH0fO0R34Z0= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= diff --git a/hack/test.env b/hack/test.env index dc4769eaa..09c38a590 100644 --- a/hack/test.env +++ b/hack/test.env @@ -129,6 +129,7 @@ GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha" GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000" GOTRUE_SECURITY_CAPTCHA_TIMEOUT="10s" GOTRUE_SAML_ENABLED="true" +GOTRUE_SCIM_ENABLED="true" GOTRUE_SAML_PRIVATE_KEY="MIIEowIBAAKCAQEAszrVveMQcSsa0Y+zN1ZFb19cRS0jn4UgIHTprW2tVBmO2PABzjY3XFCfx6vPirMAPWBYpsKmXrvm1tr0A6DZYmA8YmJd937VUQ67fa6DMyppBYTjNgGEkEhmKuszvF3MARsIKCGtZqUrmS7UG4404wYxVppnr2EYm3RGtHlkYsXu20MBqSDXP47bQP+PkJqC3BuNGk3xt5UHl2FSFpTHelkI6lBynw16B+lUT1F96SERNDaMqi/TRsZdGe5mB/29ngC/QBMpEbRBLNRir5iUevKS7Pn4aph9Qjaxx/97siktK210FJT23KjHpgcUfjoQ6BgPBTLtEeQdRyDuc/CgfwIDAQABAoIBAGYDWOEpupQPSsZ4mjMnAYJwrp4ZISuMpEqVAORbhspVeb70bLKonT4IDcmiexCg7cQBcLQKGpPVM4CbQ0RFazXZPMVq470ZDeWDEyhoCfk3bGtdxc1Zc9CDxNMs6FeQs6r1beEZug6weG5J/yRn/qYxQife3qEuDMl+lzfl2EN3HYVOSnBmdt50dxRuX26iW3nqqbMRqYn9OHuJ1LvRRfYeyVKqgC5vgt/6Tf7DAJwGe0dD7q08byHV8DBZ0pnMVU0bYpf1GTgMibgjnLjK//EVWafFHtN+RXcjzGmyJrk3+7ZyPUpzpDjO21kpzUQLrpEkkBRnmg6bwHnSrBr8avECgYEA3pq1PTCAOuLQoIm1CWR9/dhkbJQiKTJevlWV8slXQLR50P0WvI2RdFuSxlWmA4xZej8s4e7iD3MYye6SBsQHygOVGc4efvvEZV8/XTlDdyj7iLVGhnEmu2r7AFKzy8cOvXx0QcLg+zNd7vxZv/8D3Qj9Jje2LjLHKM5n/dZ3RzUCgYEAzh5Lo2anc4WN8faLGt7rPkGQF+7/18ImQE11joHWa3LzAEy7FbeOGpE/vhOv5umq5M/KlWFIRahMEQv4RusieHWI19ZLIP+JwQFxWxS+cPp3xOiGcquSAZnlyVSxZ//dlVgaZq2o2MfrxECcovRlaknl2csyf+HjFFwKlNxHm2MCgYAr//R3BdEy0oZeVRndo2lr9YvUEmu2LOihQpWDCd0fQw0ZDA2kc28eysL2RROte95r1XTvq6IvX5a0w11FzRWlDpQ4J4/LlcQ6LVt+98SoFwew+/PWuyLmxLycUbyMOOpm9eSc4wJJZNvaUzMCSkvfMtmm5jgyZYMMQ9A2Ul/9SQKBgB9mfh9mhBwVPIqgBJETZMMXOdxrjI5SBYHGSyJqpT+5Q0vIZLfqPrvNZOiQFzwWXPJ+tV4Mc/YorW3rZOdo6tdvEGnRO6DLTTEaByrY/io3/gcBZXoSqSuVRmxleqFdWWRnB56c1hwwWLqNHU+1671FhL6pNghFYVK4suP6qu4BAoGBAMk+VipXcIlD67mfGrET/xDqiWWBZtgTzTMjTpODhDY1GZck1eb4CQMP5j5V3gFJ4cSgWDJvnWg8rcz0unz/q4aeMGl1rah5WNDWj1QKWMS6vJhMHM/rqN1WHWR0ZnV83svYgtg0zDnQKlLujqW4JmGXLMU7ur6a+e6lpa1fvLsP" GOTRUE_MAX_VERIFIED_FACTORS=10 GOTRUE_SMS_TEST_OTP_VALID_UNTIL="" diff --git a/internal/api/admin.go b/internal/api/admin.go index e75fbb353..c55a3b055 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -299,7 +299,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { } if banDuration != nil { - if terr := user.Ban(tx, *banDuration); terr != nil { + if terr := user.Ban(tx, *banDuration, nil); terr != nil { return terr } } @@ -493,7 +493,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } if banDuration != nil { - if terr := user.Ban(tx, *banDuration); terr != nil { + if terr := user.Ban(tx, *banDuration, nil); terr != nil { return terr } } diff --git a/internal/api/api.go b/internal/api/api.go index c2536c0a7..a54077e0e 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -201,6 +201,48 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Post("/", api.ExternalProviderCallback) }) + // SCIM v2 API endpoints + if api.config.SCIM.Enabled { + r.Route("/scim/v2", func(r *router) { + r.Use(api.requireSCIMAuthentication) + + // SCIM-specific NotFound handler for proper error format + r.NotFound(api.scimNotFound) + r.MethodNotAllowed(api.scimMethodNotAllowed) + + // Service Provider Configuration + r.Get("/ServiceProviderConfig", api.scimServiceProviderConfig) + r.Get("/ResourceTypes", api.scimResourceTypes) + r.Get("/ResourceTypes/{resource_type_id}", api.scimResourceTypeByID) + r.Get("/Schemas", api.scimSchemas) + r.Get("/Schemas/{schema_id}", api.scimSchemaByID) + + // User endpoints + r.Route("/Users", func(r *router) { + r.Get("/", api.scimListUsers) + r.Post("/", api.scimCreateUser) + r.Route("/{user_id}", func(r *router) { + r.Get("/", api.scimGetUser) + r.Put("/", api.scimReplaceUser) + r.Patch("/", api.scimPatchUser) + r.Delete("/", api.scimDeleteUser) + }) + }) + + // Group endpoints + r.Route("/Groups", func(r *router) { + r.Get("/", api.scimListGroups) + r.Post("/", api.scimCreateGroup) + r.Route("/{group_id}", func(r *router) { + r.Get("/", api.scimGetGroup) + r.Put("/", api.scimReplaceGroup) + r.Patch("/", api.scimPatchGroup) + r.Delete("/", api.scimDeleteGroup) + }) + }) + }) + } + r.Route("/", func(r *router) { r.Use(api.isValidExternalHost) @@ -352,6 +394,14 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/", api.adminSSOProvidersGet) r.Put("/", api.adminSSOProvidersUpdate) r.Delete("/", api.adminSSOProvidersDelete) + + // SCIM management endpoints + r.Route("/scim", func(r *router) { + r.Get("/", api.adminSSOProviderGetSCIM) + r.Post("/", api.adminSSOProviderEnableSCIM) + r.Delete("/", api.adminSSOProviderDisableSCIM) + r.Post("/rotate", api.adminSSOProviderRotateSCIMToken) + }) }) }) }) diff --git a/internal/api/apierrors/apierrors.go b/internal/api/apierrors/apierrors.go index adab1d39c..110f97a6b 100644 --- a/internal/api/apierrors/apierrors.go +++ b/internal/api/apierrors/apierrors.go @@ -120,3 +120,81 @@ func (e *HTTPError) WithInternalMessage(fmtString string, args ...any) *HTTPErro e.InternalMessage = fmt.Sprintf(fmtString, args...) return e } + +// SCIMHTTPError is an error with SCIM-specific format per RFC 7644 Section 3.12 +type SCIMHTTPError struct { + HTTPStatus int `json:"-"` + Schemas []string `json:"schemas"` + Status string `json:"status"` + Detail string `json:"detail,omitempty"` + ScimType string `json:"scimType,omitempty"` + InternalError error `json:"-"` + InternalMessage string `json:"-"` +} + +const SCIMSchemaError = "urn:ietf:params:scim:api:messages:2.0:Error" + +func NewSCIMHTTPError(httpStatus int, detail string, scimType string) *SCIMHTTPError { + return &SCIMHTTPError{ + HTTPStatus: httpStatus, + Schemas: []string{SCIMSchemaError}, + Status: fmt.Sprintf("%d", httpStatus), + Detail: detail, + ScimType: scimType, + } +} + +func NewSCIMBadRequestError(detail string, scimType string) *SCIMHTTPError { + return NewSCIMHTTPError(http.StatusBadRequest, detail, scimType) +} + +func NewSCIMNotFoundError(detail string) *SCIMHTTPError { + return NewSCIMHTTPError(http.StatusNotFound, detail, "") +} + +func NewSCIMUnauthorizedError(detail string) *SCIMHTTPError { + return NewSCIMHTTPError(http.StatusUnauthorized, detail, "") +} + +func NewSCIMConflictError(detail string, scimType string) *SCIMHTTPError { + return NewSCIMHTTPError(http.StatusConflict, detail, scimType) +} + +func NewSCIMForbiddenError(detail string) *SCIMHTTPError { + return NewSCIMHTTPError(http.StatusForbidden, detail, "") +} + +func NewSCIMRequestTooLargeError(detail string) *SCIMHTTPError { + return NewSCIMHTTPError(http.StatusRequestEntityTooLarge, detail, "") +} + +func NewSCIMInternalServerError(detail string) *SCIMHTTPError { + return NewSCIMHTTPError(http.StatusInternalServerError, detail, "") +} + +func (e *SCIMHTTPError) Error() string { + if e.InternalMessage != "" { + return e.InternalMessage + } + return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Detail) +} + +// Cause returns the root cause error +func (e *SCIMHTTPError) Cause() error { + if e.InternalError != nil { + return e.InternalError + } + return e +} + +// WithInternalError adds internal error information to the error +func (e *SCIMHTTPError) WithInternalError(err error) *SCIMHTTPError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *SCIMHTTPError) WithInternalMessage(fmtString string, args ...any) *SCIMHTTPError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) + return e +} diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go index 58963eea3..10fdbc167 100644 --- a/internal/api/apierrors/errorcode.go +++ b/internal/api/apierrors/errorcode.go @@ -61,6 +61,7 @@ const ( ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists" ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found" ErrorCodeSSOProviderDisabled ErrorCode = "sso_provider_disabled" + ErrorCodeSCIMDisabled ErrorCode = "scim_disabled" ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed" ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists" ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists" diff --git a/internal/api/errors.go b/internal/api/errors.go index a9b467f36..4ae6668dd 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -188,6 +188,19 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") } + case *apierrors.SCIMHTTPError: + switch { + case e.HTTPStatus >= http.StatusInternalServerError: + log.WithError(e.Cause()).Error(e.Error()) + case e.HTTPStatus == http.StatusTooManyRequests: + log.WithError(e.Cause()).Warn(e.Error()) + default: + log.WithError(e.Cause()).Info(e.Error()) + } + if jsonErr := sendSCIMJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + case ErrorCause: HandleResponseError(e.Cause(), w, r) diff --git a/internal/api/router.go b/internal/api/router.go index 1feb66d3f..d7c9a0b3b 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -30,6 +30,9 @@ func (r *router) Post(pattern string, fn apiHandler) { func (r *router) Put(pattern string, fn apiHandler) { r.chi.Put(pattern, handler(fn)) } +func (r *router) Patch(pattern string, fn apiHandler) { + r.chi.Patch(pattern, handler(fn)) +} func (r *router) Delete(pattern string, fn apiHandler) { r.chi.Delete(pattern, handler(fn)) } @@ -51,6 +54,14 @@ func (r *router) UseBypass(fn func(next http.Handler) http.Handler) { r.chi.Use(fn) } +func (r *router) NotFound(fn apiHandler) { + r.chi.NotFound(handler(fn)) +} + +func (r *router) MethodNotAllowed(fn apiHandler) { + r.chi.MethodNotAllowed(handler(fn)) +} + func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.chi.ServeHTTP(w, req) } diff --git a/internal/api/scim.go b/internal/api/scim.go new file mode 100644 index 000000000..bbcd96e3f --- /dev/null +++ b/internal/api/scim.go @@ -0,0 +1,1428 @@ +package api + +import ( + "context" + "fmt" + "math" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + filter "github.com/scim2/filter-parser/v2" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +func (a *API) requireSCIMAuthentication(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + token, err := a.extractBearerToken(r) + if err != nil { + return nil, apierrors.NewSCIMUnauthorizedError("Invalid or missing SCIM bearer token") + } + + provider, err := models.FindSSOProviderBySCIMToken(db, token) + if err != nil { + if models.IsNotFoundError(err) { + return nil, apierrors.NewSCIMUnauthorizedError("Invalid SCIM bearer token") + } + return nil, apierrors.NewSCIMInternalServerError("Error validating SCIM token").WithInternalError(err) + } + + if !provider.IsSCIMEnabled() { + return nil, apierrors.NewSCIMForbiddenError("SCIM provisioning is not enabled for this provider") + } + + if !provider.IsEnabled() { + return nil, apierrors.NewSCIMForbiddenError("SSO provider is disabled") + } + + return withSSOProvider(ctx, provider), nil +} + +func (a *API) scimListUsers(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + startIndex, count := parseSCIMPagination(r) + + filterStr := r.URL.Query().Get("filter") + filterClause, err := ParseSCIMFilterToSQL(filterStr, SCIMUserFilterAttrs) + if err != nil { + return err + } + + providerType := "sso:" + provider.ID.String() + + users, totalResults, err := models.FindUsersByProviderWithFilter(db, providerType, filterClause, startIndex, count) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error fetching users").WithInternalError(err) + } + + resources := make([]interface{}, len(users)) + for i, user := range users { + resources[i] = a.userToSCIMResponse(user, providerType) + } + + return sendSCIMJSON(w, http.StatusOK, &SCIMListResponse{ + Schemas: []string{SCIMSchemaListResponse}, + TotalResults: totalResults, + StartIndex: startIndex, + ItemsPerPage: len(users), + Resources: resources, + }) +} + +func (a *API) scimGetUser(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + userID, err := uuid.FromString(chi.URLParam(r, "user_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + user, err := models.FindUserByID(db, userID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching user").WithInternalError(err) + } + + if !models.UserBelongsToSSOProvider(user, provider.ID) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + return sendSCIMJSON(w, http.StatusOK, a.userToSCIMResponse(user, "sso:"+provider.ID.String())) +} + +func (a *API) scimCreateUser(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + config := a.config + + var params SCIMUserParams + if err := a.parseSCIMBody(w, r, ¶ms); err != nil { + return err + } + if err := params.Validate(); err != nil { + return err + } + + email, emailType := extractPrimarySCIMEmail(params.Emails) + + if email == "" { + return apierrors.NewSCIMBadRequestError("At least one email address is required", "invalidValue") + } + + email, err := a.validateEmail(email) + if err != nil { + return apierrors.NewSCIMBadRequestError("Invalid email address", "invalidValue") + } + + providerType := "sso:" + provider.ID.String() + + var user *models.User + terr := db.Transaction(func(tx *storage.Connection) error { + nonSSOUser, err := models.FindUserByEmailAndAudience(tx, email, config.JWT.Aud) + if err != nil && !models.IsNotFoundError(err) { + return apierrors.NewSCIMInternalServerError("Error checking existing user").WithInternalError(err) + } + if nonSSOUser != nil { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + + ssoUsers, err := models.FindSSOUsersByEmailAndProvider(tx, email, config.JWT.Aud, providerType) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error checking existing SSO user").WithInternalError(err) + } + + if len(ssoUsers) > 0 { + var deprovisioned []*models.User + for _, u := range ssoUsers { + if u.BannedReason == nil || *u.BannedReason != scimDeprovisionedReason { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + deprovisioned = append(deprovisioned, u) + } + + if len(deprovisioned) > 1 { + return apierrors.NewSCIMConflictError(scimErrAmbiguousDeprovisioned, "uniqueness") + } + + candidate := deprovisioned[0] + + if params.Active == nil || bool(*params.Active) { + if err := candidate.Ban(tx, 0, nil); err != nil { + return apierrors.NewSCIMInternalServerError("Error reactivating user").WithInternalError(err) + } + } + + if params.Name != nil { + metadata := candidate.UserMetaData + if metadata == nil { + metadata = make(map[string]interface{}) + } + applySCIMNameToMetadata(metadata, params.Name) + candidate.UserMetaData = metadata + if err := tx.UpdateOnly(candidate, "raw_user_meta_data"); err != nil { + return apierrors.NewSCIMInternalServerError("Error updating user metadata").WithInternalError(err) + } + } + + if email != candidate.GetEmail() { + if err := candidate.SetEmail(tx, email); err != nil { + return apierrors.NewSCIMInternalServerError("Error updating user email").WithInternalError(err) + } + } + + for i := range candidate.Identities { + if candidate.Identities[i].Provider == providerType { + if candidate.Identities[i].IdentityData == nil { + candidate.Identities[i].IdentityData = make(map[string]interface{}) + } + candidate.Identities[i].IdentityData["user_name"] = params.UserName + candidate.Identities[i].IdentityData["email"] = email + if params.ExternalID != "" { + candidate.Identities[i].ProviderID = params.ExternalID + candidate.Identities[i].IdentityData["external_id"] = params.ExternalID + candidate.Identities[i].IdentityData["sub"] = params.ExternalID + } else { + candidate.Identities[i].ProviderID = params.UserName + candidate.Identities[i].IdentityData["sub"] = params.UserName + } + if err := tx.UpdateOnly(&candidate.Identities[i], "provider_id", "identity_data"); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrExternalIDConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating identity").WithInternalError(err) + } + break + } + } + + auditAction := "reactivated" + if params.Active != nil && !bool(*params.Active) { + auditAction = "reprovisioned_inactive" + } + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, candidate, models.UserModifiedAction, utilities.GetIPAddress(r), map[string]interface{}{ + "provider": "scim", + "sso_provider_id": provider.ID, + "action": auditAction, + }); terr != nil { + return apierrors.NewSCIMInternalServerError("Error recording audit log entry").WithInternalError(terr) + } + user = candidate + return nil + } + + user, err = models.NewUser("", email, "", config.JWT.Aud, nil) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error creating user").WithInternalError(err) + } + user.IsSSOUser = true + + if params.Name != nil { + metadata := make(map[string]interface{}) + applySCIMNameToMetadata(metadata, params.Name) + if len(metadata) > 0 { + user.UserMetaData = metadata + } + } + + if err := tx.Create(user); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error saving user").WithInternalError(err) + } + + identityID := params.ExternalID + if identityID == "" { + identityID = params.UserName + } + + if _, err := a.createNewIdentity(tx, user, providerType, map[string]interface{}{ + "sub": identityID, + "external_id": params.ExternalID, + "email": email, + "email_type": emailType, + "user_name": params.UserName, + }); err != nil { + errToCheck := err + if httpErr, ok := err.(*apierrors.HTTPError); ok && httpErr.InternalError != nil { + errToCheck = httpErr.InternalError + } + if pgErr := utilities.NewPostgresError(errToCheck); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrExternalIDConflict, "uniqueness") + } + return err + } + + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.UserSignedUpAction, utilities.GetIPAddress(r), map[string]interface{}{ + "provider": "scim", + "sso_provider_id": provider.ID, + }); terr != nil { + return apierrors.NewSCIMInternalServerError("Error recording audit log entry").WithInternalError(terr) + } + + if err := tx.Eager().Find(user, user.ID); err != nil { + return apierrors.NewSCIMInternalServerError("Error reloading user").WithInternalError(err) + } + + if params.Active != nil && !bool(*params.Active) { + if err := user.Ban(tx, time.Duration(math.MaxInt64), &scimDeprovisionedReason); err != nil { + return apierrors.NewSCIMInternalServerError("Error banning user").WithInternalError(err) + } + if err := models.Logout(tx, user.ID); err != nil { + return apierrors.NewSCIMInternalServerError("Error invalidating sessions").WithInternalError(err) + } + } + + return nil + }) + + if terr != nil { + return terr + } + + return sendSCIMJSON(w, http.StatusCreated, a.userToSCIMResponse(user, providerType)) +} + +func (a *API) scimReplaceUser(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + config := a.config + + userID, err := uuid.FromString(chi.URLParam(r, "user_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + var params SCIMUserParams + if err := a.parseSCIMBody(w, r, ¶ms); err != nil { + return err + } + if err := params.Validate(); err != nil { + return err + } + + email, _ := extractPrimarySCIMEmail(params.Emails) + if email != "" { + email, err = a.validateEmail(email) + if err != nil { + return apierrors.NewSCIMBadRequestError("Invalid email address", "invalidValue") + } + } + + var user *models.User + terr := db.Transaction(func(tx *storage.Connection) error { + var err error + user, err = models.FindUserByID(tx, userID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching user").WithInternalError(err) + } + + if !models.UserBelongsToSSOProvider(user, provider.ID) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + metadata := user.UserMetaData + if metadata == nil { + metadata = make(map[string]interface{}) + } + delete(metadata, "given_name") + delete(metadata, "family_name") + delete(metadata, "full_name") + applySCIMNameToMetadata(metadata, params.Name) + user.UserMetaData = metadata + + if params.Active != nil { + if *params.Active { + if err := user.Ban(tx, 0, nil); err != nil { + return apierrors.NewSCIMInternalServerError("Error unbanning user").WithInternalError(err) + } + } else { + if err := user.Ban(tx, time.Duration(math.MaxInt64), &scimDeprovisionedReason); err != nil { + return apierrors.NewSCIMInternalServerError("Error banning user").WithInternalError(err) + } + if err := models.Logout(tx, user.ID); err != nil { + return apierrors.NewSCIMInternalServerError("Error invalidating sessions").WithInternalError(err) + } + } + } + + providerType := "sso:" + provider.ID.String() + + if email != "" && email != user.GetEmail() { + if err := checkSCIMEmailUniqueness(tx, email, config.JWT.Aud, providerType, user.ID); err != nil { + return err + } + if err := user.SetEmail(tx, email); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating user email").WithInternalError(err) + } + } + + if err := tx.UpdateOnly(user, "raw_user_meta_data"); err != nil { + return apierrors.NewSCIMInternalServerError("Error updating user").WithInternalError(err) + } + for i := range user.Identities { + if user.Identities[i].Provider == providerType { + if user.Identities[i].IdentityData == nil { + user.Identities[i].IdentityData = make(map[string]interface{}) + } + user.Identities[i].IdentityData["user_name"] = params.UserName + if email != "" { + user.Identities[i].IdentityData["email"] = email + } + updateCols := []string{"identity_data", "provider_id"} + if params.ExternalID != "" { + user.Identities[i].ProviderID = params.ExternalID + user.Identities[i].IdentityData["external_id"] = params.ExternalID + user.Identities[i].IdentityData["sub"] = params.ExternalID + } else { + delete(user.Identities[i].IdentityData, "external_id") + user.Identities[i].ProviderID = params.UserName + user.Identities[i].IdentityData["sub"] = params.UserName + } + if err := tx.UpdateOnly(&user.Identities[i], updateCols...); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrExternalIDConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating identity").WithInternalError(err) + } + break + } + } + + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.UserModifiedAction, utilities.GetIPAddress(r), map[string]interface{}{ + "provider": "scim", + "sso_provider_id": provider.ID, + }); terr != nil { + return apierrors.NewSCIMInternalServerError("Error recording audit log entry").WithInternalError(terr) + } + + if err := tx.Eager().Find(user, user.ID); err != nil { + return apierrors.NewSCIMInternalServerError("Error reloading user").WithInternalError(err) + } + + return nil + }) + + if terr != nil { + return terr + } + + return sendSCIMJSON(w, http.StatusOK, a.userToSCIMResponse(user, "sso:"+provider.ID.String())) +} + +func (a *API) scimPatchUser(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + config := a.config + + userID, err := uuid.FromString(chi.URLParam(r, "user_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + var params SCIMPatchRequest + if err := a.parseSCIMBody(w, r, ¶ms); err != nil { + return err + } + if err := params.Validate(); err != nil { + return err + } + + var user *models.User + terr := db.Transaction(func(tx *storage.Connection) error { + var err error + user, err = models.FindUserByID(tx, userID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching user").WithInternalError(err) + } + + if !models.UserBelongsToSSOProvider(user, provider.ID) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + for _, op := range params.Operations { + if err := a.applySCIMUserPatch(tx, user, op, provider.ID); err != nil { + return err + } + } + + if err := tx.Eager().Find(user, user.ID); err != nil { + return apierrors.NewSCIMInternalServerError("Error reloading user").WithInternalError(err) + } + + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.UserModifiedAction, utilities.GetIPAddress(r), map[string]interface{}{ + "provider": "scim", + "sso_provider_id": provider.ID, + }); terr != nil { + return apierrors.NewSCIMInternalServerError("Error recording audit log entry").WithInternalError(terr) + } + + return nil + }) + + if terr != nil { + return terr + } + + return sendSCIMJSON(w, http.StatusOK, a.userToSCIMResponse(user, "sso:"+provider.ID.String())) +} + +func (a *API) applySCIMUserPatch(tx *storage.Connection, user *models.User, op SCIMPatchOperation, providerID uuid.UUID) error { + providerType := "sso:" + providerID.String() + + var path *filter.Path + if op.Path != "" { + p, err := filter.ParsePath([]byte(op.Path)) + if err != nil { + return apierrors.NewSCIMBadRequestError( + fmt.Sprintf("Invalid path: %v", err), "invalidPath") + } + path = &p + } + + switch strings.ToLower(op.Op) { + case "remove": + return a.applySCIMUserRemove(tx, user, op, path, providerType) + case "add": + return a.applySCIMUserAdd(tx, user, op, path, providerType) + case "replace": + return a.applySCIMUserReplace(tx, user, op, path, providerType) + default: + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported patch operation: %s", op.Op), "invalidSyntax") + } +} + +func (a *API) applySCIMUserRemove(tx *storage.Connection, user *models.User, op SCIMPatchOperation, path *filter.Path, providerType string) error { + if path == nil { + return apierrors.NewSCIMBadRequestError("remove operation requires a path", "noTarget") + } + attrName := strings.ToLower(path.AttributePath.AttributeName) + if attrName == "externalid" { + if identity := findSSOIdentity(user, providerType); identity != nil { + if identity.IdentityData != nil { + delete(identity.IdentityData, "external_id") + } + fallbackID := user.GetEmail() + if userName, ok := identity.IdentityData["user_name"].(string); ok && userName != "" { + fallbackID = userName + } + identity.ProviderID = fallbackID + identity.IdentityData["sub"] = fallbackID + if err := tx.UpdateOnly(identity, "provider_id", "identity_data"); err != nil { + return apierrors.NewSCIMInternalServerError("Error updating identity").WithInternalError(err) + } + } + return nil + } + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported remove path: %s", op.Path), "invalidPath") +} + +func (a *API) applySCIMUserAdd(tx *storage.Connection, user *models.User, op SCIMPatchOperation, path *filter.Path, providerType string) error { + if path != nil { + attrName := strings.ToLower(path.AttributePath.AttributeName) + if attrName == "externalid" { + externalID, ok := op.Value.(string) + if !ok { + return apierrors.NewSCIMBadRequestError("externalId must be a string", "invalidValue") + } + if identity := findSSOIdentity(user, providerType); identity != nil { + return setSCIMExternalID(tx, identity, externalID) + } + return nil + } + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported add path: %s", op.Path), "invalidPath") + } + + valueMap, ok := op.Value.(map[string]interface{}) + if !ok { + return apierrors.NewSCIMBadRequestError("add operation without path requires an object value", "invalidValue") + } + for key, val := range valueMap { + if key == "" { + continue + } + keyPath, err := filter.ParsePath([]byte(key)) + if err != nil { + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Invalid attribute path: %s", key), "invalidPath") + } + if strings.ToLower(keyPath.AttributePath.AttributeName) == "externalid" { + if externalID, ok := val.(string); ok { + if identity := findSSOIdentity(user, providerType); identity != nil { + if err := setSCIMExternalID(tx, identity, externalID); err != nil { + return err + } + } + } + } + } + return nil +} + +func (a *API) applySCIMUserReplace(tx *storage.Connection, user *models.User, op SCIMPatchOperation, path *filter.Path, providerType string) error { + if path != nil { + return a.applySCIMUserReplaceWithPath(tx, user, op, path, providerType) + } + + valueMap, ok := op.Value.(map[string]interface{}) + if !ok { + return apierrors.NewSCIMBadRequestError("replace operation value must be an object when path is not specified", "invalidValue") + } + if user.UserMetaData == nil { + user.UserMetaData = make(map[string]interface{}) + } + metadataUpdated := false + for key, val := range valueMap { + if key == "" { + continue + } + keyPath, err := filter.ParsePath([]byte(key)) + if err != nil { + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Invalid attribute path: %s", key), "invalidPath") + } + attrName := strings.ToLower(keyPath.AttributePath.AttributeName) + subAttr := strings.ToLower(keyPath.AttributePath.SubAttributeName()) + + switch { + case attrName == "username": + if userName, ok := val.(string); ok && userName != "" { + if identity := findSSOIdentity(user, providerType); identity != nil { + if err := setSCIMUserName(tx, identity, userName); err != nil { + return err + } + } + } + case attrName == "name" && subAttr == "formatted": + if v, ok := val.(string); ok { + user.UserMetaData["full_name"] = v + metadataUpdated = true + } + case attrName == "name" && subAttr == "familyname": + if v, ok := val.(string); ok { + user.UserMetaData["family_name"] = v + metadataUpdated = true + } + case attrName == "name" && subAttr == "givenname": + if v, ok := val.(string); ok { + user.UserMetaData["given_name"] = v + metadataUpdated = true + } + case attrName == "externalid": + if externalID, ok := val.(string); ok { + if identity := findSSOIdentity(user, providerType); identity != nil { + if err := setSCIMExternalID(tx, identity, externalID); err != nil { + return err + } + } + } + case attrName == "emails" && keyPath.ValueExpression != nil && strings.ToLower(keyPath.SubAttributeName()) == "value": + if emailValue, ok := val.(string); ok { + if err := a.applySCIMEmailUpdate(tx, user, emailValue, providerType); err != nil { + return err + } + } + case attrName == "active": + if err := a.applySCIMActiveUpdate(tx, user, val); err != nil { + return err + } + } + } + if metadataUpdated { + if err := tx.UpdateOnly(user, "raw_user_meta_data"); err != nil { + return apierrors.NewSCIMInternalServerError("Error updating user metadata").WithInternalError(err) + } + } + return nil +} + +func (a *API) applySCIMUserReplaceWithPath(tx *storage.Connection, user *models.User, op SCIMPatchOperation, path *filter.Path, providerType string) error { + attrName := strings.ToLower(path.AttributePath.AttributeName) + switch { + case attrName == "active": + return a.applySCIMActiveUpdate(tx, user, op.Value) + case attrName == "username": + userName, ok := op.Value.(string) + if !ok { + return apierrors.NewSCIMBadRequestError("userName must be a string", "invalidValue") + } + if identity := findSSOIdentity(user, providerType); identity != nil { + return setSCIMUserName(tx, identity, userName) + } + return nil + case attrName == "emails" && path.ValueExpression != nil && strings.ToLower(path.SubAttributeName()) == "value": + newEmail, ok := op.Value.(string) + if !ok { + return apierrors.NewSCIMBadRequestError("email value must be a string", "invalidValue") + } + return a.applySCIMEmailUpdate(tx, user, newEmail, providerType) + default: + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported replace path: %s", op.Path), "invalidPath") + } +} + +func (a *API) applySCIMEmailUpdate(tx *storage.Connection, user *models.User, newEmail, providerType string) error { + validatedEmail, err := a.validateEmail(newEmail) + if err != nil { + return apierrors.NewSCIMBadRequestError("Invalid email address", "invalidValue") + } + if err := checkSCIMEmailUniqueness(tx, validatedEmail, a.config.JWT.Aud, providerType, user.ID); err != nil { + return err + } + if err := user.SetEmail(tx, validatedEmail); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating email").WithInternalError(err) + } + if identity := findSSOIdentity(user, providerType); identity != nil { + return setSCIMIdentityField(tx, identity, "email", validatedEmail) + } + return nil +} + +func (a *API) applySCIMActiveUpdate(tx *storage.Connection, user *models.User, val interface{}) error { + active, err := parseSCIMActiveBool(val) + if err != nil { + return err + } + if active { + if err := user.Ban(tx, 0, nil); err != nil { + return apierrors.NewSCIMInternalServerError("Error unbanning user").WithInternalError(err) + } + return nil + } + if err := user.Ban(tx, time.Duration(math.MaxInt64), &scimDeprovisionedReason); err != nil { + return apierrors.NewSCIMInternalServerError("Error banning user").WithInternalError(err) + } + if err := models.Logout(tx, user.ID); err != nil { + return apierrors.NewSCIMInternalServerError("Error invalidating sessions").WithInternalError(err) + } + return nil +} + +func (a *API) scimDeleteUser(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + config := a.config + + userID, err := uuid.FromString(chi.URLParam(r, "user_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + terr := db.Transaction(func(tx *storage.Connection) error { + user, err := models.FindUserByID(tx, userID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching user").WithInternalError(err) + } + + if !models.UserBelongsToSSOProvider(user, provider.ID) { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + if user.IsBanned() && user.BannedReason != nil && *user.BannedReason == scimDeprovisionedReason { + return apierrors.NewSCIMNotFoundError(scimErrUserNotFound) + } + + if err := user.Ban(tx, time.Duration(math.MaxInt64), &scimDeprovisionedReason); err != nil { + return apierrors.NewSCIMInternalServerError("Error deprovisioning user").WithInternalError(err) + } + + if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.UserDeletedAction, utilities.GetIPAddress(r), map[string]interface{}{ + "provider": "scim", + "sso_provider_id": provider.ID, + }); terr != nil { + return apierrors.NewSCIMInternalServerError("Error recording audit log entry").WithInternalError(terr) + } + + if err := models.Logout(tx, user.ID); err != nil { + return apierrors.NewSCIMInternalServerError("Error invalidating sessions").WithInternalError(err) + } + return nil + }) + + if terr != nil { + return terr + } + + w.Header().Set("Content-Type", "application/scim+json") + w.WriteHeader(http.StatusNoContent) + return nil +} + +func (a *API) scimListGroups(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + startIndex, count := parseSCIMPagination(r) + + filterStr := r.URL.Query().Get("filter") + filterClause, err := ParseSCIMFilterToSQL(filterStr, SCIMGroupFilterAttrs) + if err != nil { + return err + } + + groups, totalResults, err := models.FindSCIMGroupsBySSOProviderWithFilter(db, provider.ID, filterClause, startIndex, count) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error fetching groups").WithInternalError(err) + } + + attrs := strings.ToLower(r.URL.Query().Get("attributes")) + excludeMembers := strings.Contains(strings.ToLower(r.URL.Query().Get("excludedAttributes")), "members") + includeMembers := !excludeMembers && (attrs == "" || strings.Contains(attrs, "members")) + + var membersByGroup map[uuid.UUID][]*models.User + if includeMembers && len(groups) > 0 { + groupIDs := make([]uuid.UUID, len(groups)) + for i, g := range groups { + groupIDs[i] = g.ID + } + var err error + membersByGroup, err = models.GetMembersForGroups(db, groupIDs) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error fetching group members").WithInternalError(err) + } + } + + resources := make([]interface{}, len(groups)) + for i, group := range groups { + var members []*models.User + if includeMembers { + members = membersByGroup[group.ID] + } + resources[i] = a.groupToSCIMResponse(group, members) + } + + return sendSCIMJSON(w, http.StatusOK, &SCIMListResponse{ + Schemas: []string{SCIMSchemaListResponse}, + TotalResults: totalResults, + StartIndex: startIndex, + ItemsPerPage: len(groups), + Resources: resources, + }) +} + +func (a *API) scimGetGroup(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + groupID, err := uuid.FromString(chi.URLParam(r, "group_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + group, err := models.FindSCIMGroupByID(db, groupID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching group").WithInternalError(err) + } + + if group.SSOProviderID != provider.ID { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + excludeMembers := strings.Contains(strings.ToLower(r.URL.Query().Get("excludedAttributes")), "members") + + var members []*models.User + if !excludeMembers { + members, err = group.GetMembers(db) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error fetching group members").WithInternalError(err) + } + } + + return sendSCIMJSON(w, http.StatusOK, a.groupToSCIMResponse(group, members)) +} + +func (a *API) scimCreateGroup(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + var params SCIMGroupParams + if err := a.parseSCIMBody(w, r, ¶ms); err != nil { + return err + } + if err := params.Validate(); err != nil { + return err + } + + var group *models.SCIMGroup + terr := db.Transaction(func(tx *storage.Connection) error { + if params.ExternalID != "" { + existing, err := models.FindSCIMGroupByExternalID(tx, provider.ID, params.ExternalID) + if err == nil && existing != nil { + return apierrors.NewSCIMConflictError(scimErrGroupExternalIDConflict, "uniqueness") + } + if err != nil && !models.IsNotFoundError(err) { + return apierrors.NewSCIMInternalServerError("Error checking existing group").WithInternalError(err) + } + } + + group = models.NewSCIMGroup(provider.ID, params.ExternalID, params.DisplayName) + if err := tx.Create(group); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrGroupDisplayNameConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error creating group").WithInternalError(err) + } + + if len(params.Members) > 0 { + memberIDs, err := parseSCIMGroupMemberRefs(params.Members) + if err != nil { + return err + } + if err := group.AddMembers(tx, memberIDs); err != nil { + if _, ok := err.(models.UserNotFoundError); ok { + return apierrors.NewSCIMNotFoundError(scimErrMembersNotFound) + } + if _, ok := err.(models.UserNotInSSOProviderError); ok { + return apierrors.NewSCIMBadRequestError(scimErrMembersWrongProvider, "invalidValue") + } + return apierrors.NewSCIMInternalServerError("Error adding group members").WithInternalError(err) + } + } + + return nil + }) + + if terr != nil { + return terr + } + + members, err := group.GetMembers(db) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error fetching group members").WithInternalError(err) + } + return sendSCIMJSON(w, http.StatusCreated, a.groupToSCIMResponse(group, members)) +} + +func (a *API) scimReplaceGroup(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + groupID, err := uuid.FromString(chi.URLParam(r, "group_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + var params SCIMGroupParams + if err := a.parseSCIMBody(w, r, ¶ms); err != nil { + return err + } + if err := params.Validate(); err != nil { + return err + } + + var group *models.SCIMGroup + terr := db.Transaction(func(tx *storage.Connection) error { + var err error + group, err = models.FindSCIMGroupByID(tx, groupID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching group").WithInternalError(err) + } + + if group.SSOProviderID != provider.ID { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + group.DisplayName = params.DisplayName + group.ExternalID = storage.NullString(params.ExternalID) + + if err := tx.Update(group); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrGroupDisplayNameConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating group").WithInternalError(err) + } + + memberIDs, err := parseSCIMGroupMemberRefs(params.Members) + if err != nil { + return err + } + + if err := group.SetMembers(tx, memberIDs); err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrMembersNotFound) + } + if _, ok := err.(models.UserNotInSSOProviderError); ok { + return apierrors.NewSCIMBadRequestError(scimErrMembersWrongProvider, "invalidValue") + } + return apierrors.NewSCIMInternalServerError("Error setting group members").WithInternalError(err) + } + return nil + }) + + if terr != nil { + return terr + } + + group, err = models.FindSCIMGroupByID(db, groupID) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error reloading group").WithInternalError(err) + } + + members, err := group.GetMembers(db) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error fetching group members").WithInternalError(err) + } + return sendSCIMJSON(w, http.StatusOK, a.groupToSCIMResponse(group, members)) +} + +func (a *API) scimPatchGroup(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + groupID, err := uuid.FromString(chi.URLParam(r, "group_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + var params SCIMPatchRequest + if err := a.parseSCIMBody(w, r, ¶ms); err != nil { + return err + } + if err := params.Validate(); err != nil { + return err + } + + var group *models.SCIMGroup + terr := db.Transaction(func(tx *storage.Connection) error { + var err error + group, err = models.FindSCIMGroupByID(tx, groupID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching group").WithInternalError(err) + } + + if group.SSOProviderID != provider.ID { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + for _, op := range params.Operations { + if err := a.applySCIMGroupPatch(tx, group, op); err != nil { + return err + } + } + + return nil + }) + + if terr != nil { + return terr + } + + group, err = models.FindSCIMGroupByID(db, groupID) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error reloading group").WithInternalError(err) + } + + members, err := group.GetMembers(db) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error fetching group members").WithInternalError(err) + } + return sendSCIMJSON(w, http.StatusOK, a.groupToSCIMResponse(group, members)) +} + +func (a *API) applySCIMGroupPatch(tx *storage.Connection, group *models.SCIMGroup, op SCIMPatchOperation) error { + var path *filter.Path + if op.Path != "" { + p, err := filter.ParsePath([]byte(op.Path)) + if err != nil { + return apierrors.NewSCIMBadRequestError( + fmt.Sprintf("Invalid path: %v", err), "invalidPath") + } + path = &p + } + + switch strings.ToLower(op.Op) { + case "add": + if path != nil { + attrName := strings.ToLower(path.AttributePath.AttributeName) + switch attrName { + case "externalid": + externalID, ok := op.Value.(string) + if !ok { + return apierrors.NewSCIMBadRequestError("externalId must be a string", "invalidValue") + } + return updateGroupExternalID(tx, group, externalID) + case "members": + // fall through to member handling below + default: + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported add path: %s", op.Path), "invalidPath") + } + } + members, ok := op.Value.([]interface{}) + if !ok { + return apierrors.NewSCIMBadRequestError("members must be an array", "invalidValue") + } + if len(members) > SCIMMaxMembers { + return apierrors.NewSCIMRequestTooLargeError(fmt.Sprintf("Maximum %d members per operation", SCIMMaxMembers)) + } + memberIDs, err := parseSCIMGroupMemberIDsRaw(members) + if err != nil { + return err + } + if err := group.AddMembers(tx, memberIDs); err != nil { + if _, ok := err.(models.UserNotFoundError); ok { + return apierrors.NewSCIMNotFoundError(scimErrMembersNotFound) + } + if _, ok := err.(models.UserNotInSSOProviderError); ok { + return apierrors.NewSCIMBadRequestError(scimErrMembersWrongProvider, "invalidValue") + } + return apierrors.NewSCIMInternalServerError("Error adding group members").WithInternalError(err) + } + + case "remove": + if path == nil { + return apierrors.NewSCIMBadRequestError("remove operation requires a path", "noTarget") + } + attrName := strings.ToLower(path.AttributePath.AttributeName) + switch { + case attrName == "externalid": + return updateGroupExternalID(tx, group, "") + case attrName == "members" && path.ValueExpression != nil: + attrExpr, ok := path.ValueExpression.(*filter.AttributeExpression) + if !ok || attrExpr.Operator != filter.EQ || strings.ToLower(attrExpr.AttributePath.AttributeName) != "value" { + return apierrors.NewSCIMBadRequestError("Unsupported member filter", "invalidFilter") + } + memberIDStr, ok := attrExpr.CompareValue.(string) + if !ok { + return apierrors.NewSCIMBadRequestError("Member filter value must be a string", "invalidValue") + } + memberID, err := uuid.FromString(memberIDStr) + if err != nil { + return apierrors.NewSCIMBadRequestError("Invalid member ID in path", "invalidValue") + } + if err := group.RemoveMember(tx, memberID); err != nil { + return apierrors.NewSCIMInternalServerError("Error removing group member").WithInternalError(err) + } + return nil + default: + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported remove path: %s", op.Path), "invalidPath") + } + + case "replace": + if path != nil { + attrName := strings.ToLower(path.AttributePath.AttributeName) + switch attrName { + case "externalid": + externalID, ok := op.Value.(string) + if !ok { + return apierrors.NewSCIMBadRequestError("externalId must be a string", "invalidValue") + } + return updateGroupExternalID(tx, group, externalID) + case "displayname": + displayName, ok := op.Value.(string) + if !ok { + return apierrors.NewSCIMBadRequestError("displayName must be a string", "invalidValue") + } + group.DisplayName = displayName + if err := tx.UpdateOnly(group, "display_name"); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrGroupDisplayNameConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating group display name").WithInternalError(err) + } + return nil + case "members": + members, ok := op.Value.([]interface{}) + if !ok { + return apierrors.NewSCIMBadRequestError("members must be an array", "invalidValue") + } + if len(members) > SCIMMaxMembers { + return apierrors.NewSCIMRequestTooLargeError(fmt.Sprintf("Maximum %d members per operation", SCIMMaxMembers)) + } + memberIDs, err := parseSCIMGroupMemberIDsRaw(members) + if err != nil { + return err + } + if err := group.SetMembers(tx, memberIDs); err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrMembersNotFound) + } + if _, ok := err.(models.UserNotInSSOProviderError); ok { + return apierrors.NewSCIMBadRequestError(scimErrMembersWrongProvider, "invalidValue") + } + return apierrors.NewSCIMInternalServerError("Error setting group members").WithInternalError(err) + } + return nil + default: + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported replace path: %s", op.Path), "invalidPath") + } + } + + valueMap, ok := op.Value.(map[string]interface{}) + if !ok { + return apierrors.NewSCIMBadRequestError("replace operation value must be an object when path is not specified", "invalidValue") + } + columnsToUpdate := []string{} + for key, val := range valueMap { + if key == "" { + continue + } + keyPath, err := filter.ParsePath([]byte(key)) + if err != nil { + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Invalid attribute path: %s", key), "invalidPath") + } + switch strings.ToLower(keyPath.AttributePath.AttributeName) { + case "externalid": + if externalID, ok := val.(string); ok { + group.ExternalID = storage.NullString(externalID) + columnsToUpdate = append(columnsToUpdate, "external_id") + } + case "displayname": + if displayName, ok := val.(string); ok { + group.DisplayName = displayName + columnsToUpdate = append(columnsToUpdate, "display_name") + } + } + } + if len(columnsToUpdate) > 0 { + if err := tx.UpdateOnly(group, columnsToUpdate...); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrGroupDisplayNameConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating group").WithInternalError(err) + } + } + + default: + return apierrors.NewSCIMBadRequestError(fmt.Sprintf("Unsupported patch operation: %s", op.Op), "invalidSyntax") + } + return nil +} + +func (a *API) scimDeleteGroup(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + provider := getSSOProvider(ctx) + + groupID, err := uuid.FromString(chi.URLParam(r, "group_id")) + if err != nil { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + terr := db.Transaction(func(tx *storage.Connection) error { + group, err := models.FindSCIMGroupByID(tx, groupID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + return apierrors.NewSCIMInternalServerError("Error fetching group").WithInternalError(err) + } + + if group.SSOProviderID != provider.ID { + return apierrors.NewSCIMNotFoundError(scimErrGroupNotFound) + } + + if err := tx.Destroy(group); err != nil { + return apierrors.NewSCIMInternalServerError("Error deleting group").WithInternalError(err) + } + return nil + }) + + if terr != nil { + return terr + } + + w.Header().Set("Content-Type", "application/scim+json") + w.WriteHeader(http.StatusNoContent) + return nil +} + +func (a *API) scimServiceProviderConfig(w http.ResponseWriter, r *http.Request) error { + baseURL := a.getSCIMBaseURL() + + return sendSCIMJSON(w, http.StatusOK, map[string]interface{}{ + "schemas": []string{"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"}, + "documentationUri": "https://supabase.com/docs/guides/auth/enterprise-sso/scim", + "patch": map[string]interface{}{"supported": true}, + "bulk": map[string]interface{}{"supported": false, "maxOperations": 0, "maxPayloadSize": 0}, + "filter": map[string]interface{}{"supported": true, "maxResults": SCIMMaxPageSize}, + "changePassword": map[string]interface{}{"supported": false}, + "sort": map[string]interface{}{"supported": false}, + "etag": map[string]interface{}{"supported": false}, + "authenticationSchemes": []map[string]interface{}{ + { + "type": "oauthbearertoken", + "name": "OAuth Bearer Token", + "description": "Authentication scheme using the OAuth Bearer Token", + "specUri": "http://www.rfc-editor.org/info/rfc6750", + "primary": true, + }, + }, + "meta": map[string]interface{}{ + "resourceType": "ServiceProviderConfig", + "location": baseURL + "/scim/v2/ServiceProviderConfig", + }, + }) +} + +func (a *API) buildSCIMResourceType(id, name, endpoint, description, schema string) map[string]interface{} { + baseURL := a.getSCIMBaseURL() + return map[string]interface{}{ + "schemas": []string{"urn:ietf:params:scim:schemas:core:2.0:ResourceType"}, + "id": id, + "name": name, + "endpoint": endpoint, + "description": description, + "schema": schema, + "meta": map[string]interface{}{"resourceType": "ResourceType", "location": baseURL + "/scim/v2/ResourceTypes/" + id}, + } +} + +func (a *API) buildSCIMSchema(id, name, description string, attributes []map[string]interface{}) map[string]interface{} { + baseURL := a.getSCIMBaseURL() + return map[string]interface{}{ + "schemas": []string{"urn:ietf:params:scim:schemas:core:2.0:Schema"}, + "id": id, + "name": name, + "description": description, + "attributes": attributes, + "meta": map[string]interface{}{ + "resourceType": "Schema", + "location": baseURL + "/scim/v2/Schemas/" + id, + }, + } +} + +var scimUserSchemaAttributes = []map[string]interface{}{ + {"name": "userName", "type": "string", "multiValued": false, "required": true, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "server"}, + {"name": "name", "type": "complex", "multiValued": false, "required": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none", "subAttributes": []map[string]interface{}{ + {"name": "formatted", "type": "string", "multiValued": false, "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + {"name": "familyName", "type": "string", "multiValued": false, "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + {"name": "givenName", "type": "string", "multiValued": false, "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + }}, + {"name": "emails", "type": "complex", "multiValued": true, "required": true, "mutability": "readWrite", "returned": "default", "uniqueness": "none", "subAttributes": []map[string]interface{}{ + {"name": "value", "type": "string", "multiValued": false, "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + {"name": "type", "type": "string", "multiValued": false, "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + {"name": "primary", "type": "boolean", "multiValued": false, "required": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + }}, + {"name": "active", "type": "boolean", "multiValued": false, "required": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + {"name": "externalId", "type": "string", "multiValued": false, "required": false, "caseExact": true, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, +} + +var scimGroupSchemaAttributes = []map[string]interface{}{ + {"name": "displayName", "type": "string", "multiValued": false, "required": true, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, + {"name": "members", "type": "complex", "multiValued": true, "required": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none", "subAttributes": []map[string]interface{}{ + {"name": "value", "type": "string", "multiValued": false, "required": false, "caseExact": false, "mutability": "immutable", "returned": "default", "uniqueness": "none"}, + {"name": "$ref", "type": "reference", "multiValued": false, "required": false, "caseExact": false, "mutability": "immutable", "returned": "default", "uniqueness": "none", "referenceTypes": []string{"User"}}, + {"name": "display", "type": "string", "multiValued": false, "required": false, "caseExact": false, "mutability": "readOnly", "returned": "default", "uniqueness": "none"}, + }}, + {"name": "externalId", "type": "string", "multiValued": false, "required": false, "caseExact": true, "mutability": "readWrite", "returned": "default", "uniqueness": "none"}, +} + +func (a *API) scimResourceTypes(w http.ResponseWriter, r *http.Request) error { + resourceTypes := []interface{}{ + a.buildSCIMResourceType("User", "User", "/Users", "User Account", SCIMSchemaUser), + a.buildSCIMResourceType("Group", "Group", "/Groups", "Group", SCIMSchemaGroup), + } + + return sendSCIMJSON(w, http.StatusOK, SCIMListResponse{ + Schemas: []string{SCIMSchemaListResponse}, + TotalResults: len(resourceTypes), + StartIndex: 1, + ItemsPerPage: len(resourceTypes), + Resources: resourceTypes, + }) +} + +func (a *API) scimSchemas(w http.ResponseWriter, r *http.Request) error { + schemas := []interface{}{ + a.buildSCIMSchema(SCIMSchemaUser, "User", "User Account", scimUserSchemaAttributes), + a.buildSCIMSchema(SCIMSchemaGroup, "Group", "Group", scimGroupSchemaAttributes), + } + + return sendSCIMJSON(w, http.StatusOK, SCIMListResponse{ + Schemas: []string{SCIMSchemaListResponse}, + TotalResults: len(schemas), + StartIndex: 1, + ItemsPerPage: len(schemas), + Resources: schemas, + }) +} + +func (a *API) scimResourceTypeByID(w http.ResponseWriter, r *http.Request) error { + resourceTypeID := chi.URLParam(r, "resource_type_id") + + var resourceType map[string]interface{} + switch resourceTypeID { + case "User": + resourceType = a.buildSCIMResourceType("User", "User", "/Users", "User Account", SCIMSchemaUser) + case "Group": + resourceType = a.buildSCIMResourceType("Group", "Group", "/Groups", "Group", SCIMSchemaGroup) + default: + return sendSCIMError(w, http.StatusNotFound, "Resource type not found") + } + + return sendSCIMJSON(w, http.StatusOK, resourceType) +} + +func (a *API) scimSchemaByID(w http.ResponseWriter, r *http.Request) error { + schemaID := chi.URLParam(r, "schema_id") + + var schema map[string]interface{} + switch schemaID { + case SCIMSchemaUser: + schema = a.buildSCIMSchema(SCIMSchemaUser, "User", "User Account", scimUserSchemaAttributes) + case SCIMSchemaGroup: + schema = a.buildSCIMSchema(SCIMSchemaGroup, "Group", "Group", scimGroupSchemaAttributes) + default: + return sendSCIMError(w, http.StatusNotFound, "Schema not found") + } + + return sendSCIMJSON(w, http.StatusOK, schema) +} + +func sendSCIMError(w http.ResponseWriter, status int, detail string) error { + return sendSCIMJSON(w, status, apierrors.NewSCIMHTTPError(status, detail, "")) +} + +func (a *API) scimNotFound(w http.ResponseWriter, r *http.Request) error { + return sendSCIMError(w, http.StatusNotFound, "Resource not found") +} + +func (a *API) scimMethodNotAllowed(w http.ResponseWriter, r *http.Request) error { + return sendSCIMError(w, http.StatusMethodNotAllowed, "Method not allowed") +} diff --git a/internal/api/scim_filter.go b/internal/api/scim_filter.go new file mode 100644 index 000000000..b30d5ebb4 --- /dev/null +++ b/internal/api/scim_filter.go @@ -0,0 +1,201 @@ +package api + +import ( + "fmt" + "strings" + + filter "github.com/scim2/filter-parser/v2" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" +) + +var SCIMUserFilterAttrs = map[string]string{ + "username": "COALESCE(NULLIF(i.identity_data->>'user_name', ''), u.email)", + "externalid": "i.identity_data->>'external_id'", + "email": "u.email", + "emails.value": "u.email", +} + +var SCIMGroupFilterAttrs = map[string]string{ + "displayname": "display_name", + "externalid": "external_id", +} + +const SCIMMaxFilterLength = 1024 + +func ParseSCIMFilterToSQL(filterStr string, allowedAttrs map[string]string) (*models.SCIMFilterClause, error) { + if filterStr == "" { + return &models.SCIMFilterClause{Where: "1=1", Args: nil}, nil + } + + if len(filterStr) > SCIMMaxFilterLength { + return nil, apierrors.NewSCIMBadRequestError("Filter exceeds maximum length", "invalidFilter") + } + + expr, err := filter.ParseFilter([]byte(filterStr)) + if err != nil { + return nil, apierrors.NewSCIMBadRequestError( + fmt.Sprintf("Invalid filter syntax: %v", err), "invalidFilter") + } + + return exprToSQL(expr, allowedAttrs) +} + +func exprToSQL(expr filter.Expression, allowedAttrs map[string]string) (*models.SCIMFilterClause, error) { + switch e := expr.(type) { + case *filter.AttributeExpression: + return attrExprToSQL(*e, allowedAttrs) + case *filter.LogicalExpression: + return logicalExprToSQL(*e, allowedAttrs) + case *filter.NotExpression: + return notExprToSQL(*e, allowedAttrs) + case *filter.ValuePath: + return valuePathToSQL(*e, allowedAttrs) + default: + return nil, apierrors.NewSCIMBadRequestError( + fmt.Sprintf("Unsupported filter expression type: %T", expr), "invalidFilter") + } +} + +func attrExprToSQL(e filter.AttributeExpression, allowedAttrs map[string]string) (*models.SCIMFilterClause, error) { + attrName := strings.ToLower(e.AttributePath.AttributeName) + if e.AttributePath.SubAttribute != nil { + attrName = attrName + "." + strings.ToLower(*e.AttributePath.SubAttribute) + } + + dbColumn, ok := allowedAttrs[attrName] + if !ok { + return nil, apierrors.NewSCIMBadRequestError( + fmt.Sprintf("Filtering on attribute '%s' is not supported", attrName), "invalidFilter") + } + + switch e.Operator { + case filter.EQ: + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("LOWER(CAST(%s AS TEXT)) = LOWER(?)", dbColumn), + Args: []interface{}{fmt.Sprintf("%v", e.CompareValue)}, + }, nil + + case filter.NE: + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("LOWER(CAST(%s AS TEXT)) != LOWER(?)", dbColumn), + Args: []interface{}{fmt.Sprintf("%v", e.CompareValue)}, + }, nil + + case filter.CO: + val, ok := e.CompareValue.(string) + if !ok { + return nil, apierrors.NewSCIMBadRequestError("'co' operator requires a string value", "invalidValue") + } + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("LOWER(CAST(%s AS TEXT)) LIKE LOWER(?) ESCAPE '\\'", dbColumn), + Args: []interface{}{"%" + escapeLikePattern(val) + "%"}, + }, nil + + case filter.SW: + val, ok := e.CompareValue.(string) + if !ok { + return nil, apierrors.NewSCIMBadRequestError("'sw' operator requires a string value", "invalidValue") + } + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("LOWER(CAST(%s AS TEXT)) LIKE LOWER(?) ESCAPE '\\'", dbColumn), + Args: []interface{}{escapeLikePattern(val) + "%"}, + }, nil + + case filter.EW: + val, ok := e.CompareValue.(string) + if !ok { + return nil, apierrors.NewSCIMBadRequestError("'ew' operator requires a string value", "invalidValue") + } + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("LOWER(CAST(%s AS TEXT)) LIKE LOWER(?) ESCAPE '\\'", dbColumn), + Args: []interface{}{"%" + escapeLikePattern(val)}, + }, nil + + case filter.PR: + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("(%s IS NOT NULL AND CAST(%s AS TEXT) != '')", dbColumn, dbColumn), + Args: nil, + }, nil + + case filter.GT, filter.GE, filter.LT, filter.LE: + if _, ok := e.CompareValue.(string); !ok { + return nil, apierrors.NewSCIMBadRequestError( + fmt.Sprintf("'%s' operator requires a string value", e.Operator), "invalidValue") + } + ops := map[filter.CompareOperator]string{ + filter.GT: ">", filter.GE: ">=", filter.LT: "<", filter.LE: "<=", + } + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("%s %s ?", dbColumn, ops[e.Operator]), + Args: []interface{}{e.CompareValue}, + }, nil + + default: + return nil, apierrors.NewSCIMBadRequestError( + fmt.Sprintf("Unsupported operator: %s", e.Operator), "invalidFilter") + } +} + +func logicalExprToSQL(e filter.LogicalExpression, allowedAttrs map[string]string) (*models.SCIMFilterClause, error) { + left, err := exprToSQL(e.Left, allowedAttrs) + if err != nil { + return nil, err + } + + right, err := exprToSQL(e.Right, allowedAttrs) + if err != nil { + return nil, err + } + + op := "AND" + if e.Operator == filter.OR { + op = "OR" + } + + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("(%s %s %s)", left.Where, op, right.Where), + Args: append(left.Args, right.Args...), + }, nil +} + +func notExprToSQL(e filter.NotExpression, allowedAttrs map[string]string) (*models.SCIMFilterClause, error) { + operand, err := exprToSQL(e.Expression, allowedAttrs) + if err != nil { + return nil, err + } + + return &models.SCIMFilterClause{ + Where: fmt.Sprintf("NOT (%s)", operand.Where), + Args: operand.Args, + }, nil +} + +// valuePathToSQL handles bracket notation (e.g., emails[value eq "x"]). +// Only emails[value ...] is supported since Supabase Auth stores one email per user. +func valuePathToSQL(e filter.ValuePath, allowedAttrs map[string]string) (*models.SCIMFilterClause, error) { + attrName := strings.ToLower(e.AttributePath.AttributeName) + + if attrName == "emails" && e.ValueFilter != nil { + if attrExpr, ok := e.ValueFilter.(*filter.AttributeExpression); ok { + if strings.ToLower(attrExpr.AttributePath.AttributeName) == "value" { + modifiedExpr := filter.AttributeExpression{ + AttributePath: filter.AttributePath{AttributeName: "email"}, + Operator: attrExpr.Operator, + CompareValue: attrExpr.CompareValue, + } + return attrExprToSQL(modifiedExpr, allowedAttrs) + } + } + } + + return nil, apierrors.NewSCIMBadRequestError( + fmt.Sprintf("Value path filter '%s[...]' is not supported", attrName), "invalidFilter") +} + +func escapeLikePattern(s string) string { + s = strings.ReplaceAll(s, "\\", "\\\\") + s = strings.ReplaceAll(s, "%", "\\%") + s = strings.ReplaceAll(s, "_", "\\_") + return s +} diff --git a/internal/api/scim_helpers.go b/internal/api/scim_helpers.go new file mode 100644 index 000000000..40bd15c16 --- /dev/null +++ b/internal/api/scim_helpers.go @@ -0,0 +1,337 @@ +package api + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +func parseSCIMPagination(r *http.Request) (startIndex, count int) { + startIndex = 1 + count = SCIMDefaultPageSize + + if v := r.URL.Query().Get("startIndex"); v != "" { + if i, err := strconv.Atoi(v); err == nil && i > 0 { + startIndex = i + if startIndex > SCIMMaxStartIndex { + startIndex = SCIMMaxStartIndex + } + } + } + + if v := r.URL.Query().Get("count"); v != "" { + if i, err := strconv.Atoi(v); err == nil && i >= 0 { + count = i + if count > SCIMMaxPageSize { + count = SCIMMaxPageSize + } + } + } + + return startIndex, count +} + +func (a *API) parseSCIMBody(w http.ResponseWriter, r *http.Request, v interface{}) error { + r.Body = http.MaxBytesReader(w, r.Body, SCIMMaxBodySize) + body, err := utilities.GetBodyBytes(r) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return apierrors.NewSCIMRequestTooLargeError("Request body exceeds maximum size of 1MB") + } + return apierrors.NewSCIMInternalServerError("Could not read request body").WithInternalError(err) + } + if err := json.Unmarshal(body, v); err != nil { + return apierrors.NewSCIMBadRequestError("Invalid JSON in request body", "invalidSyntax").WithInternalError(err) + } + return nil +} + +func (a *API) userToSCIMResponse(user *models.User, providerType string) *SCIMUserResponse { + baseURL := a.getSCIMBaseURL() + resp := &SCIMUserResponse{ + Schemas: []string{SCIMSchemaUser}, + ID: user.ID.String(), + UserName: user.GetEmail(), + Active: !user.IsBanned(), + Meta: SCIMMeta{ + ResourceType: "User", + Created: &user.CreatedAt, + LastModified: &user.UpdatedAt, + Location: baseURL + "/scim/v2/Users/" + user.ID.String(), + }, + } + + var emailType string + for _, identity := range user.Identities { + if identity.Provider == providerType { + if identity.IdentityData != nil { + if extID, ok := identity.IdentityData["external_id"].(string); ok && extID != "" { + resp.ExternalID = extID + } + if userName, ok := identity.IdentityData["user_name"].(string); ok && userName != "" { + resp.UserName = userName + } + if et, ok := identity.IdentityData["email_type"].(string); ok { + emailType = et + } + } + break + } + } + + if email := user.GetEmail(); email != "" { + scimEmail := SCIMEmail{Value: email, Primary: true} + if emailType != "" { + scimEmail.Type = emailType + } + resp.Emails = []SCIMEmail{scimEmail} + } + + if user.UserMetaData != nil { + name := &SCIMName{} + hasName := false + if v, ok := user.UserMetaData["given_name"].(string); ok { + name.GivenName = v + hasName = true + } + if v, ok := user.UserMetaData["family_name"].(string); ok { + name.FamilyName = v + hasName = true + } + if v, ok := user.UserMetaData["full_name"].(string); ok { + name.Formatted = v + hasName = true + } + if hasName { + resp.Name = name + } + } + + return resp +} + +func (a *API) groupToSCIMResponse(group *models.SCIMGroup, members []*models.User) *SCIMGroupResponse { + baseURL := a.getSCIMBaseURL() + resp := &SCIMGroupResponse{ + Schemas: []string{SCIMSchemaGroup}, + ID: group.ID.String(), + ExternalID: string(group.ExternalID), + DisplayName: group.DisplayName, + Members: []SCIMGroupMemberRef{}, + Meta: SCIMMeta{ + ResourceType: "Group", + Created: &group.CreatedAt, + LastModified: &group.UpdatedAt, + Location: baseURL + "/scim/v2/Groups/" + group.ID.String(), + }, + } + + if len(members) > 0 { + resp.Members = make([]SCIMGroupMemberRef, len(members)) + for i, m := range members { + resp.Members[i] = SCIMGroupMemberRef{ + Value: m.ID.String(), + Ref: baseURL + "/scim/v2/Users/" + m.ID.String(), + Display: m.GetEmail(), + } + } + } + + return resp +} + +func (a *API) getSCIMBaseURL() string { + return strings.TrimRight(a.config.API.ExternalURL, "/") +} + +func sendSCIMJSON(w http.ResponseWriter, status int, obj interface{}) error { + w.Header().Set("Content-Type", "application/scim+json") + w.WriteHeader(status) + return json.NewEncoder(w).Encode(obj) +} + +func parseSCIMActiveBool(val interface{}) (bool, error) { + switch v := val.(type) { + case bool: + return v, nil + case string: + switch strings.ToLower(v) { + case "true": + return true, nil + case "false": + return false, nil + } + } + return false, apierrors.NewSCIMBadRequestError("active must be a boolean or \"true\"/\"false\"", "invalidValue") +} + +func findSSOIdentity(user *models.User, providerType string) *models.Identity { + for i := range user.Identities { + if user.Identities[i].Provider == providerType { + return &user.Identities[i] + } + } + return nil +} + +func setSCIMExternalID(tx *storage.Connection, identity *models.Identity, externalID string) error { + if strings.TrimSpace(externalID) == "" { + return apierrors.NewSCIMBadRequestError("externalId must not be empty", "invalidValue") + } + + identity.ProviderID = externalID + if identity.IdentityData == nil { + identity.IdentityData = make(map[string]interface{}) + } + identity.IdentityData["external_id"] = externalID + identity.IdentityData["sub"] = externalID + if err := tx.UpdateOnly(identity, "provider_id", "identity_data"); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrExternalIDConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating identity").WithInternalError(err) + } + return nil +} + +func setSCIMUserName(tx *storage.Connection, identity *models.Identity, userName string) error { + if identity.IdentityData == nil { + identity.IdentityData = make(map[string]interface{}) + } + identity.IdentityData["user_name"] = userName + + updateCols := []string{"identity_data"} + if externalID, ok := identity.IdentityData["external_id"].(string); !ok || externalID == "" { + identity.ProviderID = userName + identity.IdentityData["sub"] = userName + updateCols = append(updateCols, "provider_id") + } + + if err := tx.UpdateOnly(identity, updateCols...); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrUserNameConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating identity").WithInternalError(err) + } + return nil +} + +func setSCIMIdentityField(tx *storage.Connection, identity *models.Identity, key, value string) error { + if identity.IdentityData == nil { + identity.IdentityData = make(map[string]interface{}) + } + identity.IdentityData[key] = value + if err := tx.UpdateOnly(identity, "identity_data"); err != nil { + return apierrors.NewSCIMInternalServerError("Error updating identity").WithInternalError(err) + } + return nil +} + +func extractPrimarySCIMEmail(emails []SCIMEmail) (email, emailType string) { + if len(emails) == 0 { + return "", "" + } + for _, e := range emails { + if e.Primary { + return e.Value, e.Type + } + } + return emails[0].Value, emails[0].Type +} + +func applySCIMNameToMetadata(metadata map[string]interface{}, name *SCIMName) { + if name == nil { + return + } + if name.GivenName != "" { + metadata["given_name"] = name.GivenName + } + if name.FamilyName != "" { + metadata["family_name"] = name.FamilyName + } + if name.Formatted != "" { + metadata["full_name"] = name.Formatted + } +} + +func parseSCIMGroupMemberRefs(members []SCIMGroupMemberRef) ([]uuid.UUID, error) { + ids := make([]uuid.UUID, 0, len(members)) + for _, member := range members { + id, err := uuid.FromString(member.Value) + if err != nil { + return nil, apierrors.NewSCIMBadRequestError(fmt.Sprintf("Invalid member ID: %s", member.Value), "invalidValue") + } + ids = append(ids, id) + } + return ids, nil +} + +func parseSCIMGroupMemberIDsRaw(members []interface{}) ([]uuid.UUID, error) { + ids := make([]uuid.UUID, 0, len(members)) + for _, m := range members { + memberMap, ok := m.(map[string]interface{}) + if !ok { + return nil, apierrors.NewSCIMBadRequestError("Invalid member format", "invalidValue") + } + value, ok := memberMap["value"].(string) + if !ok { + return nil, apierrors.NewSCIMBadRequestError("Member value must be a string", "invalidValue") + } + id, err := uuid.FromString(value) + if err != nil { + return nil, apierrors.NewSCIMBadRequestError(fmt.Sprintf("Invalid member ID: %s", value), "invalidValue") + } + ids = append(ids, id) + } + return ids, nil +} + +func updateGroupExternalID(tx *storage.Connection, group *models.SCIMGroup, externalID string) error { + group.ExternalID = storage.NullString(externalID) + if err := tx.UpdateOnly(group, "external_id"); err != nil { + if pgErr := utilities.NewPostgresError(err); pgErr != nil && pgErr.IsUniqueConstraintViolated() { + return apierrors.NewSCIMConflictError(scimErrGroupExternalIDConflict, "uniqueness") + } + return apierrors.NewSCIMInternalServerError("Error updating group external ID").WithInternalError(err) + } + return nil +} + +func checkSCIMEmailUniqueness(tx *storage.Connection, email, aud, providerType string, excludeUserID uuid.UUID) error { + existingUser, err := models.FindUserByEmailAndAudience(tx, email, aud) + if err != nil && !models.IsNotFoundError(err) { + return apierrors.NewSCIMInternalServerError("Error checking email uniqueness").WithInternalError(err) + } + if existingUser != nil && existingUser.ID != excludeUserID { + if !existingUser.IsSSOUser { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + if existingUser.BannedReason == nil || *existingUser.BannedReason != scimDeprovisionedReason { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + } + + ssoUsers, err := models.FindSSOUsersByEmailAndProvider(tx, email, aud, providerType) + if err != nil { + return apierrors.NewSCIMInternalServerError("Error checking email uniqueness").WithInternalError(err) + } + for _, u := range ssoUsers { + if u.ID == excludeUserID { + continue + } + if u.BannedReason == nil || *u.BannedReason != scimDeprovisionedReason { + return apierrors.NewSCIMConflictError(scimErrEmailConflict, "uniqueness") + } + } + return nil +} diff --git a/internal/api/scim_test.go b/internal/api/scim_test.go new file mode 100644 index 000000000..578597080 --- /dev/null +++ b/internal/api/scim_test.go @@ -0,0 +1,3105 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "math" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type scimTestUser struct { + UserName string + Email string + GivenName string + FamilyName string + Formatted string + ExternalID string +} + +type scimTestGroup struct { + DisplayName string + ExternalID string +} + +var ( + testUser1 = scimTestUser{UserName: "user1@acme.com", Email: "user1@acme.com"} + testUser2 = scimTestUser{UserName: "user2@acme.com", Email: "user2@acme.com", GivenName: "Test", FamilyName: "User", Formatted: "Test User"} + testUser3 = scimTestUser{UserName: "user3@acme.com", Email: "user3@acme.com", ExternalID: "ext-001"} + testUser4 = scimTestUser{UserName: "user4@acme.com", Email: "user4@acme.com"} + testUser5 = scimTestUser{UserName: "user5@acme.com", Email: "user5@acme.com"} + testUser6 = scimTestUser{UserName: "user6@example.com", Email: "user6@example.com"} + testUser7 = scimTestUser{UserName: "user7@example.com", Email: "user7@example.com"} + testUser8 = scimTestUser{UserName: "user8@example.com", Email: "user8@example.com"} + testUser9 = scimTestUser{UserName: "user9@acme.com", Email: "user9@acme.com", GivenName: "Jane", FamilyName: "Doe", Formatted: "Jane Doe", ExternalID: "ext-002"} + testUser10 = scimTestUser{UserName: "user10@acme.com", Email: "user10@acme.com", GivenName: "John", FamilyName: "Smith", Formatted: "John Smith", ExternalID: "ext-003"} + testUser13 = scimTestUser{UserName: "user13@example.com", Email: "user13@example.com", ExternalID: "ext-006"} + testUser14 = scimTestUser{UserName: "user14@acme.com", Email: "user14@acme.com", ExternalID: "ext-007"} + testUser15 = scimTestUser{UserName: "user15@acme.com", Email: "user15@acme.com", ExternalID: "ext-008"} + testUser16 = scimTestUser{UserName: "user16@example.com", Email: "user16@example.com", ExternalID: "ext-009"} + testUser17 = scimTestUser{UserName: "user17@acme.com", Email: "user17@acme.com", GivenName: "Reactivated", FamilyName: "User", Formatted: "Reactivated User", ExternalID: "ext-010"} + testUser18 = scimTestUser{UserName: "crossemail@acme.com", Email: "crossemail@acme.com", ExternalID: "ext-011"} + testUser19 = scimTestUser{UserName: "ambiguous@acme.com", Email: "ambiguous@acme.com", ExternalID: "ext-012"} + + testGroup1 = scimTestGroup{DisplayName: "Engineering", ExternalID: "grp-001"} + testGroup2 = scimTestGroup{DisplayName: "Sales", ExternalID: "grp-002"} + testGroup3 = scimTestGroup{DisplayName: "Marketing", ExternalID: "grp-003"} + testGroup4 = scimTestGroup{DisplayName: "Platform", ExternalID: "grp-004"} + testGroup5 = scimTestGroup{DisplayName: "Support", ExternalID: "grp-005"} +) + +type SCIMTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + SCIMToken string + SSOProvider *models.SSOProvider +} + +func TestSCIM(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &SCIMTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *SCIMTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + ts.SCIMToken = "test-scim-token-12345" + ts.SSOProvider = ts.createSSOProviderWithSCIM() +} + +func (ts *SCIMTestSuite) createSSOProviderWithSCIM() *models.SSOProvider { + provider := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider)) + provider.SetSCIMToken(ts.SCIMToken) + require.NoError(ts.T(), ts.API.db.Update(provider)) + require.NoError(ts.T(), ts.API.db.Reload(provider)) + return provider +} + +func (ts *SCIMTestSuite) makeSCIMRequest(method, path string, body interface{}) *http.Request { + var reqBody *bytes.Buffer + if body != nil { + jsonBody, err := json.Marshal(body) + require.NoError(ts.T(), err) + reqBody = bytes.NewBuffer(jsonBody) + } else { + reqBody = bytes.NewBuffer(nil) + } + + req := httptest.NewRequest(method, "http://localhost"+path, reqBody) + req.Header.Set("Authorization", "Bearer "+ts.SCIMToken) + req.Header.Set("Content-Type", "application/scim+json") + return req +} + +func (ts *SCIMTestSuite) createSCIMUser(userName, email string) *SCIMUserResponse { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": userName, + "emails": []map[string]interface{}{ + {"value": email, "primary": true, "type": "work"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Failed to create SCIM user: %s", w.Body.String()) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + return &result +} + +func (ts *SCIMTestSuite) createSCIMUserWithName(userName, email, givenName, familyName string) *SCIMUserResponse { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": userName, + "name": map[string]interface{}{ + "givenName": givenName, + "familyName": familyName, + "formatted": givenName + " " + familyName, + }, + "emails": []map[string]interface{}{ + {"value": email, "primary": true, "type": "work"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Failed to create SCIM user: %s", w.Body.String()) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + return &result +} + +func (ts *SCIMTestSuite) createSCIMUserWithExternalID(userName, email, externalID string) *SCIMUserResponse { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": userName, + "externalId": externalID, + "emails": []map[string]interface{}{ + {"value": email, "primary": true, "type": "work"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Failed to create SCIM user: %s", w.Body.String()) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + return &result +} + +func (ts *SCIMTestSuite) createSCIMGroup(displayName string) *SCIMGroupResponse { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": displayName, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Failed to create SCIM group: %s", w.Body.String()) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + return &result +} + +func (ts *SCIMTestSuite) createSCIMGroupWithExternalID(displayName, externalID string) *SCIMGroupResponse { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": displayName, + "externalId": externalID, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Failed to create SCIM group: %s", w.Body.String()) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + return &result +} + +func (ts *SCIMTestSuite) createSCIMGroupWithMembers(displayName string, memberIDs []string) *SCIMGroupResponse { + members := make([]map[string]interface{}, len(memberIDs)) + for i, id := range memberIDs { + members[i] = map[string]interface{}{"value": id} + } + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": displayName, + "members": members, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Failed to create SCIM group: %s", w.Body.String()) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + return &result +} + +func (ts *SCIMTestSuite) assertSCIMError(w *httptest.ResponseRecorder, expectedStatus int) { + require.Equal(ts.T(), expectedStatus, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + + schemas, ok := errorResp["schemas"].([]interface{}) + require.True(ts.T(), ok, "SCIM error should have schemas field") + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:api:messages:2.0:Error", schemas[0]) + + _, ok = errorResp["detail"].(string) + require.True(ts.T(), ok, "SCIM error should have detail field") + + // SCIM status is a string per RFC 7644 + status, ok := errorResp["status"].(string) + require.True(ts.T(), ok, "SCIM error should have status field") + require.Equal(ts.T(), fmt.Sprintf("%d", expectedStatus), status) +} + +func (ts *SCIMTestSuite) assertSCIMListResponse(w *httptest.ResponseRecorder, expectedTotal int) *SCIMListResponse { + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), expectedTotal, result.TotalResults) + require.GreaterOrEqual(ts.T(), result.StartIndex, 1) + + return &result +} +func (ts *SCIMTestSuite) TestSCIMProviderSetup() { + require.NotNil(ts.T(), ts.SSOProvider) + require.True(ts.T(), ts.SSOProvider.IsSCIMEnabled()) +} + +func (ts *SCIMTestSuite) TestSCIMTokenValidation() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *SCIMTestSuite) TestSCIMInvalidToken() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusUnauthorized) +} + +func (ts *SCIMTestSuite) TestSCIMMissingToken() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusUnauthorized) +} + +func (ts *SCIMTestSuite) TestSCIMEmptyUserList() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + result := ts.assertSCIMListResponse(w, 0) + require.Len(ts.T(), result.Resources, 0) +} + +func (ts *SCIMTestSuite) TestSCIMEmptyGroupList() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + result := ts.assertSCIMListResponse(w, 0) + require.Len(ts.T(), result.Resources, 0) +} + +func (ts *SCIMTestSuite) TestSCIMCreateUser() { + user := ts.createSCIMUser(testUser1.UserName, testUser1.Email) + + require.NotEmpty(ts.T(), user.ID) + require.Equal(ts.T(), testUser1.UserName, user.UserName) + require.True(ts.T(), user.Active) + require.Len(ts.T(), user.Emails, 1) + require.Equal(ts.T(), testUser1.Email, user.Emails[0].Value) +} + +func (ts *SCIMTestSuite) TestSCIMCreateGroup() { + group := ts.createSCIMGroup(testGroup1.DisplayName) + + require.NotEmpty(ts.T(), group.ID) + require.Equal(ts.T(), testGroup1.DisplayName, group.DisplayName) +} + +func (ts *SCIMTestSuite) TestSCIMServiceProviderConfig() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/ServiceProviderConfig", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + schemas, ok := result["schemas"].([]interface{}) + require.True(ts.T(), ok) + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig", schemas[0]) + + patch, ok := result["patch"].(map[string]interface{}) + require.True(ts.T(), ok) + require.True(ts.T(), patch["supported"].(bool)) + + filter, ok := result["filter"].(map[string]interface{}) + require.True(ts.T(), ok) + require.True(ts.T(), filter["supported"].(bool)) +} + +func (ts *SCIMTestSuite) TestSCIMResourceTypes() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/ResourceTypes", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Equal(ts.T(), 2, result.TotalResults) + require.Len(ts.T(), result.Resources, 2) +} + +func (ts *SCIMTestSuite) TestSCIMSchemas() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Schemas", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Equal(ts.T(), 2, result.TotalResults) + require.Len(ts.T(), result.Resources, 2) +} + +func (ts *SCIMTestSuite) TestSCIMGetUserNotFound() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/00000000-0000-0000-0000-000000000000", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMGetGroupNotFound() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/00000000-0000-0000-0000-000000000000", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMMethodNotAllowedReturnsSCIMError() { + user := ts.createSCIMUser("method_not_allowed@test.com", "method_not_allowed@test.com") + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users/"+user.ID, map[string]interface{}{}) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusMethodNotAllowed) + require.Equal(ts.T(), "application/scim+json", w.Header().Get("Content-Type")) +} + +func (ts *SCIMTestSuite) TestSCIMCreateUserWithName() { + user := ts.createSCIMUserWithName(testUser2.UserName, testUser2.Email, testUser2.GivenName, testUser2.FamilyName) + + require.NotEmpty(ts.T(), user.ID) + require.Equal(ts.T(), testUser2.UserName, user.UserName) + require.NotNil(ts.T(), user.Name) + require.Equal(ts.T(), testUser2.GivenName, user.Name.GivenName) + require.Equal(ts.T(), testUser2.FamilyName, user.Name.FamilyName) + require.Equal(ts.T(), testUser2.Formatted, user.Name.Formatted) +} + +func (ts *SCIMTestSuite) TestSCIMCreateUserWithExternalID() { + user := ts.createSCIMUserWithExternalID(testUser3.UserName, testUser3.Email, testUser3.ExternalID) + + require.NotEmpty(ts.T(), user.ID) + require.Equal(ts.T(), testUser3.UserName, user.UserName) + require.Equal(ts.T(), testUser3.ExternalID, user.ExternalID) +} + +func (ts *SCIMTestSuite) TestSCIMGetUser() { + created := ts.createSCIMUser(testUser4.UserName, testUser4.Email) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+created.ID, nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Equal(ts.T(), created.ID, result.ID) + require.Equal(ts.T(), testUser4.UserName, result.UserName) +} + +func (ts *SCIMTestSuite) TestSCIMGetGroup() { + created := ts.createSCIMGroup(testGroup2.DisplayName) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+created.ID, nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Equal(ts.T(), created.ID, result.ID) + require.Equal(ts.T(), testGroup2.DisplayName, result.DisplayName) +} + +func (ts *SCIMTestSuite) TestSCIMListUsersWithData() { + ts.createSCIMUser(testUser1.UserName, testUser1.Email) + ts.createSCIMUser(testUser2.UserName, testUser2.Email) + ts.createSCIMUser(testUser3.UserName, testUser3.Email) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + result := ts.assertSCIMListResponse(w, 3) + require.Len(ts.T(), result.Resources, 3) +} + +func (ts *SCIMTestSuite) TestSCIMListGroupsWithData() { + ts.createSCIMGroup(testGroup1.DisplayName) + ts.createSCIMGroup(testGroup3.DisplayName) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + result := ts.assertSCIMListResponse(w, 2) + require.Len(ts.T(), result.Resources, 2) +} + +func (ts *SCIMTestSuite) TestSCIMDeleteUser() { + user := ts.createSCIMUser(testUser5.UserName, testUser5.Email) + + require.True(ts.T(), user.Active) + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Users/"+user.ID, nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.False(ts.T(), result.Active, "Deprovisioned user should have active=false") + + req = ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNotFound, w.Code) +} + +func (ts *SCIMTestSuite) TestSCIMDeleteGroup() { + group := ts.createSCIMGroup(testGroup4.DisplayName) + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Groups/"+group.ID, nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMCreateGroupWithMembers() { + user1 := ts.createSCIMUser(testUser6.UserName, testUser6.Email) + user2 := ts.createSCIMUser(testUser7.UserName, testUser7.Email) + + group := ts.createSCIMGroupWithMembers(testGroup5.DisplayName, []string{user1.ID, user2.ID}) + + require.NotEmpty(ts.T(), group.ID) + require.Equal(ts.T(), testGroup5.DisplayName, group.DisplayName) + require.Len(ts.T(), group.Members, 2) +} + +func (ts *SCIMTestSuite) TestSCIMContentTypeHeader() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), "application/scim+json", w.Header().Get("Content-Type")) +} + +func (ts *SCIMTestSuite) TestSCIMCreateUserMissingUserName() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "emails": []map[string]interface{}{ + {"value": "test@example.com", "primary": true}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusBadRequest) +} + +func (ts *SCIMTestSuite) TestSCIMCreateGroupMissingDisplayName() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusBadRequest) +} + +func (ts *SCIMTestSuite) TestSCIMUserPagination() { + for i := 0; i < 5; i++ { + ts.createSCIMUser(fmt.Sprintf("pageuser%d@acme.com", i), fmt.Sprintf("pageuser%d@acme.com", i)) + } + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?startIndex=1&count=2", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Equal(ts.T(), 5, result.TotalResults) + require.Equal(ts.T(), 2, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 2) +} + +func (ts *SCIMTestSuite) assertSCIMErrorWithType(w *httptest.ResponseRecorder, expectedStatus int, expectedScimType string) { + require.Equal(ts.T(), expectedStatus, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + + schemas, ok := errorResp["schemas"].([]interface{}) + require.True(ts.T(), ok, "SCIM error should have schemas field") + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:api:messages:2.0:Error", schemas[0]) + + _, ok = errorResp["detail"].(string) + require.True(ts.T(), ok, "SCIM error should have detail field") + + status, ok := errorResp["status"].(string) + require.True(ts.T(), ok, "SCIM error should have status field") + require.Equal(ts.T(), fmt.Sprintf("%d", expectedStatus), status) + + if expectedScimType != "" { + scimType, ok := errorResp["scimType"].(string) + require.True(ts.T(), ok, "SCIM error should have scimType field") + require.Equal(ts.T(), expectedScimType, scimType) + } +} + +func (ts *SCIMTestSuite) TestSCIMCreateUserAzure() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": testUser9.UserName, + "externalId": testUser9.ExternalID, + "name": map[string]interface{}{ + "formatted": testUser9.Formatted, + "familyName": testUser9.FamilyName, + "givenName": testUser9.GivenName, + }, + "emails": []map[string]interface{}{ + {"primary": true, "value": testUser9.Email}, + }, + "active": true, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusCreated, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaUser, result.Schemas[0]) + require.NotEmpty(ts.T(), result.ID) + require.Equal(ts.T(), testUser9.ExternalID, result.ExternalID) + require.Equal(ts.T(), testUser9.UserName, result.UserName) + require.NotNil(ts.T(), result.Name) + require.Equal(ts.T(), testUser9.Formatted, result.Name.Formatted) + require.Equal(ts.T(), testUser9.FamilyName, result.Name.FamilyName) + require.Equal(ts.T(), testUser9.GivenName, result.Name.GivenName) + require.Len(ts.T(), result.Emails, 1) + require.Equal(ts.T(), testUser9.Email, result.Emails[0].Value) + require.True(ts.T(), bool(result.Emails[0].Primary)) + require.True(ts.T(), result.Active) + require.Equal(ts.T(), "User", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Users/"+result.ID) +} + +func (ts *SCIMTestSuite) TestSCIMCreateUserDuplicateExternalID() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": testUser10.UserName, + "externalId": testUser10.ExternalID, + "name": map[string]interface{}{ + "formatted": testUser10.Formatted, + "familyName": testUser10.FamilyName, + "givenName": testUser10.GivenName, + }, + "emails": []map[string]interface{}{ + {"primary": true, "value": testUser10.Email}, + }, + "active": true, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code) + + req = ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusConflict, "uniqueness") +} + +func (ts *SCIMTestSuite) TestSCIMDeleteNonExistentUser() { + nonExistentID := "f1937c5d-cd6d-4151-93b7-dbfb7fb9b31d" + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Users/"+nonExistentID, nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMReactivateDeprovisionedUser() { + user := ts.createSCIMUserWithName(testUser17.UserName, testUser17.Email, testUser17.GivenName, testUser17.FamilyName) + require.True(ts.T(), user.Active) + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Users/"+user.ID, nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var deprovisioned SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&deprovisioned)) + require.False(ts.T(), deprovisioned.Active) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": testUser17.UserName, + "name": map[string]interface{}{ + "givenName": "Updated", + "familyName": "Name", + "formatted": "Updated Name", + }, + "emails": []map[string]interface{}{ + {"value": testUser17.Email, "primary": true, "type": "work"}, + }, + } + + req = ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Reactivating a deprovisioned SSO user should succeed: %s", w.Body.String()) + + var reactivated SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&reactivated)) + require.True(ts.T(), reactivated.Active) + require.Equal(ts.T(), user.ID, reactivated.ID, "Reactivated user should have the same ID") + require.Equal(ts.T(), "Updated", reactivated.Name.GivenName) + require.Equal(ts.T(), "Name", reactivated.Name.FamilyName) +} + +func (ts *SCIMTestSuite) TestSCIMReactivateAmbiguousDeprovisioned() { + user1 := ts.createSCIMUserWithExternalID(testUser19.UserName, testUser19.Email, testUser19.ExternalID) + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Users/"+user1.ID, nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + user2, err := models.NewUser("", testUser19.Email, "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + user2.IsSSOUser = true + reason := "SCIM_DEPROVISIONED" + user2.BannedReason = &reason + bannedUntil := time.Now().Add(time.Duration(math.MaxInt64)) + user2.BannedUntil = &bannedUntil + require.NoError(ts.T(), ts.API.db.Create(user2)) + + providerType := "sso:" + ts.SSOProvider.ID.String() + identity, err := models.NewIdentity(user2, providerType, map[string]interface{}{"sub": user2.ID.String()}) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(identity)) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": testUser19.UserName, + "emails": []map[string]interface{}{ + {"value": testUser19.Email, "primary": true, "type": "work"}, + }, + } + + req = ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusConflict, w.Code, "Ambiguous deprovisioned users should return 409: %s", w.Body.String()) +} + +func (ts *SCIMTestSuite) TestSCIMCreateUserCrossProviderSameEmail() { + ts.createSCIMUserWithExternalID(testUser18.UserName, testUser18.Email, testUser18.ExternalID) + + provider2 := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider2)) + token2 := "other-provider-token-cross" + provider2.SetSCIMToken(token2) + require.NoError(ts.T(), ts.API.db.Update(provider2)) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": testUser18.UserName, + "externalId": "other-provider-ext", + "emails": []map[string]interface{}{ + {"value": testUser18.Email, "primary": true, "type": "work"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code, "Cross-provider create with same email should succeed: %s", w.Body.String()) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.True(ts.T(), result.Active) + require.Equal(ts.T(), testUser18.UserName, result.UserName) +} + +func (ts *SCIMTestSuite) TestSCIMFilterUserByUserNameExisting() { + created := ts.createSCIMUserWithExternalID(testUser13.UserName, testUser13.Email, testUser13.ExternalID) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+eq+%22user13%40example.com%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), 1, result.TotalResults) + require.Equal(ts.T(), 1, result.StartIndex) + require.Equal(ts.T(), 1, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 1) + + resource := result.Resources[0].(map[string]interface{}) + require.Equal(ts.T(), created.ID, resource["id"]) + require.Equal(ts.T(), testUser13.UserName, resource["userName"]) + require.Equal(ts.T(), testUser13.ExternalID, resource["externalId"]) + require.Equal(ts.T(), true, resource["active"]) + + schemas := resource["schemas"].([]interface{}) + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), SCIMSchemaUser, schemas[0]) + + meta := resource["meta"].(map[string]interface{}) + require.Equal(ts.T(), "User", meta["resourceType"]) + require.NotEmpty(ts.T(), meta["created"]) + require.NotEmpty(ts.T(), meta["lastModified"]) + require.Contains(ts.T(), meta["location"], "/scim/v2/Users/"+created.ID) +} + +func (ts *SCIMTestSuite) TestSCIMFilterUserByUserNameNonExistent() { + ts.createSCIMUser(testUser8.UserName, testUser8.Email) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+eq+%22nonexistent%40example.com%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), 0, result.TotalResults) + require.Equal(ts.T(), 1, result.StartIndex) + require.Equal(ts.T(), 0, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 0) +} + +func (ts *SCIMTestSuite) TestSCIMFilterUserByUserNameCaseInsensitive() { + created := ts.createSCIMUserWithExternalID(testUser14.UserName, testUser14.Email, testUser14.ExternalID) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+eq+%22USER14%40ACME.COM%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), 1, result.TotalResults) + require.Equal(ts.T(), 1, result.StartIndex) + require.Equal(ts.T(), 1, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 1) + + resource := result.Resources[0].(map[string]interface{}) + require.Equal(ts.T(), created.ID, resource["id"]) + require.Equal(ts.T(), testUser14.UserName, resource["userName"]) + require.Equal(ts.T(), true, resource["active"]) + + schemas := resource["schemas"].([]interface{}) + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), SCIMSchemaUser, schemas[0]) + + meta := resource["meta"].(map[string]interface{}) + require.Equal(ts.T(), "User", meta["resourceType"]) + require.NotEmpty(ts.T(), meta["created"]) + require.NotEmpty(ts.T(), meta["lastModified"]) + require.Contains(ts.T(), meta["location"], "/scim/v2/Users/"+created.ID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserUpdateUserName() { + user := ts.createSCIMUserWithExternalID(testUser15.UserName, testUser15.Email, testUser15.ExternalID) + newUserName := "sam.updated@acme.com" + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": map[string]interface{}{"userName": newUserName}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaUser, result.Schemas[0]) + require.Equal(ts.T(), user.ID, result.ID) + require.Equal(ts.T(), testUser15.ExternalID, result.ExternalID) + require.Equal(ts.T(), newUserName, result.UserName) + require.True(ts.T(), result.Active) + require.Equal(ts.T(), "User", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Users/"+result.ID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserDisable() { + user := ts.createSCIMUserWithExternalID(testUser16.UserName, testUser16.Email, testUser16.ExternalID) + + require.True(ts.T(), user.Active) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "active", "value": false}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaUser, result.Schemas[0]) + require.Equal(ts.T(), user.ID, result.ID) + require.Equal(ts.T(), testUser16.ExternalID, result.ExternalID) + require.Equal(ts.T(), testUser16.UserName, result.UserName) + require.False(ts.T(), result.Active) + require.Equal(ts.T(), "User", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Users/"+result.ID) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.False(ts.T(), getResult.Active) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserReplaceEmailPrimaryEqTrue() { + origUserName := "patchemail@acme.com" + origEmail := "patchemail@acme.com" + newEmail := "updated.email@acme.com" + user := ts.createSCIMUserWithExternalID(origUserName, origEmail, "ext-patch-email-001") + + require.Len(ts.T(), user.Emails, 1) + require.Equal(ts.T(), origEmail, user.Emails[0].Value) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "emails[primary eq true].value", "value": newEmail}, + {"op": "replace", "value": map[string]interface{}{ + "name.formatted": "Updated Name", + "name.familyName": "Name", + "name.givenName": "Updated", + "active": true, + "externalId": "ext-patch-email-002", + }}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaUser, result.Schemas[0]) + require.Equal(ts.T(), user.ID, result.ID) + require.Equal(ts.T(), "ext-patch-email-002", result.ExternalID) + require.Equal(ts.T(), origUserName, result.UserName) + require.NotNil(ts.T(), result.Name) + require.Equal(ts.T(), "Updated Name", result.Name.Formatted) + require.Equal(ts.T(), "Name", result.Name.FamilyName) + require.Equal(ts.T(), "Updated", result.Name.GivenName) + require.True(ts.T(), result.Active) + + require.Len(ts.T(), result.Emails, 1) + require.Equal(ts.T(), newEmail, result.Emails[0].Value) + require.True(ts.T(), bool(result.Emails[0].Primary)) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Len(ts.T(), getResult.Emails, 1) + require.Equal(ts.T(), newEmail, getResult.Emails[0].Value) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserMultipleOperationsSameAttribute() { + userName := "multiop@acme.com" + user := ts.createSCIMUserWithExternalID(userName, userName, "ext-multi-001") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "remove", "path": "externalId"}, + {"op": "add", "value": map[string]interface{}{"externalId": "ext-multi-002"}}, + {"op": "replace", "value": map[string]interface{}{"externalId": "ext-multi-003"}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaUser, result.Schemas[0]) + require.Equal(ts.T(), user.ID, result.ID) + require.Equal(ts.T(), "ext-multi-003", result.ExternalID) + require.Equal(ts.T(), userName, result.UserName) + require.True(ts.T(), result.Active) + require.Equal(ts.T(), "User", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Users/"+result.ID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserNotFound() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "active", "value": false}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/00000000-0000-0000-0000-000000000000", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserReEnableUser() { + user := ts.createSCIMUserWithExternalID("disabled_user@test.com", "disabled_user@test.com", "disable-reenable-test") + + disableBody := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "active", "value": false}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, disableBody) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var disabledResult SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&disabledResult)) + require.False(ts.T(), disabledResult.Active) + + enableBody := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "active", "value": true}, + }, + } + + req = ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, enableBody) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var enabledResult SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enabledResult)) + require.True(ts.T(), enabledResult.Active) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.True(ts.T(), getResult.Active) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserUpdateUserNameWithPath() { + user := ts.createSCIMUserWithExternalID("original_username@test.com", "original_username@test.com", "username-path-test") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "userName", "value": "new_username@test.com"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "new_username@test.com", result.UserName) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Equal(ts.T(), "new_username@test.com", getResult.UserName) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserUpdateUserNameWithPathSyncsSubjectWhenExternalIDMissing() { + oldUserName := "subject_sync_original@test.com" + newUserName := "subject_sync_new@test.com" + user := ts.createSCIMUser(oldUserName, oldUserName) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "userName", "value": newUserName}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), newUserName, result.UserName) + + providerType := "sso:" + ts.SSOProvider.ID.String() + identity, err := models.FindIdentityByIdAndProvider(ts.API.db, newUserName, providerType) + require.NoError(ts.T(), err) + require.Equal(ts.T(), newUserName, identity.ProviderID) + require.Equal(ts.T(), newUserName, identity.IdentityData["user_name"]) + require.Equal(ts.T(), newUserName, identity.IdentityData["sub"]) + + _, err = models.FindIdentityByIdAndProvider(ts.API.db, oldUserName, providerType) + require.Error(ts.T(), err) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserInvalidActiveType() { + user := ts.createSCIMUser("invalid_active_test@test.com", "invalid_active_test@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "active", "value": "not_a_boolean"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidValue") +} + +func (ts *SCIMTestSuite) TestSCIMCreateGroupAzure() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": "QGKWKSWJWHXE", + "externalId": "7dae2322-0f90-42d2-97a1-b8268d2993d3", + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusCreated, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, result.Schemas[0]) + require.NotEmpty(ts.T(), result.ID) + require.Equal(ts.T(), "7dae2322-0f90-42d2-97a1-b8268d2993d3", result.ExternalID) + require.Equal(ts.T(), "QGKWKSWJWHXE", result.DisplayName) + require.Equal(ts.T(), "Group", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Groups/"+result.ID) +} + +func (ts *SCIMTestSuite) TestSCIMCreateGroupDuplicateExternalID() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": "SMVGZDBVFFRO", + "externalId": "e164812e-d012-4cc3-85dc-9ceb13765d62", + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusCreated, w.Code) + + body["displayName"] = "DIFFERENT_NAME" + req = ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusConflict, "uniqueness") +} + +func (ts *SCIMTestSuite) TestSCIMDeleteNonExistentGroup() { + nonExistentID := "a0f1d64e-cf53-45cf-8b4b-ea0d7b9ada90" + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Groups/"+nonExistentID, nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMDeleteGroupTwice() { + group := ts.createSCIMGroupWithExternalID("YLKGXWFUUUOH", "69565956-96c5-4951-910d-951bba6d2533") + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Groups/"+group.ID, nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + req = ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMGetGroupByIdExcludingMembers() { + group := ts.createSCIMGroupWithExternalID("YWWBHTHEMMLR", "94631638-0b6c-4b97-a369-aba35a454041") + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID+"?excludedAttributes=members", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, result.Schemas[0]) + require.Equal(ts.T(), group.ID, result.ID) + require.Equal(ts.T(), "94631638-0b6c-4b97-a369-aba35a454041", result.ExternalID) + require.Equal(ts.T(), "YWWBHTHEMMLR", result.DisplayName) + require.Equal(ts.T(), "Group", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Groups/"+result.ID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserInvalidUserNameType() { + user := ts.createSCIMUser("invalid_username_test@test.com", "invalid_username_test@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "userName", "value": 12345}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidValue") +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserUnsupportedOp() { + user := ts.createSCIMUser("unsupported_op_test@test.com", "unsupported_op_test@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "copy", "path": "userName", "value": "new@test.com"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidSyntax") +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserUnsupportedReplacePath() { + user := ts.createSCIMUser("unsup_replace_path@test.com", "unsup_replace_path@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "displayName", "value": "Foo"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidPath") +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserUnsupportedRemovePath() { + user := ts.createSCIMUser("unsup_remove_path@test.com", "unsup_remove_path@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "remove", "path": "displayName"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidPath") +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserRemoveWithoutPath() { + user := ts.createSCIMUser("remove_no_path@test.com", "remove_no_path@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "remove"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "noTarget") +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserAddExternalIDWithPath() { + user := ts.createSCIMUser("add_ext_path@test.com", "add_ext_path@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "add", "path": "externalId", "value": "new-ext-via-path"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code, "add with path should succeed: %s", w.Body.String()) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "new-ext-via-path", result.ExternalID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserRejectsEmptyExternalID() { + user := ts.createSCIMUserWithExternalID("empty_external_id_patch@test.com", "empty_external_id_patch@test.com", "ext-original-id") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "externalId", "value": ""}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidValue") + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "ext-original-id", result.ExternalID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserAddInvalidValueType() { + user := ts.createSCIMUser("add_invalid_val@test.com", "add_invalid_val@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "add", "value": "not_an_object"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidValue") +} + +func (ts *SCIMTestSuite) TestSCIMPatchUserReplaceInvalidValueType() { + user := ts.createSCIMUser("replace_invalid_val@test.com", "replace_invalid_val@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": "not_an_object"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidValue") +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupUnsupportedReplacePath() { + group := ts.createSCIMGroupWithExternalID("UnsupReplPath", "unsup-repl-path-ext") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "schemas", "value": "Foo"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidPath") +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupReplaceInvalidValueType() { + group := ts.createSCIMGroupWithExternalID("ReplInvalidVal", "repl-invalid-val-ext") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": "not_an_object"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidValue") +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupUnsupportedAddPath() { + group := ts.createSCIMGroupWithExternalID("UnsupAddPath", "unsup-add-path-ext") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "add", "path": "schemas", "value": "Foo"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidPath") +} + +func (ts *SCIMTestSuite) TestSCIMFilterGroupByDisplayNameExisting() { + created := ts.createSCIMGroupWithExternalID("YWWBHTHEMMLR", "94631638-0b6c-4b97-a369-aba35a454041") + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups?filter=displayName+eq+%22YWWBHTHEMMLR%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), 1, result.TotalResults) + require.Equal(ts.T(), 1, result.StartIndex) + require.Equal(ts.T(), 1, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 1) + + resource := result.Resources[0].(map[string]interface{}) + require.Equal(ts.T(), created.ID, resource["id"]) + require.Equal(ts.T(), "YWWBHTHEMMLR", resource["displayName"]) + require.Equal(ts.T(), "94631638-0b6c-4b97-a369-aba35a454041", resource["externalId"]) + + schemas := resource["schemas"].([]interface{}) + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, schemas[0]) + + meta := resource["meta"].(map[string]interface{}) + require.Equal(ts.T(), "Group", meta["resourceType"]) + require.NotEmpty(ts.T(), meta["created"]) + require.NotEmpty(ts.T(), meta["lastModified"]) + require.Contains(ts.T(), meta["location"], "/scim/v2/Groups/"+created.ID) +} + +func (ts *SCIMTestSuite) TestSCIMFilterGroupByDisplayNameExcludingMembers() { + created := ts.createSCIMGroupWithExternalID("YWWBHTHEMMLR", "94631638-0b6c-4b97-a369-aba35a454041") + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups?excludedAttributes=members&filter=displayName+eq+%22YWWBHTHEMMLR%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), 1, result.TotalResults) + require.Equal(ts.T(), 1, result.StartIndex) + require.Equal(ts.T(), 1, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 1) + + resource := result.Resources[0].(map[string]interface{}) + require.Equal(ts.T(), created.ID, resource["id"]) + require.Equal(ts.T(), "YWWBHTHEMMLR", resource["displayName"]) + require.Equal(ts.T(), "94631638-0b6c-4b97-a369-aba35a454041", resource["externalId"]) + + _, hasMembers := resource["members"] + require.False(ts.T(), hasMembers, "Response should exclude members attribute") + + schemas := resource["schemas"].([]interface{}) + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, schemas[0]) + + meta := resource["meta"].(map[string]interface{}) + require.Equal(ts.T(), "Group", meta["resourceType"]) + require.NotEmpty(ts.T(), meta["created"]) + require.NotEmpty(ts.T(), meta["lastModified"]) + require.Contains(ts.T(), meta["location"], "/scim/v2/Groups/"+created.ID) +} + +func (ts *SCIMTestSuite) TestSCIMFilterGroupByDisplayNameNonExistent() { + ts.createSCIMGroup("SomeExistingGroup") + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups?filter=displayName+eq+%22nonexistente997dccbd8b7_EOKNVHIYLTCZ%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), 0, result.TotalResults) + require.Equal(ts.T(), 1, result.StartIndex) + require.Equal(ts.T(), 0, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 0) +} + +func (ts *SCIMTestSuite) TestSCIMFilterGroupByDisplayNameCaseInsensitive() { + created := ts.createSCIMGroupWithExternalID("YWWBHTHEMMLR", "94631638-0b6c-4b97-a369-aba35a454041") + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups?filter=displayName+eq+%22ywwbhthemmlr%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaListResponse, result.Schemas[0]) + require.Equal(ts.T(), 1, result.TotalResults) + require.Equal(ts.T(), 1, result.StartIndex) + require.Equal(ts.T(), 1, result.ItemsPerPage) + require.Len(ts.T(), result.Resources, 1) + + resource := result.Resources[0].(map[string]interface{}) + require.Equal(ts.T(), created.ID, resource["id"]) + require.Equal(ts.T(), "YWWBHTHEMMLR", resource["displayName"]) + require.Equal(ts.T(), "94631638-0b6c-4b97-a369-aba35a454041", resource["externalId"]) + + schemas := resource["schemas"].([]interface{}) + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, schemas[0]) + + meta := resource["meta"].(map[string]interface{}) + require.Equal(ts.T(), "Group", meta["resourceType"]) + require.NotEmpty(ts.T(), meta["created"]) + require.NotEmpty(ts.T(), meta["lastModified"]) + require.Contains(ts.T(), meta["location"], "/scim/v2/Groups/"+created.ID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupReplaceExternalID() { + group := ts.createSCIMGroupWithExternalID("SFSNYLFDSMIG", "643a3bd4-43e1-481a-9ea6-bd82d65bbd04") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": map[string]interface{}{"externalId": "3d413e4f-7404-45e9-86b9-478c9b6a894a"}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, result.Schemas[0]) + require.Equal(ts.T(), group.ID, result.ID) + require.Equal(ts.T(), "3d413e4f-7404-45e9-86b9-478c9b6a894a", result.ExternalID) + require.Equal(ts.T(), "SFSNYLFDSMIG", result.DisplayName) + require.Equal(ts.T(), "Group", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Groups/"+result.ID) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Equal(ts.T(), "3d413e4f-7404-45e9-86b9-478c9b6a894a", getResult.ExternalID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupUpdateDisplayName() { + group := ts.createSCIMGroupWithExternalID("NUOSLUZYECIZ", "fa01b7f2-ab68-4f97-a211-11f5732d0e15") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": map[string]interface{}{"displayName": "YJCESZMOUKCA"}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, result.Schemas[0]) + require.Equal(ts.T(), group.ID, result.ID) + require.Equal(ts.T(), "fa01b7f2-ab68-4f97-a211-11f5732d0e15", result.ExternalID) + require.Equal(ts.T(), "YJCESZMOUKCA", result.DisplayName) + require.Equal(ts.T(), "Group", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Groups/"+result.ID) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Equal(ts.T(), "YJCESZMOUKCA", getResult.DisplayName) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupAddMember() { + groupName := "AddMemberGroup" + groupExtID := "grp-add-001" + memberEmail := "member1@acme.com" + group := ts.createSCIMGroupWithExternalID(groupName, groupExtID) + user := ts.createSCIMUserWithExternalID(memberEmail, memberEmail, "usr-member-001") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "members", "value": []map[string]interface{}{ + {"value": user.ID}, + }}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, result.Schemas[0]) + require.Equal(ts.T(), group.ID, result.ID) + require.Equal(ts.T(), groupExtID, result.ExternalID) + require.Equal(ts.T(), groupName, result.DisplayName) + require.Len(ts.T(), result.Members, 1) + require.Equal(ts.T(), user.ID, result.Members[0].Value) + require.Contains(ts.T(), result.Members[0].Ref, "/scim/v2/Users/"+user.ID) + require.Equal(ts.T(), memberEmail, result.Members[0].Display) + require.Equal(ts.T(), "Group", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Groups/"+result.ID) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Len(ts.T(), getResult.Members, 1) + require.Equal(ts.T(), user.ID, getResult.Members[0].Value) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupRemoveMember() { + groupName := "RemoveMemberGroup" + groupExtID := "grp-remove-001" + member1Email := "member2@acme.com" + member2Email := "member3@acme.com" + group := ts.createSCIMGroupWithExternalID(groupName, groupExtID) + user1 := ts.createSCIMUserWithExternalID(member1Email, member1Email, "usr-member-002") + user2 := ts.createSCIMUserWithExternalID(member2Email, member2Email, "usr-member-003") + + addMembersBody := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "members", "value": []map[string]interface{}{ + {"value": user1.ID}, + {"value": user2.ID}, + }}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, addMembersBody) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var addResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&addResult)) + require.Len(ts.T(), addResult.Members, 2) + + removeMemberBody := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "remove", "path": fmt.Sprintf("members[value eq \"%s\"]", user1.ID)}, + }, + } + + req = ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, removeMemberBody) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, result.Schemas[0]) + require.Equal(ts.T(), group.ID, result.ID) + require.Equal(ts.T(), groupExtID, result.ExternalID) + require.Equal(ts.T(), groupName, result.DisplayName) + require.Len(ts.T(), result.Members, 1) + require.Equal(ts.T(), user2.ID, result.Members[0].Value) + require.Contains(ts.T(), result.Members[0].Ref, "/scim/v2/Users/"+user2.ID) + require.Equal(ts.T(), member2Email, result.Members[0].Display) + require.Equal(ts.T(), "Group", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Groups/"+result.ID) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Len(ts.T(), getResult.Members, 1) + require.Equal(ts.T(), user2.ID, getResult.Members[0].Value) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupMultipleOperationsAddThenRemoveMember() { + groupName := "MultiOpGroup" + groupExtID := "grp-multiop-001" + memberEmail := "member4@acme.com" + group := ts.createSCIMGroupWithExternalID(groupName, groupExtID) + user := ts.createSCIMUserWithExternalID(memberEmail, memberEmail, "usr-member-004") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "add", "path": "members", "value": []map[string]interface{}{ + {"value": user.ID}, + }}, + {"op": "remove", "path": fmt.Sprintf("members[value eq \"%s\"]", user.ID)}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + + require.Len(ts.T(), result.Schemas, 1) + require.Equal(ts.T(), SCIMSchemaGroup, result.Schemas[0]) + require.Equal(ts.T(), group.ID, result.ID) + require.Equal(ts.T(), groupExtID, result.ExternalID) + require.Equal(ts.T(), groupName, result.DisplayName) + require.Empty(ts.T(), result.Members) + require.Equal(ts.T(), "Group", result.Meta.ResourceType) + require.NotNil(ts.T(), result.Meta.Created) + require.NotNil(ts.T(), result.Meta.LastModified) + require.Contains(ts.T(), result.Meta.Location, "/scim/v2/Groups/"+result.ID) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Empty(ts.T(), getResult.Members) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupNotFound() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": map[string]interface{}{"displayName": "NewName"}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/00000000-0000-0000-0000-000000000000", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupUpdateDisplayNameWithPath() { + group := ts.createSCIMGroupWithExternalID("ORIGINALNAME", "path-test-ext-id") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "displayName", "value": "NEWDISPLAYNAME"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "NEWDISPLAYNAME", result.DisplayName) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Equal(ts.T(), "NEWDISPLAYNAME", getResult.DisplayName) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupAddMemberWithAddOp() { + group := ts.createSCIMGroup("AddOpTestGroup") + user := ts.createSCIMUser("addop_member@test.com", "addop_member@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "add", "path": "members", "value": []map[string]interface{}{ + {"value": user.ID}, + }}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Len(ts.T(), result.Members, 1) + require.Equal(ts.T(), user.ID, result.Members[0].Value) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Len(ts.T(), getResult.Members, 1) + require.Equal(ts.T(), user.ID, getResult.Members[0].Value) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupRemoveAllMembers() { + user1 := ts.createSCIMUser("remove_all_member1@test.com", "remove_all_member1@test.com") + user2 := ts.createSCIMUser("remove_all_member2@test.com", "remove_all_member2@test.com") + group := ts.createSCIMGroupWithMembers("RemoveAllMembersGroup", []string{user1.ID, user2.ID}) + + require.Len(ts.T(), group.Members, 2) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "members", "value": []map[string]interface{}{}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Empty(ts.T(), result.Members) + + req = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var getResult SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getResult)) + require.Empty(ts.T(), getResult.Members) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupDisplayNameConflict() { + _ = ts.createSCIMGroupWithExternalID("FirstGroup", "conflict-ext-1") + secondGroup := ts.createSCIMGroupWithExternalID("SecondGroup", "conflict-ext-2") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "displayName", "value": "FirstGroup"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+secondGroup.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusConflict, "uniqueness") +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupDisplayNameConflictValueMap() { + _ = ts.createSCIMGroupWithExternalID("ValueMapFirst", "vm-ext-1") + secondGroup := ts.createSCIMGroupWithExternalID("ValueMapSecond", "vm-ext-2") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": map[string]interface{}{"displayName": "ValueMapFirst"}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+secondGroup.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusConflict, "uniqueness") +} + +func (ts *SCIMTestSuite) TestSCIMReplaceGroupDisplayNameConflict() { + _ = ts.createSCIMGroupWithExternalID("ReplaceFirst", "replace-ext-1") + secondGroup := ts.createSCIMGroupWithExternalID("ReplaceSecond", "replace-ext-2") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": "ReplaceFirst", + } + + req := ts.makeSCIMRequest(http.MethodPut, "/scim/v2/Groups/"+secondGroup.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusConflict, "uniqueness") +} + +func (ts *SCIMTestSuite) TestSCIMAuthMissingAuthorizationHeader() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusUnauthorized) +} + +func (ts *SCIMTestSuite) TestSCIMAuthInvalidBearerToken() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Authorization", "Bearer completely-invalid-token-xyz") + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusUnauthorized) +} + +func (ts *SCIMTestSuite) TestSCIMAuthMalformedAuthorizationHeader() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusUnauthorized) +} + +func (ts *SCIMTestSuite) TestSCIMAuthEmptyBearerToken() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Authorization", "Bearer ") + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusUnauthorized) +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidFilterSyntax() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=invalid+++syntax", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidFilter") +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidFilterUnclosedQuote() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+eq+%22unclosed", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidFilter") +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidFilterUnsupportedAttribute() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=unsupportedAttr+eq+%22value%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidFilter") +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidFilterGroupUnsupportedAttribute() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups?filter=invalidAttr+eq+%22value%22", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidFilter") +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidPatchOperationCopy() { + user := ts.createSCIMUser("patch_copy_op@test.com", "patch_copy_op@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "copy", "from": "userName", "path": "externalId"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidSyntax") +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidPatchOperationMove() { + user := ts.createSCIMUser("patch_move_op@test.com", "patch_move_op@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "move", "from": "userName", "path": "externalId"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidSyntax") +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidPatchMissingOperations() { + user := ts.createSCIMUser("patch_missing_ops@test.com", "patch_missing_ops@test.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMError(w, http.StatusBadRequest) +} + +func (ts *SCIMTestSuite) TestSCIMErrorInvalidJSON() { + req := httptest.NewRequest(http.MethodPost, "http://localhost/scim/v2/Users", bytes.NewBuffer([]byte("{invalid json"))) + req.Header.Set("Authorization", "Bearer "+ts.SCIMToken) + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidSyntax") +} + +func (ts *SCIMTestSuite) TestSCIMErrorResponseFormatUsers() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/00000000-0000-0000-0000-000000000000", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + + schemas, ok := errorResp["schemas"].([]interface{}) + require.True(ts.T(), ok, "SCIM error must have schemas field per RFC 7644") + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:api:messages:2.0:Error", schemas[0]) + + detail, ok := errorResp["detail"].(string) + require.True(ts.T(), ok, "SCIM error must have detail field per RFC 7644") + require.NotEmpty(ts.T(), detail) + + status, ok := errorResp["status"].(string) + require.True(ts.T(), ok, "SCIM error status must be a string per RFC 7644") + require.Equal(ts.T(), "404", status) +} + +func (ts *SCIMTestSuite) TestSCIMErrorResponseFormatGroups() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/00000000-0000-0000-0000-000000000000", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + + schemas, ok := errorResp["schemas"].([]interface{}) + require.True(ts.T(), ok, "SCIM error must have schemas field per RFC 7644") + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:api:messages:2.0:Error", schemas[0]) + + detail, ok := errorResp["detail"].(string) + require.True(ts.T(), ok, "SCIM error must have detail field per RFC 7644") + require.NotEmpty(ts.T(), detail) + + status, ok := errorResp["status"].(string) + require.True(ts.T(), ok, "SCIM error status must be a string per RFC 7644") + require.Equal(ts.T(), "404", status) +} + +func (ts *SCIMTestSuite) TestSCIMErrorSchemaValidationMissingRequiredField() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "emails": []map[string]interface{}{ + {"value": "test@example.com", "primary": true}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + + schemas, ok := errorResp["schemas"].([]interface{}) + require.True(ts.T(), ok, "SCIM error should have schemas field") + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:api:messages:2.0:Error", schemas[0]) + + detail, ok := errorResp["detail"].(string) + require.True(ts.T(), ok, "SCIM error should have detail field") + require.Contains(ts.T(), detail, "userName") + + status, ok := errorResp["status"].(string) + require.True(ts.T(), ok, "SCIM error should have status field") + require.Equal(ts.T(), "400", status) + + scimType, ok := errorResp["scimType"].(string) + require.True(ts.T(), ok, "SCIM error should have scimType field") + require.Equal(ts.T(), "invalidSyntax", scimType) +} + +func (ts *SCIMTestSuite) TestSCIMErrorGroupSchemaValidationMissingDisplayName() { + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + } + + req := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Groups", body) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + + schemas, ok := errorResp["schemas"].([]interface{}) + require.True(ts.T(), ok, "SCIM error should have schemas field") + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:api:messages:2.0:Error", schemas[0]) + + detail, ok := errorResp["detail"].(string) + require.True(ts.T(), ok, "SCIM error should have detail field") + require.Contains(ts.T(), detail, "displayName") + + status, ok := errorResp["status"].(string) + require.True(ts.T(), ok, "SCIM error should have status field") + require.Equal(ts.T(), "400", status) + + scimType, ok := errorResp["scimType"].(string) + require.True(ts.T(), ok, "SCIM error should have scimType field") + require.Equal(ts.T(), "invalidSyntax", scimType) +} + +func (ts *SCIMTestSuite) TestSCIMReplaceUser() { + user := ts.createSCIMUserWithName(testUser9.UserName, testUser9.Email, testUser9.GivenName, testUser9.FamilyName) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": "replaced@acme.com", + "name": map[string]interface{}{ + "givenName": "Replaced", + "familyName": "Name", + "formatted": "Replaced Name", + }, + "emails": []map[string]interface{}{ + {"value": "replaced@acme.com", "primary": true, "type": "work"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPut, "/scim/v2/Users/"+user.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code, w.Body.String()) + + var result SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "replaced@acme.com", result.UserName) + require.NotNil(ts.T(), result.Name) + require.Equal(ts.T(), "Replaced", result.Name.GivenName) + require.Equal(ts.T(), "Name", result.Name.FamilyName) +} + +func (ts *SCIMTestSuite) TestSCIMReplaceUserNotFound() { + req := ts.makeSCIMRequest(http.MethodPut, "/scim/v2/Users/00000000-0000-0000-0000-000000000000", map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": "nobody@acme.com", + "emails": []map[string]interface{}{{"value": "nobody@acme.com", "primary": true}}, + }) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMReplaceGroup() { + group := ts.createSCIMGroupWithExternalID(testGroup1.DisplayName, testGroup1.ExternalID) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": "Replaced Engineering", + "externalId": "replaced-ext-001", + } + + req := ts.makeSCIMRequest(http.MethodPut, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code, w.Body.String()) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "Replaced Engineering", result.DisplayName) + require.Equal(ts.T(), "replaced-ext-001", result.ExternalID) +} + +func (ts *SCIMTestSuite) TestSCIMReplaceGroupNotFound() { + req := ts.makeSCIMRequest(http.MethodPut, "/scim/v2/Groups/00000000-0000-0000-0000-000000000000", map[string]interface{}{ + "schemas": []string{SCIMSchemaGroup}, + "displayName": "Ghost", + }) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMCrossProviderIsolationUsers() { + user := ts.createSCIMUser(testUser1.UserName, testUser1.Email) + + provider2 := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider2)) + token2 := "other-provider-token" + provider2.SetSCIMToken(token2) + require.NoError(ts.T(), ts.API.db.Update(provider2)) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/"+user.ID, nil) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMCrossProviderIsolationGroups() { + group := ts.createSCIMGroup(testGroup1.DisplayName) + + provider2 := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider2)) + token2 := "other-provider-token" + provider2.SetSCIMToken(token2) + require.NoError(ts.T(), ts.API.db.Update(provider2)) + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups/"+group.ID, nil) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMPutEmailUniqueness() { + userA := ts.createSCIMUser("uniqueA@acme.com", "uniqueA@acme.com") + ts.createSCIMUser("uniqueB@acme.com", "uniqueB@acme.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": "uniqueB@acme.com", + "emails": []map[string]interface{}{{"value": "uniqueB@acme.com", "primary": true}}, + } + + req := ts.makeSCIMRequest(http.MethodPut, "/scim/v2/Users/"+userA.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusConflict, "uniqueness") +} + +func (ts *SCIMTestSuite) TestSCIMPatchEmailUniqueness() { + userA := ts.createSCIMUser("patchA@acme.com", "patchA@acme.com") + ts.createSCIMUser("patchB@acme.com", "patchB@acme.com") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "emails[value eq \"patchA@acme.com\"].value", "value": "patchB@acme.com"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+userA.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusConflict, "uniqueness") +} + +func (ts *SCIMTestSuite) TestSCIMErrorResponseContentType() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users/not-a-uuid", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNotFound, w.Code) + require.Equal(ts.T(), "application/scim+json", w.Header().Get("Content-Type")) +} + +func (ts *SCIMTestSuite) adminToken() string { + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + return token +} + +func (ts *SCIMTestSuite) makeAdminRequest(method, path string, body interface{}) *http.Request { + var reqBody *bytes.Buffer + if body != nil { + jsonBody, err := json.Marshal(body) + require.NoError(ts.T(), err) + reqBody = bytes.NewBuffer(jsonBody) + } else { + reqBody = bytes.NewBuffer(nil) + } + req := httptest.NewRequest(method, "http://localhost"+path, reqBody) + req.Header.Set("Authorization", "Bearer "+ts.adminToken()) + req.Header.Set("Content-Type", "application/json") + return req +} + +func (ts *SCIMTestSuite) TestSCIMAdminGetConfig() { + req := ts.makeAdminRequest(http.MethodGet, "/admin/sso/providers/"+ts.SSOProvider.ID.String()+"/scim", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), true, result["enabled"]) + require.Equal(ts.T(), true, result["token_set"]) + require.NotEmpty(ts.T(), result["base_url"]) +} + +func (ts *SCIMTestSuite) TestSCIMAdminEnableSCIM() { + provider := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider)) + + req := ts.makeAdminRequest(http.MethodPost, "/admin/sso/providers/"+provider.ID.String()+"/scim", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), true, result["enabled"]) + require.NotEmpty(ts.T(), result["token"]) + require.NotEmpty(ts.T(), result["base_url"]) +} + +func (ts *SCIMTestSuite) TestSCIMAdminDisableSCIM() { + req := ts.makeAdminRequest(http.MethodDelete, "/admin/sso/providers/"+ts.SSOProvider.ID.String()+"/scim", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), false, result["enabled"]) + + scimReq := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users", nil) + scimW := httptest.NewRecorder() + ts.API.handler.ServeHTTP(scimW, scimReq) + require.Equal(ts.T(), http.StatusUnauthorized, scimW.Code) +} + +func (ts *SCIMTestSuite) TestSCIMAdminRotateToken() { + scimReq := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users", nil) + scimW := httptest.NewRecorder() + ts.API.handler.ServeHTTP(scimW, scimReq) + require.Equal(ts.T(), http.StatusOK, scimW.Code) + + req := ts.makeAdminRequest(http.MethodPost, "/admin/sso/providers/"+ts.SSOProvider.ID.String()+"/scim/rotate", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), true, result["enabled"]) + newToken, ok := result["token"].(string) + require.True(ts.T(), ok) + require.NotEmpty(ts.T(), newToken) + + scimReq = ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users", nil) + scimW = httptest.NewRecorder() + ts.API.handler.ServeHTTP(scimW, scimReq) + require.Equal(ts.T(), http.StatusUnauthorized, scimW.Code) + + scimReq2 := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + scimReq2.Header.Set("Authorization", "Bearer "+newToken) + scimReq2.Header.Set("Content-Type", "application/scim+json") + scimW2 := httptest.NewRecorder() + ts.API.handler.ServeHTTP(scimW2, scimReq2) + require.Equal(ts.T(), http.StatusOK, scimW2.Code) +} + +func (ts *SCIMTestSuite) TestSCIMAdminRotateTokenWhenDisabled() { + provider := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider)) + + req := ts.makeAdminRequest(http.MethodPost, "/admin/sso/providers/"+provider.ID.String()+"/scim/rotate", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "scim_disabled", result["error_code"]) +} + +func (ts *SCIMTestSuite) TestSCIMDisabledSCIMProvider() { + provider := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider)) + token := "disabled-scim-provider-token" + provider.SetSCIMToken(token) + provider.ClearSCIMToken() + require.NoError(ts.T(), ts.API.db.Update(provider)) + + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusUnauthorized, w.Code) +} + +func (ts *SCIMTestSuite) TestSCIMDisabledSSOProvider() { + provider := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider)) + token := "disabled-sso-provider-token" + provider.SetSCIMToken(token) + disabled := true + provider.Disabled = &disabled + require.NoError(ts.T(), ts.API.db.Update(provider)) + + req := httptest.NewRequest(http.MethodGet, "http://localhost/scim/v2/Users", nil) + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusForbidden, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + detail, ok := errorResp["detail"].(string) + require.True(ts.T(), ok) + require.Contains(ts.T(), detail, "SSO provider is disabled") +} + +func (ts *SCIMTestSuite) createFilterTestUsers() { + ts.createSCIMUserWithExternalID("user1@acme.com", "user1@acme.com", "ext-f-001") + ts.createSCIMUserWithExternalID("user2@acme.com", "user2@acme.com", "ext-f-002") + ts.createSCIMUserWithExternalID("user3@other.com", "user3@other.com", "ext-f-003") + ts.createSCIMUser("user4@acme.com", "user4@acme.com") + ts.createSCIMUser("user5@other.com", "user5@other.com") +} + +func (ts *SCIMTestSuite) TestSCIMFilterNE() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+ne+%22user1%40acme.com%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 4, result.TotalResults) + for _, r := range result.Resources { + resource := r.(map[string]interface{}) + require.NotEqual(ts.T(), "user1@acme.com", resource["userName"]) + } +} + +func (ts *SCIMTestSuite) TestSCIMFilterCO() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+co+%22acme%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 3, result.TotalResults) + for _, r := range result.Resources { + resource := r.(map[string]interface{}) + require.Contains(ts.T(), resource["userName"], "acme") + } +} + +func (ts *SCIMTestSuite) TestSCIMFilterSW() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+sw+%22user1%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 1, result.TotalResults) + resource := result.Resources[0].(map[string]interface{}) + require.Equal(ts.T(), "user1@acme.com", resource["userName"]) +} + +func (ts *SCIMTestSuite) TestSCIMFilterEW() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+ew+%22acme.com%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 3, result.TotalResults) + for _, r := range result.Resources { + resource := r.(map[string]interface{}) + userName := resource["userName"].(string) + require.True(ts.T(), strings.HasSuffix(userName, "acme.com")) + } +} + +func (ts *SCIMTestSuite) TestSCIMFilterPR() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=externalId+pr", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 3, result.TotalResults) + for _, r := range result.Resources { + resource := r.(map[string]interface{}) + require.NotEmpty(ts.T(), resource["externalId"]) + } +} + +func (ts *SCIMTestSuite) TestSCIMFilterAnd() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+sw+%22user%22+and+userName+ew+%22acme.com%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 3, result.TotalResults) +} + +func (ts *SCIMTestSuite) TestSCIMFilterOr() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=userName+eq+%22user1%40acme.com%22+or+userName+eq+%22user2%40acme.com%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 2, result.TotalResults) +} + +func (ts *SCIMTestSuite) TestSCIMFilterNot() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=not+userName+eq+%22user1%40acme.com%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 4, result.TotalResults) + for _, r := range result.Resources { + resource := r.(map[string]interface{}) + require.NotEqual(ts.T(), "user1@acme.com", resource["userName"]) + } +} + +func (ts *SCIMTestSuite) TestSCIMFilterEmailsValuePath() { + ts.createFilterTestUsers() + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter=emails%5Bvalue+eq+%22user1%40acme.com%22%5D", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 1, result.TotalResults) +} + +func (ts *SCIMTestSuite) TestSCIMGroupFilterCO() { + ts.createSCIMGroupWithExternalID("Engineering Team", "grp-fc-001") + ts.createSCIMGroupWithExternalID("Sales Team", "grp-fc-002") + ts.createSCIMGroupWithExternalID("Eng Ops", "grp-fc-003") + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups?filter=displayName+co+%22Eng%22", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 2, result.TotalResults) +} + +func (ts *SCIMTestSuite) TestSCIMBodyExceedsMaxSize() { + largeBody := strings.Repeat("x", SCIMMaxBodySize+1) + req := httptest.NewRequest(http.MethodPost, "http://localhost/scim/v2/Users", bytes.NewBufferString(largeBody)) + req.Header.Set("Authorization", "Bearer "+ts.SCIMToken) + req.Header.Set("Content-Type", "application/scim+json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.True(ts.T(), w.Code >= 400, "Expected error status for oversized body, got %d", w.Code) +} + +func (ts *SCIMTestSuite) TestSCIMFilterExceedsMaxLength() { + longFilter := "userName eq \"" + strings.Repeat("a", SCIMMaxFilterLength+1) + "\"" + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?filter="+longFilter, nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "invalidFilter") +} + +func (ts *SCIMTestSuite) TestSCIMResourceTypeByID() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/ResourceTypes/User", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "User", result["id"]) + require.Equal(ts.T(), "User", result["name"]) +} + +func (ts *SCIMTestSuite) TestSCIMResourceTypeByIDGroup() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/ResourceTypes/Group", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "Group", result["id"]) + require.Equal(ts.T(), "Group", result["name"]) +} + +func (ts *SCIMTestSuite) TestSCIMResourceTypeByIDNotFound() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/ResourceTypes/Invalid", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMSchemaByID() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Schemas/"+SCIMSchemaUser, nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), SCIMSchemaUser, result["id"]) + require.Equal(ts.T(), "User", result["name"]) +} + +func (ts *SCIMTestSuite) TestSCIMSchemaByIDNotFound() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Schemas/invalid", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMNotFoundRoute() { + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/nonexistent", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNotFound, w.Code) + + var errorResp map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&errorResp)) + schemas, ok := errorResp["schemas"].([]interface{}) + require.True(ts.T(), ok) + require.Len(ts.T(), schemas, 1) + require.Equal(ts.T(), "urn:ietf:params:scim:api:messages:2.0:Error", schemas[0]) +} + +func (ts *SCIMTestSuite) TestSCIMPaginationCountZero() { + for i := 0; i < 3; i++ { + ts.createSCIMUser(fmt.Sprintf("pagezero%d@acme.com", i), fmt.Sprintf("pagezero%d@acme.com", i)) + } + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?count=0", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 3, result.TotalResults) + require.Empty(ts.T(), result.Resources) +} + +func (ts *SCIMTestSuite) TestSCIMPaginationStartIndexExceedsTotal() { + for i := 0; i < 5; i++ { + ts.createSCIMUser(fmt.Sprintf("pageexceed%d@acme.com", i), fmt.Sprintf("pageexceed%d@acme.com", i)) + } + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Users?startIndex=999", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 5, result.TotalResults) + require.Empty(ts.T(), result.Resources) +} + +func (ts *SCIMTestSuite) TestSCIMGroupPagination() { + for i := 0; i < 5; i++ { + ts.createSCIMGroupWithExternalID(fmt.Sprintf("PagGroup%d", i), fmt.Sprintf("pag-grp-%d", i)) + } + + req := ts.makeSCIMRequest(http.MethodGet, "/scim/v2/Groups?startIndex=1&count=2", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMListResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), 5, result.TotalResults) + require.Len(ts.T(), result.Resources, 2) +} + +func (ts *SCIMTestSuite) setupCrossProviderIsolation() (string, *SCIMUserResponse, *SCIMGroupResponse) { + user := ts.createSCIMUser("cross_iso@acme.com", "cross_iso@acme.com") + group := ts.createSCIMGroup("CrossIsoGroup") + + provider2 := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider2)) + token2 := "cross-provider-iso-token" + provider2.SetSCIMToken(token2) + require.NoError(ts.T(), ts.API.db.Update(provider2)) + + return token2, user, group +} + +func (ts *SCIMTestSuite) TestSCIMCrossProviderPatchUser() { + token2, user, _ := ts.setupCrossProviderIsolation() + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": map[string]interface{}{"userName": "hacked@evil.com"}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Users/"+user.ID, body) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMCrossProviderPutUser() { + token2, user, _ := ts.setupCrossProviderIsolation() + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": "hacked@evil.com", + "emails": []map[string]interface{}{{"value": "hacked@evil.com", "primary": true}}, + } + + req := ts.makeSCIMRequest(http.MethodPut, "/scim/v2/Users/"+user.ID, body) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMCrossProviderDeleteUser() { + token2, user, _ := ts.setupCrossProviderIsolation() + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Users/"+user.ID, nil) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMCrossProviderPatchGroup() { + token2, _, group := ts.setupCrossProviderIsolation() + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "value": map[string]interface{}{"displayName": "HackedGroup"}}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMCrossProviderDeleteGroup() { + token2, _, group := ts.setupCrossProviderIsolation() + + req := ts.makeSCIMRequest(http.MethodDelete, "/scim/v2/Groups/"+group.ID, nil) + req.Header.Set("Authorization", "Bearer "+token2) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMError(w, http.StatusNotFound) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupReplaceExternalIDWithPath() { + group := ts.createSCIMGroupWithExternalID("ExtIDPathGroup", "orig-ext-id") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "replace", "path": "externalId", "value": "new-ext-id"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var result SCIMGroupResponse + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&result)) + require.Equal(ts.T(), "new-ext-id", result.ExternalID) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupAddMemberWrongProvider() { + group := ts.createSCIMGroup("WrongProviderGroup") + + provider2 := &models.SSOProvider{} + require.NoError(ts.T(), ts.API.db.Create(provider2)) + token2 := "wrong-provider-member-token" + provider2.SetSCIMToken(token2) + require.NoError(ts.T(), ts.API.db.Update(provider2)) + + userBody := map[string]interface{}{ + "schemas": []string{SCIMSchemaUser}, + "userName": "otherprovider@test.com", + "emails": []map[string]interface{}{{"value": "otherprovider@test.com", "primary": true}}, + } + userReq := ts.makeSCIMRequest(http.MethodPost, "/scim/v2/Users", userBody) + userReq.Header.Set("Authorization", "Bearer "+token2) + userW := httptest.NewRecorder() + ts.API.handler.ServeHTTP(userW, userReq) + require.Equal(ts.T(), http.StatusCreated, userW.Code) + + var otherUser SCIMUserResponse + require.NoError(ts.T(), json.NewDecoder(userW.Body).Decode(&otherUser)) + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "add", "path": "members", "value": []map[string]interface{}{ + {"value": otherUser.ID}, + }}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.True(ts.T(), w.Code >= 400, "Adding cross-provider member should fail, got %d: %s", w.Code, w.Body.String()) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupAddNonExistentMember() { + group := ts.createSCIMGroup("NonExistentMemberGroup") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "add", "path": "members", "value": []map[string]interface{}{ + {"value": "00000000-0000-0000-0000-000000000000"}, + }}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.True(ts.T(), w.Code >= 400, "Adding non-existent member should fail, got %d: %s", w.Code, w.Body.String()) +} + +func (ts *SCIMTestSuite) TestSCIMPatchGroupRemoveWithoutPath() { + group := ts.createSCIMGroup("RemoveNoPathGroup") + + body := map[string]interface{}{ + "schemas": []string{SCIMSchemaPatchOp}, + "Operations": []map[string]interface{}{ + {"op": "remove"}, + }, + } + + req := ts.makeSCIMRequest(http.MethodPatch, "/scim/v2/Groups/"+group.ID, body) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.assertSCIMErrorWithType(w, http.StatusBadRequest, "noTarget") +} diff --git a/internal/api/scim_types.go b/internal/api/scim_types.go new file mode 100644 index 000000000..dd4ba3d12 --- /dev/null +++ b/internal/api/scim_types.go @@ -0,0 +1,189 @@ +package api + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/supabase/auth/internal/api/apierrors" +) + +const ( + SCIMDefaultPageSize = 100 + SCIMMaxPageSize = 1000 + SCIMMaxBodySize = 1 << 20 // 1 MB + SCIMMaxMembers = 1000 + SCIMMaxPatchOperations = 100 + SCIMMaxStartIndex = 100000 + SCIMSchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User" + SCIMSchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group" + SCIMSchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse" + SCIMSchemaPatchOp = "urn:ietf:params:scim:api:messages:2.0:PatchOp" + + scimErrUserNotFound = "User not found" + scimErrGroupNotFound = "Group not found" + scimErrEmailConflict = "Email already in use by another user" + scimErrExternalIDConflict = "User with this externalId already exists" + scimErrUserNameConflict = "User with this userName already exists" + scimErrGroupExternalIDConflict = "Group with this externalId already exists" + scimErrGroupDisplayNameConflict = "Group with this displayName already exists" + scimErrMembersNotFound = "One or more members not found" + scimErrMembersWrongProvider = "One or more members do not belong to this SSO provider" + scimErrAmbiguousDeprovisioned = "Multiple deprovisioned users exist for this email" +) + +// Must be var (not const) because it's passed by pointer to user.Ban() +var scimDeprovisionedReason = "SCIM_DEPROVISIONED" + +// FlexBool handles both bool and string ("true"/"false") - Azure AD sends strings +type FlexBool bool + +func (fb *FlexBool) UnmarshalJSON(data []byte) error { + var b bool + if err := json.Unmarshal(data, &b); err == nil { + *fb = FlexBool(b) + return nil + } + var s string + if err := json.Unmarshal(data, &s); err == nil { + switch strings.ToLower(s) { + case "true": + *fb = FlexBool(true) + case "false": + *fb = FlexBool(false) + default: + return fmt.Errorf("cannot unmarshal %q into FlexBool: must be true or false", s) + } + return nil + } + return fmt.Errorf("cannot unmarshal %s into FlexBool", string(data)) +} + +type SCIMUserParams struct { + Schemas []string `json:"schemas"` + ExternalID string `json:"externalId"` + UserName string `json:"userName"` + Name *SCIMName `json:"name,omitempty"` + Emails []SCIMEmail `json:"emails,omitempty"` + Active *FlexBool `json:"active,omitempty"` +} + +func (p *SCIMUserParams) Validate() error { + if err := requireSCIMSchema(p.Schemas, SCIMSchemaUser); err != nil { + return err + } + if p.UserName == "" { + return apierrors.NewSCIMBadRequestError("userName is required", "invalidSyntax") + } + return nil +} + +type SCIMName struct { + Formatted string `json:"formatted,omitempty"` + FamilyName string `json:"familyName,omitempty"` + GivenName string `json:"givenName,omitempty"` +} + +type SCIMEmail struct { + Value string `json:"value"` + Type string `json:"type,omitempty"` + Primary FlexBool `json:"primary,omitempty"` +} + +type SCIMGroupParams struct { + Schemas []string `json:"schemas"` + ExternalID string `json:"externalId"` + DisplayName string `json:"displayName"` + Members []SCIMGroupMemberRef `json:"members,omitempty"` +} + +func (p *SCIMGroupParams) Validate() error { + if err := requireSCIMSchema(p.Schemas, SCIMSchemaGroup); err != nil { + return err + } + if p.DisplayName == "" { + return apierrors.NewSCIMBadRequestError("displayName is required", "invalidSyntax") + } + if len(p.Members) > SCIMMaxMembers { + return apierrors.NewSCIMRequestTooLargeError(fmt.Sprintf("Maximum %d members per request", SCIMMaxMembers)) + } + return nil +} + +type SCIMGroupMemberRef struct { + Value string `json:"value"` + Ref string `json:"$ref,omitempty"` + Display string `json:"display,omitempty"` +} + +type SCIMPatchRequest struct { + Schemas []string `json:"schemas"` + Operations []SCIMPatchOperation `json:"Operations"` +} + +func (p *SCIMPatchRequest) Validate() error { + if err := requireSCIMSchema(p.Schemas, SCIMSchemaPatchOp); err != nil { + return err + } + if len(p.Operations) == 0 { + return apierrors.NewSCIMBadRequestError("At least one operation is required", "invalidSyntax") + } + if len(p.Operations) > SCIMMaxPatchOperations { + return apierrors.NewSCIMRequestTooLargeError(fmt.Sprintf("Maximum %d operations per request", SCIMMaxPatchOperations)) + } + return nil +} + +type SCIMPatchOperation struct { + Op string `json:"op"` + Path string `json:"path,omitempty"` + Value interface{} `json:"value,omitempty"` +} + +type SCIMMeta struct { + ResourceType string `json:"resourceType"` + Created *time.Time `json:"created,omitempty"` + LastModified *time.Time `json:"lastModified,omitempty"` + Location string `json:"location,omitempty"` +} + +type SCIMUserResponse struct { + Schemas []string `json:"schemas"` + ID string `json:"id"` + ExternalID string `json:"externalId,omitempty"` + UserName string `json:"userName"` + Name *SCIMName `json:"name,omitempty"` + Emails []SCIMEmail `json:"emails,omitempty"` + Active bool `json:"active"` + Meta SCIMMeta `json:"meta"` +} + +type SCIMGroupResponse struct { + Schemas []string `json:"schemas"` + ID string `json:"id"` + ExternalID string `json:"externalId,omitempty"` + DisplayName string `json:"displayName"` + Members []SCIMGroupMemberRef `json:"members,omitempty"` + Meta SCIMMeta `json:"meta"` +} + +type SCIMListResponse struct { + Schemas []string `json:"schemas"` + TotalResults int `json:"totalResults"` + StartIndex int `json:"startIndex"` + ItemsPerPage int `json:"itemsPerPage"` + Resources []interface{} `json:"Resources"` +} + +func requireSCIMSchema(schemas []string, required string) error { + for _, s := range schemas { + if s == required { + return nil + } + } + return apierrors.NewSCIMBadRequestError( + fmt.Sprintf("schemas must include %s", required), + "invalidValue", + ) +} diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 1b1d9519c..3af02790a 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -13,12 +13,19 @@ import ( "github.com/go-chi/chi/v5" "github.com/gofrs/uuid" "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/utilities" ) +const SCIMTokenPrefix = "scim_" + +func generateSCIMToken() string { + return SCIMTokenPrefix + crypto.SecureAlphanumeric(32) +} + // loadSSOProvider looks for an idp_id and first checks it for a "resource_" // prefix, if present the provider is loaded by resource_id. Otherwise the // provider is loaded by id. @@ -460,3 +467,83 @@ func (a *API) adminSSOProvidersDelete(w http.ResponseWriter, r *http.Request) er return sendJSON(w, http.StatusOK, provider) } + +// adminSSOProviderGetSCIM returns the SCIM configuration for an SSO provider. +func (a *API) adminSSOProviderGetSCIM(w http.ResponseWriter, r *http.Request) error { + provider := getSSOProvider(r.Context()) + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "enabled": provider.IsSCIMEnabled(), + "token_set": provider.SCIMBearerTokenHash != nil, + "base_url": a.getSCIMBaseURL() + "/scim/v2", + }) +} + +// adminSSOProviderEnableSCIM enables SCIM for an SSO provider and generates a new token. +func (a *API) adminSSOProviderEnableSCIM(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + provider := getSSOProvider(ctx) + + token := generateSCIMToken() + + if err := db.Transaction(func(tx *storage.Connection) error { + provider.SetSCIMToken(token) + return tx.UpdateOnly(provider, "scim_enabled", "scim_bearer_token_hash") + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "enabled": true, + "token": token, + "base_url": a.getSCIMBaseURL() + "/scim/v2", + }) +} + +// adminSSOProviderDisableSCIM disables SCIM for an SSO provider. +func (a *API) adminSSOProviderDisableSCIM(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + provider := getSSOProvider(ctx) + provider.ClearSCIMToken() + + if err := db.Transaction(func(tx *storage.Connection) error { + return tx.UpdateOnly(provider, "scim_enabled", "scim_bearer_token_hash") + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "enabled": false, + }) +} + +// adminSSOProviderRotateSCIMToken rotates the SCIM token for an SSO provider. +func (a *API) adminSSOProviderRotateSCIMToken(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + provider := getSSOProvider(ctx) + + if !provider.IsSCIMEnabled() { + return apierrors.NewBadRequestError(apierrors.ErrorCodeSCIMDisabled, "SCIM is not enabled for this provider") + } + + token := generateSCIMToken() + + if err := db.Transaction(func(tx *storage.Connection) error { + provider.SetSCIMToken(token) + return tx.UpdateOnly(provider, "scim_bearer_token_hash") + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "enabled": true, + "token": token, + "base_url": a.getSCIMBaseURL() + "/scim/v2", + }) +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index d5df496ef..518a45301 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -351,6 +351,7 @@ type GlobalConfiguration struct { Sessions SessionsConfiguration `json:"sessions"` MFA MFAConfiguration `json:"MFA"` SAML SAMLConfiguration `json:"saml"` + SCIM SCIMConfiguration `json:"scim"` CORS CORSConfiguration `json:"cors"` IndexWorker IndexWorkerConfiguration `json:"index_worker" split_words:"true"` @@ -358,6 +359,10 @@ type GlobalConfiguration struct { Reloading ReloadingConfiguration `json:"reloading"` } +type SCIMConfiguration struct { + Enabled bool `json:"enabled" default:"false"` +} + type CORSConfiguration struct { AllowedHeaders []string `json:"allowed_headers" split_words:"true"` } diff --git a/internal/models/connection.go b/internal/models/connection.go index 82a5e8775..35fcccc35 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -50,6 +50,8 @@ func TruncateAll(conn *storage.Connection) error { (&pop.Model{Value: FlowState{}}).TableName(), (&pop.Model{Value: OneTimeToken{}}).TableName(), (&pop.Model{Value: OAuthServerClient{}}).TableName(), + (&pop.Model{Value: SCIMGroup{}}).TableName(), + (&pop.Model{Value: SCIMGroupMember{}}).TableName(), } for _, tableName := range tables { diff --git a/internal/models/errors.go b/internal/models/errors.go index 4f1c95e60..6e321f991 100644 --- a/internal/models/errors.go +++ b/internal/models/errors.go @@ -23,6 +23,8 @@ func IsNotFoundError(err error) bool { return true case SAMLRelayStateNotFoundError, *SAMLRelayStateNotFoundError: return true + case SCIMGroupNotFoundError, *SCIMGroupNotFoundError: + return true case FlowStateNotFoundError, *FlowStateNotFoundError: return true case OneTimeTokenNotFoundError, *OneTimeTokenNotFoundError: @@ -108,6 +110,20 @@ func (e SAMLRelayStateNotFoundError) Error() string { return "SAML RelayState not found" } +// SCIMGroupNotFoundError represents an error when a SCIM group can't be found. +type SCIMGroupNotFoundError struct{} + +func (e SCIMGroupNotFoundError) Error() string { + return "SCIM Group not found" +} + +// UserNotInSSOProviderError represents when a user does not belong to an SSO provider. +type UserNotInSSOProviderError struct{} + +func (e UserNotInSSOProviderError) Error() string { + return "User does not belong to this SSO provider" +} + // FlowStateNotFoundError represents an error when an FlowState can't be // found. type FlowStateNotFoundError struct{} diff --git a/internal/models/scim_group.go b/internal/models/scim_group.go new file mode 100644 index 000000000..926b24189 --- /dev/null +++ b/internal/models/scim_group.go @@ -0,0 +1,376 @@ +package models + +import ( + "database/sql" + "strings" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +var ( + scimGroupTable = (&pop.Model{Value: SCIMGroup{}}).TableName() + scimGroupMemberTable = (&pop.Model{Value: SCIMGroupMember{}}).TableName() +) + +type SCIMGroup struct { + ID uuid.UUID `db:"id" json:"id"` + SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` + ExternalID storage.NullString `db:"external_id" json:"external_id,omitempty"` + DisplayName string `db:"display_name" json:"display_name"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + + SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` + Members []User `many_to_many:"scim_group_members" json:"members,omitempty"` +} + +func (SCIMGroup) TableName() string { + return "scim_groups" +} + +type SCIMGroupMember struct { + GroupID uuid.UUID `db:"group_id" json:"-"` + UserID uuid.UUID `db:"user_id" json:"-"` + CreatedAt time.Time `db:"created_at" json:"-"` +} + +func (SCIMGroupMember) TableName() string { + return "scim_group_members" +} + +func NewSCIMGroup(ssoProviderID uuid.UUID, externalID, displayName string) *SCIMGroup { + id := uuid.Must(uuid.NewV4()) + group := &SCIMGroup{ + ID: id, + SSOProviderID: ssoProviderID, + DisplayName: displayName, + } + // Only set ExternalID if non-empty (NULL in DB otherwise) + if externalID != "" { + group.ExternalID = storage.NullString(externalID) + } + return group +} + +func FindSCIMGroupByID(tx *storage.Connection, id uuid.UUID) (*SCIMGroup, error) { + var group SCIMGroup + if err := tx.Find(&group, id); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SCIMGroupNotFoundError{} + } + return nil, errors.Wrap(err, "error finding SCIM group by ID") + } + return &group, nil +} + +func FindSCIMGroupByExternalID(tx *storage.Connection, ssoProviderID uuid.UUID, externalID string) (*SCIMGroup, error) { + var group SCIMGroup + if err := tx.Q().Where("sso_provider_id = ? AND external_id = ?", ssoProviderID, externalID).First(&group); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SCIMGroupNotFoundError{} + } + return nil, errors.Wrap(err, "error finding SCIM group by external ID") + } + return &group, nil +} + +// SCIMFilterClause represents a parsed SCIM filter as SQL WHERE clause +type SCIMFilterClause struct { + Where string + Args []interface{} +} + +// FindSCIMGroupsBySSOProviderWithFilter finds groups with optional SCIM filter. +// The filterClause should be generated by ParseSCIMFilterToSQL. +func FindSCIMGroupsBySSOProviderWithFilter(tx *storage.Connection, ssoProviderID uuid.UUID, filterClause *SCIMFilterClause, startIndex, count int) ([]*SCIMGroup, int, error) { + groups := []*SCIMGroup{} + + offset := startIndex - 1 + if offset < 0 { + offset = 0 + } + + // Build query dynamically based on filter + whereClause := "sso_provider_id = ?" + args := []interface{}{ssoProviderID} + + if filterClause != nil && filterClause.Where != "" && filterClause.Where != "1=1" { + whereClause += " AND (" + filterClause.Where + ")" + args = append(args, filterClause.Args...) + } + + var totalResults int + countQuery := "SELECT COUNT(*) FROM " + scimGroupTable + " WHERE " + whereClause + if err := tx.RawQuery(countQuery, args...).First(&totalResults); err != nil { + return nil, 0, errors.Wrap(err, "error counting SCIM groups") + } + + query := "SELECT * FROM " + scimGroupTable + " WHERE " + whereClause + " ORDER BY created_at ASC LIMIT ? OFFSET ?" + args = append(args, count, offset) + if err := tx.RawQuery(query, args...).All(&groups); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return []*SCIMGroup{}, totalResults, nil + } + return nil, 0, errors.Wrap(err, "error finding SCIM groups") + } + return groups, totalResults, nil +} + +func (g *SCIMGroup) AddMember(tx *storage.Connection, userID uuid.UUID) error { + user, err := FindUserByID(tx, userID) + if err != nil { + return err + } + + if !UserBelongsToSSOProvider(user, g.SSOProviderID) { + return UserNotInSSOProviderError{} + } + + return tx.RawQuery( + "INSERT INTO "+scimGroupMemberTable+" (group_id, user_id, created_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING", + g.ID, userID, time.Now(), + ).Exec() +} + +// UserBelongsToSSOProvider checks if a user has an identity linked to the specified SSO provider. +func UserBelongsToSSOProvider(user *User, ssoProviderID uuid.UUID) bool { + providerType := "sso:" + ssoProviderID.String() + for _, identity := range user.Identities { + if identity.Provider == providerType { + return true + } + } + return false +} + +func (g *SCIMGroup) AddMembers(tx *storage.Connection, userIDs []uuid.UUID) error { + if len(userIDs) == 0 { + return nil + } + + userIDs = deduplicateUUIDs(userIDs) + + identityTable := (&pop.Model{Value: Identity{}}).TableName() + userTable := (&pop.Model{Value: User{}}).TableName() + providerType := "sso:" + g.SSOProviderID.String() + + placeholders := make([]string, len(userIDs)) + queryArgs := make([]interface{}, len(userIDs)) + for i, id := range userIDs { + placeholders[i] = "?" + queryArgs[i] = id + } + inClause := strings.Join(placeholders, ",") + + var rawValidIDs []uuid.UUID + validationArgs := make([]interface{}, 0, len(userIDs)+1) + validationArgs = append(validationArgs, queryArgs...) + validationArgs = append(validationArgs, providerType) + if err := tx.RawQuery( + "SELECT u.id FROM "+userTable+" u "+ + "INNER JOIN "+identityTable+" i ON i.user_id = u.id "+ + "WHERE u.id IN ("+inClause+") AND i.provider = ? "+ + "FOR SHARE OF u, i", + validationArgs..., + ).All(&rawValidIDs); err != nil { + return errors.Wrap(err, "error validating SCIM group member IDs") + } + + validSet := make(map[uuid.UUID]struct{}, len(rawValidIDs)) + for _, id := range rawValidIDs { + validSet[id] = struct{}{} + } + + if len(validSet) != len(userIDs) { + for _, id := range userIDs { + if _, ok := validSet[id]; !ok { + if _, err := FindUserByID(tx, id); err != nil { + if IsNotFoundError(err) { + return UserNotFoundError{} + } + return errors.Wrap(err, "error looking up user for SCIM group membership") + } + return UserNotInSSOProviderError{} + } + } + } + + now := time.Now() + insertArgs := make([]interface{}, 0, 2+len(userIDs)+1) + insertArgs = append(insertArgs, g.ID, now) + insertArgs = append(insertArgs, queryArgs...) + insertArgs = append(insertArgs, providerType) + + if err := tx.RawQuery( + "INSERT INTO "+scimGroupMemberTable+" (group_id, user_id, created_at) "+ + "SELECT ?, u.id, ? FROM "+userTable+" u "+ + "INNER JOIN "+identityTable+" i ON i.user_id = u.id "+ + "WHERE u.id IN ("+inClause+") AND i.provider = ? "+ + "ON CONFLICT DO NOTHING", + insertArgs..., + ).Exec(); err != nil { + return errors.Wrap(err, "error adding SCIM group members") + } + return nil +} + +func (g *SCIMGroup) RemoveMember(tx *storage.Connection, userID uuid.UUID) error { + return tx.RawQuery( + "DELETE FROM "+scimGroupMemberTable+" WHERE group_id = ? AND user_id = ?", + g.ID, userID, + ).Exec() +} + +func (g *SCIMGroup) GetMembers(tx *storage.Connection) ([]*User, error) { + users := []*User{} + userTable := (&pop.Model{Value: User{}}).TableName() + if err := tx.RawQuery( + "SELECT u.* FROM "+userTable+" u INNER JOIN "+scimGroupMemberTable+" m ON u.id = m.user_id WHERE m.group_id = ? ORDER BY u.email ASC LIMIT 10000", + g.ID, + ).All(&users); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return []*User{}, nil + } + return nil, errors.Wrap(err, "error getting SCIM group members") + } + return users, nil +} + +func (g *SCIMGroup) SetMembers(tx *storage.Connection, userIDs []uuid.UUID) error { + if len(userIDs) == 0 { + if err := tx.RawQuery("DELETE FROM "+scimGroupMemberTable+" WHERE group_id = ?", g.ID).Exec(); err != nil { + return errors.Wrap(err, "error clearing SCIM group members") + } + return nil + } + + userIDs = deduplicateUUIDs(userIDs) + + identityTable := (&pop.Model{Value: Identity{}}).TableName() + userTable := (&pop.Model{Value: User{}}).TableName() + providerType := "sso:" + g.SSOProviderID.String() + + placeholders := make([]string, len(userIDs)) + queryArgs := make([]interface{}, len(userIDs)) + for i, id := range userIDs { + placeholders[i] = "?" + queryArgs[i] = id + } + inClause := strings.Join(placeholders, ",") + + var rawValidIDs []uuid.UUID + validationArgs := make([]interface{}, 0, len(userIDs)+1) + validationArgs = append(validationArgs, queryArgs...) + validationArgs = append(validationArgs, providerType) + // Lock both user and identity rows during validation to prevent concurrent + // deletion or identity changes between validation and the membership write. + // DISTINCT is omitted because PostgreSQL disallows row-locking with DISTINCT; + // de-duplication is done in Go below. + if err := tx.RawQuery( + "SELECT u.id FROM "+userTable+" u "+ + "INNER JOIN "+identityTable+" i ON i.user_id = u.id "+ + "WHERE u.id IN ("+inClause+") AND i.provider = ? "+ + "FOR SHARE OF u, i", + validationArgs..., + ).All(&rawValidIDs); err != nil { + return errors.Wrap(err, "error validating SCIM group member IDs") + } + + // De-duplicate IDs in Go since we cannot use DISTINCT with FOR SHARE. + validSet := make(map[uuid.UUID]struct{}, len(rawValidIDs)) + for _, id := range rawValidIDs { + validSet[id] = struct{}{} + } + + if len(validSet) != len(userIDs) { + for _, id := range userIDs { + if _, ok := validSet[id]; !ok { + if _, err := FindUserByID(tx, id); err != nil { + if IsNotFoundError(err) { + return UserNotFoundError{} + } + return errors.Wrap(err, "error looking up user for SCIM group membership") + } + return UserNotInSSOProviderError{} + } + } + } + + if err := tx.RawQuery("DELETE FROM "+scimGroupMemberTable+" WHERE group_id = ?", g.ID).Exec(); err != nil { + return errors.Wrap(err, "error clearing SCIM group members") + } + + now := time.Now() + insertArgs := make([]interface{}, 0, 2+len(userIDs)+1) + insertArgs = append(insertArgs, g.ID, now) + insertArgs = append(insertArgs, queryArgs...) + insertArgs = append(insertArgs, providerType) + + if err := tx.RawQuery( + "INSERT INTO "+scimGroupMemberTable+" (group_id, user_id, created_at) "+ + "SELECT ?, u.id, ? FROM "+userTable+" u "+ + "INNER JOIN "+identityTable+" i ON i.user_id = u.id "+ + "WHERE u.id IN ("+inClause+") AND i.provider = ? "+ + "ON CONFLICT DO NOTHING", + insertArgs..., + ).Exec(); err != nil { + return errors.Wrap(err, "error setting SCIM group members") + } + return nil +} + +func GetMembersForGroups(tx *storage.Connection, groupIDs []uuid.UUID) (map[uuid.UUID][]*User, error) { + result := make(map[uuid.UUID][]*User) + if len(groupIDs) == 0 { + return result, nil + } + + userTable := (&pop.Model{Value: User{}}).TableName() + + type memberRow struct { + GroupID uuid.UUID `db:"group_id"` + User + } + + placeholders := make([]string, len(groupIDs)) + args := make([]interface{}, len(groupIDs)) + for i, id := range groupIDs { + placeholders[i] = "?" + args[i] = id + } + + rows := []memberRow{} + if err := tx.RawQuery( + "SELECT m.group_id, u.* FROM "+userTable+" u "+ + "INNER JOIN "+scimGroupMemberTable+" m ON u.id = m.user_id "+ + "WHERE m.group_id IN ("+strings.Join(placeholders, ",")+") "+ + "ORDER BY u.email ASC", + args..., + ).All(&rows); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return result, nil + } + return nil, errors.Wrap(err, "error batch loading SCIM group members") + } + + for i := range rows { + result[rows[i].GroupID] = append(result[rows[i].GroupID], &rows[i].User) + } + return result, nil +} + +func deduplicateUUIDs(ids []uuid.UUID) []uuid.UUID { + seen := make(map[uuid.UUID]struct{}, len(ids)) + out := make([]uuid.UUID, 0, len(ids)) + for _, id := range ids { + if _, ok := seen[id]; !ok { + seen[id] = struct{}{} + out = append(out, id) + } + } + return out +} diff --git a/internal/models/sso.go b/internal/models/sso.go index 3a5be7d97..746f85cb9 100644 --- a/internal/models/sso.go +++ b/internal/models/sso.go @@ -1,9 +1,11 @@ package models import ( + "crypto/sha256" "database/sql" "database/sql/driver" "encoding/json" + "fmt" "net/url" "reflect" "strings" @@ -23,6 +25,9 @@ type SSOProvider struct { SAMLProvider SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"` SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"` + SCIMEnabled *bool `db:"scim_enabled" json:"scim_enabled,omitempty"` + SCIMBearerTokenHash *string `db:"scim_bearer_token_hash" json:"-"` + CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } @@ -39,6 +44,27 @@ func (p SSOProvider) Type() string { return "saml" } +func (p SSOProvider) IsSCIMEnabled() bool { + return p.SCIMEnabled != nil && *p.SCIMEnabled +} + +func scimTokenHash(token string) string { + return fmt.Sprintf("%x", sha256.Sum256([]byte(token))) +} + +func (p *SSOProvider) SetSCIMToken(token string) { + hash := scimTokenHash(token) + p.SCIMBearerTokenHash = &hash + enabled := true + p.SCIMEnabled = &enabled +} + +func (p *SSOProvider) ClearSCIMToken() { + p.SCIMBearerTokenHash = nil + enabled := false + p.SCIMEnabled = &enabled +} + type SAMLAttribute struct { Name string `json:"name,omitempty"` Names []string `json:"names,omitempty"` @@ -266,12 +292,24 @@ func FindAllSSOProviders(tx *storage.Connection) ([]SSOProvider, error) { return providers, nil } +func FindSSOProviderBySCIMToken(tx *storage.Connection, token string) (*SSOProvider, error) { + hash := scimTokenHash(token) + + var provider SSOProvider + if err := tx.Q().Where("scim_enabled = ? AND scim_bearer_token_hash = ?", true, hash).First(&provider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + return nil, errors.Wrap(err, "error finding SSO provider by SCIM token") + } + return &provider, nil +} + const ( resourceIDFilter = "resource_id" resourceIDPrefixFilter = "resource_id_prefix" ) -// FindAllSSOProvidersByFilter finds SSO Providers with the matching filter. func FindAllSSOProvidersByFilter( tx *storage.Connection, queryValues url.Values, diff --git a/internal/models/user.go b/internal/models/user.go index 3c706b80e..1b9d25bc4 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -65,11 +65,12 @@ type User struct { Factors []Factor `json:"factors,omitempty" has_many:"factors"` Identities []Identity `json:"identities" has_many:"identities"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` - BannedUntil *time.Time `json:"banned_until,omitempty" db:"banned_until"` - DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` - IsAnonymous bool `json:"is_anonymous" db:"is_anonymous"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + BannedUntil *time.Time `json:"banned_until,omitempty" db:"banned_until"` + BannedReason *string `json:"banned_reason,omitempty" db:"banned_reason"` + DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` + IsAnonymous bool `json:"is_anonymous" db:"is_anonymous"` DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` } @@ -620,6 +621,28 @@ func FindUserByEmailAndAudience(tx *storage.Connection, email, aud string) (*Use return findUser(tx, "instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false", uuid.Nil, strings.ToLower(email), aud) } +// FindSSOUsersByEmailAndProvider finds all SSO users with the matching email, +// audience, and identity provider. This is used by SCIM provisioning to detect +// previously deprovisioned SSO users for reactivation without crossing provider +// boundaries. Results are ordered with active users first. +func FindSSOUsersByEmailAndProvider(tx *storage.Connection, email, aud, provider string) ([]*User, error) { + users := []*User{} + user := &User{} + query := ` + SELECT DISTINCT u.* FROM ` + user.TableName() + ` u + INNER JOIN identities i ON u.id = i.user_id + WHERE u.instance_id = ? AND LOWER(u.email) = ? AND u.aud = ? AND u.is_sso_user = true AND i.provider = ? + ORDER BY u.banned_until ASC NULLS FIRST, u.created_at ASC + ` + if err := tx.Eager().RawQuery(query, uuid.Nil, strings.ToLower(email), aud, provider).All(&users); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + return nil, errors.Wrap(err, "error finding SSO users by email and provider") + } + return users, nil +} + // FindUserByPhoneAndAudience finds a user with the matching email and audience. func FindUserByPhoneAndAudience(tx *storage.Connection, phone, aud string) (*User, error) { return findUser(tx, "instance_id = ? and phone = ? and aud = ? and is_sso_user = false", uuid.Nil, phone, aud) @@ -630,6 +653,93 @@ func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) { return findUser(tx, "instance_id = ? and id = ?", uuid.Nil, id) } +func CountUsersByProvider(tx *storage.Connection, provider string) (int, error) { + var count int + err := tx.RawQuery(` + SELECT COUNT(DISTINCT u.id) FROM `+(&User{}).TableName()+` u + INNER JOIN identities i ON u.id = i.user_id + WHERE i.provider = ? AND u.instance_id = ? + `, provider, uuid.Nil).First(&count) + if err != nil { + return 0, errors.Wrap(err, "error counting users by provider") + } + return count, nil +} + +// startIndex is 1-indexed per SCIM spec. count is the max number of results to return. +func FindUsersByProvider(tx *storage.Connection, provider string, startIndex, count int) ([]*User, error) { + users := []*User{} + + offset := startIndex - 1 + if offset < 0 { + offset = 0 + } + + query := ` + SELECT DISTINCT u.* FROM ` + (&User{}).TableName() + ` u + INNER JOIN identities i ON u.id = i.user_id + WHERE i.provider = ? AND u.instance_id = ? + ORDER BY u.created_at ASC + LIMIT ? OFFSET ? + ` + err := tx.Eager().RawQuery(query, provider, uuid.Nil, count, offset).All(&users) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return users, nil + } + return nil, errors.Wrap(err, "error finding users by provider") + } + return users, nil +} + +// FindUsersByProviderWithFilter finds users by provider with optional SCIM filter. +// The filterClause should be generated by ParseSCIMFilterToSQL with SCIMUserFilterAttrs. +func FindUsersByProviderWithFilter(tx *storage.Connection, provider string, filterClause *SCIMFilterClause, startIndex, count int) ([]*User, int, error) { + users := []*User{} + + offset := startIndex - 1 + if offset < 0 { + offset = 0 + } + + // Base WHERE clause for provider + baseWhere := "i.provider = ? AND u.instance_id = ?" + baseArgs := []interface{}{provider, uuid.Nil} + + // Add filter clause if present + whereClause := baseWhere + args := baseArgs + if filterClause != nil && filterClause.Where != "" && filterClause.Where != "1=1" { + whereClause += " AND (" + filterClause.Where + ")" + args = append(args, filterClause.Args...) + } + + var totalResults int + countQuery := ` + SELECT COUNT(DISTINCT u.id) FROM ` + (&User{}).TableName() + ` u + INNER JOIN identities i ON u.id = i.user_id + WHERE ` + whereClause + if err := tx.RawQuery(countQuery, args...).First(&totalResults); err != nil { + return nil, 0, errors.Wrap(err, "error counting users by provider with filter") + } + + query := ` + SELECT DISTINCT u.* FROM ` + (&User{}).TableName() + ` u + INNER JOIN identities i ON u.id = i.user_id + WHERE ` + whereClause + ` + ORDER BY u.created_at ASC + LIMIT ? OFFSET ? + ` + args = append(args, count, offset) + if err := tx.Eager("Identities").RawQuery(query, args...).All(&users); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return users, totalResults, nil + } + return nil, 0, errors.Wrap(err, "error finding users by provider with filter") + } + return users, totalResults, nil +} + // FindUserWithRefreshToken finds a user from the provided refresh token. If // forUpdate is set to true, then the SELECT statement used by the query has // the form SELECT ... FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE @@ -838,15 +948,17 @@ func IsDuplicatedPhone(tx *storage.Connection, phone, aud string) (bool, error) return true, nil } -// Ban a user for a given duration. -func (u *User) Ban(tx *storage.Connection, duration time.Duration) error { +// Ban a user for a given duration with an optional reason. +func (u *User) Ban(tx *storage.Connection, duration time.Duration, reason *string) error { if duration == time.Duration(0) { u.BannedUntil = nil + u.BannedReason = nil } else { t := time.Now().Add(duration) u.BannedUntil = &t + u.BannedReason = reason } - return tx.UpdateOnly(u, "banned_until") + return tx.UpdateOnly(u, "banned_until", "banned_reason") } // IsBanned checks if a user is banned or not diff --git a/migrations/20251210100000_add_scim_to_sso_providers.up.sql b/migrations/20251210100000_add_scim_to_sso_providers.up.sql new file mode 100644 index 000000000..4fce688d7 --- /dev/null +++ b/migrations/20251210100000_add_scim_to_sso_providers.up.sql @@ -0,0 +1,72 @@ +-- Add SCIM provisioning support + +-- Add SCIM columns to SSO providers +do $$ begin + alter table only {{ index .Options "Namespace" }}.sso_providers + add column if not exists scim_enabled boolean null default false, + add column if not exists scim_bearer_token_hash text null; +end $$; + +comment on column {{ index .Options "Namespace" }}.sso_providers.scim_enabled is 'Auth: Whether SCIM provisioning is enabled for this SSO provider'; +comment on column {{ index .Options "Namespace" }}.sso_providers.scim_bearer_token_hash is 'Auth: SHA-256 hash of the SCIM bearer token used by the IdP'; + +-- Index for direct SCIM token hash lookup +create unique index if not exists sso_providers_scim_token_hash_idx + on {{ index .Options "Namespace" }}.sso_providers (scim_bearer_token_hash) + where scim_bearer_token_hash is not null; + +-- Add banned_reason to users for SCIM deprovisioning +do $$ begin + alter table only {{ index .Options "Namespace" }}.users + add column if not exists banned_reason text null; +end $$; + +comment on column {{ index .Options "Namespace" }}.users.banned_reason is 'Auth: Reason for user ban (e.g., SCIM_DEPROVISIONED)'; + +-- SCIM Groups +create table if not exists {{ index .Options "Namespace" }}.scim_groups ( + id uuid not null, + sso_provider_id uuid not null, + external_id text null, + display_name text not null, + created_at timestamptz not null default now(), + updated_at timestamptz not null default now(), + + constraint scim_groups_pkey primary key (id), + constraint scim_groups_sso_provider_fkey foreign key (sso_provider_id) + references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + constraint "external_id not empty if set" check (external_id is null or char_length(external_id) > 0), + constraint "display_name not empty" check (char_length(display_name) > 0) +); + +create unique index if not exists scim_groups_sso_provider_external_id_idx + on {{ index .Options "Namespace" }}.scim_groups (sso_provider_id, external_id) + where external_id is not null; + +create unique index if not exists scim_groups_sso_provider_display_name_idx + on {{ index .Options "Namespace" }}.scim_groups (sso_provider_id, lower(display_name)); + +create index if not exists scim_groups_sso_provider_id_idx + on {{ index .Options "Namespace" }}.scim_groups (sso_provider_id); + +comment on table {{ index .Options "Namespace" }}.scim_groups is 'Auth: Manages SCIM groups provisioned by SSO identity providers.'; +comment on column {{ index .Options "Namespace" }}.scim_groups.external_id is 'Auth: The group ID from the external identity provider.'; +comment on column {{ index .Options "Namespace" }}.scim_groups.display_name is 'Auth: Human-readable name of the group.'; + +-- SCIM Group Members +create table if not exists {{ index .Options "Namespace" }}.scim_group_members ( + group_id uuid not null, + user_id uuid not null, + created_at timestamptz not null default now(), + + constraint scim_group_members_pkey primary key (group_id, user_id), + constraint scim_group_members_group_fkey foreign key (group_id) + references {{ index .Options "Namespace" }}.scim_groups (id) on delete cascade, + constraint scim_group_members_user_fkey foreign key (user_id) + references {{ index .Options "Namespace" }}.users (id) on delete cascade +); + +create index if not exists scim_group_members_user_id_idx + on {{ index .Options "Namespace" }}.scim_group_members (user_id); + +comment on table {{ index .Options "Namespace" }}.scim_group_members is 'Auth: Junction table for SCIM group membership.';