diff --git a/README.md b/README.md index 76496b7..cd69107 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Inverting Proxy and Agent +## Inverting Proxy and Agent This repository defines a reverse proxy that inverts the direction of traffic between the proxy and the backend servers. diff --git a/agent/websockets/connection.go b/agent/websockets/connection.go index c14d445..2d281a4 100644 --- a/agent/websockets/connection.go +++ b/agent/websockets/connection.go @@ -17,17 +17,17 @@ limitations under the License. package websockets import ( + "context" "encoding/base64" "encoding/json" "errors" "fmt" "log" "net/http" + "sync" "time" - "context" - - "github.com/gorilla/websocket" + "google3/third_party/golang/gorilla/websocket/websocket" ) var websocketShimInjectedHeadersPath = []string{"resource", "headers"} @@ -57,12 +57,14 @@ func (m *message) Serialize(version int) interface{} { // and encapsulates it in an API that is a little more amenable to how the server side // of our websocket shim is implemented. type Connection struct { - done func() <-chan struct{} - cancel context.CancelFunc - clientMessages chan *message - serverMessages chan *message - protocolVersion int - subprotocol string + done func() <-chan struct{} + cancel context.CancelFunc + clientMessages chan *message + serverMessages chan *message + protocolVersion int + subprotocol string + mu sync.Mutex + lastActivityTime time.Time } // This map defines the set of headers that should be stripped from the WS request, as they @@ -87,6 +89,20 @@ func stripWSHeader(header http.Header) http.Header { return result } +// updateActivity updates the last activity timestamp. +func (conn *Connection) updateActivity() { + conn.mu.Lock() + defer conn.mu.Unlock() + conn.lastActivityTime = time.Now() +} + +// lastActivity returns the last activity timestamp. +func (conn *Connection) lastActivity() time.Time { + conn.mu.Lock() + defer conn.mu.Unlock() + return conn.lastActivityTime +} + // NewConnection creates and returns a new Connection. func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) { ctx, cancel := context.WithCancel(ctx) @@ -137,9 +153,6 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er return case clientMsg, ok := <-clientMessages: if !ok { - // The Connection object was explicitly closed. We record that by passing `nil` to - // the error callback. - errCallback(nil) return } if clientMsg == nil { @@ -162,11 +175,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er } }() return &Connection{ - done: ctx.Done, - cancel: cancel, - clientMessages: clientMessages, - serverMessages: serverMessages, - subprotocol: serverConn.Subprotocol(), + done: ctx.Done, + cancel: cancel, + clientMessages: clientMessages, + serverMessages: serverMessages, + subprotocol: serverConn.Subprotocol(), + lastActivityTime: time.Now(), }, nil } @@ -184,6 +198,7 @@ func (conn *Connection) Close() { // // The returned error value is non-nill if the connection has been closed. func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error { + conn.updateActivity() var clientMessage *message if textMsg, ok := msg.(string); ok { clientMessage = &message{ @@ -244,6 +259,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) { // The server messages channel has been closed. return nil, fmt.Errorf("attempt to read a server message from a closed websocket connection") } + conn.updateActivity() msgs = append(msgs, serverMsg.Serialize(conn.protocolVersion)) for { select { @@ -307,3 +323,4 @@ func injectWebsocketMessage(msg *message, injectionPath []string, injectionValue return &message{Type: msg.Type, Data: newMsgBytes}, nil } + diff --git a/agent/websockets/shim.go b/agent/websockets/shim.go index 135a7a2..3ed739d 100644 --- a/agent/websockets/shim.go +++ b/agent/websockets/shim.go @@ -18,6 +18,7 @@ package websockets import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -31,9 +32,9 @@ import ( "sync" "sync/atomic" "text/template" + "time" - "context" - "github.com/google/inverting-proxy/agent/metrics" + "google3/third_party/golang/invertingproxy/agent/metrics/metrics" ) const ( @@ -295,36 +296,33 @@ type sessionMessage struct { Subprotocol string `json:"s,omitempty"` } -type connectionErrorHandler struct { - mu sync.Mutex - firstError error - reportedError bool -} - -func (c *connectionErrorHandler) Error() error { - c.mu.Lock() - defer c.mu.Unlock() - return c.firstError -} - -func (c *connectionErrorHandler) ReportError(err error) { - c.mu.Lock() - defer c.mu.Unlock() - if c.reportedError { - return - } - c.reportedError = true - c.firstError = err - if err != nil { - log.Printf("Websocket failure: %v", err) - } -} - -func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler { +func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler { var connections sync.Map var sessionCount uint64 + // Background goroutine to clean up inactive websocket shim connections. + go func() { + ticker := time.NewTicker(min(timeout, 30*time.Second)) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + connections.Range(func(key, value any) bool { + sessionID := key.(string) + conn := value.(*Connection) + if time.Since(conn.lastActivity()) > timeout { + log.Printf("Closing inactive websocket shim session %q after timeout", sessionID) + conn.Close() + connections.Delete(sessionID) + } + return true // Continue iteration + }) + } + } + }() + mux := http.NewServeMux() - errorHandler := &connectionErrorHandler{} openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sessionID := fmt.Sprintf("%d", atomic.AddUint64(&sessionCount, 1)) targetURL := *(r.URL) @@ -333,7 +331,10 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b if originalHost := r.Host; rewriteHost && originalHost != "" { r.Header.Set("Host", originalHost) } - conn, err := NewConnection(ctx, targetURL.String(), r.Header, errorHandler.ReportError) + conn, err := NewConnection(ctx, targetURL.String(), r.Header, + func(err error) { + log.Printf("Websocket failure: %v", err) + }) if err != nil { log.Printf("Failed to dial the websocket server %q: %v\n", targetURL.String(), err) statusCode := http.StatusInternalServerError @@ -351,9 +352,9 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b } } resp := &sessionMessage{ - ID: sessionID, - Message: targetURL.String(), - Version: conn.protocolVersion, + ID: sessionID, + Message: targetURL.String(), + Version: conn.protocolVersion, Subprotocol: conn.Subprotocol(), } respBytes, err := json.Marshal(resp) @@ -469,11 +470,7 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b } if err := conn.SendClientMessage(msg.Message, enableWebsocketInjection, injectedHeaders); err != nil { statusCode := http.StatusBadRequest - errorMessage := fmt.Sprintf("attempt to send data on a closed session: %q", msg.ID) - if closureReason := errorHandler.Error(); closureReason != nil { - errorMessage = fmt.Sprintf("attempt to send data on a closed session: %q, closure reason: %q", msg.ID, closureReason) - } - http.Error(w, errorMessage, statusCode) + http.Error(w, fmt.Sprintf("attempt to send data on a closed session: %q", msg.ID), statusCode) metricHandler.WriteResponseCodeMetric(statusCode) return } @@ -515,11 +512,7 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b serverMsgs, err := conn.ReadServerMessages() if err != nil { statusCode := http.StatusBadRequest - errorMessage := fmt.Sprintf("attempt to read data from a closed session: %q", msg.ID) - if closureReason := errorHandler.Error(); closureReason != nil { - errorMessage = fmt.Sprintf("attempt to read data from a closed session: %q, closure reason: %q", msg.ID, closureReason) - } - http.Error(w, errorMessage, statusCode) + http.Error(w, fmt.Sprintf("attempt to read data from a closed session: %q", msg.ID), statusCode) metricHandler.WriteResponseCodeMetric(statusCode) connections.Delete(msg.ID) return @@ -548,13 +541,14 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b // openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original // targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it // is finished processing the request. -func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) { +func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) { mux := http.NewServeMux() if shimPath != "" { shimPath = path.Clean("/"+shimPath) + "/" - shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler) + shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler, timeout) mux.Handle(shimPath, shimServer) } mux.Handle("/", wrapped) return mux, nil } +