From 4da9346a89a25381b249311379d768d538b8b2f0 Mon Sep 17 00:00:00 2001 From: "Hareesh.Veligeti" Date: Fri, 23 Jan 2026 12:30:00 +0530 Subject: [PATCH] handling agent memory leak --- agent/websockets/connection.go | 43 +++++++++++++++++++++++++--------- agent/websockets/shim.go | 36 ++++++++++++++++++++++++---- 2 files changed, 63 insertions(+), 16 deletions(-) diff --git a/agent/websockets/connection.go b/agent/websockets/connection.go index c14d445..09fc1da 100644 --- a/agent/websockets/connection.go +++ b/agent/websockets/connection.go @@ -17,15 +17,16 @@ limitations under the License. package websockets import ( + "context" "encoding/base64" "encoding/json" "errors" "fmt" "log" "net/http" + "sync" "time" - "context" "github.com/gorilla/websocket" ) @@ -59,10 +60,12 @@ func (m *message) Serialize(version int) interface{} { type Connection struct { done func() <-chan struct{} cancel context.CancelFunc - clientMessages chan *message - serverMessages chan *message - protocolVersion int - subprotocol string + clientMessages chan *message + serverMessages chan *message + protocolVersion int + subprotocol string + mu sync.Mutex + lastActivityTime time.Time } // This map defines the set of headers that should be stripped from the WS request, as they @@ -87,6 +90,21 @@ func stripWSHeader(header http.Header) http.Header { return result } +// updateActivity updates the last activity timestamp. +func (conn *Connection) updateActivity() { + conn.mu.Lock() + defer conn.mu.Unlock() + conn.lastActivityTime = time.Now() +} + +// lastActivity returns the last activity timestamp. +func (conn *Connection) lastActivity() time.Time { + conn.mu.Lock() + defer conn.mu.Unlock() + return conn.lastActivityTime +} + + // NewConnection creates and returns a new Connection. func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) { ctx, cancel := context.WithCancel(ctx) @@ -162,11 +180,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er } }() return &Connection{ - done: ctx.Done, - cancel: cancel, - clientMessages: clientMessages, - serverMessages: serverMessages, - subprotocol: serverConn.Subprotocol(), + done: ctx.Done, + cancel: cancel, + clientMessages: clientMessages, + serverMessages: serverMessages, + subprotocol: serverConn.Subprotocol(), + lastActivityTime: time.Now(), }, nil } @@ -184,6 +203,7 @@ func (conn *Connection) Close() { // // The returned error value is non-nill if the connection has been closed. func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error { + conn.updateActivity() var clientMessage *message if textMsg, ok := msg.(string); ok { clientMessage = &message{ @@ -244,6 +264,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) { // The server messages channel has been closed. return nil, fmt.Errorf("attempt to read a server message from a closed websocket connection") } + conn.updateActivity() msgs = append(msgs, serverMsg.Serialize(conn.protocolVersion)) for { select { @@ -306,4 +327,4 @@ func injectWebsocketMessage(msg *message, injectionPath []string, injectionValue } return &message{Type: msg.Type, Data: newMsgBytes}, nil -} +} \ No newline at end of file diff --git a/agent/websockets/shim.go b/agent/websockets/shim.go index 135a7a2..b8e442c 100644 --- a/agent/websockets/shim.go +++ b/agent/websockets/shim.go @@ -18,6 +18,7 @@ package websockets import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -31,8 +32,8 @@ import ( "sync" "sync/atomic" "text/template" + "time" - "context" "github.com/google/inverting-proxy/agent/metrics" ) @@ -320,9 +321,34 @@ func (c *connectionErrorHandler) ReportError(err error) { } } -func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler { +func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler { var connections sync.Map var sessionCount uint64 + + // Background goroutine to clean up inactive websocket shim connections. + go func() { + ticker := time.NewTicker(min(timeout, 30*time.Second)) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + connections.Range(func(key, value any) bool { + sessionID := key.(string) + conn := value.(*Connection) + if time.Since(conn.lastActivity()) > timeout { + log.Printf("Closing inactive websocket shim session %q after timeout", sessionID) + conn.Close() + connections.Delete(sessionID) + } + return true // Continue iteration + }) + } + } + }() + + mux := http.NewServeMux() errorHandler := &connectionErrorHandler{} openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -548,13 +574,13 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b // openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original // targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it // is finished processing the request. -func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) { +func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) { mux := http.NewServeMux() if shimPath != "" { shimPath = path.Clean("/"+shimPath) + "/" - shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler) + shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler,timeout) mux.Handle(shimPath, shimServer) } mux.Handle("/", wrapped) return mux, nil -} +} \ No newline at end of file