Skip to content
12 changes: 12 additions & 0 deletions pkg/metrics/handler_method.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package metrics

import (
"bufio" // Added for Hijack
"errors" // Added for error handling
"net" // Added for Hijack
"net/http"
"time"

Expand Down Expand Up @@ -152,3 +155,12 @@ func (rw *responseWriter) WriteHeader(statusCode int) {
rw.statusCode = statusCode
rw.ResponseWriter.WriteHeader(statusCode)
}

// Hijack exposes http.Hijacker on the wrapped response writer so that
// handlers behind the metrics middleware can take over the connection
// (e.g. for WebSocket upgrades). It fails when the underlying writer
// does not itself support hijacking.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	h, ok := rw.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, errors.New("underlying ResponseWriter does not support Hijack")
	}
	return h.Hijack()
}
46 changes: 27 additions & 19 deletions pkg/middleware/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package middleware
import (
"encoding/hex"
"net/http"
"runtime"
"sync"
"time"

Expand Down Expand Up @@ -50,7 +51,9 @@ func (g *IDGenerator) init() {
go func() {
// Pre-allocate a batch of UUIDs to insert quickly when needed
const batchSize = 1000
batchUUIDs := make([]string, 0, batchSize)
batchUUIDs := make([]string, batchSize)
batchIndex := 0 // Current position in batch (0 means batch is empty/consumed)
batchLen := 0 // Number of valid UUIDs in batch

// Used to determine if we need to batch-fill when channel is getting empty
lastChannelLen := g.size
Expand All @@ -70,42 +73,44 @@ func (g *IDGenerator) init() {
// batch-fill it immediately with multiple UUIDs
if currentLen < emptyThreshold && lastChannelLen > currentLen {
// Channel is being consumed quickly, pre-generate a batch
if len(batchUUIDs) == 0 {
// Refill our batch
batchUUIDs = batchUUIDs[:0] // Clear without deallocating
for range batchSize {
batchUUIDs = append(batchUUIDs, generateUUID())
if batchIndex >= batchLen {
// Refill our batch using index-based access to avoid slice header garbage
for i := range batchSize {
batchUUIDs[i] = generateUUID()
}
batchIndex = 0
batchLen = batchSize
}

// Add from our batch as many as we can without blocking
for len(batchUUIDs) > 0 {
// Add from our batch as many as we can
addingBatch:
for batchIndex < batchLen {
select {
case <-g.stop:
return
case g.idChan <- batchUUIDs[0]:
case g.idChan <- batchUUIDs[batchIndex]:
// Successfully added one from batch
batchUUIDs = batchUUIDs[1:]
batchIndex++
default:
// Channel is now full, stop adding
}
if len(g.idChan) == g.size {
break
// Channel is full, stop batch insertion without blocking
break addingBatch
}
}

// Very short sleep to prevent CPU thrashing but still be responsive
time.Sleep(100 * time.Microsecond) // 100μs instead of 10ms
// Yield to scheduler instead of fixed sleep for better efficiency
runtime.Gosched()
} else {
// Normal case: channel has plenty of capacity, add one at a time
select {
case <-g.stop:
return
case g.idChan <- generateUUID():
// Successfully added a new UUID
// Yield briefly to avoid monopolizing CPU during refill
runtime.Gosched()
default:
// Channel is full, sleep longer to save CPU
time.Sleep(1 * time.Millisecond) // 1ms instead of 10ms
// Channel is full, sleep to avoid tight spin loop
time.Sleep(1 * time.Millisecond)
}
}

Expand All @@ -131,7 +136,10 @@ func generateUUID() string {
if err != nil {
id = uuid.New()
}
return hex.EncodeToString(id[:])
// Use stack-allocated buffer to avoid heap allocation
var buf [32]byte
hex.Encode(buf[:], id[:])
return string(buf[:])
}

// GetID returns a precomputed UUID from the buffer channel.
Expand Down
115 changes: 112 additions & 3 deletions pkg/middleware/trace_stop_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"sync"
"testing"
"time"
)
Expand All @@ -20,10 +21,118 @@ func TestIDGeneratorStopDuringBatchFill(t *testing.T) {
}

// TestIDGeneratorStopDuringNormalFill ensures Stop works during normal generation.
// This covers the `case <-g.stop: return` in the normal fill select.
func TestIDGeneratorStopDuringNormalFill(t *testing.T) {
g := NewIDGenerator(1)
_ = g.GetIDNonBlocking() // drain so filler runs normally
time.Sleep(1 * time.Millisecond)
// Use a larger buffer so normal fill runs longer (batch mode won't trigger)
// With size 50, emptyThreshold = 5, so we need currentLen < 5 for batch mode
// If we keep channel above 5 items, we stay in normal fill mode
g := NewIDGenerator(50)

// Wait for initial fill
time.Sleep(50 * time.Millisecond)

// Start draining to keep the worker active in normal fill mode
done := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-done:
return
default:
_ = g.GetIDNonBlocking()
// Keep some items in buffer so we don't trigger batch mode
if len(g.idChan) < 40 {
time.Sleep(100 * time.Microsecond)
}
}
}
}()

// Let the drain run briefly
time.Sleep(5 * time.Millisecond)

// Stop while normal fill is running
g.Stop()

close(done)
wg.Wait()

// Second stop should be safe
g.Stop()
}

// TestIDGeneratorStopWhileBatchLoopActive tests stopping during active batch insertion.
// This specifically covers the `case <-g.stop: return` inside the batch fill inner loop.
func TestIDGeneratorStopWhileBatchLoopActive(t *testing.T) {
	// A larger buffer with a low threshold keeps the batch loop busy longer.
	const bufferSize = 100
	g := NewIDGenerator(bufferSize)

	// Give the worker time to populate the channel initially.
	time.Sleep(50 * time.Millisecond)

	quit := make(chan struct{})
	var consumers sync.WaitGroup

	// Spin up several consumers that drain as fast as possible so the
	// generator stays in batch-fill mode while we call Stop.
	for range 5 {
		consumers.Add(1)
		go func() {
			defer consumers.Done()
			for {
				select {
				case <-quit:
					return
				default:
					_ = g.GetIDNonBlocking()
				}
			}
		}()
	}

	// Allow the aggressive draining to push the worker into batch fill.
	time.Sleep(10 * time.Millisecond)

	// Stop while the batch insertion loop is expected to be running.
	g.Stop()

	// Tear down the consumer goroutines.
	close(quit)
	consumers.Wait()
}

// TestIDGeneratorBatchFillChannelFull tests the default case in batch fill loop.
// This covers the `default` branch that exits batch insertion when the channel fills.
func TestIDGeneratorBatchFillChannelFull(t *testing.T) {
	// The batch filler precomputes 1000 UUIDs at a time; with only 10 slots
	// the channel fills almost immediately, forcing the default branch.
	const bufferSize = 10
	g := NewIDGenerator(bufferSize)

	// Allow the worker to top up the channel first.
	time.Sleep(50 * time.Millisecond)

	// Empty the buffer quickly to trigger batch fill mode: the channel must
	// drop below the 10% threshold while its length is decreasing.
	for range bufferSize {
		_ = g.GetIDNonBlocking()
	}

	// Leave time for the filler to notice the depletion, start a batch fill,
	// and bail out via the default case once the channel is full again.
	time.Sleep(10 * time.Millisecond)

	// A full buffer proves the batch refill ran and stopped at capacity.
	if got := len(g.idChan); got != bufferSize {
		t.Errorf("Expected buffer to be full (%d), got %d", bufferSize, got)
	}

	g.Stop()
}
24 changes: 15 additions & 9 deletions pkg/router/register_generic_route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) {
logger := zap.NewNop()
r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser)

timeout := 1 * time.Millisecond
timeout := 10 * time.Millisecond
ctxErrCh := make(chan error, 1)

RegisterGenericRoute(r, RouteConfig[RequestType, ResponseType]{
Expand All @@ -625,30 +625,36 @@ func TestRegisterGenericRouteWithTimeout(t *testing.T) {
Handler: func(r *http.Request, req RequestType) (ResponseType, error) {
<-r.Context().Done()
ctxErrCh <- r.Context().Err()
return ResponseType{Message: "Should have timed out"}, r.Context().Err()
// Return nil error to avoid handleError race with timeout response.
// The timeout middleware writes 408 before this returns.
return ResponseType{}, nil
},
SourceType: Body,
Overrides: common.RouteOverrides{Timeout: timeout},
}, timeout, 0, nil)

// Use httptest.Server to properly handle concurrent access to response
ts := httptest.NewServer(r)
defer ts.Close()

reqBody := RequestType{ID: "123", Name: "John"}
reqBytes, err := json.Marshal(reqBody)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/test", strings.NewReader(string(reqBytes)))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)

if rr.Code != http.StatusRequestTimeout {
t.Errorf("Expected status code %d, got %d", http.StatusRequestTimeout, rr.Code)
resp, err := ts.Client().Post(ts.URL+"/test", "application/json", strings.NewReader(string(reqBytes)))
require.NoError(t, err)
defer resp.Body.Close()

if resp.StatusCode != http.StatusRequestTimeout {
t.Errorf("Expected status code %d, got %d", http.StatusRequestTimeout, resp.StatusCode)
}

select {
case err := <-ctxErrCh:
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("Expected context deadline exceeded, got %v", err)
}
default:
case <-time.After(1 * time.Second):
t.Error("Handler did not receive context cancellation")
}
}
Expand Down
Loading