diff --git a/client.go b/client.go index 5fe1059..f55315e 100644 --- a/client.go +++ b/client.go @@ -138,7 +138,11 @@ func (c *Client) LoginUrlForRequest(r *http.Request) (string, error) { } q := u.Query() - q.Add("service", sanitisedURLString(service)) + serviceURLStr, err := sanitisedURLString(service) + if err != nil { + return "", err + } + q.Add("service", serviceURLStr) u.RawQuery = q.Encode() return u.String(), nil @@ -158,7 +162,11 @@ func (c *Client) LogoutUrlForRequest(r *http.Request) (string, error) { } q := u.Query() - q.Add("service", sanitisedURLString(service)) + serviceURLStr, err := sanitisedURLString(service) + if err != nil { + return "", err + } + q.Add("service", serviceURLStr) u.RawQuery = q.Encode() } @@ -187,7 +195,7 @@ func (c *Client) ValidateUrlForRequest(ticket string, r *http.Request) (string, func (c *Client) RedirectToLogout(w http.ResponseWriter, r *http.Request) { u, err := c.LogoutUrlForRequest(r) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -204,7 +212,7 @@ func (c *Client) RedirectToLogout(w http.ResponseWriter, r *http.Request) { func (c *Client) RedirectToLogin(w http.ResponseWriter, r *http.Request) { u, err := c.LoginUrlForRequest(r) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusBadRequest) return } diff --git a/go.mod b/go.mod index d0dfe8a..5fcc70d 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module gopkg.in/cas.v2 +module github.com/yy1987316/cas go 1.12 diff --git a/sanitise.go b/sanitise.go index a2cc17c..f2b1d76 100644 --- a/sanitise.go +++ b/sanitise.go @@ -9,9 +9,12 @@ var ( ) // sanitisedURL cleans a URL of CAS specific parameters -func sanitisedURL(unclean *url.URL) *url.URL { - // Shouldn't be any errors parsing an existing *url.URL - u, _ := url.Parse(unclean.String()) +func sanitisedURL(unclean *url.URL) (*url.URL, error) { + // Parse maybe occur errors, cause unclean is dealt with requestURL method + u, err := url.Parse(unclean.String()) + if err != nil { + return nil, err + } q := u.Query() for _, param := range urlCleanParameters { @@ -19,10 +22,14 @@ func sanitisedURL(unclean *url.URL) *url.URL { } u.RawQuery = q.Encode() - return u + return u, nil } // sanitisedURLString cleans a URL and returns its string value -func sanitisedURLString(unclean *url.URL) string { - return sanitisedURL(unclean).String() +func sanitisedURLString(unclean *url.URL) (string, error) { + u, err := sanitisedURL(unclean) + if err != nil { + return "", err + } + return u.String(), nil } diff --git a/sanitise_test.go b/sanitise_test.go new file mode 100644 index 0000000..d15983e --- /dev/null +++ b/sanitise_test.go @@ -0,0 +1,51 @@ +package cas + +import ( + "net/url" + "reflect" + "testing" +) + +func Test_sanitisedURL(t *testing.T) { + type args struct { + unclean *url.URL + } + tests := []struct { + name string + args args + want *url.URL + wantErr bool + }{ + { + name: "Test the URL Scheme chaos value, cause be dealt with requestURL method", + args: args{ + unclean: &url.URL{ + Scheme: "chaos_input_from_header_X-Forwarded-Proto", + Opaque: "", + User: &url.Userinfo{}, + Host: "a.b.c", + Path: "/", + RawPath: "/", + ForceQuery: false, + RawQuery: "", + Fragment: "", + RawFragment: "", + }, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := sanitisedURL(tt.args.unclean) + if (err != nil) != tt.wantErr { + t.Errorf("sanitisedURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("sanitisedURL() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/service_validate.go b/service_validate.go index 5754bb2..a9b7f17 100644 --- a/service_validate.go +++ b/service_validate.go @@ -99,7 +99,11 @@ func (validator *ServiceTicketValidator) ServiceValidateUrl(serviceURL *url.URL, } q := u.Query() - q.Add("service", sanitisedURLString(serviceURL)) + serviceURLStr, err := sanitisedURLString(serviceURL) + if err != nil { + return "", err + } + q.Add("service", serviceURLStr) q.Add("ticket", ticket) u.RawQuery = q.Encode() @@ -175,7 +179,11 @@ func (validator *ServiceTicketValidator) ValidateUrl(serviceURL *url.URL, ticket } q := u.Query() - q.Add("service", sanitisedURLString(serviceURL)) + serviceURLStr, err := sanitisedURLString(serviceURL) + if err != nil { + return "", err + } + q.Add("service", serviceURLStr) q.Add("ticket", ticket) u.RawQuery = q.Encode()