Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ SRouter is a high-performance HTTP router for Go that wraps [julienschmidt/httpr
- [Context Management](./docs/context-management.md)
- [Custom Error Handling](./docs/error-handling.md)
- [Custom Middleware](./docs/middleware.md)
- [WebSockets](./docs/websockets.md)
- [Source Types](./docs/generic-routes.md#source-types)
- [Custom Codecs](./docs/codecs.md)
- [Metrics](./docs/metrics.md)
Expand Down Expand Up @@ -52,6 +53,7 @@ SRouter is a high-performance HTTP router for Go that wraps [julienschmidt/httpr
- **Path Parameters**: Easy access to path parameters via request context
- **Graceful Shutdown**: Properly handle in-flight requests during shutdown
- **Flexible Metrics System**: Support for multiple metric formats, custom collectors, and dependency injection
- **WebSocket Endpoints**: Register WebSocket handlers while reusing the same middleware, authentication, and logging pipeline
- **Intelligent Logging**: Structured logging using `zap` with appropriate log levels for different types of events. Requires a logger instance in config.
- **Trace ID Logging**: Automatically generate and include a unique trace ID for each request in context and log entries.
- **Flexible Request Data Sources**: Support for retrieving request data from various sources including request body, query parameters, and path parameters with automatic decoding
Expand Down
49 changes: 49 additions & 0 deletions docs/websockets.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# WebSocket Support

SRouter can now register WebSocket endpoints alongside standard HTTP routes. WebSocket routes reuse the same middleware and authentication pipeline used by REST handlers, so cross-cutting behavior such as tracing, logging, rate limiting, and custom middleware still apply to the initial upgrade request.

## Declaring a WebSocket route

Use `WebSocketRouteConfig` inside a `SubRouterConfig` just like any other `RouteDefinition`:

```go
wsRoute := router.WebSocketRouteConfig{
Path: "/echo",
Middlewares: []common.Middleware{loggingMiddleware},
Handler: func(ctx context.Context, conn *websocket.Conn) {
for {
msgType, payload, err := conn.ReadMessage()
if err != nil {
return
}
_ = conn.WriteMessage(msgType, payload)
}
},
}

r := router.NewRouter(router.RouterConfig{
Logger: logger,
SubRouters: []router.SubRouterConfig{{
PathPrefix: "/ws",
Routes: []router.RouteDefinition{wsRoute},
}},
}, authFn, userFromUserFn)
```

The route is registered under the sub-router path prefix (e.g., `/ws/echo`). Only `GET` is used for WebSocket registration because the handshake is defined on `GET`.

## Upgrader configuration

`WebSocketRouteConfig` accepts an optional `Upgrader`. When omitted, a permissive upgrader is used (it allows all origins). Provide a custom `websocket.Upgrader` when you need stricter origin checks or other advanced settings.

## Middleware, authentication, and limits

The router wraps WebSocket routes with the same middleware chain as HTTP routes:

- Global, sub-router, and route-level middleware are executed before the upgrade occurs.
- Authentication levels (`AuthRequired`, `AuthOptional`, `NoAuth`) are honored for the handshake request.
- Timeout, rate limit, and max body size overrides are applied to the handshake phase. For long-lived WebSocket sessions, consider leaving timeouts unset or explicitly set them to `0` for the route.

## Shutdown behavior

During graceful shutdown, SRouter waits for active WebSocket handlers to return before completing shutdown, just like regular HTTP handlers. Your handler should monitor `ctx.Done()` and exit when requested to allow a timely shutdown.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
Expand Down
70 changes: 50 additions & 20 deletions pkg/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ func (r *Router[T, U]) registerSubRouter(sr SubRouterConfig) {
// The function itself will handle calculating effective settings and calling RegisterGenericRoute
route(r, sr) // Call the registration function

case WebSocketRouteConfig:
fullPath := sr.PathPrefix + route.Path

timeout := r.getEffectiveTimeout(route.Overrides.Timeout, sr.Overrides.Timeout)
maxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, sr.Overrides.MaxBodySize)
rateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, sr.Overrides.RateLimit)
authLevel := route.AuthLevel
if authLevel == nil {
authLevel = sr.AuthLevel
}

allMiddlewares := make([]common.Middleware, 0, len(sr.Middlewares)+len(route.Middlewares))
allMiddlewares = append(allMiddlewares, sr.Middlewares...)
allMiddlewares = append(allMiddlewares, route.Middlewares...)

wrapped := r.wrapHandler(r.wrapWebSocketHandler(route), authLevel, timeout, maxBodySize, rateLimit, allMiddlewares)
r.router.Handle(http.MethodGet, fullPath, r.convertToHTTPRouterHandle(wrapped, fullPath))

default:
// Log or handle unexpected type in Routes slice
r.logger.Warn("Unsupported type found in SubRouterConfig.Routes",
Expand Down Expand Up @@ -415,19 +433,19 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar
fields = r.addTrace(fields, req)
r.logger.Error("Request timed out", fields...)

// Acquire lock to safely check and potentially write timeout response.
wrappedW.mu.Lock()
// Check if handler already started writing. Use Swap for atomic check-and-set.
if !wrappedW.wroteHeader.Swap(true) {
// Handler hasn't written yet, we can write the timeout error.
// Hold the lock while writing headers and body for timeout.
// Use the new JSON error writer, passing the request
// Attempt to write a timeout response only if the handler hasn't written yet.
if !wrappedW.wroteHeader.Load() {
traceID := scontext.GetTraceIDFromRequest[T, U](req)
r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID)
r.writeJSONError(wrappedW, req, http.StatusRequestTimeout, "Request Timeout", traceID)
}
// Ensure the handler observes cancellation before returning to avoid races in tests
// and potential goroutine leaks. Cancel explicitly (in addition to deferred cancel)
// and yield once to see if the goroutine already exited, but don't block on it.
cancel()
select {
case <-done:
default:
}
// If wroteHeader was already true, handler won the race, do nothing here.
// Unlock should happen regardless of whether we wrote the error or not.
wrappedW.mu.Unlock()
return
}
})
Expand Down Expand Up @@ -988,6 +1006,27 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err
// It includes the trace ID in the JSON payload if available and enabled.
// It also adds CORS headers based on information stored in the context by the CORS middleware.
func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Add req parameter
// If the writer is mutexResponseWriter, lock it to prevent concurrent header writes
// and short-circuit if a response has already been sent.
if mw, ok := w.(*mutexResponseWriter); ok {
mw.mu.Lock()
defer mw.mu.Unlock()
if mw.wroteHeader.Load() {
return
}
// Write using the underlying ResponseWriter while holding the lock and
// mark the response as written so subsequent attempts are skipped.
mw.wroteHeader.Store(true)
r.writeJSONErrorLocked(mw.ResponseWriter, req, statusCode, message, traceID)
return
}

r.writeJSONErrorLocked(w, req, statusCode, message, traceID)
}

// writeJSONErrorLocked performs the actual write of the JSON error response. It assumes any
// necessary synchronization has already been handled by the caller when needed.
func (r *Router[T, U]) writeJSONErrorLocked(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) {
// Retrieve CORS info from context using the passed-in request
allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req)

Expand All @@ -1008,15 +1047,6 @@ func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request,
}
}

// Check if headers have already been written (best effort)
// This check might not be foolproof depending on the ResponseWriter implementation.
// http.Error handles this internally, but we need to be careful here.
// A common pattern is to use a custom ResponseWriter wrapper that tracks this state.
// Since we have mutexResponseWriter and metricsResponseWriter, they might offer ways,
// but for simplicity, we'll rely on the fact that these error handlers are often
// called before the main handler writes anything. If a panic/timeout happens *after*
// writing has started, writing the JSON error might fail or corrupt the response.

w.Header().Set("Content-Type", "application/json; charset=utf-8")
// Ensure the status code is written *before* the body.
// CORS headers are set above, before this.
Expand Down
76 changes: 76 additions & 0 deletions pkg/router/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package router

import (
"context"
"errors"
"net/http"

"github.com/Suhaibinator/SRouter/pkg/common"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)

// WebSocketHandler defines the signature for WebSocket route handlers.
// The request context is passed so callers can respect cancellation or deadlines
// applied by middleware such as timeouts or shutdown handling.
type WebSocketHandler func(ctx context.Context, conn *websocket.Conn)

// WebSocketRouteConfig defines a WebSocket endpoint that can be registered like
// any other RouteDefinition.
//
// Middlewares, authentication, and rate limiting are applied to the handshake
// request using the existing pipeline so behavior matches standard routes.
type WebSocketRouteConfig struct {
Path string
AuthLevel *AuthLevel
Overrides common.RouteOverrides
Middlewares []common.Middleware
Upgrader *websocket.Upgrader
Handler WebSocketHandler
}

// isRouteDefinition implements the RouteDefinition interface.
func (WebSocketRouteConfig) isRouteDefinition() {}

// defaultWebSocketUpgrader returns a lenient upgrader suitable for most tests
// and local development scenarios. Users can supply their own Upgrader when they
// need stricter origin checks or advanced configuration.
func defaultWebSocketUpgrader() *websocket.Upgrader {
return &websocket.Upgrader{
CheckOrigin: func(*http.Request) bool {
return true
},
}
}

// wrapWebSocketHandler creates an http.Handler that performs the WebSocket upgrade
// before delegating to the provided WebSocketHandler. The handler returned here is
// still wrapped by the router's middleware chain via wrapHandler.
func (r *Router[T, U]) wrapWebSocketHandler(route WebSocketRouteConfig) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
upgrader := route.Upgrader
if upgrader == nil {
upgrader = defaultWebSocketUpgrader()
}

conn, err := upgrader.Upgrade(w, req, nil)
Comment on lines +52 to +56

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge WebSocket upgrade fails when timeout middleware is enabled

Because WebSocket routes are run through wrapHandler, any non-zero Global/SubRouter/route timeout adds the timeout middleware, which replaces the http.ResponseWriter with mutexResponseWriter that does not implement http.Hijacker. The gorilla/websocket.Upgrader.Upgrade call here requires a Hijacker and will return websocket: response does not implement http.Hijacker, so WebSocket endpoints cannot upgrade once timeouts (or any similar middleware that wraps the writer) are configured.

Useful? React with 👍 / 👎.

if err != nil {
// If the error is caused by a failed upgrade the upgrader already
// wrote the appropriate response. Just log and return.
var closeError *websocket.CloseError
if errors.As(err, &closeError) {
r.logger.Debug("WebSocket upgrade closed", zap.Error(err))
} else {
r.logger.Error("WebSocket upgrade failed", zap.Error(err))
}
return
}
defer func() {
if closeErr := conn.Close(); closeErr != nil && !errors.Is(closeErr, websocket.ErrCloseSent) {
r.logger.Warn("WebSocket close failed", zap.Error(closeErr))
}
}()

route.Handler(req.Context(), conn)
}
}
62 changes: 62 additions & 0 deletions pkg/router/websocket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package router

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"

"github.com/Suhaibinator/SRouter/pkg/common"
"github.com/Suhaibinator/SRouter/pkg/router/internal/mocks"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

func TestWebSocketRoute(t *testing.T) {
logger := zap.NewNop()

middlewareCalled := atomic.Bool{}

wsRoute := WebSocketRouteConfig{
Path: "/echo",
Middlewares: []common.Middleware{
func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
middlewareCalled.Store(true)
next.ServeHTTP(w, r)
})
},
},
Handler: func(ctx context.Context, conn *websocket.Conn) {
msgType, data, err := conn.ReadMessage()
require.NoError(t, err)
require.NoError(t, conn.WriteMessage(msgType, append([]byte("echo:"), data...)))
},
}

r := NewRouter(RouterConfig{Logger: logger, SubRouters: []SubRouterConfig{{
PathPrefix: "/ws",
Routes: []RouteDefinition{wsRoute},
}}}, mocks.MockAuthFunction, mocks.MockUserIDFromUser)

server := httptest.NewServer(r)
defer server.Close()

wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws/echo"

conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})

require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte("hello")))
_, resp, err := conn.ReadMessage()
require.NoError(t, err)
require.Equal(t, "echo:hello", string(resp))

require.True(t, middlewareCalled.Load(), "middleware should run before WebSocket handler")
}