From 4644029c6c04908f509875d9943c20cabe3c5eb9 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:23:49 +0000 Subject: [PATCH 01/12] feat: Add WebSocket support by disabling timeouts for specific routes This commit introduces the `IsWebSocket` flag to `RouteConfigBase`. When set to true, this flag ensures that the router's global or sub-router timeout configurations are ignored for the specific route, effectively disabling the timeout middleware. This is necessary for long-lived WebSocket connections that would otherwise be terminated by the timeout. Changes: - Modified `RouteConfigBase` in `pkg/router/config.go` to include `IsWebSocket`. - Updated `registerSubRouter` in `pkg/router/router.go` to override timeout to 0 if `IsWebSocket` is true. - Updated `RegisterRoute` in `pkg/router/route.go` to override timeout to 0 if `IsWebSocket` is true. - Added `TestWebSocketRoute` in `pkg/router/websocket_test.go`. - Updated `README.md` with documentation and usage examples. --- README.md | 22 +++++++++++ pkg/router/config.go | 1 + pkg/router/route.go | 6 +++ pkg/router/router.go | 6 +++ pkg/router/websocket_test.go | 73 ++++++++++++++++++++++++++++++++++++ 5 files changed, 108 insertions(+) create mode 100644 pkg/router/websocket_test.go diff --git a/README.md b/README.md index c28ff71..e4d2a8f 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ SRouter is a high-performance HTTP router for Go that wraps [julienschmidt/httpr - **Advanced Features** - [IP Configuration](./docs/ip-configuration.md) - [Rate Limiting](./docs/rate-limiting.md) + - [WebSocket Support](#websocket-support) - [Authentication](./docs/authentication.md) - [Context Management](./docs/context-management.md) - [Custom Error Handling](./docs/error-handling.md) @@ -321,6 +322,26 @@ func GetUserHandler(w http.ResponseWriter, r *http.Request) { } ``` +### WebSocket Support + +SRouter supports WebSocket connections by allowing you to disable the automatic request timeout for specific routes. This is crucial for long-lived connections. + +To enable WebSocket support for a route, set the `IsWebSocket` flag to `true` in your `RouteConfigBase`. This will prevent the global or sub-router timeout from terminating the connection. + +```go +// Register a WebSocket route +r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, // Disables timeout for this route + Handler: func(w http.ResponseWriter, r *http.Request) { + // Upgrade the connection to a WebSocket + // conn, err := upgrader.Upgrade(w, r, nil) + // ... handle connection ... + }, +}) +``` + ### Trace ID Logging SRouter provides built-in support for trace ID logging, which allows you to correlate log entries across different parts of your application for a single request. Each request is assigned a unique trace ID (UUID) that is automatically included in all log entries when `EnableTraceLogging` is true. @@ -1224,6 +1245,7 @@ type RouteConfigBase struct { Overrides common.RouteOverrides // Optional per-route overrides Handler http.HandlerFunc // Standard HTTP handler function Middlewares []common.Middleware // Middlewares applied to this specific route + IsWebSocket bool // Indicates if this route is a WebSocket route. If true, timeout is disabled. } ``` diff --git a/pkg/router/config.go b/pkg/router/config.go index 4c8f701..ab4bc9a 100644 --- a/pkg/router/config.go +++ b/pkg/router/config.go @@ -195,6 +195,7 @@ type RouteConfigBase struct { Overrides common.RouteOverrides // Configuration overrides for this specific route Handler http.HandlerFunc // Standard HTTP handler function Middlewares []common.Middleware // Middlewares applied to this specific route (combined with sub-router and global middlewares) + IsWebSocket bool // Indicates if this route is a WebSocket route. If true, timeout is disabled. } // Implement RouteDefinition for RouteConfigBase diff --git a/pkg/router/route.go b/pkg/router/route.go index a82f7a2..e822bb1 100644 --- a/pkg/router/route.go +++ b/pkg/router/route.go @@ -25,6 +25,12 @@ import ( func (r *Router[T, U]) RegisterRoute(route RouteConfigBase) { // Get effective timeout, max body size, and rate limit for this route timeout := r.getEffectiveTimeout(route.Overrides.Timeout, 0) + + // If route is a WebSocket, disable timeout + if route.IsWebSocket { + timeout = 0 + } + maxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, 0) // Pass the specific route config (which is *common.RateLimitConfig[any, any]) // to getEffectiveRateLimit. The conversion happens inside getEffectiveRateLimit. diff --git a/pkg/router/router.go b/pkg/router/router.go index 38b71da..99fc792 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -210,6 +210,12 @@ func (r *Router[T, U]) registerSubRouter(sr SubRouterConfig) { // Get effective settings considering overrides timeout := r.getEffectiveTimeout(route.Overrides.Timeout, sr.Overrides.Timeout) + + // If route is a WebSocket, disable timeout + if route.IsWebSocket { + timeout = 0 + } + maxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, sr.Overrides.MaxBodySize) rateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, sr.Overrides.RateLimit) authLevel := route.AuthLevel // Use route-specific first diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go new file mode 100644 index 0000000..eda8e4a --- /dev/null +++ b/pkg/router/websocket_test.go @@ -0,0 +1,73 @@ +package router_test + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router" + "go.uber.org/zap" +) + +func TestWebSocketRoute(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + } + + r := router.NewRouter[string, string](config, nil, nil) + + // Register a "WebSocket" route that sleeps longer than the global timeout + // Since IsWebSocket is true, it should NOT timeout. + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, + Handler: func(w http.ResponseWriter, r *http.Request) { + // Simulate long-lived connection + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }) + + // Register a normal route that SHOULD timeout + r.RegisterRoute(router.RouteConfigBase{ + Path: "/normal", + Methods: []router.HttpMethod{router.MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }) + + server := httptest.NewServer(r) + defer server.Close() + + client := server.Client() + + // Test WebSocket Route + start := time.Now() + resp, err := client.Get(server.URL + "/ws") + duration := time.Since(start) + + if err != nil { + t.Fatalf("/ws request failed: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("/ws: expected 200 OK, got %d", resp.StatusCode) + } + if duration < 200*time.Millisecond { + t.Errorf("/ws: completed too fast (%v), sleep didn't happen?", duration) + } + + // Test Normal Route (Control Case) + resp, err = client.Get(server.URL + "/normal") + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("/normal: expected 408 Timeout, got %d", resp.StatusCode) + } +} From 341c7367cff09c9f77f03048922a7336260066dc Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:31:03 +0000 Subject: [PATCH 02/12] test: add more comprehensive tests for WebSocket support Expanded the test suite to cover WebSocket support in both top-level routers and sub-routers. This ensures that the `IsWebSocket` flag correctly disables timeouts in all routing scenarios. --- pkg/router/websocket_test.go | 112 +++++++++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 18 deletions(-) diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index eda8e4a..d614448 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/Suhaibinator/SRouter/pkg/common" "github.com/Suhaibinator/SRouter/pkg/router" "go.uber.org/zap" ) @@ -48,26 +49,101 @@ func TestWebSocketRoute(t *testing.T) { client := server.Client() // Test WebSocket Route - start := time.Now() - resp, err := client.Get(server.URL + "/ws") - duration := time.Since(start) + t.Run("WebSocket Route should not timeout", func(t *testing.T) { + start := time.Now() + resp, err := client.Get(server.URL + "/ws") + duration := time.Since(start) - if err != nil { - t.Fatalf("/ws request failed: %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("/ws: expected 200 OK, got %d", resp.StatusCode) - } - if duration < 200*time.Millisecond { - t.Errorf("/ws: completed too fast (%v), sleep didn't happen?", duration) - } + if err != nil { + t.Fatalf("/ws request failed: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("/ws: expected 200 OK, got %d", resp.StatusCode) + } + if duration < 200*time.Millisecond { + t.Errorf("/ws: completed too fast (%v), sleep didn't happen?", duration) + } + }) // Test Normal Route (Control Case) - resp, err = client.Get(server.URL + "/normal") - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusRequestTimeout { - t.Errorf("/normal: expected 408 Timeout, got %d", resp.StatusCode) + t.Run("Normal Route should timeout", func(t *testing.T) { + resp, err := client.Get(server.URL + "/normal") + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("/normal: expected 408 Timeout, got %d", resp.StatusCode) + } + }) +} + +func TestSubRouterWebSocketRoute(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + SubRouters: []router.SubRouterConfig{ + { + PathPrefix: "/sub", + Overrides: common.RouteOverrides{ + Timeout: 50 * time.Millisecond, + }, + Routes: []router.RouteDefinition{ + router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, + Handler: func(w http.ResponseWriter, r *http.Request) { + // Simulate long-lived connection + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }, + router.RouteConfigBase{ + Path: "/normal", + Methods: []router.HttpMethod{router.MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }, + }, + }, + }, } + + r := router.NewRouter[string, string](config, nil, nil) + + server := httptest.NewServer(r) + defer server.Close() + + client := server.Client() + + // Test SubRouter WebSocket Route + t.Run("SubRouter WebSocket Route should not timeout", func(t *testing.T) { + start := time.Now() + resp, err := client.Get(server.URL + "/sub/ws") + duration := time.Since(start) + + if err != nil { + t.Fatalf("/sub/ws request failed: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("/sub/ws: expected 200 OK, got %d", resp.StatusCode) + } + if duration < 200*time.Millisecond { + t.Errorf("/sub/ws: completed too fast (%v), sleep didn't happen?", duration) + } + }) + + // Test SubRouter Normal Route (Control Case) + t.Run("SubRouter Normal Route should timeout", func(t *testing.T) { + resp, err := client.Get(server.URL + "/sub/normal") + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("/sub/normal: expected 408 Timeout, got %d", resp.StatusCode) + } + }) } From 3f99d825d5a11c0a136926bf143197b50b11fe48 Mon Sep 17 00:00:00 2001 From: suhaib Date: Sun, 14 Dec 2025 22:31:44 -0800 Subject: [PATCH 03/12] add support for hijacking connections --- pkg/router/router.go | 18 ++++++++ pkg/router/websocket_test.go | 84 ++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/pkg/router/router.go b/pkg/router/router.go index 99fc792..1eec830 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -3,10 +3,12 @@ package router import ( + "bufio" "context" "encoding/json" // Added for JSON marshalling "errors" "fmt" + "net" "net/http" "slices" // Added for CORS "strconv" // Added for CORS @@ -792,6 +794,12 @@ type baseResponseWriter struct { http.ResponseWriter } +// Unwrap returns the underlying ResponseWriter. +// This enables http.ResponseController to reach optional interfaces on the original writer. +func (bw *baseResponseWriter) Unwrap() http.ResponseWriter { + return bw.ResponseWriter +} + // WriteHeader calls the underlying ResponseWriter's WriteHeader. func (bw *baseResponseWriter) WriteHeader(statusCode int) { bw.ResponseWriter.WriteHeader(statusCode) @@ -809,6 +817,16 @@ func (bw *baseResponseWriter) Flush() { } } +// Hijack delegates to the underlying ResponseWriter when it supports http.Hijacker. +// This is required for WebSocket upgrades to work through ResponseWriter wrappers. +func (bw *baseResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := bw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, http.ErrNotSupported + } + return h.Hijack() +} + // metricsResponseWriter is a wrapper around http.ResponseWriter that captures metrics. // It tracks the status code, bytes written, and timing information for each response. type metricsResponseWriter[T comparable, U any] struct { diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index d614448..7471df1 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -1,6 +1,9 @@ package router_test import ( + "bufio" + "errors" + "net" "net/http" "net/http/httptest" "testing" @@ -11,6 +14,36 @@ import ( "go.uber.org/zap" ) +type hijackableRecorder struct { + *httptest.ResponseRecorder + hijacked bool + serverConn net.Conn + clientConn net.Conn +} + +func newHijackableRecorder() *hijackableRecorder { + return &hijackableRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (rw *hijackableRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if rw.serverConn != nil || rw.clientConn != nil { + return nil, nil, errors.New("connection already hijacked") + } + + rw.hijacked = true + rw.clientConn, rw.serverConn = net.Pipe() + return rw.serverConn, bufio.NewReadWriter(bufio.NewReader(rw.serverConn), bufio.NewWriter(rw.serverConn)), nil +} + +func (rw *hijackableRecorder) Close() { + if rw.serverConn != nil { + _ = rw.serverConn.Close() + } + if rw.clientConn != nil { + _ = rw.clientConn.Close() + } +} + func TestWebSocketRoute(t *testing.T) { logger := zap.NewNop() config := router.RouterConfig{ @@ -77,6 +110,57 @@ func TestWebSocketRoute(t *testing.T) { }) } +func TestWebSocketRoutePreservesHijackerWithTracingEnabled(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + TraceIDBufferSize: 1, + } + + r := router.NewRouter[string, string](config, nil, nil) + + var sawHijacker bool + var hijackErr error + + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, + Handler: func(w http.ResponseWriter, _ *http.Request) { + h, ok := w.(http.Hijacker) + if !ok { + hijackErr = errors.New("response writer does not implement http.Hijacker") + return + } + sawHijacker = true + + conn, _, err := h.Hijack() + if err != nil { + hijackErr = err + return + } + _ = conn.Close() + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + rr := newHijackableRecorder() + defer rr.Close() + + r.ServeHTTP(rr, req) + + if !sawHijacker { + t.Fatalf("expected handler to receive an http.Hijacker when tracing is enabled") + } + if hijackErr != nil { + t.Fatalf("expected Hijack to succeed, got %v", hijackErr) + } + if !rr.hijacked { + t.Fatalf("expected Hijack to be delegated to the underlying response writer") + } +} + func TestSubRouterWebSocketRoute(t *testing.T) { logger := zap.NewNop() config := router.RouterConfig{ From 936d686fe9d88b78961d8a691609c107c4941b9f Mon Sep 17 00:00:00 2001 From: suhaib Date: Sun, 14 Dec 2025 23:08:53 -0800 Subject: [PATCH 04/12] test: enhance WebSocket route tests for hijacking behavior --- pkg/router/router.go | 5 ++-- pkg/router/websocket_test.go | 44 ++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/pkg/router/router.go b/pkg/router/router.go index 1eec830..d0ae96b 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -795,7 +795,8 @@ type baseResponseWriter struct { } // Unwrap returns the underlying ResponseWriter. -// This enables http.ResponseController to reach optional interfaces on the original writer. +// This enables Go 1.20+'s http.ResponseController to reach optional interfaces (e.g. Flusher, Hijacker) +// implemented by the original writer when this writer is wrapped. func (bw *baseResponseWriter) Unwrap() http.ResponseWriter { return bw.ResponseWriter } @@ -822,7 +823,7 @@ func (bw *baseResponseWriter) Flush() { func (bw *baseResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { h, ok := bw.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, http.ErrNotSupported + return nil, nil, fmt.Errorf("underlying ResponseWriter (%T) does not support hijacking: %w", bw.ResponseWriter, http.ErrNotSupported) } return h.Hijack() } diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index 7471df1..c76d6da 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -161,6 +162,49 @@ func TestWebSocketRoutePreservesHijackerWithTracingEnabled(t *testing.T) { } } +func TestWebSocketRouteHijackNotSupportedIsWrapped(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + TraceIDBufferSize: 1, + } + + r := router.NewRouter[string, string](config, nil, nil) + + var hijackErr error + + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, + Handler: func(w http.ResponseWriter, _ *http.Request) { + h, ok := w.(http.Hijacker) + if !ok { + hijackErr = errors.New("response writer does not implement http.Hijacker") + return + } + + _, _, hijackErr = h.Hijack() + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + rr := httptest.NewRecorder() + + r.ServeHTTP(rr, req) + + if hijackErr == nil { + t.Fatalf("expected Hijack to fail") + } + if !errors.Is(hijackErr, http.ErrNotSupported) { + t.Fatalf("expected errors.Is(hijackErr, http.ErrNotSupported) to be true, got %v", hijackErr) + } + if !strings.Contains(hijackErr.Error(), "does not support hijacking") { + t.Fatalf("expected Hijack error to include context, got %q", hijackErr.Error()) + } +} + func TestSubRouterWebSocketRoute(t *testing.T) { logger := zap.NewNop() config := router.RouterConfig{ From ef2abc725a3fcb82bfef3bed92ee333fb066e9b5 Mon Sep 17 00:00:00 2001 From: suhaib Date: Sun, 14 Dec 2025 23:32:45 -0800 Subject: [PATCH 05/12] test: add ResponseController tests for WebSocket route deadlines and full duplex --- pkg/router/websocket_test.go | 87 ++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index c76d6da..243e422 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -20,6 +20,10 @@ type hijackableRecorder struct { hijacked bool serverConn net.Conn clientConn net.Conn + + readDeadline time.Time + writeDeadline time.Time + fullDuplexEnabled bool } func newHijackableRecorder() *hijackableRecorder { @@ -36,6 +40,21 @@ func (rw *hijackableRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return rw.serverConn, bufio.NewReadWriter(bufio.NewReader(rw.serverConn), bufio.NewWriter(rw.serverConn)), nil } +func (rw *hijackableRecorder) SetReadDeadline(deadline time.Time) error { + rw.readDeadline = deadline + return nil +} + +func (rw *hijackableRecorder) SetWriteDeadline(deadline time.Time) error { + rw.writeDeadline = deadline + return nil +} + +func (rw *hijackableRecorder) EnableFullDuplex() error { + rw.fullDuplexEnabled = true + return nil +} + func (rw *hijackableRecorder) Close() { if rw.serverConn != nil { _ = rw.serverConn.Close() @@ -205,6 +224,74 @@ func TestWebSocketRouteHijackNotSupportedIsWrapped(t *testing.T) { } } +func TestWebSocketRouteResponseControllerCanReachOptionalInterfaces(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + TraceIDBufferSize: 1, // ensures the router wraps the ResponseWriter + } + + r := router.NewRouter[string, string](config, nil, nil) + + var controllerErr error + var sawDeadlines bool + var sawFullDuplex bool + + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, + Handler: func(w http.ResponseWriter, _ *http.Request) { + rc := http.NewResponseController(w) + deadline := time.Now().Add(5 * time.Second) + + if err := rc.SetReadDeadline(deadline); err != nil { + controllerErr = err + return + } + if err := rc.SetWriteDeadline(deadline); err != nil { + controllerErr = err + return + } + if err := rc.EnableFullDuplex(); err != nil { + controllerErr = err + return + } + + // Also exercise Hijack through ResponseController, which is commonly used by WebSocket implementations. + conn, _, err := rc.Hijack() + if err != nil { + controllerErr = err + return + } + _ = conn.Close() + + sawDeadlines = true + sawFullDuplex = true + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + rr := newHijackableRecorder() + defer rr.Close() + + r.ServeHTTP(rr, req) + + if controllerErr != nil { + t.Fatalf("expected ResponseController methods to succeed, got %v", controllerErr) + } + if !sawDeadlines || rr.readDeadline.IsZero() || rr.writeDeadline.IsZero() { + t.Fatalf("expected ResponseController to reach SetReadDeadline/SetWriteDeadline on the underlying writer") + } + if !sawFullDuplex || !rr.fullDuplexEnabled { + t.Fatalf("expected ResponseController to reach EnableFullDuplex on the underlying writer") + } + if !rr.hijacked { + t.Fatalf("expected Hijack to be delegated to the underlying response writer") + } +} + func TestSubRouterWebSocketRoute(t *testing.T) { logger := zap.NewNop() config := router.RouterConfig{ From b851c8ab4676761caa35685fb050497e563ec3d1 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 15 Dec 2025 10:05:45 -0800 Subject: [PATCH 06/12] Add minimal WebSocket example application (#102) * feat: add minimal websocket example app - Add `github.com/gorilla/websocket` dependency. - Create `examples/websocket/main.go` to demonstrate functional WebSocket support. - Include a client test in the example to verify REST and WebSocket endpoints. - Prove that `IsWebSocket: true` bypasses global timeouts. * refactor: simplify router initialization and update dependencies --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: Suhaib --- examples/websocket/main.go | 146 +++++++++++++++++++++++++++++++ go.mod | 20 +++-- go.sum | 58 +++++------- pkg/router/handler_error_test.go | 6 +- pkg/router/websocket_test.go | 6 +- 5 files changed, 185 insertions(+), 51 deletions(-) create mode 100644 examples/websocket/main.go diff --git a/examples/websocket/main.go b/examples/websocket/main.go new file mode 100644 index 0000000..4ee623a --- /dev/null +++ b/examples/websocket/main.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router" + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + // Allow all origins for this example + CheckOrigin: func(r *http.Request) bool { return true }, +} + +func main() { + // 1. Setup Server + logger, _ := zap.NewProduction() + defer logger.Sync() + + routerConfig := router.RouterConfig{ + ServiceName: "websocket-example", + Logger: logger, + GlobalTimeout: 5 * time.Second, // Global timeout to test IsWebSocket bypass + } + + // Simple auth - accept everything + authFunc := func(ctx context.Context, token string) (*string, bool) { + user := "generic-user" + return &user, true + } + userIdFunc := func(user *string) string { return *user } + + r := router.NewRouter(routerConfig, authFunc, userIdFunc) + + // REST Endpoint + r.RegisterRoute(router.RouteConfigBase{ + Path: "/hello", + Methods: []router.HttpMethod{router.MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello, World!")) + }, + }) + + // WebSocket Endpoint + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, // Crucial: disables global timeout + Handler: func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Error("upgrade failed", zap.Error(err)) + return + } + defer conn.Close() + + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + return + } + // Echo message back + if err := conn.WriteMessage(messageType, p); err != nil { + return + } + } + }, + }) + + // Start server in goroutine + port := "8089" + server := &http.Server{Addr: ":" + port, Handler: r} + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("ListenAndServe(): %v", err) + } + }() + fmt.Printf("Server started on port %s\n", port) + + // Give server a moment to start + time.Sleep(100 * time.Millisecond) + + // 2. Test Client Logic + testREST(port) + testWebSocket(port) + + // Shutdown + server.Shutdown(context.Background()) + fmt.Println("Done.") +} + +func testREST(port string) { + fmt.Println("--- Testing REST Endpoint ---") + resp, err := http.Get(fmt.Sprintf("http://localhost:%s/hello", port)) + if err != nil { + log.Fatalf("REST request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Fatalf("REST expected status 200, got %d", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + fmt.Printf("REST Response: %s\n", string(body)) + fmt.Println("REST Test Passed!") +} + +func testWebSocket(port string) { + fmt.Println("--- Testing WebSocket Endpoint ---") + u := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("WebSocket dial failed: %v", err) + } + defer c.Close() + + msg := "hello websocket" + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + log.Fatalf("WebSocket write failed: %v", err) + } + + _, message, err := c.ReadMessage() + if err != nil { + log.Fatalf("WebSocket read failed: %v", err) + } + + fmt.Printf("WebSocket Response: %s\n", string(message)) + if string(message) != msg { + log.Fatalf("WebSocket expected echo '%s', got '%s'", msg, string(message)) + } + fmt.Println("WebSocket Test Passed!") +} diff --git a/go.mod b/go.mod index ea335bc..224ee55 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,14 @@ go 1.24.0 require ( github.com/julienschmidt/httprouter v1.3.0 - go.uber.org/zap v1.27.0 + go.uber.org/zap v1.27.1 ) require ( github.com/google/uuid v1.6.0 - github.com/stretchr/testify v1.10.0 - gorm.io/gorm v1.30.1 + github.com/gorilla/websocket v1.5.3 + github.com/stretchr/testify v1.11.1 + gorm.io/gorm v1.31.1 ) require ( @@ -18,7 +19,8 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/text v0.28.0 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect + golang.org/x/text v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) @@ -28,14 +30,14 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.2 - github.com/prometheus/common v0.65.0 // indirect - github.com/prometheus/procfs v0.17.0 // indirect + github.com/prometheus/common v0.67.4 // indirect + github.com/prometheus/procfs v0.19.2 // indirect go.uber.org/ratelimit v0.3.1 - golang.org/x/sys v0.35.0 // indirect - google.golang.org/protobuf v1.36.7 + golang.org/x/sys v0.39.0 // indirect + google.golang.org/protobuf v1.36.11 ) require ( - github.com/prometheus/client_golang v1.23.0 + github.com/prometheus/client_golang v1.23.2 go.uber.org/multierr v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 9ab7e1b..ca0d78e 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -28,24 +30,18 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= -github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= -github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= -github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4= -github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= -github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= -github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= -github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= -github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc= +github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI= +github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= +github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -54,30 +50,20 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= -go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= -go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= -google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= -google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= -gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= -gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/pkg/router/handler_error_test.go b/pkg/router/handler_error_test.go index 25eccd8..afc5120 100644 --- a/pkg/router/handler_error_test.go +++ b/pkg/router/handler_error_test.go @@ -23,7 +23,7 @@ func TestGenericRouteHandlerError(t *testing.T) { return 0 } - router := NewRouter[int, interface{}](RouterConfig{ + router := NewRouter(RouterConfig{ Logger: zap.NewNop(), }, getUserByID, getUserID) @@ -177,11 +177,11 @@ func TestHandlerErrorWithMultipleMiddleware(t *testing.T) { getUserByID := func(ctx context.Context, userID string) (*interface{}, bool) { return nil, false } - getUserID := func(user *interface{}) int { + getUserID := func(user *any) int { return 0 } - router := NewRouter[int, interface{}](RouterConfig{ + router := NewRouter(RouterConfig{ Logger: zap.NewNop(), }, getUserByID, getUserID) diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index 243e422..cdcd5d4 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -21,9 +21,9 @@ type hijackableRecorder struct { serverConn net.Conn clientConn net.Conn - readDeadline time.Time - writeDeadline time.Time - fullDuplexEnabled bool + readDeadline time.Time + writeDeadline time.Time + fullDuplexEnabled bool } func newHijackableRecorder() *hijackableRecorder { From addda8648f91d903c559287a547f4b1a76bed84c Mon Sep 17 00:00:00 2001 From: Suhaibinator <42899065+Suhaibinator@users.noreply.github.com> Date: Mon, 15 Dec 2025 13:05:24 -0800 Subject: [PATCH 07/12] Example websocket proof 5925307507322852050 (#103) * feat: add minimal websocket example app - Add `github.com/gorilla/websocket` dependency. - Create `examples/websocket/main.go` to demonstrate functional WebSocket support. - Include a client test in the example to verify REST and WebSocket endpoints. - Prove that `IsWebSocket: true` bypasses global timeouts. * refactor: simplify router initialization and update dependencies * fix: improve timeout handling in mutexResponseWriter to prevent race conditions --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- pkg/router/router.go | 114 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 103 insertions(+), 11 deletions(-) diff --git a/pkg/router/router.go b/pkg/router/router.go index d0ae96b..ef8219d 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -423,19 +423,48 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar fields = r.addTrace(fields, req) r.logger.Error("Request timed out", fields...) - // Acquire lock to safely check and potentially write timeout response. - wrappedW.mu.Lock() - // Check if handler already started writing. Use Swap for atomic check-and-set. - if !wrappedW.wroteHeader.Swap(true) { - // Handler hasn't written yet, we can write the timeout error. - // Hold the lock while writing headers and body for timeout. - // Use the new JSON error writer, passing the request - traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) + // If the handler already started writing, don't attempt to take over the response. + // Wait for the handler to finish to avoid returning while another goroutine is writing. + if wrappedW.wroteHeader.Load() { + <-done + select { + case p := <-panicChan: + panic(p) + default: + } + return + } + + // Mark timed out so any in-flight handler writes fail fast and don't touch the underlying writer. + wrappedW.timedOut.Store(true) + + // Reserve the response so the handler can't race to write its own error response. + if !wrappedW.wroteHeader.CompareAndSwap(false, true) { + <-done + select { + case p := <-panicChan: + panic(p) + default: + } + return } - // If wroteHeader was already true, handler won the race, do nothing here. - // Unlock should happen regardless of whether we wrote the error or not. + + // Serialize the timeout response write with any handler goroutine currently inside rw methods. + wrappedW.mu.Lock() + traceID := scontext.GetTraceIDFromRequest[T, U](req) + r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) wrappedW.mu.Unlock() + + // Give the handler a chance to observe cancellation and exit promptly. + select { + case <-done: + select { + case p := <-panicChan: + panic(p) + default: + } + case <-time.After(50 * time.Millisecond): + } return } }) @@ -1013,6 +1042,56 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err // It includes the trace ID in the JSON payload if available and enabled. // It also adds CORS headers based on information stored in the context by the CORS middleware. func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Add req parameter + if mrw, ok := w.(*mutexResponseWriter); ok { + if mrw.timedOut.Load() { + return + } + if !mrw.wroteHeader.CompareAndSwap(false, true) { + return + } + + mrw.mu.Lock() + defer mrw.mu.Unlock() + + allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) + header := mrw.ResponseWriter.Header() + + if corsOK { + if allowedOrigin != "" { + header.Set("Access-Control-Allow-Origin", allowedOrigin) + } + if credentialsAllowed { + header.Set("Access-Control-Allow-Credentials", "true") + } + if allowedOrigin != "" && allowedOrigin != "*" { + header.Add("Vary", "Origin") + } + } + + header.Set("Content-Type", "application/json; charset=utf-8") + mrw.ResponseWriter.WriteHeader(statusCode) + + errorPayload := map[string]any{ + "error": map[string]string{ + "message": message, + }, + } + if r.config.TraceIDBufferSize > 0 && traceID != "" { + errorMap := errorPayload["error"].(map[string]string) + errorMap["trace_id"] = traceID + } + + if err := json.NewEncoder(mrw.ResponseWriter).Encode(errorPayload); err != nil { + r.logger.Error("Failed to write JSON error response", + zap.Error(err), + zap.Int("original_status", statusCode), + zap.String("original_message", message), + zap.String("trace_id", traceID), + ) + } + return + } + // Retrieve CORS info from context using the passed-in request allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) @@ -1220,10 +1299,14 @@ type mutexResponseWriter struct { http.ResponseWriter mu *sync.Mutex wroteHeader atomic.Bool // Tracks if WriteHeader or Write has been called + timedOut atomic.Bool // When true, reject all writes to the underlying writer } // Header acquires the mutex and returns the underlying Header map. func (rw *mutexResponseWriter) Header() http.Header { + if rw.timedOut.Load() { + return make(http.Header) + } rw.mu.Lock() defer rw.mu.Unlock() return rw.ResponseWriter.Header() @@ -1231,6 +1314,9 @@ func (rw *mutexResponseWriter) Header() http.Header { // WriteHeader acquires the mutex, marks headers as written, and calls the underlying ResponseWriter.WriteHeader. func (rw *mutexResponseWriter) WriteHeader(statusCode int) { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() if !rw.wroteHeader.Swap(true) { // Atomically set flag and check previous value @@ -1241,6 +1327,9 @@ func (rw *mutexResponseWriter) WriteHeader(statusCode int) { // Write acquires the mutex, marks headers/body as written, and calls the underlying ResponseWriter.Write. func (rw *mutexResponseWriter) Write(b []byte) (int, error) { + if rw.timedOut.Load() { + return 0, http.ErrHandlerTimeout + } rw.mu.Lock() defer rw.mu.Unlock() rw.wroteHeader.Store(true) // Mark as written (headers might be implicitly written here) @@ -1249,6 +1338,9 @@ func (rw *mutexResponseWriter) Write(b []byte) (int, error) { // Flush acquires the mutex and calls the underlying ResponseWriter.Flush if it implements http.Flusher. func (rw *mutexResponseWriter) Flush() { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() if f, ok := rw.ResponseWriter.(http.Flusher); ok { From d0db71373e72ecebeaeeef61e3dfba863587ed11 Mon Sep 17 00:00:00 2001 From: Suhaib Date: Mon, 15 Dec 2025 13:27:30 -0800 Subject: [PATCH 08/12] test: add race condition tests for timeout middleware --- .vscode/settings.json | 7 ++ pkg/router/timeout_middleware_race_test.go | 113 +++++++++++++++++++++ pkg/router/timeout_middleware_test.go | 101 ++++++++++++++++++ pkg/scontext/copy_test.go | 2 +- 4 files changed, 222 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json create mode 100644 pkg/router/timeout_middleware_race_test.go create mode 100644 pkg/router/timeout_middleware_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8e4aa4f --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "gopls": { + "buildFlags": [ + "-tags=race" + ] + } +} \ No newline at end of file diff --git a/pkg/router/timeout_middleware_race_test.go b/pkg/router/timeout_middleware_race_test.go new file mode 100644 index 0000000..b8ffa20 --- /dev/null +++ b/pkg/router/timeout_middleware_race_test.go @@ -0,0 +1,113 @@ +//go:build race + +package router + +import ( + "context" + "net/http" + "net/http/httptest" + "runtime" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "go.uber.org/zap" +) + +func TestTimeoutMiddleware_WhenHandlerWritesBetweenHeaderCheckAndTimeoutStore_TakeoverCASFails(t *testing.T) { + oldProcs := runtime.GOMAXPROCS(2) + defer runtime.GOMAXPROCS(oldProcs) + + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + timeout := 50 * time.Microsecond + + deadline := time.Now().Add(5 * time.Second) + attempts := 0 + + for time.Now().Before(deadline) { + attempts++ + + mrwCh := make(chan *mutexResponseWriter, 1) + ctxErrCh := make(chan error, 1) + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if mrw, ok := w.(*mutexResponseWriter); ok { + select { + case mrwCh <- mrw: + default: + } + } + + <-req.Context().Done() + ctxErrCh <- req.Context().Err() + + w.WriteHeader(http.StatusAccepted) + }) + + h := r.timeoutMiddleware(timeout)(handler) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + mrw := <-mrwCh + if rr.Code == http.StatusAccepted && mrw.timedOut.Load() { + select { + case err := <-ctxErrCh: + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded, got %v", err) + } + default: + t.Fatalf("expected handler to observe context cancellation") + } + return + } + } + + t.Fatalf("did not observe timeout takeover CAS failure within deadline (attempts=%d)", attempts) +} + +func TestTimeoutMiddleware_WhenHandlerPanicsInCASFailurePath_RethrowsToRecovery(t *testing.T) { + oldProcs := runtime.GOMAXPROCS(2) + defer runtime.GOMAXPROCS(oldProcs) + + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + timeout := 50 * time.Microsecond + + deadline := time.Now().Add(5 * time.Second) + attempts := 0 + + for time.Now().Before(deadline) { + attempts++ + + mrwCh := make(chan *mutexResponseWriter, 1) + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if mrw, ok := w.(*mutexResponseWriter); ok { + select { + case mrwCh <- mrw: + default: + } + } + <-req.Context().Done() + w.WriteHeader(http.StatusAccepted) + panic("boom") + }) + + h := r.recoveryMiddleware(r.timeoutMiddleware(timeout)(handler)) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + mrw := <-mrwCh + if rr.Code == http.StatusAccepted && mrw.timedOut.Load() { + if msg := parseJSONErrorMessage(t, rr.Body.Bytes()); msg != "Internal Server Error" { + t.Fatalf("expected internal server error payload, got %q", msg) + } + return + } + } + + t.Fatalf("did not observe panic rethrow in CAS-failure path within deadline (attempts=%d)", attempts) +} diff --git a/pkg/router/timeout_middleware_test.go b/pkg/router/timeout_middleware_test.go new file mode 100644 index 0000000..cfca017 --- /dev/null +++ b/pkg/router/timeout_middleware_test.go @@ -0,0 +1,101 @@ +package router + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "go.uber.org/zap" +) + +func parseJSONErrorMessage(t *testing.T, body []byte) string { + t.Helper() + + var payload struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("expected JSON error payload, got %q: %v", string(body), err) + } + return payload.Error.Message +} + +func TestTimeoutMiddleware_WhenHandlerStartedWriting_DoesNotOverrideResponse(t *testing.T) { + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + timeout := 25 * time.Millisecond + wroteHeader := make(chan struct{}) + ctxErrCh := make(chan error, 1) + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusCreated) + close(wroteHeader) + + <-req.Context().Done() + ctxErrCh <- req.Context().Err() + time.Sleep(10 * time.Millisecond) + + _, _ = w.Write([]byte("handler-finished")) + }) + + h := r.recoveryMiddleware(r.timeoutMiddleware(timeout)(handler)) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + + select { + case <-wroteHeader: + t.Fatalf("handler should not have executed before ServeHTTP") + default: + } + + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code) + } + if rr.Body.String() != "handler-finished" { + t.Fatalf("expected body %q, got %q", "handler-finished", rr.Body.String()) + } + + select { + case err := <-ctxErrCh: + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded, got %v", err) + } + default: + t.Fatalf("expected handler to observe context cancellation") + } +} + +func TestTimeoutMiddleware_WhenHandlerPanicsAfterTimeoutAndStartedWrite_RethrowsToRecovery(t *testing.T) { + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + timeout := 15 * time.Millisecond + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusTeapot) + <-req.Context().Done() + time.Sleep(10 * time.Millisecond) + panic("boom") + }) + + h := r.recoveryMiddleware(r.timeoutMiddleware(timeout)(handler)) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusTeapot { + t.Fatalf("expected status %d, got %d", http.StatusTeapot, rr.Code) + } + if msg := parseJSONErrorMessage(t, rr.Body.Bytes()); msg != "Internal Server Error" { + t.Fatalf("expected internal server error payload, got %q", msg) + } +} diff --git a/pkg/scontext/copy_test.go b/pkg/scontext/copy_test.go index 71646d1..fc8092b 100644 --- a/pkg/scontext/copy_test.go +++ b/pkg/scontext/copy_test.go @@ -44,7 +44,7 @@ func createFullSRouterContext() context.Context { // Set all values in context ctx = WithUserID[int, testUser](ctx, userID) - ctx = WithUser[int, testUser](ctx, user) + ctx = WithUser[int](ctx, user) ctx = WithTraceID[int, testUser](ctx, traceID) ctx = WithClientIP[int, testUser](ctx, clientIP) ctx = WithUserAgent[int, testUser](ctx, userAgent) From 7dc4ca0f01f523403ff567cef88fb10efc589d03 Mon Sep 17 00:00:00 2001 From: Suhaib Date: Mon, 15 Dec 2025 13:36:27 -0800 Subject: [PATCH 09/12] test: add tests for Base64 and Base62 query parameter decoding errors --- pkg/router/route_query_decode_error_test.go | 94 ++++++++++++++ pkg/router/write_json_error_test.go | 131 ++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 pkg/router/route_query_decode_error_test.go create mode 100644 pkg/router/write_json_error_test.go diff --git a/pkg/router/route_query_decode_error_test.go b/pkg/router/route_query_decode_error_test.go new file mode 100644 index 0000000..74e95f7 --- /dev/null +++ b/pkg/router/route_query_decode_error_test.go @@ -0,0 +1,94 @@ +package router + +import ( + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Suhaibinator/SRouter/pkg/codec" + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func encodeBase62(b []byte) string { + const alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + n := new(big.Int).SetBytes(b) + if n.Sign() == 0 { + return "0" + } + + base := big.NewInt(62) + mod := new(big.Int) + + var out []byte + for n.Sign() > 0 { + n.DivMod(n, base, mod) + out = append(out, alphabet[mod.Int64()]) + } + + for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 { + out[i], out[j] = out[j], out[i] + } + return string(out) +} + +func TestRegisterGenericRoute_Base64QueryParameter_DecodeBytesError(t *testing.T) { + logger := zap.NewNop() + r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + RegisterGenericRoute(r, RouteConfig[RequestType, ResponseType]{ + Path: "/test", + Methods: []HttpMethod{MethodGet}, + Codec: codec.NewJSONCodec[RequestType, ResponseType](), + SourceType: Base64QueryParameter, + SourceKey: "qdata", + Handler: func(r *http.Request, req RequestType) (ResponseType, error) { + t.Fatalf("handler should not be called on decode error") + return ResponseType{}, nil + }, + }, 0, 0, nil) + + invalidJSONBase64 := base64.StdEncoding.EncodeToString([]byte("{invalid json")) + req := httptest.NewRequest(http.MethodGet, "/test?qdata="+invalidJSONBase64, nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + + var body map[string]map[string]string + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + require.Equal(t, "Failed to decode query parameter data", body["error"]["message"]) +} + +func TestRegisterGenericRoute_Base62QueryParameter_DecodeBytesError(t *testing.T) { + logger := zap.NewNop() + r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + RegisterGenericRoute(r, RouteConfig[RequestType, ResponseType]{ + Path: "/test", + Methods: []HttpMethod{MethodGet}, + Codec: codec.NewJSONCodec[RequestType, ResponseType](), + SourceType: Base62QueryParameter, + SourceKey: "qdata", + Handler: func(r *http.Request, req RequestType) (ResponseType, error) { + t.Fatalf("handler should not be called on decode error") + return ResponseType{}, nil + }, + }, 0, 0, nil) + + invalidJSONBase62 := encodeBase62([]byte("{invalid json")) + req := httptest.NewRequest(http.MethodGet, "/test?qdata="+invalidJSONBase62, nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + + var body map[string]map[string]string + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + require.Equal(t, "Failed to decode query parameter data", body["error"]["message"]) +} diff --git a/pkg/router/write_json_error_test.go b/pkg/router/write_json_error_test.go new file mode 100644 index 0000000..3b6330b --- /dev/null +++ b/pkg/router/write_json_error_test.go @@ -0,0 +1,131 @@ +package router + +import ( + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "github.com/Suhaibinator/SRouter/pkg/scontext" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestWriteJSONError_MutexResponseWriter_SetsCORSHeaders(t *testing.T) { + tests := []struct { + name string + allowedOrigin string + credentialsAllowed bool + wantVaryOrigin bool + }{ + { + name: "specific_origin_with_credentials_sets_vary", + allowedOrigin: "https://example.com", + credentialsAllowed: true, + wantVaryOrigin: true, + }, + { + name: "wildcard_origin_no_vary", + allowedOrigin: "*", + credentialsAllowed: false, + wantVaryOrigin: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := NewRouter[string, string](RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + req = req.WithContext(scontext.WithCORSInfo[string, string](req.Context(), tc.allowedOrigin, tc.credentialsAllowed)) + + rr := httptest.NewRecorder() + var mu sync.Mutex + mrw := &mutexResponseWriter{ResponseWriter: rr, mu: &mu} + + r.writeJSONError(mrw, req, http.StatusBadRequest, "Bad Request", "") + + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != tc.allowedOrigin { + t.Fatalf("expected Access-Control-Allow-Origin %q, got %q", tc.allowedOrigin, got) + } + + if tc.credentialsAllowed { + if got := rr.Header().Get("Access-Control-Allow-Credentials"); got != "true" { + t.Fatalf("expected Access-Control-Allow-Credentials %q, got %q", "true", got) + } + } else if got := rr.Header().Get("Access-Control-Allow-Credentials"); got != "" { + t.Fatalf("expected no Access-Control-Allow-Credentials header, got %q", got) + } + + if tc.wantVaryOrigin { + if got := rr.Header().Get("Vary"); got != "Origin" { + t.Fatalf("expected Vary %q, got %q", "Origin", got) + } + } else if got := rr.Header().Get("Vary"); got != "" { + t.Fatalf("expected no Vary header, got %q", got) + } + }) + } +} + +type errResponseWriter struct { + header http.Header +} + +func (w *errResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *errResponseWriter) WriteHeader(statusCode int) {} + +func (w *errResponseWriter) Write([]byte) (int, error) { + return 0, errors.New("write failed") +} + +func TestWriteJSONError_MutexResponseWriter_LogsOnEncodeFailure(t *testing.T) { + core, logs := observer.New(zap.ErrorLevel) + logger := zap.New(core) + r := NewRouter[string, string](RouterConfig{Logger: logger, TraceIDBufferSize: 1}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + + var mu sync.Mutex + mrw := &mutexResponseWriter{ResponseWriter: &errResponseWriter{}, mu: &mu} + + r.writeJSONError(mrw, req, http.StatusInternalServerError, "Internal Server Error", "trace-123") + + entries := logs.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Message != "Failed to write JSON error response" { + t.Fatalf("expected log message %q, got %q", "Failed to write JSON error response", entries[0].Message) + } + + var foundStatus, foundMessage, foundTrace bool + for _, f := range entries[0].Context { + switch f.Key { + case "original_status": + foundStatus = f.Integer == int64(http.StatusInternalServerError) + case "original_message": + foundMessage = f.String == "Internal Server Error" + case "trace_id": + foundTrace = f.String == "trace-123" + } + } + + if !foundStatus { + t.Fatalf("expected original_status field to be present") + } + if !foundMessage { + t.Fatalf("expected original_message field to be present") + } + if !foundTrace { + t.Fatalf("expected trace_id field to be present") + } +} From 71fdcb04d899134d64c4d8fc892aac954bc80e6f Mon Sep 17 00:00:00 2001 From: Suhaib Date: Mon, 15 Dec 2025 13:52:06 -0800 Subject: [PATCH 10/12] test: add tests for CORS headers and logging in JSON error responses --- pkg/router/write_json_error_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pkg/router/write_json_error_test.go b/pkg/router/write_json_error_test.go index 3b6330b..aa28ce1 100644 --- a/pkg/router/write_json_error_test.go +++ b/pkg/router/write_json_error_test.go @@ -9,6 +9,7 @@ import ( "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" "github.com/Suhaibinator/SRouter/pkg/scontext" + "github.com/stretchr/testify/require" "go.uber.org/zap" "go.uber.org/zap/zaptest/observer" ) @@ -129,3 +130,27 @@ func TestWriteJSONError_MutexResponseWriter_LogsOnEncodeFailure(t *testing.T) { t.Fatalf("expected trace_id field to be present") } } + +func TestWriteJSONError_MutexResponseWriter_NoOpWhenHeaderAlreadyWritten(t *testing.T) { + r := NewRouter[string, string](RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + rr.Header().Set("X-Existing", "1") + + var mu sync.Mutex + mrw := &mutexResponseWriter{ResponseWriter: rr, mu: &mu} + + // Simulate the handler having already started the response. + mrw.WriteHeader(http.StatusCreated) + + r.writeJSONError(mrw, req, http.StatusBadRequest, "Bad Request", "trace-ignored") + + require.Equal(t, http.StatusCreated, rr.Code) + require.Equal(t, "1", rr.Header().Get("X-Existing")) + require.Equal(t, "", rr.Header().Get("Content-Type")) + require.Equal(t, "", rr.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", rr.Header().Get("Access-Control-Allow-Credentials")) + require.Equal(t, "", rr.Header().Get("Vary")) + require.Equal(t, "", rr.Body.String()) +} From 689e73df28ada647622e8230ea303b83b3cac393 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 15 Dec 2025 18:39:44 -0800 Subject: [PATCH 11/12] Rename IsWebSocket to DisableTimeout in RouteConfigBase (#104) This is a breaking change that renames the `IsWebSocket` field to `DisableTimeout` in `RouteConfigBase`. This new name better reflects the flag's purpose, which is to disable the global or sub-router timeout for any long-lived connection, such as WebSockets or Server-Sent Events (SSE). Updates: - `pkg/router/config.go`: Renamed field and updated comments. - `pkg/router/route.go` & `pkg/router/router.go`: Updated logic to use `DisableTimeout`. - `pkg/router/websocket_test.go`: Updated tests. - `examples/websocket/main.go`: Updated example. - `README.md`: Updated documentation. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- README.md | 6 +++--- examples/websocket/main.go | 4 ++-- pkg/router/config.go | 2 +- pkg/router/route.go | 4 ++-- pkg/router/router.go | 4 ++-- pkg/router/websocket_test.go | 12 ++++++------ 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index e4d2a8f..40abf8c 100644 --- a/README.md +++ b/README.md @@ -326,14 +326,14 @@ func GetUserHandler(w http.ResponseWriter, r *http.Request) { SRouter supports WebSocket connections by allowing you to disable the automatic request timeout for specific routes. This is crucial for long-lived connections. -To enable WebSocket support for a route, set the `IsWebSocket` flag to `true` in your `RouteConfigBase`. This will prevent the global or sub-router timeout from terminating the connection. +To enable WebSocket support for a route, set the `DisableTimeout` flag to `true` in your `RouteConfigBase`. This will prevent the global or sub-router timeout from terminating the connection. This is also useful for other long-lived connections such as Server-Sent Events (SSE). ```go // Register a WebSocket route r.RegisterRoute(router.RouteConfigBase{ Path: "/ws", Methods: []router.HttpMethod{router.MethodGet}, - IsWebSocket: true, // Disables timeout for this route + DisableTimeout: true, // Disables timeout for this route Handler: func(w http.ResponseWriter, r *http.Request) { // Upgrade the connection to a WebSocket // conn, err := upgrader.Upgrade(w, r, nil) @@ -1245,7 +1245,7 @@ type RouteConfigBase struct { Overrides common.RouteOverrides // Optional per-route overrides Handler http.HandlerFunc // Standard HTTP handler function Middlewares []common.Middleware // Middlewares applied to this specific route - IsWebSocket bool // Indicates if this route is a WebSocket route. If true, timeout is disabled. + DisableTimeout bool // Indicates if the timeout should be disabled for this route (e.g., for WebSockets or long-lived connections). } ``` diff --git a/examples/websocket/main.go b/examples/websocket/main.go index 4ee623a..754d5bd 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -29,7 +29,7 @@ func main() { routerConfig := router.RouterConfig{ ServiceName: "websocket-example", Logger: logger, - GlobalTimeout: 5 * time.Second, // Global timeout to test IsWebSocket bypass + GlobalTimeout: 5 * time.Second, // Global timeout to test DisableTimeout bypass } // Simple auth - accept everything @@ -55,7 +55,7 @@ func main() { r.RegisterRoute(router.RouteConfigBase{ Path: "/ws", Methods: []router.HttpMethod{router.MethodGet}, - IsWebSocket: true, // Crucial: disables global timeout + DisableTimeout: true, // Crucial: disables global timeout Handler: func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { diff --git a/pkg/router/config.go b/pkg/router/config.go index ab4bc9a..477f46c 100644 --- a/pkg/router/config.go +++ b/pkg/router/config.go @@ -195,7 +195,7 @@ type RouteConfigBase struct { Overrides common.RouteOverrides // Configuration overrides for this specific route Handler http.HandlerFunc // Standard HTTP handler function Middlewares []common.Middleware // Middlewares applied to this specific route (combined with sub-router and global middlewares) - IsWebSocket bool // Indicates if this route is a WebSocket route. If true, timeout is disabled. + DisableTimeout bool // Indicates if the timeout should be disabled for this route (e.g., for WebSockets or long-lived connections). } // Implement RouteDefinition for RouteConfigBase diff --git a/pkg/router/route.go b/pkg/router/route.go index e822bb1..cf826b1 100644 --- a/pkg/router/route.go +++ b/pkg/router/route.go @@ -26,8 +26,8 @@ func (r *Router[T, U]) RegisterRoute(route RouteConfigBase) { // Get effective timeout, max body size, and rate limit for this route timeout := r.getEffectiveTimeout(route.Overrides.Timeout, 0) - // If route is a WebSocket, disable timeout - if route.IsWebSocket { + // If route has timeout disabled, set timeout to 0 + if route.DisableTimeout { timeout = 0 } diff --git a/pkg/router/router.go b/pkg/router/router.go index ef8219d..3b605d8 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -213,8 +213,8 @@ func (r *Router[T, U]) registerSubRouter(sr SubRouterConfig) { // Get effective settings considering overrides timeout := r.getEffectiveTimeout(route.Overrides.Timeout, sr.Overrides.Timeout) - // If route is a WebSocket, disable timeout - if route.IsWebSocket { + // If route has timeout disabled, set timeout to 0 + if route.DisableTimeout { timeout = 0 } diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index cdcd5d4..6d1e361 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -74,11 +74,11 @@ func TestWebSocketRoute(t *testing.T) { r := router.NewRouter[string, string](config, nil, nil) // Register a "WebSocket" route that sleeps longer than the global timeout - // Since IsWebSocket is true, it should NOT timeout. + // Since DisableTimeout is true, it should NOT timeout. r.RegisterRoute(router.RouteConfigBase{ Path: "/ws", Methods: []router.HttpMethod{router.MethodGet}, - IsWebSocket: true, + DisableTimeout: true, Handler: func(w http.ResponseWriter, r *http.Request) { // Simulate long-lived connection time.Sleep(200 * time.Millisecond) @@ -146,7 +146,7 @@ func TestWebSocketRoutePreservesHijackerWithTracingEnabled(t *testing.T) { r.RegisterRoute(router.RouteConfigBase{ Path: "/ws", Methods: []router.HttpMethod{router.MethodGet}, - IsWebSocket: true, + DisableTimeout: true, Handler: func(w http.ResponseWriter, _ *http.Request) { h, ok := w.(http.Hijacker) if !ok { @@ -196,7 +196,7 @@ func TestWebSocketRouteHijackNotSupportedIsWrapped(t *testing.T) { r.RegisterRoute(router.RouteConfigBase{ Path: "/ws", Methods: []router.HttpMethod{router.MethodGet}, - IsWebSocket: true, + DisableTimeout: true, Handler: func(w http.ResponseWriter, _ *http.Request) { h, ok := w.(http.Hijacker) if !ok { @@ -241,7 +241,7 @@ func TestWebSocketRouteResponseControllerCanReachOptionalInterfaces(t *testing.T r.RegisterRoute(router.RouteConfigBase{ Path: "/ws", Methods: []router.HttpMethod{router.MethodGet}, - IsWebSocket: true, + DisableTimeout: true, Handler: func(w http.ResponseWriter, _ *http.Request) { rc := http.NewResponseController(w) deadline := time.Now().Add(5 * time.Second) @@ -307,7 +307,7 @@ func TestSubRouterWebSocketRoute(t *testing.T) { router.RouteConfigBase{ Path: "/ws", Methods: []router.HttpMethod{router.MethodGet}, - IsWebSocket: true, + DisableTimeout: true, Handler: func(w http.ResponseWriter, r *http.Request) { // Simulate long-lived connection time.Sleep(200 * time.Millisecond) From 0397b8fee87d16c522f38b0fe8b0c15cda5516cd Mon Sep 17 00:00:00 2001 From: suhaib Date: Mon, 15 Dec 2025 18:40:29 -0800 Subject: [PATCH 12/12] refactor: standardize formatting in WebSocket route configuration --- examples/websocket/main.go | 4 ++-- pkg/router/config.go | 12 ++++++------ pkg/router/websocket_test.go | 20 ++++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/websocket/main.go b/examples/websocket/main.go index 754d5bd..a1c3d06 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -53,8 +53,8 @@ func main() { // WebSocket Endpoint r.RegisterRoute(router.RouteConfigBase{ - Path: "/ws", - Methods: []router.HttpMethod{router.MethodGet}, + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, DisableTimeout: true, // Crucial: disables global timeout Handler: func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) diff --git a/pkg/router/config.go b/pkg/router/config.go index 477f46c..5ae1f0d 100644 --- a/pkg/router/config.go +++ b/pkg/router/config.go @@ -189,12 +189,12 @@ type SubRouterConfig struct { // - Sub-router settings override global settings // - Middlewares are additive (not replaced) type RouteConfigBase struct { - Path string // Route path (will be prefixed with sub-router path prefix if applicable) - Methods []HttpMethod // HTTP methods this route handles (use constants like MethodGet) - AuthLevel *AuthLevel // Authentication level for this route. If nil, inherits from sub-router or defaults to NoAuth - Overrides common.RouteOverrides // Configuration overrides for this specific route - Handler http.HandlerFunc // Standard HTTP handler function - Middlewares []common.Middleware // Middlewares applied to this specific route (combined with sub-router and global middlewares) + Path string // Route path (will be prefixed with sub-router path prefix if applicable) + Methods []HttpMethod // HTTP methods this route handles (use constants like MethodGet) + AuthLevel *AuthLevel // Authentication level for this route. If nil, inherits from sub-router or defaults to NoAuth + Overrides common.RouteOverrides // Configuration overrides for this specific route + Handler http.HandlerFunc // Standard HTTP handler function + Middlewares []common.Middleware // Middlewares applied to this specific route (combined with sub-router and global middlewares) DisableTimeout bool // Indicates if the timeout should be disabled for this route (e.g., for WebSockets or long-lived connections). } diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index 6d1e361..1c5b707 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -76,8 +76,8 @@ func TestWebSocketRoute(t *testing.T) { // Register a "WebSocket" route that sleeps longer than the global timeout // Since DisableTimeout is true, it should NOT timeout. r.RegisterRoute(router.RouteConfigBase{ - Path: "/ws", - Methods: []router.HttpMethod{router.MethodGet}, + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, DisableTimeout: true, Handler: func(w http.ResponseWriter, r *http.Request) { // Simulate long-lived connection @@ -144,8 +144,8 @@ func TestWebSocketRoutePreservesHijackerWithTracingEnabled(t *testing.T) { var hijackErr error r.RegisterRoute(router.RouteConfigBase{ - Path: "/ws", - Methods: []router.HttpMethod{router.MethodGet}, + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, DisableTimeout: true, Handler: func(w http.ResponseWriter, _ *http.Request) { h, ok := w.(http.Hijacker) @@ -194,8 +194,8 @@ func TestWebSocketRouteHijackNotSupportedIsWrapped(t *testing.T) { var hijackErr error r.RegisterRoute(router.RouteConfigBase{ - Path: "/ws", - Methods: []router.HttpMethod{router.MethodGet}, + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, DisableTimeout: true, Handler: func(w http.ResponseWriter, _ *http.Request) { h, ok := w.(http.Hijacker) @@ -239,8 +239,8 @@ func TestWebSocketRouteResponseControllerCanReachOptionalInterfaces(t *testing.T var sawFullDuplex bool r.RegisterRoute(router.RouteConfigBase{ - Path: "/ws", - Methods: []router.HttpMethod{router.MethodGet}, + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, DisableTimeout: true, Handler: func(w http.ResponseWriter, _ *http.Request) { rc := http.NewResponseController(w) @@ -305,8 +305,8 @@ func TestSubRouterWebSocketRoute(t *testing.T) { }, Routes: []router.RouteDefinition{ router.RouteConfigBase{ - Path: "/ws", - Methods: []router.HttpMethod{router.MethodGet}, + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, DisableTimeout: true, Handler: func(w http.ResponseWriter, r *http.Request) { // Simulate long-lived connection