From 79fcd7ca1b3d19d63d76243d2f5ec2efeaea8f2e Mon Sep 17 00:00:00 2001 From: Andrew Stuart Date: Tue, 20 Nov 2018 18:40:00 -0700 Subject: [PATCH 1/4] Filter cookies arbitrarily --- jar.go | 47 ++++++++++++++++++++++++++++++++++++++++------- serialize.go | 3 ++- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/jar.go b/jar.go index 5ff2733..188bd08 100644 --- a/jar.go +++ b/jar.go @@ -72,6 +72,15 @@ type Options struct { // (useful for tests). If this is true, the value of Filename will be // ignored. NoPersist bool + + // Filter specifies the filter to be used when deciding whether each cookie + // should be persisted to the filesystem. + Filter CookieFilter +} + +// CookieFilter is a type for deciding whether a Cookie should be persisted +type CookieFilter interface { + IsPersistent(*http.Cookie) bool } // Jar implements the http.CookieJar interface from the net/http package. @@ -87,6 +96,9 @@ type Jar struct { // entries is a set of entries, keyed by their eTLD+1 and subkeyed by // their name/domain/path. entries map[string]map[string]entry + + // Filter from Options + filter CookieFilter } var noOptions Options @@ -108,6 +120,11 @@ func newAtTime(o *Options, now time.Time) (*Jar, error) { if o == nil { o = &noOptions } + jar.filter = DefaultFilter + if o.Filter != nil { + jar.filter = o.Filter + } + if jar.psList = o.PublicSuffixList; jar.psList == nil { jar.psList = publicsuffix.List } @@ -144,7 +161,6 @@ type entry struct { Path string Secure bool HttpOnly bool - Persistent bool HostOnly bool Expires time.Time Creation time.Time @@ -255,11 +271,11 @@ func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { // cookies is like Cookies but takes the current time as a parameter. func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { if u.Scheme != "http" && u.Scheme != "https" { - return cookies + return } host, err := canonicalHost(u.Host) if err != nil { - return cookies + return } key := jarKey(host, j.psList) @@ -268,7 +284,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { submap := j.entries[key] if submap == nil { - return cookies + return } https := u.Scheme == "https" @@ -310,7 +326,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { // have Domain, Expires, HttpOnly, Name, Secure, Path, and Value filled // out. Expired cookies will not be returned. This function does not // modify the cookie jar. -func (j *Jar) AllCookies() (cookies []*http.Cookie) { +func (j *Jar) AllCookies() []*http.Cookie { return j.allCookies(time.Now()) } @@ -575,6 +591,24 @@ func defaultPath(path string) string { return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". } +// CookieFilterFunc implements CookieFilter by calling the underlying func +type CookieFilterFunc func(*http.Cookie) bool + +// IsPersistent implements CookieFilter for arbitrary funcs +func (cff CookieFilterFunc) IsPersistent(c *http.Cookie) bool { + return cff(c) +} + +// Well-known FilterFuncs +var ( + DefaultFilter = CookieFilterFunc(func(c *http.Cookie) bool { + return c.MaxAge == 0 && !c.Expires.IsZero() + }) + AnyFilter = CookieFilterFunc(func(c *http.Cookie) bool { + return true + }) +) + // newEntry creates an entry from a http.Cookie c. now is the current // time and is compared to c.Expires to determine deletion of c. defPath // and host are the default-path and the canonical host name of the URL @@ -597,9 +631,9 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e if err != nil { return e, err } + // MaxAge takes precedence over Expires. if c.MaxAge != 0 { - e.Persistent = true e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) if c.MaxAge < 0 { return e, nil @@ -607,7 +641,6 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e } else if c.Expires.IsZero() { e.Expires = endOfTime } else { - e.Persistent = true e.Expires = c.Expires if !c.Expires.After(now) { return e, nil diff --git a/serialize.go b/serialize.go index 2792dfb..2e1f548 100644 --- a/serialize.go +++ b/serialize.go @@ -8,6 +8,7 @@ import ( "encoding/json" "io" "log" + "net/http" "os" "path/filepath" "sort" @@ -137,7 +138,7 @@ func (j *Jar) allPersistentEntries() []entry { var entries []entry for _, submap := range j.entries { for _, e := range submap { - if e.Persistent { + if j.filter.IsPersistent(&http.Cookie{Name: e.Name, Value: e.Value}) { entries = append(entries, e) } } From a690c21b604cdddd4e5e60a4d93e5238d154a36b Mon Sep 17 00:00:00 2001 From: Andrew Stuart Date: Tue, 20 Nov 2018 19:27:27 -0700 Subject: [PATCH 2/4] More work on filtering --- jar.go | 9 ++++++--- jar_test.go | 38 ++++++++++++++++++++++++++++++++++++++ serialize.go | 11 ++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/jar.go b/jar.go index 188bd08..c4622e6 100644 --- a/jar.go +++ b/jar.go @@ -164,6 +164,7 @@ type entry struct { HostOnly bool Expires time.Time Creation time.Time + MaxAge int LastAccess time.Time // Updated records when the cookie was updated. @@ -271,11 +272,11 @@ func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { // cookies is like Cookies but takes the current time as a parameter. func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { if u.Scheme != "http" && u.Scheme != "https" { - return + return cookies } host, err := canonicalHost(u.Host) if err != nil { - return + return cookies } key := jarKey(host, j.psList) @@ -284,7 +285,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { submap := j.entries[key] if submap == nil { - return + return cookies } https := u.Scheme == "https" @@ -604,6 +605,7 @@ var ( DefaultFilter = CookieFilterFunc(func(c *http.Cookie) bool { return c.MaxAge == 0 && !c.Expires.IsZero() }) + AnyFilter = CookieFilterFunc(func(c *http.Cookie) bool { return true }) @@ -621,6 +623,7 @@ var ( // A malformed c.Domain will result in an error. func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, err error) { e.Name = c.Name + e.MaxAge = c.MaxAge if c.Path == "" || c.Path[0] != '/' { e.Path = defPath } else { diff --git a/jar_test.go b/jar_test.go index 5ea3467..80cc3ef 100644 --- a/jar_test.go +++ b/jar_test.go @@ -5,6 +5,7 @@ package cookiejar import ( + "encoding/json" "fmt" "io/ioutil" "net/http" @@ -2067,6 +2068,43 @@ func TestRemoveAllHostIP(t *testing.T) { testRemoveAllHost(t, mustParseURL("https://10.1.1.1"), "10.1.1.1", true) } +func TestFilter(t *testing.T) { + j := newTestJar("") + + j.filter = AnyFilter + google := mustParseURL("https://www.google.com") + + j.SetCookies( + google, + []*http.Cookie{ + &http.Cookie{ + Name: "test-cookie", + Value: "test-value", + Expires: time.Now().Add(24 * time.Hour), + }, + &http.Cookie{ + Name: "test-cookie2", + Value: "test-value", + Expires: time.Now().Add(-24 * time.Hour), + }, + }, + ) + + bs, err := j.MarshalJSON() + if err != nil { + t.Errorf("error marshaling json") + } + + var es []entry + if json.Unmarshal(bs, &es) != nil { + t.Errorf("error remarshaling") + } + + if len(es) < 2 { + t.Errorf("fewer than two entries were marshaled") + } +} + func testRemoveAllHost(t *testing.T, setURL *url.URL, removeHost string, shouldRemove bool) { jar := newTestJar("") google := mustParseURL("https://www.google.com") diff --git a/serialize.go b/serialize.go index 2e1f548..11507e8 100644 --- a/serialize.go +++ b/serialize.go @@ -138,7 +138,16 @@ func (j *Jar) allPersistentEntries() []entry { var entries []entry for _, submap := range j.entries { for _, e := range submap { - if j.filter.IsPersistent(&http.Cookie{Name: e.Name, Value: e.Value}) { + if j.filter.IsPersistent(&http.Cookie{ + Domain: e.Domain, + Expires: e.Expires, + HttpOnly: e.HttpOnly, + MaxAge: e.MaxAge, + Name: e.Name, + Path: e.Path, + Secure: e.Secure, + Value: e.Value, + }) { entries = append(entries, e) } } From 4d3d3105ce6611c79785202ee3f4fe5cf2eb6fd9 Mon Sep 17 00:00:00 2001 From: Andrew Stuart Date: Tue, 20 Nov 2018 19:31:26 -0700 Subject: [PATCH 3/4] Fix logic --- jar.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jar.go b/jar.go index c4622e6..b9744c2 100644 --- a/jar.go +++ b/jar.go @@ -603,7 +603,7 @@ func (cff CookieFilterFunc) IsPersistent(c *http.Cookie) bool { // Well-known FilterFuncs var ( DefaultFilter = CookieFilterFunc(func(c *http.Cookie) bool { - return c.MaxAge == 0 && !c.Expires.IsZero() + return c.MaxAge != 0 || !c.Expires.IsZero() }) AnyFilter = CookieFilterFunc(func(c *http.Cookie) bool { From afb54bd74b6bd054fae734b05efd62e34d8685c1 Mon Sep 17 00:00:00 2001 From: Andrew Stuart Date: Tue, 20 Nov 2018 20:11:08 -0700 Subject: [PATCH 4/4] Add expired helper func --- jar.go | 20 ++++++++++++-------- jar_test.go | 51 +++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/jar.go b/jar.go index b9744c2..4ad290f 100644 --- a/jar.go +++ b/jar.go @@ -220,6 +220,10 @@ func (e *entry) pathMatch(requestPath string) bool { return false } +func (e *entry) isExpiredAfter(t time.Time) bool { + return !e.Expires.IsZero() && t.After(e.Expires) +} + // hasDotSuffix reports whether s ends in "."+suffix. func hasDotSuffix(s, suffix string) bool { return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix @@ -296,7 +300,7 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { var selected []entry for id, e := range submap { - if !e.Expires.After(now) { + if e.isExpiredAfter(now) { // Save some space by deleting the value when the cookie // expires. We can't delete the cookie itself because then // we wouldn't know that the cookie had expired when @@ -338,7 +342,7 @@ func (j *Jar) allCookies(now time.Time) []*http.Cookie { defer j.mu.Unlock() for _, submap := range j.entries { for _, e := range submap { - if !e.Expires.After(now) { + if e.isExpiredAfter(now) { // Do not return expired cookies. continue } @@ -410,7 +414,7 @@ var expiryRemovalDuration = 24 * time.Hour func (j *Jar) deleteExpired(now time.Time) { for tld, submap := range j.entries { for id, e := range submap { - if !e.Expires.After(now) && !e.Updated.Add(expiryRemovalDuration).After(now) { + if e.isExpiredAfter(now) && !e.Updated.Add(expiryRemovalDuration).After(now) { delete(submap, id) } } @@ -600,13 +604,15 @@ func (cff CookieFilterFunc) IsPersistent(c *http.Cookie) bool { return cff(c) } -// Well-known FilterFuncs var ( + // DefaultFilter is the previous behavior which does not persist session + // cookies. DefaultFilter = CookieFilterFunc(func(c *http.Cookie) bool { return c.MaxAge != 0 || !c.Expires.IsZero() }) - AnyFilter = CookieFilterFunc(func(c *http.Cookie) bool { + // AllowAllFilter does not check any cookie properties before persisting. + AllowAllFilter = CookieFilterFunc(func(_ *http.Cookie) bool { return true }) ) @@ -641,11 +647,9 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e if c.MaxAge < 0 { return e, nil } - } else if c.Expires.IsZero() { - e.Expires = endOfTime } else { e.Expires = c.Expires - if !c.Expires.After(now) { + if e.isExpiredAfter(now) { return e, nil } } diff --git a/jar_test.go b/jar_test.go index 80cc3ef..49a066e 100644 --- a/jar_test.go +++ b/jar_test.go @@ -1812,7 +1812,7 @@ func allCookies(jar *Jar, now time.Time) string { var cs []string for _, submap := range jar.entries { for _, cookie := range submap { - if !cookie.Expires.After(now) { + if !cookie.Expires.IsZero() && now.After(cookie.Expires) { continue } cs = append(cs, cookie.Name+"="+cookie.Value) @@ -2071,7 +2071,6 @@ func TestRemoveAllHostIP(t *testing.T) { func TestFilter(t *testing.T) { j := newTestJar("") - j.filter = AnyFilter google := mustParseURL("https://www.google.com") j.SetCookies( @@ -2083,26 +2082,58 @@ func TestFilter(t *testing.T) { Expires: time.Now().Add(24 * time.Hour), }, &http.Cookie{ - Name: "test-cookie2", - Value: "test-value", - Expires: time.Now().Add(-24 * time.Hour), + Name: "test-cookie2", + Value: "test-value", }, }, ) + es, err := jsonRoundTrip(j) + if err != nil { + t.Fatalf("json failed: %v", err) + } + + if len(es) != 1 { + t.Errorf("expected only one entry, got %d", len(es)) + } + + j.filter = CookieFilterFunc(func(_ *http.Cookie) bool { + return false + }) + + es, err = jsonRoundTrip(j) + if err != nil { + t.Fatalf("json failed: %v", err) + } + + if len(es) > 0 { + t.Errorf("expected zero entries") + } + + j.filter = AllowAllFilter + + es, err = jsonRoundTrip(j) + if err != nil { + t.Fatalf("json failed: %v", err) + } + + if len(es) < 2 { + t.Errorf("got fewer than two entries with AllowAllFilter") + } +} + +func jsonRoundTrip(j *Jar) ([]entry, error) { bs, err := j.MarshalJSON() if err != nil { - t.Errorf("error marshaling json") + return nil, err } var es []entry if json.Unmarshal(bs, &es) != nil { - t.Errorf("error remarshaling") + return nil, err } - if len(es) < 2 { - t.Errorf("fewer than two entries were marshaled") - } + return es, nil } func testRemoveAllHost(t *testing.T, setURL *url.URL, removeHost string, shouldRemove bool) {