diff --git a/httpcache.go b/httpcache.go index b41a63d..957d0d3 100644 --- a/httpcache.go +++ b/httpcache.go @@ -9,7 +9,9 @@ package httpcache import ( "bufio" "bytes" + "context" "errors" + "fmt" "io" "io/ioutil" "net/http" @@ -38,6 +40,41 @@ type Cache interface { Delete(key string) } +// ContextCache the same as Cache except that its functions accept a +// context.Context argument and return an additional error object. +type ContextCache interface { + // Get returns the []byte representation of a cached response and a bool + // set to true if the value isn't empty + Get(ctx context.Context, key string) (responseBytes []byte, ok bool, err error) + // Set stores the []byte representation of a response against a key + Set(ctx context.Context, key string, responseBytes []byte) error + // Delete removes the value associated with the key + Delete(ctx context.Context, key string) error +} + +// cacheAsContextCache is an implementation of ContextCache that wraps a regular Cache. +type cacheAsContextCache struct{ cache Cache } + +var _ ContextCache = cacheAsContextCache{} + +// Delete implements ContextCache +func (c cacheAsContextCache) Delete(_ context.Context, key string) error { + c.cache.Delete(key) + return nil +} + +// Get implements ContextCache +func (c cacheAsContextCache) Get(_ context.Context, key string) (responseBytes []byte, ok bool, err error) { + got, ok := c.cache.Get(key) + return got, ok, nil +} + +// Set implements ContextCache +func (c cacheAsContextCache) Set(_ context.Context, key string, responseBytes []byte) error { + c.cache.Set(key, responseBytes) + return nil +} + // cacheKey returns the cache key for req. func cacheKey(req *http.Request) string { if req.Method == http.MethodGet { @@ -59,6 +96,21 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) return http.ReadResponse(bufio.NewReader(b), req) } +// contextCachedResponse returns the cached http.Response for req if present, and nil +// otherwise. +func contextCachedResponse(c ContextCache, req *http.Request) (resp *http.Response, err error) { + cachedVal, ok, err := c.Get(req.Context(), cacheKey(req)) + if err != nil { + return nil, fmt.Errorf("httpcache Get error: %w", err) + } + if !ok { + return + } + + b := bytes.NewBuffer(cachedVal) + return http.ReadResponse(bufio.NewReader(b), req) +} + // MemoryCache is an implemtation of Cache that stores responses in an in-memory map. type MemoryCache struct { mu sync.RWMutex @@ -101,6 +153,14 @@ type Transport struct { // If nil, http.DefaultTransport is used Transport http.RoundTripper Cache Cache + // ContextCache, if set, will be used instead of Cache by the transport. + // + // The Context() method of http.Request is used to obtain the + // context.Context argument for the cache. Errors from the ContextCache + // cause the Transport's RoundTrip method to return errors. + // + // If ContextCache is non-nil, Cache may be nil. + ContextCache ContextCache // If true, responses returned from the cache will be given an extra header, X-From-Cache MarkCachedResponses bool } @@ -141,10 +201,20 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" var cachedResp *http.Response if cacheable { - cachedResp, err = CachedResponse(t.Cache, req) + if t.ContextCache != nil { + cachedResp, err = contextCachedResponse(t.ContextCache, req) + } else { + cachedResp, err = CachedResponse(t.Cache, req) + } } else { // Need to invalidate an existing value - t.Cache.Delete(cacheKey) + if t.ContextCache != nil { + if err := t.ContextCache.Delete(req.Context(), cacheKey); err != nil { + return nil, fmt.Errorf("httpcache Delete error: %w", err) + } + } else { + t.Cache.Delete(cacheKey) + } } transport := t.Transport @@ -200,7 +270,14 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return cachedResp, nil } else { if err != nil || resp.StatusCode != http.StatusOK { - t.Cache.Delete(cacheKey) + if t.ContextCache != nil { + // Don't overwrite non-nil err. + if cacheErr := t.ContextCache.Delete(req.Context(), cacheKey); err == nil { + err = cacheErr + } + } else { + t.Cache.Delete(cacheKey) + } } if err != nil { return nil, err @@ -232,23 +309,42 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error // Delay caching until EOF is reached. resp.Body = &cachingReadCloser{ R: resp.Body, - OnEOF: func(r io.Reader) { + OnEOF: func(r io.Reader) error { resp := *resp resp.Body = ioutil.NopCloser(r) respBytes, err := httputil.DumpResponse(&resp, true) if err == nil { - t.Cache.Set(cacheKey, respBytes) + if t.ContextCache != nil { + if err := t.ContextCache.Set(req.Context(), cacheKey, respBytes); err != nil { + return fmt.Errorf("httpcache Set error: %w", err) + } + } else { + t.Cache.Set(cacheKey, respBytes) + } } + return err }, } default: respBytes, err := httputil.DumpResponse(resp, true) if err == nil { - t.Cache.Set(cacheKey, respBytes) + if t.ContextCache != nil { + if err := t.ContextCache.Set(req.Context(), cacheKey, respBytes); err != nil { + return nil, err + } + } else { + t.Cache.Set(cacheKey, respBytes) + } } } } else { - t.Cache.Delete(cacheKey) + if t.ContextCache != nil { + if err := t.ContextCache.Delete(req.Context(), cacheKey); err != nil { + return nil, fmt.Errorf("httpcache Delete error: %w", err) + } + } else { + t.Cache.Delete(cacheKey) + } } return resp, nil } @@ -473,6 +569,10 @@ func cloneRequest(r *http.Request) *http.Request { for k, s := range r.Header { r2.Header[k] = s } + ctx := r.Context() + if ctx != nil { + r2 = r2.WithContext(ctx) + } return r2 } @@ -521,7 +621,7 @@ type cachingReadCloser struct { // Underlying ReadCloser. R io.ReadCloser // OnEOF is called with a copy of the content of R when EOF is reached. - OnEOF func(io.Reader) + OnEOF func(io.Reader) error buf bytes.Buffer // buf stores a copy of the content of R. } @@ -534,7 +634,9 @@ func (r *cachingReadCloser) Read(p []byte) (n int, err error) { n, err = r.R.Read(p) r.buf.Write(p[:n]) if err == io.EOF { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) + if err := r.OnEOF(bytes.NewReader(r.buf.Bytes())); err != nil { + return n, err + } } return n, err } diff --git a/httpcache_test.go b/httpcache_test.go index a504641..9a51a11 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -14,6 +14,10 @@ import ( "time" ) +var ( + testContextCache = flag.Bool("test-context-cache", false, "if true, tests the functionality of the ContextCache property") +) + var s struct { server *httptest.Server client http.Client @@ -167,6 +171,11 @@ func teardown() { func resetTest() { s.transport.Cache = NewMemoryCache() + if *testContextCache { + s.transport.ContextCache = cacheAsContextCache{s.transport.Cache} + } else { + s.transport.ContextCache = nil + } clock = &realClock{} } @@ -223,8 +232,8 @@ func TestCacheableMethod(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("XFromCache header isn't blank") + if got := resp.Header.Get(XFromCache); got != "" { + t.Errorf("XFromCache header isn't blank: %q", got) } } } @@ -305,8 +314,8 @@ func TestDontStorePartialRangeInCache(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) } - if resp.Header.Get(XFromCache) != "" { - t.Error("XFromCache header isn't blank") + if got := resp.Header.Get(XFromCache); got != "" { + t.Errorf("XFromCache header isn't blank: %q", got) } } { @@ -469,8 +478,8 @@ func TestGetOnlyIfCachedMiss(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") + if got := resp.Header.Get(XFromCache); got != "" { + t.Errorf("XFromCache header isn't blank: %q", got) } if resp.StatusCode != http.StatusGatewayTimeout { t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) @@ -490,8 +499,8 @@ func TestGetNoStoreRequest(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") + if got := resp.Header.Get(XFromCache); got != "" { + t.Errorf("XFromCache header isn't blank: %q", got) } } {