From 99d2c0a1b279e9b3e90c4e9ded0dfbf5ae0f2135 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 15 Dec 2025 09:06:03 +0000 Subject: [PATCH 1/3] feat: add minimal websocket example app - Add `github.com/gorilla/websocket` dependency. - Create `examples/websocket/main.go` to demonstrate functional WebSocket support. - Include a client test in the example to verify REST and WebSocket endpoints. - Prove that `IsWebSocket: true` bypasses global timeouts. --- examples/websocket/main.go | 146 +++++++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + 3 files changed, 149 insertions(+) create mode 100644 examples/websocket/main.go diff --git a/examples/websocket/main.go b/examples/websocket/main.go new file mode 100644 index 0000000..db2bca2 --- /dev/null +++ b/examples/websocket/main.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router" + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + // Allow all origins for this example + CheckOrigin: func(r *http.Request) bool { return true }, +} + +func main() { + // 1. Setup Server + logger, _ := zap.NewProduction() + defer logger.Sync() + + routerConfig := router.RouterConfig{ + ServiceName: "websocket-example", + Logger: logger, + GlobalTimeout: 5 * time.Second, // Global timeout to test IsWebSocket bypass + } + + // Simple auth - accept everything + authFunc := func(ctx context.Context, token string) (*string, bool) { + user := "generic-user" + return &user, true + } + userIdFunc := func(user *string) string { return *user } + + r := router.NewRouter[string, string](routerConfig, authFunc, userIdFunc) + + // REST Endpoint + r.RegisterRoute(router.RouteConfigBase{ + Path: "/hello", + Methods: []router.HttpMethod{router.MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello, World!")) + }, + }) + + // WebSocket Endpoint + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + IsWebSocket: true, // Crucial: disables global timeout + Handler: func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Error("upgrade failed", zap.Error(err)) + return + } + defer conn.Close() + + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + return + } + // Echo message back + if err := conn.WriteMessage(messageType, p); err != nil { + return + } + } + }, + }) + + // Start server in goroutine + port := "8089" + server := &http.Server{Addr: ":" + port, Handler: r} + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("ListenAndServe(): %v", err) + } + }() + fmt.Printf("Server started on port %s\n", port) + + // Give server a moment to start + time.Sleep(100 * time.Millisecond) + + // 2. Test Client Logic + testREST(port) + testWebSocket(port) + + // Shutdown + server.Shutdown(context.Background()) + fmt.Println("Done.") +} + +func testREST(port string) { + fmt.Println("--- Testing REST Endpoint ---") + resp, err := http.Get(fmt.Sprintf("http://localhost:%s/hello", port)) + if err != nil { + log.Fatalf("REST request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Fatalf("REST expected status 200, got %d", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + fmt.Printf("REST Response: %s\n", string(body)) + fmt.Println("REST Test Passed!") +} + +func testWebSocket(port string) { + fmt.Println("--- Testing WebSocket Endpoint ---") + u := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("WebSocket dial failed: %v", err) + } + defer c.Close() + + msg := "hello websocket" + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + log.Fatalf("WebSocket write failed: %v", err) + } + + _, message, err := c.ReadMessage() + if err != nil { + log.Fatalf("WebSocket read failed: %v", err) + } + + fmt.Printf("WebSocket Response: %s\n", string(message)) + if string(message) != msg { + log.Fatalf("WebSocket expected echo '%s', got '%s'", msg, string(message)) + } + fmt.Println("WebSocket Test Passed!") +} diff --git a/go.mod b/go.mod index ea335bc..203dd63 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 9ab7e1b..0c9986f 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= From af96b94838850bb54fbbbc02ef3041e1ecff5503 Mon Sep 17 00:00:00 2001 From: Suhaib Date: Mon, 15 Dec 2025 10:04:48 -0800 Subject: [PATCH 2/3] refactor: simplify router initialization and update dependencies --- examples/websocket/main.go | 2 +- go.mod | 21 ++++++------ go.sum | 56 ++++++++++++-------------------- pkg/router/handler_error_test.go | 6 ++-- pkg/router/websocket_test.go | 6 ++-- 5 files changed, 38 insertions(+), 53 deletions(-) diff --git a/examples/websocket/main.go b/examples/websocket/main.go index db2bca2..4ee623a 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -39,7 +39,7 @@ func main() { } userIdFunc := func(user *string) string { return *user } - r := router.NewRouter[string, string](routerConfig, authFunc, userIdFunc) + r := router.NewRouter(routerConfig, authFunc, userIdFunc) // REST Endpoint r.RegisterRoute(router.RouteConfigBase{ diff --git a/go.mod b/go.mod index 203dd63..224ee55 100644 --- a/go.mod +++ b/go.mod @@ -4,22 +4,23 @@ go 1.24.0 require ( github.com/julienschmidt/httprouter v1.3.0 - go.uber.org/zap v1.27.0 + go.uber.org/zap v1.27.1 ) require ( github.com/google/uuid v1.6.0 - github.com/stretchr/testify v1.10.0 - gorm.io/gorm v1.30.1 + github.com/gorilla/websocket v1.5.3 + github.com/stretchr/testify v1.11.1 + gorm.io/gorm v1.31.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/text v0.28.0 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect + golang.org/x/text v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) @@ -29,14 +30,14 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.2 - github.com/prometheus/common v0.65.0 // indirect - github.com/prometheus/procfs v0.17.0 // indirect + github.com/prometheus/common v0.67.4 // indirect + github.com/prometheus/procfs v0.19.2 // indirect go.uber.org/ratelimit v0.3.1 - golang.org/x/sys v0.35.0 // indirect - google.golang.org/protobuf v1.36.7 + golang.org/x/sys v0.39.0 // indirect + google.golang.org/protobuf v1.36.11 ) require ( - github.com/prometheus/client_golang v1.23.0 + github.com/prometheus/client_golang v1.23.2 go.uber.org/multierr v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 0c9986f..ca0d78e 100644 --- a/go.sum +++ b/go.sum @@ -30,24 +30,18 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= -github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= -github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= -github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4= -github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= -github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= -github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= -github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= -github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc= +github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI= +github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= +github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -56,30 +50,20 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= -go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= -go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= -google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= -google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= -gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= -gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/pkg/router/handler_error_test.go b/pkg/router/handler_error_test.go index 25eccd8..afc5120 100644 --- a/pkg/router/handler_error_test.go +++ b/pkg/router/handler_error_test.go @@ -23,7 +23,7 @@ func TestGenericRouteHandlerError(t *testing.T) { return 0 } - router := NewRouter[int, interface{}](RouterConfig{ + router := NewRouter(RouterConfig{ Logger: zap.NewNop(), }, getUserByID, getUserID) @@ -177,11 +177,11 @@ func TestHandlerErrorWithMultipleMiddleware(t *testing.T) { getUserByID := func(ctx context.Context, userID string) (*interface{}, bool) { return nil, false } - getUserID := func(user *interface{}) int { + getUserID := func(user *any) int { return 0 } - router := NewRouter[int, interface{}](RouterConfig{ + router := NewRouter(RouterConfig{ Logger: zap.NewNop(), }, getUserByID, getUserID) diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go index 243e422..cdcd5d4 100644 --- a/pkg/router/websocket_test.go +++ b/pkg/router/websocket_test.go @@ -21,9 +21,9 @@ type hijackableRecorder struct { serverConn net.Conn clientConn net.Conn - readDeadline time.Time - writeDeadline time.Time - fullDuplexEnabled bool + readDeadline time.Time + writeDeadline time.Time + fullDuplexEnabled bool } func newHijackableRecorder() *hijackableRecorder { From 66b9ccc0bcfdfea2bcb0b37fcca0154a82052158 Mon Sep 17 00:00:00 2001 From: Suhaib Date: Mon, 15 Dec 2025 13:04:28 -0800 Subject: [PATCH 3/3] fix: improve timeout handling in mutexResponseWriter to prevent race conditions --- pkg/router/router.go | 114 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 103 insertions(+), 11 deletions(-) diff --git a/pkg/router/router.go b/pkg/router/router.go index d0ae96b..ef8219d 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -423,19 +423,48 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar fields = r.addTrace(fields, req) r.logger.Error("Request timed out", fields...) - // Acquire lock to safely check and potentially write timeout response. - wrappedW.mu.Lock() - // Check if handler already started writing. Use Swap for atomic check-and-set. - if !wrappedW.wroteHeader.Swap(true) { - // Handler hasn't written yet, we can write the timeout error. - // Hold the lock while writing headers and body for timeout. - // Use the new JSON error writer, passing the request - traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) + // If the handler already started writing, don't attempt to take over the response. + // Wait for the handler to finish to avoid returning while another goroutine is writing. + if wrappedW.wroteHeader.Load() { + <-done + select { + case p := <-panicChan: + panic(p) + default: + } + return + } + + // Mark timed out so any in-flight handler writes fail fast and don't touch the underlying writer. + wrappedW.timedOut.Store(true) + + // Reserve the response so the handler can't race to write its own error response. + if !wrappedW.wroteHeader.CompareAndSwap(false, true) { + <-done + select { + case p := <-panicChan: + panic(p) + default: + } + return } - // If wroteHeader was already true, handler won the race, do nothing here. - // Unlock should happen regardless of whether we wrote the error or not. + + // Serialize the timeout response write with any handler goroutine currently inside rw methods. + wrappedW.mu.Lock() + traceID := scontext.GetTraceIDFromRequest[T, U](req) + r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) wrappedW.mu.Unlock() + + // Give the handler a chance to observe cancellation and exit promptly. + select { + case <-done: + select { + case p := <-panicChan: + panic(p) + default: + } + case <-time.After(50 * time.Millisecond): + } return } }) @@ -1013,6 +1042,56 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err // It includes the trace ID in the JSON payload if available and enabled. // It also adds CORS headers based on information stored in the context by the CORS middleware. func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Add req parameter + if mrw, ok := w.(*mutexResponseWriter); ok { + if mrw.timedOut.Load() { + return + } + if !mrw.wroteHeader.CompareAndSwap(false, true) { + return + } + + mrw.mu.Lock() + defer mrw.mu.Unlock() + + allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) + header := mrw.ResponseWriter.Header() + + if corsOK { + if allowedOrigin != "" { + header.Set("Access-Control-Allow-Origin", allowedOrigin) + } + if credentialsAllowed { + header.Set("Access-Control-Allow-Credentials", "true") + } + if allowedOrigin != "" && allowedOrigin != "*" { + header.Add("Vary", "Origin") + } + } + + header.Set("Content-Type", "application/json; charset=utf-8") + mrw.ResponseWriter.WriteHeader(statusCode) + + errorPayload := map[string]any{ + "error": map[string]string{ + "message": message, + }, + } + if r.config.TraceIDBufferSize > 0 && traceID != "" { + errorMap := errorPayload["error"].(map[string]string) + errorMap["trace_id"] = traceID + } + + if err := json.NewEncoder(mrw.ResponseWriter).Encode(errorPayload); err != nil { + r.logger.Error("Failed to write JSON error response", + zap.Error(err), + zap.Int("original_status", statusCode), + zap.String("original_message", message), + zap.String("trace_id", traceID), + ) + } + return + } + // Retrieve CORS info from context using the passed-in request allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) @@ -1220,10 +1299,14 @@ type mutexResponseWriter struct { http.ResponseWriter mu *sync.Mutex wroteHeader atomic.Bool // Tracks if WriteHeader or Write has been called + timedOut atomic.Bool // When true, reject all writes to the underlying writer } // Header acquires the mutex and returns the underlying Header map. func (rw *mutexResponseWriter) Header() http.Header { + if rw.timedOut.Load() { + return make(http.Header) + } rw.mu.Lock() defer rw.mu.Unlock() return rw.ResponseWriter.Header() @@ -1231,6 +1314,9 @@ func (rw *mutexResponseWriter) Header() http.Header { // WriteHeader acquires the mutex, marks headers as written, and calls the underlying ResponseWriter.WriteHeader. func (rw *mutexResponseWriter) WriteHeader(statusCode int) { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() if !rw.wroteHeader.Swap(true) { // Atomically set flag and check previous value @@ -1241,6 +1327,9 @@ func (rw *mutexResponseWriter) WriteHeader(statusCode int) { // Write acquires the mutex, marks headers/body as written, and calls the underlying ResponseWriter.Write. func (rw *mutexResponseWriter) Write(b []byte) (int, error) { + if rw.timedOut.Load() { + return 0, http.ErrHandlerTimeout + } rw.mu.Lock() defer rw.mu.Unlock() rw.wroteHeader.Store(true) // Mark as written (headers might be implicitly written here) @@ -1249,6 +1338,9 @@ func (rw *mutexResponseWriter) Write(b []byte) (int, error) { // Flush acquires the mutex and calls the underlying ResponseWriter.Flush if it implements http.Flusher. func (rw *mutexResponseWriter) Flush() { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() if f, ok := rw.ResponseWriter.(http.Flusher); ok {