diff --git a/pkg/metrics/handler_method.go b/pkg/metrics/handler_method.go index 9f737e1..65291ea 100644 --- a/pkg/metrics/handler_method.go +++ b/pkg/metrics/handler_method.go @@ -1,6 +1,9 @@ package metrics import ( + "bufio" // Added for Hijack + "errors" // Added for error handling + "net" // Added for Hijack "net/http" "time" @@ -152,3 +155,12 @@ func (rw *responseWriter) WriteHeader(statusCode int) { rw.statusCode = statusCode rw.ResponseWriter.WriteHeader(statusCode) } + +// Hijack attempts to hijack the connection for the underlying response writer. +// This allows the metrics middleware to support WebSocket upgrades. +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not support Hijack") +} diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index d19c6c2..aabed43 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -4,6 +4,7 @@ package middleware import ( "encoding/hex" "net/http" + "runtime" "sync" "time" @@ -50,7 +51,9 @@ func (g *IDGenerator) init() { go func() { // Pre-allocate a batch of UUIDs to insert quickly when needed const batchSize = 1000 - batchUUIDs := make([]string, 0, batchSize) + batchUUIDs := make([]string, batchSize) + batchIndex := 0 // Current position in batch (0 means batch is empty/consumed) + batchLen := 0 // Number of valid UUIDs in batch // Used to determine if we need to batch-fill when channel is getting empty lastChannelLen := g.size @@ -70,32 +73,32 @@ func (g *IDGenerator) init() { // batch-fill it immediately with multiple UUIDs if currentLen < emptyThreshold && lastChannelLen > currentLen { // Channel is being consumed quickly, pre-generate a batch - if len(batchUUIDs) == 0 { - // Refill our batch - batchUUIDs = batchUUIDs[:0] // Clear without deallocating - for range batchSize { - batchUUIDs = append(batchUUIDs, 
generateUUID()) + if batchIndex >= batchLen { + // Refill our batch using index-based access to avoid slice header garbage + for i := range batchSize { + batchUUIDs[i] = generateUUID() } + batchIndex = 0 + batchLen = batchSize } - // Add from our batch as many as we can without blocking - for len(batchUUIDs) > 0 { + // Add from our batch as many as we can + addingBatch: + for batchIndex < batchLen { select { case <-g.stop: return - case g.idChan <- batchUUIDs[0]: + case g.idChan <- batchUUIDs[batchIndex]: // Successfully added one from batch - batchUUIDs = batchUUIDs[1:] + batchIndex++ default: - // Channel is now full, stop adding - } - if len(g.idChan) == g.size { - break + // Channel is full, stop batch insertion without blocking + break addingBatch } } - // Very short sleep to prevent CPU thrashing but still be responsive - time.Sleep(100 * time.Microsecond) // 100μs instead of 10ms + // Yield to scheduler instead of fixed sleep for better efficiency + runtime.Gosched() } else { // Normal case: channel has plenty of capacity, add one at a time select { @@ -103,9 +106,11 @@ func (g *IDGenerator) init() { return case g.idChan <- generateUUID(): // Successfully added a new UUID + // Yield briefly to avoid monopolizing CPU during refill + runtime.Gosched() default: - // Channel is full, sleep longer to save CPU - time.Sleep(1 * time.Millisecond) // 1ms instead of 10ms + // Channel is full, sleep to avoid tight spin loop + time.Sleep(1 * time.Millisecond) } } @@ -131,7 +136,10 @@ func generateUUID() string { if err != nil { id = uuid.New() } - return hex.EncodeToString(id[:]) + // Use stack-allocated buffer to avoid heap allocation + var buf [32]byte + hex.Encode(buf[:], id[:]) + return string(buf[:]) } // GetID returns a precomputed UUID from the buffer channel. 
diff --git a/pkg/middleware/trace_stop_test.go b/pkg/middleware/trace_stop_test.go index a0af310..76fc4f0 100644 --- a/pkg/middleware/trace_stop_test.go +++ b/pkg/middleware/trace_stop_test.go @@ -1,6 +1,7 @@ package middleware import ( + "sync" "testing" "time" ) @@ -20,10 +21,118 @@ func TestIDGeneratorStopDuringBatchFill(t *testing.T) { } // TestIDGeneratorStopDuringNormalFill ensures Stop works during normal generation. +// This covers the `case <-g.stop: return` in the normal fill select. func TestIDGeneratorStopDuringNormalFill(t *testing.T) { - g := NewIDGenerator(1) - _ = g.GetIDNonBlocking() // drain so filler runs normally - time.Sleep(1 * time.Millisecond) + // Use a larger buffer so normal fill runs longer (batch mode won't trigger) + // With size 50, emptyThreshold = 5, so we need currentLen < 5 for batch mode + // If we keep channel above 5 items, we stay in normal fill mode + g := NewIDGenerator(50) + + // Wait for initial fill + time.Sleep(50 * time.Millisecond) + + // Start draining to keep the worker active in normal fill mode + done := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-done: + return + default: + _ = g.GetIDNonBlocking() + // Keep some items in buffer so we don't trigger batch mode + if len(g.idChan) < 40 { + time.Sleep(100 * time.Microsecond) + } + } + } + }() + + // Let the drain run briefly + time.Sleep(5 * time.Millisecond) + + // Stop while normal fill is running + g.Stop() + + close(done) + wg.Wait() + + // Second stop should be safe + g.Stop() +} + +// TestIDGeneratorStopWhileBatchLoopActive tests stopping during active batch insertion. +// This specifically covers the `case <-g.stop: return` inside the batch fill inner loop. 
+func TestIDGeneratorStopWhileBatchLoopActive(t *testing.T) { + // Use a larger buffer with low threshold to keep batch loop running longer + bufferSize := 100 + g := NewIDGenerator(bufferSize) + + // Wait for initial fill + time.Sleep(50 * time.Millisecond) + + // Use multiple goroutines to aggressively drain and keep batch fill active + done := make(chan struct{}) + var wg sync.WaitGroup + + // Multiple consumers to drain aggressively + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-done: + return + default: + _ = g.GetIDNonBlocking() + } + } + }() + } + + // Let consumers drain aggressively to trigger and maintain batch fill mode + time.Sleep(10 * time.Millisecond) + + // Stop while batch loop is actively running g.Stop() + + // Signal drain goroutines to stop + close(done) + wg.Wait() +} + +// TestIDGeneratorBatchFillChannelFull tests the default case in batch fill loop. +// This covers the `default` branch that exits batch insertion when the channel fills. +func TestIDGeneratorBatchFillChannelFull(t *testing.T) { + // Use a very small buffer - the batch fill precomputes 1000 UUIDs, + // so a small buffer will fill up quickly and trigger the default case + bufferSize := 10 + g := NewIDGenerator(bufferSize) + + // Wait for initial fill + time.Sleep(50 * time.Millisecond) + + // Drain buffer rapidly to trigger batch fill mode + // (channel must be below 10% threshold AND decreasing) + for range bufferSize { + _ = g.GetIDNonBlocking() + } + + // Give the batch filler time to: + // 1. Detect the depletion (currentLen < emptyThreshold && lastChannelLen > currentLen) + // 2. Start batch fill + // 3. 
Hit the default case when channel fills up (buffer is only 10, batch has 1000) + time.Sleep(10 * time.Millisecond) + + // Verify buffer is refilled (proves batch fill ran and hit the full condition) + finalLen := len(g.idChan) + if finalLen != bufferSize { + t.Errorf("Expected buffer to be full (%d), got %d", bufferSize, finalLen) + } + g.Stop() } diff --git a/pkg/router/register_generic_route_test.go b/pkg/router/register_generic_route_test.go index e2b7169..569587d 100644 --- a/pkg/router/register_generic_route_test.go +++ b/pkg/router/register_generic_route_test.go @@ -615,7 +615,7 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) { logger := zap.NewNop() r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) - timeout := 1 * time.Millisecond + timeout := 10 * time.Millisecond ctxErrCh := make(chan error, 1) RegisterGenericRoute(r, RouteConfig[RequestType, ResponseType]{ @@ -625,22 +625,28 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) { Handler: func(r *http.Request, req RequestType) (ResponseType, error) { <-r.Context().Done() ctxErrCh <- r.Context().Err() - return ResponseType{Message: "Should have timed out"}, r.Context().Err() + // Return nil error to avoid handleError race with timeout response. + // The timeout middleware writes 408 before this returns. 
+ return ResponseType{}, nil }, SourceType: Body, Overrides: common.RouteOverrides{Timeout: timeout}, }, timeout, 0, nil) + // Use httptest.Server to properly handle concurrent access to response + ts := httptest.NewServer(r) + defer ts.Close() + reqBody := RequestType{ID: "123", Name: "John"} reqBytes, err := json.Marshal(reqBody) require.NoError(t, err) - req := httptest.NewRequest("POST", "/test", strings.NewReader(string(reqBytes))) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - r.ServeHTTP(rr, req) - if rr.Code != http.StatusRequestTimeout { - t.Errorf("Expected status code %d, got %d", http.StatusRequestTimeout, rr.Code) + resp, err := ts.Client().Post(ts.URL+"/test", "application/json", strings.NewReader(string(reqBytes))) + require.NoError(t, err) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("Expected status code %d, got %d", http.StatusRequestTimeout, resp.StatusCode) } select { @@ -648,7 +654,7 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) { if !errors.Is(err, context.DeadlineExceeded) { t.Errorf("Expected context deadline exceeded, got %v", err) } - default: + case <-time.After(1 * time.Second): t.Error("Handler did not receive context cancellation") } } diff --git a/pkg/router/router.go b/pkg/router/router.go index 38b71da..f36b774 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -3,10 +3,12 @@ package router import ( + "bufio" // Added for Hijack "context" "encoding/json" // Added for JSON marshalling "errors" "fmt" + "net" // Added for Hijack "net/http" "slices" // Added for CORS "strconv" // Added for CORS @@ -375,13 +377,7 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar defer cancel() req = req.WithContext(ctx) - var wMutex sync.Mutex - wrappedW := &mutexResponseWriter{ - ResponseWriter: w, - mu: &wMutex, - // wroteHeader initialized to false - } - + tw := newTimeoutResponseWriter(w) done := make(chan 
struct{}) panicChan := make(chan any, 1) // Channel to capture panic @@ -389,25 +385,22 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar defer func() { if p := recover(); p != nil { panicChan <- p // Send panic to the channel + return } - close(done) // Signal completion (normal or panic) + close(done) // Signal normal completion }() - next.ServeHTTP(wrappedW, req) + next.ServeHTTP(tw, req) }() select { + case p := <-panicChan: + // Re-panic so the recoveryMiddleware can handle it + panic(p) case <-done: - // Handler finished (normally or panicked). Check panicChan. - select { - case p := <-panicChan: - // Re-panic so the recoveryMiddleware can handle it - panic(p) - default: - // No panic, normal completion - } + // Handler finished normally. return case <-ctx.Done(): - // Timeout occurred. Log it. + // Timeout occurred. Log it and send JSON response from the main goroutine. fields := append(r.baseFields(req), zap.Duration("timeout", timeout), zap.String("client_ip", req.RemoteAddr), @@ -415,19 +408,9 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar fields = r.addTrace(fields, req) r.logger.Error("Request timed out", fields...) - // Acquire lock to safely check and potentially write timeout response. - wrappedW.mu.Lock() - // Check if handler already started writing. Use Swap for atomic check-and-set. - if !wrappedW.wroteHeader.Swap(true) { - // Handler hasn't written yet, we can write the timeout error. - // Hold the lock while writing headers and body for timeout. - // Use the new JSON error writer, passing the request - traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) - } - // If wroteHeader was already true, handler won the race, do nothing here. - // Unlock should happen regardless of whether we wrote the error or not. 
- wrappedW.mu.Unlock() + // Tell the response writer to stop accepting handler writes and best-effort emit the status. + tw.timeout(http.StatusRequestTimeout) + return } }) @@ -803,6 +786,14 @@ func (bw *baseResponseWriter) Flush() { } } +// Hijack checks if the underlying ResponseWriter supports Hijack and calls it. +func (bw *baseResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := bw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not support Hijack") +} + // metricsResponseWriter is a wrapper around http.ResponseWriter that captures metrics. // It tracks the status code, bytes written, and timing information for each response. type metricsResponseWriter[T comparable, U any] struct { @@ -832,6 +823,11 @@ func (rw *metricsResponseWriter[T, U]) Flush() { rw.baseResponseWriter.Flush() } +// Hijack delegates to the underlying baseResponseWriter (which delegates to the original writer). +func (rw *metricsResponseWriter[T, U]) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.baseResponseWriter.Hijack() +} + // Shutdown gracefully shuts down the router. // It stops accepting new requests and waits for existing requests to complete. func (r *Router[T, U]) Shutdown(ctx context.Context) error { @@ -988,6 +984,20 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err // It includes the trace ID in the JSON payload if available and enabled. // It also adds CORS headers based on information stored in the context by the CORS middleware. func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Add req parameter + // If this is a mutexResponseWriter, atomically claim the write. + // This prevents races with the timeout handler which may also be trying to write. + if mrw, ok := w.(*mutexResponseWriter); ok { + // Try to claim the write atomically. 
If someone else already claimed it, abort. + if !mrw.wroteHeader.CompareAndSwap(false, true) { + return + } + // We claimed it. Hold the mutex while writing to prevent concurrent access. + mrw.mu.Lock() + defer mrw.mu.Unlock() + // Use the underlying writer directly since we hold the mutex + w = mrw.ResponseWriter + } + // Retrieve CORS info from context using the passed-in request allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) @@ -1189,6 +1199,11 @@ func (rw *responseWriter) Flush() { rw.baseResponseWriter.Flush() } +// Hijack delegates to the underlying baseResponseWriter. +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.baseResponseWriter.Hijack() +} + // mutexResponseWriter is a wrapper around http.ResponseWriter that uses a mutex to protect access // and tracks if headers/body have been written. type mutexResponseWriter struct { @@ -1230,3 +1245,139 @@ func (rw *mutexResponseWriter) Flush() { f.Flush() } } + +// Hijack acquires the mutex to ensure no concurrent writes, then delegates to the underlying writer. +// Note: After a successful hijack, the standard library usually manages the connection directly, +// so subsequent writes via ResponseWriter might fail or behavior is undefined, but the mutex is released. +func (rw *mutexResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw.mu.Lock() + defer rw.mu.Unlock() + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not support Hijack") +} + +// timeoutResponseWriter proxies writes to the underlying http.ResponseWriter while +// coordinating with the timeout middleware to ensure only one goroutine writes at a time. +// It keeps a private Header map so handler code can mutate headers without racing the +// timeout path, and it stops accepting writes once the request has timed out. 
type timeoutResponseWriter struct {
	// w is the wrapped writer. After construction it is only used while mu is held.
	w http.ResponseWriter

	mu sync.Mutex
	// header is the private, handler-facing header map returned by Header().
	// It is mirrored onto w.Header() when a status line is forwarded.
	header http.Header
	// wroteHeader is true once a final (non-informational) status has been sent to w.
	wroteHeader bool
	// timedOut is true once timeout() has taken over the response.
	timedOut bool
	// hijacked is true once Hijack() has succeeded on the underlying writer.
	hijacked bool
}

// Compile-time checks for the interfaces this wrapper implements.
var (
	_ http.ResponseWriter = (*timeoutResponseWriter)(nil)
	_ http.Flusher        = (*timeoutResponseWriter)(nil)
	_ http.Hijacker       = (*timeoutResponseWriter)(nil)
)

// newTimeoutResponseWriter wraps w. The private header map starts as a deep copy of
// w.Header(), so headers set before wrapping remain visible to the handler.
func newTimeoutResponseWriter(w http.ResponseWriter) *timeoutResponseWriter {
	return &timeoutResponseWriter{
		w:      w,
		header: cloneHeader(w.Header()),
	}
}

// Header returns the private header map. Changes to it reach the underlying writer
// only when a status line is forwarded (see commitLocked).
func (tw *timeoutResponseWriter) Header() http.Header {
	return tw.header
}

// WriteHeader forwards statusCode to the underlying writer unless the request has
// already timed out or the connection has been hijacked, in which case it is a no-op.
func (tw *timeoutResponseWriter) WriteHeader(statusCode int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.timedOut || tw.hijacked {
		return
	}
	tw.commitLocked(statusCode)
}

// Write forwards b to the underlying writer, committing an implicit 200 first if no
// final status has been sent. It returns http.ErrHandlerTimeout after a timeout and
// http.ErrHijacked after a successful Hijack, without touching the underlying writer.
func (tw *timeoutResponseWriter) Write(b []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	switch {
	case tw.timedOut:
		return 0, http.ErrHandlerTimeout
	case tw.hijacked:
		// The connection no longer belongs to the ResponseWriter; do not emit a
		// status line on it.
		return 0, http.ErrHijacked
	}
	tw.commitLocked(http.StatusOK) // no-op when a final status was already sent
	return tw.w.Write(b)
}

// Flush commits an implicit 200 if needed and flushes the underlying writer when it
// implements http.Flusher. It is a no-op after a timeout or a hijack.
func (tw *timeoutResponseWriter) Flush() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.timedOut || tw.hijacked {
		return
	}
	tw.commitLocked(http.StatusOK) // no-op when a final status was already sent
	if f, ok := tw.w.(http.Flusher); ok {
		f.Flush()
	}
}

// Hijack attempts to hijack the connection.
// It marks the writer as hijacked so the timeout handler knows to back off.
// It returns http.ErrHandlerTimeout if the request has already timed out, and an
// error if the underlying writer does not implement http.Hijacker.
func (tw *timeoutResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()

	if tw.timedOut {
		return nil, nil, http.ErrHandlerTimeout
	}

	hijacker, ok := tw.w.(http.Hijacker)
	if !ok {
		return nil, nil, errors.New("underlying ResponseWriter does not support Hijack")
	}
	conn, rw, err := hijacker.Hijack()
	if err == nil {
		tw.hijacked = true
	}
	return conn, rw, err
}

// timeout is invoked by the timeout middleware to stop further handler writes and emit
// the timeout status code if the handler hasn't committed a response yet.
// Only the bare status code is written: no body, and the handler's pending private
// headers are not copied. It does nothing if the connection was hijacked or if it
// has already been called.
func (tw *timeoutResponseWriter) timeout(statusCode int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.timedOut {
		return
	}
	// If hijacked, don't interfere
	if tw.hijacked {
		return
	}
	tw.timedOut = true
	if tw.wroteHeader {
		// The handler already committed a final status; we can only stop further writes.
		return
	}
	tw.wroteHeader = true
	tw.w.WriteHeader(statusCode)
}

// commitLocked mirrors the private header map onto the underlying writer and forwards
// statusCode. The caller must hold tw.mu. It is a no-op once a final status was sent.
//
// Informational 1xx codes other than 101 are forwarded without marking the response
// as committed, so the handler (or timeout) can still send the final status.
func (tw *timeoutResponseWriter) commitLocked(statusCode int) {
	if tw.wroteHeader {
		return
	}
	dst := tw.w.Header()
	// Mirror deletions: drop keys the handler removed from its private copy, which
	// a plain overlay would leave behind on the underlying writer.
	for k := range dst {
		if _, keep := tw.header[k]; !keep {
			delete(dst, k)
		}
	}
	copyHeader(dst, tw.header)

	informational := statusCode >= 100 && statusCode <= 199 &&
		statusCode != http.StatusSwitchingProtocols
	if !informational {
		tw.wroteHeader = true
	}
	tw.w.WriteHeader(statusCode)
}

// cloneHeader returns a deep copy of src. The result is always a non-nil map.
func cloneHeader(src http.Header) http.Header {
	dst := make(http.Header, len(src))
	copyHeader(dst, src)
	return dst
}

// copyHeader overlays every key of src onto dst. Value slices are copied so the two
// maps never share backing arrays. Keys present only in dst are left untouched.
func copyHeader(dst, src http.Header) {
	for k, vv := range src {
		copied := make([]string, len(vv))
		copy(copied, vv)
		dst[k] = copied
	}
}