diff --git a/README.md b/README.md index c28ff71..dcbca6d 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ SRouter is a high-performance HTTP router for Go that wraps [julienschmidt/httpr - [Context Management](./docs/context-management.md) - [Custom Error Handling](./docs/error-handling.md) - [Custom Middleware](./docs/middleware.md) + - [WebSockets](./docs/websockets.md) - [Source Types](./docs/generic-routes.md#source-types) - [Custom Codecs](./docs/codecs.md) - [Metrics](./docs/metrics.md) @@ -52,6 +53,7 @@ SRouter is a high-performance HTTP router for Go that wraps [julienschmidt/httpr - **Path Parameters**: Easy access to path parameters via request context - **Graceful Shutdown**: Properly handle in-flight requests during shutdown - **Flexible Metrics System**: Support for multiple metric formats, custom collectors, and dependency injection +- **WebSocket Endpoints**: Register WebSocket handlers while reusing the same middleware, authentication, and logging pipeline - **Intelligent Logging**: Structured logging using `zap` with appropriate log levels for different types of events. Requires a logger instance in config. - **Trace ID Logging**: Automatically generate and include a unique trace ID for each request in context and log entries. - **Flexible Request Data Sources**: Support for retrieving request data from various sources including request body, query parameters, and path parameters with automatic decoding diff --git a/docs/websockets.md b/docs/websockets.md new file mode 100644 index 0000000..e63b3c7 --- /dev/null +++ b/docs/websockets.md @@ -0,0 +1,49 @@ +# WebSocket Support + +SRouter can now register WebSocket endpoints alongside standard HTTP routes. WebSocket routes reuse the same middleware and authentication pipeline used by REST handlers, so cross-cutting behavior such as tracing, logging, rate limiting, and custom middleware still apply to the initial upgrade request. + +## Declaring a WebSocket route + +Use `WebSocketRouteConfig` inside a `SubRouterConfig` just like any other `RouteDefinition`: + +```go +wsRoute := router.WebSocketRouteConfig{ + Path: "/echo", + Middlewares: []common.Middleware{loggingMiddleware}, + Handler: func(ctx context.Context, conn *websocket.Conn) { + for { + msgType, payload, err := conn.ReadMessage() + if err != nil { + return + } + _ = conn.WriteMessage(msgType, payload) + } + }, +} + +r := router.NewRouter(router.RouterConfig{ + Logger: logger, + SubRouters: []router.SubRouterConfig{{ + PathPrefix: "/ws", + Routes: []router.RouteDefinition{wsRoute}, + }}, +}, authFn, userFromUserFn) +``` + +The route is registered under the sub-router path prefix (e.g., `/ws/echo`). Only `GET` is used for WebSocket registration because the handshake is defined on `GET`. + +## Upgrader configuration + +`WebSocketRouteConfig` accepts an optional `Upgrader`. When omitted, a permissive upgrader is used (it allows all origins). Provide a custom `websocket.Upgrader` when you need stricter origin checks or other advanced settings. + +## Middleware, authentication, and limits + +The router wraps WebSocket routes with the same middleware chain as HTTP routes: + +- Global, sub-router, and route-level middleware are executed before the upgrade occurs. +- Authentication levels (`AuthRequired`, `AuthOptional`, `NoAuth`) are honored for the handshake request. +- Timeout, rate limit, and max body size overrides are applied to the handshake phase. For long-lived WebSocket sessions, consider leaving timeouts unset or explicitly set them to `0` for the route. + +## Shutdown behavior + +During graceful shutdown, SRouter waits for active WebSocket handlers to return before completing shutdown, just like regular HTTP handlers. Your handler should monitor `ctx.Done()` and exit when requested to allow a timely shutdown. diff --git a/go.mod b/go.mod index ea335bc..203dd63 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 9ab7e1b..0c9986f 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= diff --git a/pkg/router/router.go b/pkg/router/router.go index 38b71da..ff6fef3 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -235,6 +235,24 @@ func (r *Router[T, U]) registerSubRouter(sr SubRouterConfig) { // The function itself will handle calculating effective settings and calling RegisterGenericRoute route(r, sr) // Call the registration function + case WebSocketRouteConfig: + fullPath := sr.PathPrefix + route.Path + + timeout := r.getEffectiveTimeout(route.Overrides.Timeout, sr.Overrides.Timeout) + maxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, sr.Overrides.MaxBodySize) + rateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, sr.Overrides.RateLimit) + authLevel := route.AuthLevel + if authLevel == nil { + authLevel = sr.AuthLevel + } + + allMiddlewares := make([]common.Middleware, 0, len(sr.Middlewares)+len(route.Middlewares)) + allMiddlewares = append(allMiddlewares, sr.Middlewares...) + allMiddlewares = append(allMiddlewares, route.Middlewares...) + + wrapped := r.wrapHandler(r.wrapWebSocketHandler(route), authLevel, timeout, maxBodySize, rateLimit, allMiddlewares) + r.router.Handle(http.MethodGet, fullPath, r.convertToHTTPRouterHandle(wrapped, fullPath)) + default: // Log or handle unexpected type in Routes slice r.logger.Warn("Unsupported type found in SubRouterConfig.Routes", @@ -415,19 +433,19 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar fields = r.addTrace(fields, req) r.logger.Error("Request timed out", fields...) - // Acquire lock to safely check and potentially write timeout response. - wrappedW.mu.Lock() - // Check if handler already started writing. Use Swap for atomic check-and-set. - if !wrappedW.wroteHeader.Swap(true) { - // Handler hasn't written yet, we can write the timeout error. - // Hold the lock while writing headers and body for timeout. - // Use the new JSON error writer, passing the request + // Attempt to write a timeout response only if the handler hasn't written yet. + if !wrappedW.wroteHeader.Load() { traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) + r.writeJSONError(wrappedW, req, http.StatusRequestTimeout, "Request Timeout", traceID) + } + // Ensure the handler observes cancellation before returning to avoid races in tests + // and potential goroutine leaks. Cancel explicitly (in addition to deferred cancel) + // and yield once to see if the goroutine already exited, but don't block on it. + cancel() + select { + case <-done: + default: } - // If wroteHeader was already true, handler won the race, do nothing here. - // Unlock should happen regardless of whether we wrote the error or not. - wrappedW.mu.Unlock() return } }) @@ -988,6 +1006,27 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err // It includes the trace ID in the JSON payload if available and enabled. // It also adds CORS headers based on information stored in the context by the CORS middleware. func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Add req parameter + // If the writer is mutexResponseWriter, lock it to prevent concurrent header writes + // and short-circuit if a response has already been sent. + if mw, ok := w.(*mutexResponseWriter); ok { + mw.mu.Lock() + defer mw.mu.Unlock() + if mw.wroteHeader.Load() { + return + } + // Write using the underlying ResponseWriter while holding the lock and + // mark the response as written so subsequent attempts are skipped. + mw.wroteHeader.Store(true) + r.writeJSONErrorLocked(mw.ResponseWriter, req, statusCode, message, traceID) + return + } + + r.writeJSONErrorLocked(w, req, statusCode, message, traceID) +} + +// writeJSONErrorLocked performs the actual write of the JSON error response. It assumes any +// necessary synchronization has already been handled by the caller when needed. +func (r *Router[T, U]) writeJSONErrorLocked(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Retrieve CORS info from context using the passed-in request allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) @@ -1008,15 +1047,6 @@ func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, } } - // Check if headers have already been written (best effort) - // This check might not be foolproof depending on the ResponseWriter implementation. - // http.Error handles this internally, but we need to be careful here. - // A common pattern is to use a custom ResponseWriter wrapper that tracks this state. - // Since we have mutexResponseWriter and metricsResponseWriter, they might offer ways, - // but for simplicity, we'll rely on the fact that these error handlers are often - // called before the main handler writes anything. If a panic/timeout happens *after* - // writing has started, writing the JSON error might fail or corrupt the response. - w.Header().Set("Content-Type", "application/json; charset=utf-8") // Ensure the status code is written *before* the body. // CORS headers are set above, before this. diff --git a/pkg/router/websocket.go b/pkg/router/websocket.go new file mode 100644 index 0000000..36c5051 --- /dev/null +++ b/pkg/router/websocket.go @@ -0,0 +1,76 @@ +package router + +import ( + "context" + "errors" + "net/http" + + "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +// WebSocketHandler defines the signature for WebSocket route handlers. +// The request context is passed so callers can respect cancellation or deadlines +// applied by middleware such as timeouts or shutdown handling. +type WebSocketHandler func(ctx context.Context, conn *websocket.Conn) + +// WebSocketRouteConfig defines a WebSocket endpoint that can be registered like +// any other RouteDefinition. +// +// Middlewares, authentication, and rate limiting are applied to the handshake +// request using the existing pipeline so behavior matches standard routes. +type WebSocketRouteConfig struct { + Path string + AuthLevel *AuthLevel + Overrides common.RouteOverrides + Middlewares []common.Middleware + Upgrader *websocket.Upgrader + Handler WebSocketHandler +} + +// isRouteDefinition implements the RouteDefinition interface. +func (WebSocketRouteConfig) isRouteDefinition() {} + +// defaultWebSocketUpgrader returns a lenient upgrader suitable for most tests +// and local development scenarios. Users can supply their own Upgrader when they +// need stricter origin checks or advanced configuration. +func defaultWebSocketUpgrader() *websocket.Upgrader { + return &websocket.Upgrader{ + CheckOrigin: func(*http.Request) bool { + return true + }, + } +} + +// wrapWebSocketHandler creates an http.Handler that performs the WebSocket upgrade +// before delegating to the provided WebSocketHandler. The handler returned here is +// still wrapped by the router's middleware chain via wrapHandler. +func (r *Router[T, U]) wrapWebSocketHandler(route WebSocketRouteConfig) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + upgrader := route.Upgrader + if upgrader == nil { + upgrader = defaultWebSocketUpgrader() + } + + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + // If the error is caused by a failed upgrade the upgrader already + // wrote the appropriate response. Just log and return. + var closeError *websocket.CloseError + if errors.As(err, &closeError) { + r.logger.Debug("WebSocket upgrade closed", zap.Error(err)) + } else { + r.logger.Error("WebSocket upgrade failed", zap.Error(err)) + } + return + } + defer func() { + if closeErr := conn.Close(); closeErr != nil && !errors.Is(closeErr, websocket.ErrCloseSent) { + r.logger.Warn("WebSocket close failed", zap.Error(closeErr)) + } + }() + + route.Handler(req.Context(), conn) + } +} diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go new file mode 100644 index 0000000..561411a --- /dev/null +++ b/pkg/router/websocket_test.go @@ -0,0 +1,62 @@ +package router + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestWebSocketRoute(t *testing.T) { + logger := zap.NewNop() + + middlewareCalled := atomic.Bool{} + + wsRoute := WebSocketRouteConfig{ + Path: "/echo", + Middlewares: []common.Middleware{ + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + middlewareCalled.Store(true) + next.ServeHTTP(w, r) + }) + }, + }, + Handler: func(ctx context.Context, conn *websocket.Conn) { + msgType, data, err := conn.ReadMessage() + require.NoError(t, err) + require.NoError(t, conn.WriteMessage(msgType, append([]byte("echo:"), data...))) + }, + } + + r := NewRouter(RouterConfig{Logger: logger, SubRouters: []SubRouterConfig{{ + PathPrefix: "/ws", + Routes: []RouteDefinition{wsRoute}, + }}}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + server := httptest.NewServer(r) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws/echo" + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + + require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte("hello"))) + _, resp, err := conn.ReadMessage() + require.NoError(t, err) + require.Equal(t, "echo:hello", string(resp)) + + require.True(t, middlewareCalled.Load(), "middleware should run before WebSocket handler") +}