Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions agent/websockets/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ limitations under the License.
package websockets

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"sync"
"time"

"context"

"github.com/gorilla/websocket"
)
Expand Down Expand Up @@ -59,10 +60,12 @@ func (m *message) Serialize(version int) interface{} {
type Connection struct {
done func() <-chan struct{}
cancel context.CancelFunc
clientMessages chan *message
serverMessages chan *message
protocolVersion int
subprotocol string
clientMessages chan *message
serverMessages chan *message
protocolVersion int
subprotocol string
mu sync.Mutex
lastActivityTime time.Time
}

// This map defines the set of headers that should be stripped from the WS request, as they
Expand All @@ -87,6 +90,21 @@ func stripWSHeader(header http.Header) http.Header {
return result
}

// updateActivity updates the last activity timestamp.
func (conn *Connection) updateActivity() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.lastActivityTime = time.Now()
}

// lastActivity returns the last activity timestamp.
func (conn *Connection) lastActivity() time.Time {
conn.mu.Lock()
defer conn.mu.Unlock()
return conn.lastActivityTime
}


// NewConnection creates and returns a new Connection.
func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) {
ctx, cancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -162,11 +180,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er
}
}()
return &Connection{
done: ctx.Done,
cancel: cancel,
clientMessages: clientMessages,
serverMessages: serverMessages,
subprotocol: serverConn.Subprotocol(),
done: ctx.Done,
cancel: cancel,
clientMessages: clientMessages,
serverMessages: serverMessages,
subprotocol: serverConn.Subprotocol(),
lastActivityTime: time.Now(),
}, nil
}

Expand All @@ -184,6 +203,7 @@ func (conn *Connection) Close() {
//
// The returned error value is non-nill if the connection has been closed.
func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error {
conn.updateActivity()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The formatting for this line and the one below (on line 267) is wrong.

Please run gofmt over the changed files.

var clientMessage *message
if textMsg, ok := msg.(string); ok {
clientMessage = &message{
Expand Down Expand Up @@ -244,6 +264,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
// The server messages channel has been closed.
return nil, fmt.Errorf("attempt to read a server message from a closed websocket connection")
}
conn.updateActivity()
msgs = append(msgs, serverMsg.Serialize(conn.protocolVersion))
for {
select {
Expand Down Expand Up @@ -306,4 +327,4 @@ func injectWebsocketMessage(msg *message, injectionPath []string, injectionValue
}

return &message{Type: msg.Type, Data: newMsgBytes}, nil
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't delete the trailing newline.

36 changes: 31 additions & 5 deletions agent/websockets/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package websockets

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -31,8 +32,8 @@ import (
"sync"
"sync/atomic"
"text/template"
"time"

"context"
"github.com/google/inverting-proxy/agent/metrics"
)

Expand Down Expand Up @@ -320,9 +321,34 @@ func (c *connectionErrorHandler) ReportError(err error) {
}
}

func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler {
func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler {
var connections sync.Map
var sessionCount uint64

// Background goroutine to clean up inactive websocket shim connections.
go func() {
ticker := time.NewTicker(min(timeout, 30*time.Second))
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
connections.Range(func(key, value any) bool {
sessionID := key.(string)
conn := value.(*Connection)
if time.Since(conn.lastActivity()) > timeout {
log.Printf("Closing inactive websocket shim session %q after timeout", sessionID)
conn.Close()
connections.Delete(sessionID)
}
return true // Continue iteration
})
}
}
}()


mux := http.NewServeMux()
errorHandler := &connectionErrorHandler{}
openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -548,13 +574,13 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
// openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original
// targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it
// is finished processing the request.
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) {
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) {
mux := http.NewServeMux()
if shimPath != "" {
shimPath = path.Clean("/"+shimPath) + "/"
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler)
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler,timeout)
mux.Handle(shimPath, shimServer)
}
mux.Handle("/", wrapped)
return mux, nil
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, here, don't delete the trailing newline.

Loading