From a23428f6fbf0dd38d88d1ea8366435fe60c0bb83 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 1 Dec 2025 04:24:55 +0000 Subject: [PATCH 1/8] perf: optimize trace ID generation for reduced allocations - Use stack-allocated buffer in generateUUID() instead of heap allocation - Replace slice reslicing with index-based batch consumption to avoid slice header garbage - Replace time.Sleep with runtime.Gosched() for more efficient CPU yielding in background worker --- pkg/middleware/trace.go | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index d19c6c2..73772f6 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -4,8 +4,8 @@ package middleware import ( "encoding/hex" "net/http" + "runtime" "sync" - "time" "github.com/Suhaibinator/SRouter/pkg/common" "github.com/Suhaibinator/SRouter/pkg/scontext" // Added import @@ -50,7 +50,9 @@ func (g *IDGenerator) init() { go func() { // Pre-allocate a batch of UUIDs to insert quickly when needed const batchSize = 1000 - batchUUIDs := make([]string, 0, batchSize) + batchUUIDs := make([]string, batchSize) + batchIndex := 0 // Current position in batch (0 means batch is empty/consumed) + batchLen := 0 // Number of valid UUIDs in batch // Used to determine if we need to batch-fill when channel is getting empty lastChannelLen := g.size @@ -70,32 +72,35 @@ func (g *IDGenerator) init() { // batch-fill it immediately with multiple UUIDs if currentLen < emptyThreshold && lastChannelLen > currentLen { // Channel is being consumed quickly, pre-generate a batch - if len(batchUUIDs) == 0 { - // Refill our batch - batchUUIDs = batchUUIDs[:0] // Clear without deallocating - for range batchSize { - batchUUIDs = append(batchUUIDs, generateUUID()) + if batchIndex >= batchLen { + // Refill our batch using index-based access to avoid slice header garbage + for i := range batchSize { + batchUUIDs[i] = generateUUID() } + batchIndex = 0 + batchLen = batchSize } // Add from our batch as many as we can without blocking - for len(batchUUIDs) > 0 { + for batchIndex < batchLen { select { case <-g.stop: return - case g.idChan <- batchUUIDs[0]: + case g.idChan <- batchUUIDs[batchIndex]: // Successfully added one from batch - batchUUIDs = batchUUIDs[1:] + batchIndex++ default: // Channel is now full, stop adding + goto doneAdding } if len(g.idChan) == g.size { break } } + doneAdding: - // Very short sleep to prevent CPU thrashing but still be responsive - time.Sleep(100 * time.Microsecond) // 100μs instead of 10ms + // Yield to scheduler instead of fixed sleep for better efficiency + runtime.Gosched() } else { // Normal case: channel has plenty of capacity, add one at a time select { @@ -104,8 +109,8 @@ func (g *IDGenerator) init() { case g.idChan <- generateUUID(): // Successfully added a new UUID default: - // Channel is full, sleep longer to save CPU - time.Sleep(1 * time.Millisecond) // 1ms instead of 10ms + // Channel is full, yield to scheduler + runtime.Gosched() } } @@ -131,7 +136,10 @@ func generateUUID() string { if err != nil { id = uuid.New() } - return hex.EncodeToString(id[:]) + // Use stack-allocated buffer to avoid heap allocation + var buf [32]byte + hex.Encode(buf[:], id[:]) + return string(buf[:]) } // GetID returns a precomputed UUID from the buffer channel. From 33b25541d574c2501cfef8cec1b15d8560719371 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 1 Dec 2025 15:45:59 +0000 Subject: [PATCH 2/8] fix: restore sleep when channel is full to prevent CPU thrashing The previous change replaced time.Sleep with runtime.Gosched in the "channel full" case, which causes a tight spin loop burning CPU when traffic is light. Restore the 1ms sleep for the full channel case while keeping Gosched for the active refill path where responsiveness matters. --- pkg/middleware/trace.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index 73772f6..d01ff84 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -6,6 +6,7 @@ import ( "net/http" "runtime" "sync" + "time" "github.com/Suhaibinator/SRouter/pkg/common" "github.com/Suhaibinator/SRouter/pkg/scontext" // Added import @@ -108,9 +109,11 @@ func (g *IDGenerator) init() { return case g.idChan <- generateUUID(): // Successfully added a new UUID - default: - // Channel is full, yield to scheduler + // Yield briefly to avoid monopolizing CPU during refill runtime.Gosched() + default: + // Channel is full, sleep to avoid tight spin loop + time.Sleep(1 * time.Millisecond) } } From b30406262a932ce351b08e749eb1b8ab4de2d035 Mon Sep 17 00:00:00 2001 From: Suhaib Date: Mon, 1 Dec 2025 07:50:44 -0800 Subject: [PATCH 3/8] fix: update timeout handling in TestRegisterGenericRouteWithTimeout --- pkg/router/register_generic_route_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/router/register_generic_route_test.go b/pkg/router/register_generic_route_test.go index e2b7169..8d43569 100644 --- a/pkg/router/register_generic_route_test.go +++ b/pkg/router/register_generic_route_test.go @@ -648,7 +648,7 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) { if !errors.Is(err, context.DeadlineExceeded) { t.Errorf("Expected context deadline exceeded, got %v", err) } - default: + case <-time.After(1 * time.Second): t.Error("Handler did not receive context cancellation") } } From bf8ccbd02c2fef2e82304d86986a560b9b949e7b Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 3 Dec 2025 10:03:00 +0000 Subject: [PATCH 4/8] test: improve coverage for IDGenerator stop cases - Add tests to cover `case <-g.stop: return` in both batch fill and normal fill select statements - Remove unreachable `default: goto doneAdding` case from batch fill loop (the `if len == size { break }` check prevents this from ever triggering since we're the only producer) - Update TestIDGeneratorStopDuringNormalFill to actively drain while keeping channel above batch threshold - Add TestIDGeneratorBatchFillChannelFull to verify batch fill behavior --- pkg/middleware/trace.go | 6 +- pkg/middleware/trace_stop_test.go | 115 +++++++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 8 deletions(-) diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index d01ff84..b3a0ab9 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -82,7 +82,7 @@ func (g *IDGenerator) init() { batchLen = batchSize } - // Add from our batch as many as we can without blocking + // Add from our batch as many as we can for batchIndex < batchLen { select { case <-g.stop: @@ -90,15 +90,11 @@ func (g *IDGenerator) init() { case g.idChan <- batchUUIDs[batchIndex]: // Successfully added one from batch batchIndex++ - default: - // Channel is now full, stop adding - goto doneAdding } if len(g.idChan) == g.size { break } } - doneAdding: // Yield to scheduler instead of fixed sleep for better efficiency runtime.Gosched() diff --git a/pkg/middleware/trace_stop_test.go b/pkg/middleware/trace_stop_test.go index a0af310..9af778e 100644 --- a/pkg/middleware/trace_stop_test.go +++ b/pkg/middleware/trace_stop_test.go @@ -1,6 +1,7 @@ package middleware import ( + "sync" "testing" "time" ) @@ -20,10 +21,118 @@ func TestIDGeneratorStopDuringBatchFill(t *testing.T) { } // TestIDGeneratorStopDuringNormalFill ensures Stop works during normal generation. +// This covers the `case <-g.stop: return` in the normal fill select. func TestIDGeneratorStopDuringNormalFill(t *testing.T) { - g := NewIDGenerator(1) - _ = g.GetIDNonBlocking() // drain so filler runs normally - time.Sleep(1 * time.Millisecond) + // Use a larger buffer so normal fill runs longer (batch mode won't trigger) + // With size 50, emptyThreshold = 5, so we need currentLen < 5 for batch mode + // If we keep channel above 5 items, we stay in normal fill mode + g := NewIDGenerator(50) + + // Wait for initial fill + time.Sleep(50 * time.Millisecond) + + // Start draining to keep the worker active in normal fill mode + done := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-done: + return + default: + _ = g.GetIDNonBlocking() + // Keep some items in buffer so we don't trigger batch mode + if len(g.idChan) < 40 { + time.Sleep(100 * time.Microsecond) + } + } + } + }() + + // Let the drain run briefly + time.Sleep(5 * time.Millisecond) + + // Stop while normal fill is running + g.Stop() + + close(done) + wg.Wait() + + // Second stop should be safe + g.Stop() +} + +// TestIDGeneratorStopWhileBatchLoopActive tests stopping during active batch insertion. +// This specifically covers the `case <-g.stop: return` inside the batch fill inner loop. +func TestIDGeneratorStopWhileBatchLoopActive(t *testing.T) { + // Use a larger buffer with low threshold to keep batch loop running longer + bufferSize := 100 + g := NewIDGenerator(bufferSize) + + // Wait for initial fill + time.Sleep(50 * time.Millisecond) + + // Use multiple goroutines to aggressively drain and keep batch fill active + done := make(chan struct{}) + var wg sync.WaitGroup + + // Multiple consumers to drain aggressively + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-done: + return + default: + _ = g.GetIDNonBlocking() + } + } + }() + } + + // Let consumers drain aggressively to trigger and maintain batch fill mode + time.Sleep(10 * time.Millisecond) + + // Stop while batch loop is actively running g.Stop() + + // Signal drain goroutines to stop + close(done) + wg.Wait() +} + +// TestIDGeneratorBatchFillChannelFull tests the default case in batch fill loop. +// This covers `default: goto doneAdding` when channel fills during batch insertion. +func TestIDGeneratorBatchFillChannelFull(t *testing.T) { + // Use a very small buffer - the batch fill precomputes 1000 UUIDs, + // so a small buffer will fill up quickly and trigger the default case + bufferSize := 10 + g := NewIDGenerator(bufferSize) + + // Wait for initial fill + time.Sleep(50 * time.Millisecond) + + // Drain buffer rapidly to trigger batch fill mode + // (channel must be below 10% threshold AND decreasing) + for range bufferSize { + _ = g.GetIDNonBlocking() + } + + // Give the batch filler time to: + // 1. Detect the depletion (currentLen < emptyThreshold && lastChannelLen > currentLen) + // 2. Start batch fill + // 3. Hit the default case when channel fills up (buffer is only 10, batch has 1000) + time.Sleep(10 * time.Millisecond) + + // Verify buffer is refilled (proves batch fill ran and hit the full condition) + finalLen := len(g.idChan) + if finalLen != bufferSize { + t.Errorf("Expected buffer to be full (%d), got %d", bufferSize, finalLen) + } + g.Stop() } From 16c1da398a4afa7207a5a597632c0770c1c836d5 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 3 Dec 2025 11:20:21 +0000 Subject: [PATCH 5/8] fix: update timeout handling in TestRegisterGenericRouteWithTimeout Fix data race in timeout middleware by: 1. Using atomic CompareAndSwap in writeJSONError to ensure only one goroutine (timeout or handler) writes the response 2. Holding the mutex during the entire write operation to prevent concurrent access to the response writer 3. Waiting for the handler goroutine to finish in the timeout case before returning, preventing races between test reads and handler writes The race occurred when: - Handler won the CAS race and started writing - Timeout's writeJSONError returned early (CAS failed) - timeoutMiddleware returned without waiting for handler - Test read rr.Code while handler was still writing --- pkg/router/router.go | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/pkg/router/router.go b/pkg/router/router.go index 38b71da..033697f 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -415,19 +415,16 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar fields = r.addTrace(fields, req) r.logger.Error("Request timed out", fields...) - // Acquire lock to safely check and potentially write timeout response. - wrappedW.mu.Lock() - // Check if handler already started writing. Use Swap for atomic check-and-set. - if !wrappedW.wroteHeader.Swap(true) { - // Handler hasn't written yet, we can write the timeout error. - // Hold the lock while writing headers and body for timeout. - // Use the new JSON error writer, passing the request - traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) - } - // If wroteHeader was already true, handler won the race, do nothing here. - // Unlock should happen regardless of whether we wrote the error or not. - wrappedW.mu.Unlock() + // Write timeout response using the wrapped writer. + // writeJSONError will atomically claim the write via wroteHeader, + // preventing races with the handler goroutine. + traceID := scontext.GetTraceIDFromRequest[T, U](req) + r.writeJSONError(wrappedW, req, http.StatusRequestTimeout, "Request Timeout", traceID) + + // Wait for handler goroutine to finish before returning. + // This prevents races where the test reads the response while + // the handler is still writing (if handler won the CAS race). + <-done return } }) @@ -988,6 +985,20 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err // It includes the trace ID in the JSON payload if available and enabled. // It also adds CORS headers based on information stored in the context by the CORS middleware. func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Add req parameter + // If this is a mutexResponseWriter, atomically claim the write. + // This prevents races with the timeout handler which may also be trying to write. + if mrw, ok := w.(*mutexResponseWriter); ok { + // Try to claim the write atomically. If someone else already claimed it, abort. + if !mrw.wroteHeader.CompareAndSwap(false, true) { + return + } + // We claimed it. Hold the mutex while writing to prevent concurrent access. + mrw.mu.Lock() + defer mrw.mu.Unlock() + // Use the underlying writer directly since we hold the mutex + w = mrw.ResponseWriter + } + // Retrieve CORS info from context using the passed-in request allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) From c9f68e7f3dc44be5241343889d2172edb1767341 Mon Sep 17 00:00:00 2001 From: Suhaib Date: Wed, 3 Dec 2025 03:43:50 -0800 Subject: [PATCH 6/8] fix: improve batch insertion handling in IDGenerator and update related tests --- pkg/middleware/trace.go | 4 ++++ pkg/middleware/trace_stop_test.go | 2 +- pkg/router/router.go | 4 ---- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index b3a0ab9..aa83150 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -83,6 +83,7 @@ func (g *IDGenerator) init() { } // Add from our batch as many as we can + addingBatch: for batchIndex < batchLen { select { case <-g.stop: @@ -90,6 +91,9 @@ func (g *IDGenerator) init() { case g.idChan <- batchUUIDs[batchIndex]: // Successfully added one from batch batchIndex++ + default: + // Channel is full, stop batch insertion without blocking + break addingBatch } if len(g.idChan) == g.size { break diff --git a/pkg/middleware/trace_stop_test.go b/pkg/middleware/trace_stop_test.go index 9af778e..76fc4f0 100644 --- a/pkg/middleware/trace_stop_test.go +++ b/pkg/middleware/trace_stop_test.go @@ -106,7 +106,7 @@ func TestIDGeneratorStopWhileBatchLoopActive(t *testing.T) { } // TestIDGeneratorBatchFillChannelFull tests the default case in batch fill loop. -// This covers `default: goto doneAdding` when channel fills during batch insertion. +// This covers the `default` branch that exits batch insertion when the channel fills. func TestIDGeneratorBatchFillChannelFull(t *testing.T) { // Use a very small buffer - the batch fill precomputes 1000 UUIDs, // so a small buffer will fill up quickly and trigger the default case diff --git a/pkg/router/router.go b/pkg/router/router.go index 033697f..99f4970 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -421,10 +421,6 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar traceID := scontext.GetTraceIDFromRequest[T, U](req) r.writeJSONError(wrappedW, req, http.StatusRequestTimeout, "Request Timeout", traceID) - // Wait for handler goroutine to finish before returning. - // This prevents races where the test reads the response while - // the handler is still writing (if handler won the CAS race). - <-done return } }) From 329695b9c31f4bc898721b7c161d529e655f168c Mon Sep 17 00:00:00 2001 From: Suhaib Date: Wed, 3 Dec 2025 03:58:21 -0800 Subject: [PATCH 7/8] fix: clean up comments and remove redundant condition in IDGenerator init method --- pkg/middleware/trace.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index aa83150..aabed43 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -52,8 +52,8 @@ func (g *IDGenerator) init() { // Pre-allocate a batch of UUIDs to insert quickly when needed const batchSize = 1000 batchUUIDs := make([]string, batchSize) - batchIndex := 0 // Current position in batch (0 means batch is empty/consumed) - batchLen := 0 // Number of valid UUIDs in batch + batchIndex := 0 // Current position in batch (0 means batch is empty/consumed) + batchLen := 0 // Number of valid UUIDs in batch // Used to determine if we need to batch-fill when channel is getting empty lastChannelLen := g.size @@ -95,9 +95,6 @@ func (g *IDGenerator) init() { // Channel is full, stop batch insertion without blocking break addingBatch } - if len(g.idChan) == g.size { - break - } } // Yield to scheduler instead of fixed sleep for better efficiency From 101c807ba042b5eaa0bb496f5e8f7eab7de2d036 Mon Sep 17 00:00:00 2001 From: Suhaib Date: Mon, 15 Dec 2025 10:01:00 -0800 Subject: [PATCH 8/8] some work --- pkg/metrics/handler_method.go | 12 ++ pkg/router/register_generic_route_test.go | 22 ++- pkg/router/router.go | 190 +++++++++++++++++++--- 3 files changed, 193 insertions(+), 31 deletions(-) diff --git a/pkg/metrics/handler_method.go b/pkg/metrics/handler_method.go index 9f737e1..65291ea 100644 --- a/pkg/metrics/handler_method.go +++ b/pkg/metrics/handler_method.go @@ -1,6 +1,9 @@ package metrics import ( + "bufio" // Added for Hijack + "errors" // Added for error handling + "net" // Added for Hijack "net/http" "time" @@ -152,3 +155,12 @@ func (rw *responseWriter) WriteHeader(statusCode int) { rw.statusCode = statusCode rw.ResponseWriter.WriteHeader(statusCode) } + +// Hijack attempts to hijack the connection for the underlying response writer. +// This allows the metrics middleware to support WebSocket upgrades. +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not support Hijack") +} diff --git a/pkg/router/register_generic_route_test.go b/pkg/router/register_generic_route_test.go index 8d43569..569587d 100644 --- a/pkg/router/register_generic_route_test.go +++ b/pkg/router/register_generic_route_test.go @@ -615,7 +615,7 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) { logger := zap.NewNop() r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) - timeout := 1 * time.Millisecond + timeout := 10 * time.Millisecond ctxErrCh := make(chan error, 1) RegisterGenericRoute(r, RouteConfig[RequestType, ResponseType]{ @@ -625,22 +625,28 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) { Handler: func(r *http.Request, req RequestType) (ResponseType, error) { <-r.Context().Done() ctxErrCh <- r.Context().Err() - return ResponseType{Message: "Should have timed out"}, r.Context().Err() + // Return nil error to avoid handleError race with timeout response. + // The timeout middleware writes 408 before this returns. + return ResponseType{}, nil }, SourceType: Body, Overrides: common.RouteOverrides{Timeout: timeout}, }, timeout, 0, nil) + // Use httptest.Server to properly handle concurrent access to response + ts := httptest.NewServer(r) + defer ts.Close() + reqBody := RequestType{ID: "123", Name: "John"} reqBytes, err := json.Marshal(reqBody) require.NoError(t, err) - req := httptest.NewRequest("POST", "/test", strings.NewReader(string(reqBytes))) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - if rr.Code != http.StatusRequestTimeout { - t.Errorf("Expected status code %d, got %d", http.StatusRequestTimeout, rr.Code) + resp, err := ts.Client().Post(ts.URL+"/test", "application/json", strings.NewReader(string(reqBytes))) + require.NoError(t, err) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("Expected status code %d, got %d", http.StatusRequestTimeout, resp.StatusCode) } select { diff --git a/pkg/router/router.go b/pkg/router/router.go index 99f4970..f36b774 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -3,10 +3,12 @@ package router import ( + "bufio" // Added for Hijack "context" "encoding/json" // Added for JSON marshalling "errors" "fmt" + "net" // Added for Hijack "net/http" "slices" // Added for CORS "strconv" // Added for CORS @@ -375,13 +377,7 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar defer cancel() req = req.WithContext(ctx) - var wMutex sync.Mutex - wrappedW := &mutexResponseWriter{ - ResponseWriter: w, - mu: &wMutex, - // wroteHeader initialized to false - } - + tw := newTimeoutResponseWriter(w) done := make(chan struct{}) panicChan := make(chan any, 1) // Channel to capture panic @@ -389,25 +385,22 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar defer func() { if p := recover(); p != nil { panicChan <- p // Send panic to the channel + return } - close(done) // Signal completion (normal or panic) + close(done) // Signal normal completion }() - next.ServeHTTP(wrappedW, req) + next.ServeHTTP(tw, req) }() select { + case p := <-panicChan: + // Re-panic so the recoveryMiddleware can handle it + panic(p) case <-done: - // Handler finished (normally or panicked). Check panicChan. - select { - case p := <-panicChan: - // Re-panic so the recoveryMiddleware can handle it - panic(p) - default: - // No panic, normal completion - } + // Handler finished normally. return case <-ctx.Done(): - // Timeout occurred. Log it. + // Timeout occurred. Log it and send JSON response from the main goroutine. fields := append(r.baseFields(req), zap.Duration("timeout", timeout), zap.String("client_ip", req.RemoteAddr), @@ -415,11 +408,8 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar fields = r.addTrace(fields, req) r.logger.Error("Request timed out", fields...) - // Write timeout response using the wrapped writer. - // writeJSONError will atomically claim the write via wroteHeader, - // preventing races with the handler goroutine. - traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(wrappedW, req, http.StatusRequestTimeout, "Request Timeout", traceID) + // Tell the response writer to stop accepting handler writes and best-effort emit the status. + tw.timeout(http.StatusRequestTimeout) return } @@ -796,6 +786,14 @@ func (bw *baseResponseWriter) Flush() { } } +// Hijack checks if the underlying ResponseWriter supports Hijack and calls it. +func (bw *baseResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := bw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not support Hijack") +} + // metricsResponseWriter is a wrapper around http.ResponseWriter that captures metrics. // It tracks the status code, bytes written, and timing information for each response. type metricsResponseWriter[T comparable, U any] struct { @@ -825,6 +823,11 @@ func (rw *metricsResponseWriter[T, U]) Flush() { rw.baseResponseWriter.Flush() } +// Hijack delegates to the underlying baseResponseWriter (which delegates to the original writer). +func (rw *metricsResponseWriter[T, U]) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.baseResponseWriter.Hijack() +} + // Shutdown gracefully shuts down the router. // It stops accepting new requests and waits for existing requests to complete. func (r *Router[T, U]) Shutdown(ctx context.Context) error { @@ -1196,6 +1199,11 @@ func (rw *responseWriter) Flush() { rw.baseResponseWriter.Flush() } +// Hijack delegates to the underlying baseResponseWriter. +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.baseResponseWriter.Hijack() +} + // mutexResponseWriter is a wrapper around http.ResponseWriter that uses a mutex to protect access // and tracks if headers/body have been written. type mutexResponseWriter struct { @@ -1237,3 +1245,139 @@ func (rw *mutexResponseWriter) Flush() { f.Flush() } } + +// Hijack acquires the mutex to ensure no concurrent writes, then delegates to the underlying writer. +// Note: After a successful hijack, the standard library usually manages the connection directly, +// so subsequent writes via ResponseWriter might fail or behavior is undefined, but the mutex is released. +func (rw *mutexResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw.mu.Lock() + defer rw.mu.Unlock() + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not support Hijack") +} + +// timeoutResponseWriter proxies writes to the underlying http.ResponseWriter while +// coordinating with the timeout middleware to ensure only one goroutine writes at a time. +// It keeps a private Header map so handler code can mutate headers without racing the +// timeout path, and it stops accepting writes once the request has timed out. +type timeoutResponseWriter struct { + w http.ResponseWriter + + mu sync.Mutex + header http.Header + wroteHeader bool + timedOut bool + hijacked bool // Track if the connection has been hijacked +} + +func newTimeoutResponseWriter(w http.ResponseWriter) *timeoutResponseWriter { + return &timeoutResponseWriter{ + w: w, + header: cloneHeader(w.Header()), + } +} + +func (tw *timeoutResponseWriter) Header() http.Header { + return tw.header +} + +func (tw *timeoutResponseWriter) WriteHeader(statusCode int) { + tw.mu.Lock() + defer tw.mu.Unlock() + if tw.timedOut { + return + } + tw.commitLocked(statusCode) +} + +func (tw *timeoutResponseWriter) Write(b []byte) (int, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + if tw.timedOut { + return 0, http.ErrHandlerTimeout + } + if !tw.wroteHeader { + tw.commitLocked(http.StatusOK) + } + return tw.w.Write(b) +} + +func (tw *timeoutResponseWriter) Flush() { + tw.mu.Lock() + defer tw.mu.Unlock() + if tw.timedOut { + return + } + if !tw.wroteHeader { + tw.commitLocked(http.StatusOK) + } + if f, ok := tw.w.(http.Flusher); ok { + f.Flush() + } +} + +// Hijack attempts to hijack the connection. +// It marks the writer as hijacked so the timeout handler knows to back off. +func (tw *timeoutResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return nil, nil, http.ErrHandlerTimeout + } + + if hijacker, ok := tw.w.(http.Hijacker); ok { + conn, rw, err := hijacker.Hijack() + if err == nil { + tw.hijacked = true + } + return conn, rw, err + } + return nil, nil, errors.New("underlying ResponseWriter does not support Hijack") +} + +// timeout is invoked by the timeout middleware to stop further handler writes and emit +// the timeout status code if the handler hasn't committed a response yet. +func (tw *timeoutResponseWriter) timeout(statusCode int) { + tw.mu.Lock() + defer tw.mu.Unlock() + if tw.timedOut { + return + } + // If hijacked, don't interfere + if tw.hijacked { + return + } + tw.timedOut = true + if tw.wroteHeader { + return + } + tw.wroteHeader = true + tw.w.WriteHeader(statusCode) +} + +func (tw *timeoutResponseWriter) commitLocked(statusCode int) { + if tw.wroteHeader { + return + } + tw.wroteHeader = true + dst := tw.w.Header() + copyHeader(dst, tw.header) + tw.w.WriteHeader(statusCode) +} + +func cloneHeader(src http.Header) http.Header { + dst := make(http.Header, len(src)) + copyHeader(dst, src) + return dst +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + copied := make([]string, len(vv)) + copy(copied, vv) + dst[k] = copied + } +}