diff --git a/buffer/buffer.go b/buffer/buffer.go index 2b1165c5..30ff1f03 100644 --- a/buffer/buffer.go +++ b/buffer/buffer.go @@ -60,9 +60,11 @@ var errHandler utils.ErrorHandler = &SizeErrHandler{} // Buffer is responsible for buffering requests and responses // It buffers large requests and responses to disk,. type Buffer struct { + disableRequest bool maxRequestBodyBytes int64 memRequestBodyBytes int64 + disableResponse bool maxResponseBodyBytes int64 memResponseBodyBytes int64 @@ -109,6 +111,12 @@ func (b *Buffer) Wrap(next http.Handler) error { } func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if b.disableRequest && b.disableResponse { + b.next.ServeHTTP(w, req) + + return + } + if b.verbose { dump := utils.DumpHTTPRequest(req) @@ -116,60 +124,75 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { defer b.log.Debug("vulcand/oxy/buffer: completed ServeHttp on request: %s", dump) } - if err := b.checkLimit(req); err != nil { - b.log.Error("vulcand/oxy/buffer: request body over limit, err: %v", err) - b.errHandler.ServeHTTP(w, req, err) + var body multibuf.MultiReader - return - } + var totalSize int64 - // Read the body while keeping limits in mind. This reader controls the maximum bytes - // to read into memory and disk. This reader returns an error if the total request size exceeds the - // predefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 - // and the reader would be unbounded bufio in the http.Server - body, err := multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes)) - if err != nil || body == nil { - if req.Context().Err() != nil { - b.log.Error("vulcand/oxy/buffer: error when reading request body, err: %v", req.Context().Err()) - b.errHandler.ServeHTTP(w, req, req.Context().Err()) + outReq := req + + if !b.disableRequest { + if err := b.checkLimit(req); err != nil { + b.log.Error("vulcand/oxy/buffer: request body over limit, err: %v", err) + b.errHandler.ServeHTTP(w, req, err) return } - b.log.Error("vulcand/oxy/buffer: error when reading request body, err: %v", err) - b.errHandler.ServeHTTP(w, req, err) + // Read the body while keeping limits in mind. This reader controls the maximum bytes + // to read into memory and disk. This reader returns an error if the total request size exceeds the + // predefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 + // and the reader would be unbounded bufio in the http.Server + var err error - return - } + body, err = multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes)) + if err != nil || body == nil { + if req.Context().Err() != nil { + b.log.Error("vulcand/oxy/buffer: error when reading request body, err: %v", req.Context().Err()) + b.errHandler.ServeHTTP(w, req, req.Context().Err()) - // Set request body to buffered reader that can replay the read and execute Seek - // Note that we don't change the original request body as it's handled by the http server - // and we don't want to mess with standard library - defer func() { - if body != nil { - errClose := body.Close() - if errClose != nil { - b.log.Error("vulcand/oxy/buffer: failed to close body, err: %v", errClose) + return + } + + b.log.Error("vulcand/oxy/buffer: error when reading request body, err: %v", err) + b.errHandler.ServeHTTP(w, req, err) + + return + } + + // Set request body to buffered reader that can replay the read and execute Seek + // Note that we don't change the original request body as it's handled by the http server + // and we don't want to mess with standard library + defer func() { + if body != nil { + errClose := body.Close() + if errClose != nil { + b.log.Error("vulcand/oxy/buffer: failed to close body, err: %v", errClose) + } } + }() + + // We need to set ContentLength based on known request size. The incoming request may have been + // set without content length or using chunked TransferEncoding + totalSize, err = body.Size() + if err != nil { + b.log.Error("vulcand/oxy/buffer: failed to get request size, err: %v", err) + b.errHandler.ServeHTTP(w, req, err) + + return } - }() - // We need to set ContentLength based on known request size. The incoming request may have been - // set without content length or using chunked TransferEncoding - totalSize, err := body.Size() - if err != nil { - b.log.Error("vulcand/oxy/buffer: failed to get request size, err: %v", err) - b.errHandler.ServeHTTP(w, req, err) + if totalSize == 0 { + body = nil + } - return + outReq = b.copyRequest(req, body, totalSize) } - if totalSize == 0 { - body = nil + if b.disableResponse { + b.next.ServeHTTP(w, outReq) + return } - outReq := b.copyRequest(req, body, totalSize) - attempt := 1 for { @@ -220,7 +243,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { reader = rdr } - if (b.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) || + if body == nil || (b.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) || !b.retryPredicate(&context{r: req, attempt: attempt, responseCode: bw.code}) { utils.CopyHeaders(w.Header(), bw.Header()) w.WriteHeader(bw.code) @@ -236,7 +259,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if body != nil { if _, err := body.Seek(0, 0); err != nil { - b.log.Error("vulcand/oxy/buffer: failed to rewind response body, err: %v", err) + b.log.Error("vulcand/oxy/buffer: failed to rewind request body, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return diff --git a/buffer/buffer_test.go b/buffer/buffer_test.go index e172e7e1..2e0d1569 100644 --- a/buffer/buffer_test.go +++ b/buffer/buffer_test.go @@ -496,3 +496,92 @@ func TestBuffer_GRPC_OKResponse(t *testing.T) { assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "grpc-body", string(body)) } + +func TestBuffer_disableRequestBuffer(t *testing.T) { + var ( + reqBody string + contentLength int64 + actuallyBuffered bool + ) + + srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + + reqBody = string(body) + contentLength = req.ContentLength + // When buffering is disabled, chunked requests should preserve their transfer encoding, and have no content-length. + actuallyBuffered = contentLength > 0 || len(req.TransferEncoding) == 0 + _, _ = w.Write([]byte("response")) + }) + t.Cleanup(srv.Close) + + fwd := forward.New(false) + rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req.URL = testutils.MustParseRequestURI(srv.URL) + fwd.ServeHTTP(w, req) + }) + + // buffer with disabled request buffering. + st, err := New(rdr, DisableRequestBuffer()) + require.NoError(t, err) + + proxy := httptest.NewServer(st) + t.Cleanup(proxy.Close) + + // Send a chunked request - when buffering is disabled, it should remain chunked. + conn, err := net.Dial("tcp", testutils.MustParseRequestURI(proxy.URL).Host) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + + _, _ = fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: %s\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\n\r\n", testutils.MustParseRequestURI(proxy.URL).Host) + status, err := bufio.NewReader(conn).ReadString('\n') + require.NoError(t, err) + + assert.Equal(t, "HTTP/1.1 200 OK\r\n", status) + assert.Equal(t, "test", reqBody) + // When buffering is disabled, chunked encoding should be preserved (not converted to Content-Length). + assert.False(t, actuallyBuffered, "Request should not have been buffered") + assert.Equal(t, int64(-1), contentLength, "Content-Length should be -1 for chunked requests when not buffered") +} + +func TestBuffer_disableResponseBuffer(t *testing.T) { + largeResponseBody := strings.Repeat("A", 1000) + srv := testutils.NewHandler(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(largeResponseBody)) + }) + t.Cleanup(srv.Close) + + fwd := forward.New(false) + rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req.URL = testutils.MustParseRequestURI(srv.URL) + fwd.ServeHTTP(w, req) + }) + + // buffer with a small max response size. + st, err := New(rdr, MaxResponseBodyBytes(4)) + require.NoError(t, err) + + proxy := httptest.NewServer(st) + t.Cleanup(proxy.Close) + + resp, _, err := testutils.Get(proxy.URL) + require.NoError(t, err) + // Response should not pass through as it exceeds the limit. + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + + // buffer with disabled response buffering and a small max response size. + st, err = New(rdr, DisableResponseBuffer(), MaxResponseBodyBytes(4)) + require.NoError(t, err) + + proxy2 := httptest.NewServer(st) + t.Cleanup(proxy2.Close) + + resp2, body, err := testutils.Get(proxy2.URL) + require.NoError(t, err) + // Response should pass through even though it exceeds the limit, because buffering has been disabled. + assert.Equal(t, http.StatusOK, resp2.StatusCode) + assert.Equal(t, largeResponseBody, string(body)) +} diff --git a/buffer/options.go b/buffer/options.go index 0ba9d83a..1a16747f 100644 --- a/buffer/options.go +++ b/buffer/options.go @@ -69,6 +69,24 @@ func ErrorHandler(h utils.ErrorHandler) Option { } } +// DisableRequestBuffer disables request buffering. +func DisableRequestBuffer() Option { + return func(b *Buffer) error { + b.disableRequest = true + + return nil + } +} + +// DisableResponseBuffer disables response buffering. +func DisableResponseBuffer() Option { + return func(b *Buffer) error { + b.disableResponse = true + + return nil + } +} + // MaxRequestBodyBytes sets the maximum request body size in bytes. func MaxRequestBodyBytes(m int64) Option { return func(b *Buffer) error {