Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Inverting Proxy and Agent
## Inverting Proxy and Agent

This repository defines a reverse proxy that inverts the direction of traffic
between the proxy and the backend servers.
Expand Down
51 changes: 34 additions & 17 deletions agent/websockets/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ limitations under the License.
package websockets

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"sync"
"time"

"context"

"github.com/gorilla/websocket"
"google3/third_party/golang/gorilla/websocket/websocket"
)

var websocketShimInjectedHeadersPath = []string{"resource", "headers"}
Expand Down Expand Up @@ -57,12 +57,14 @@ func (m *message) Serialize(version int) interface{} {
// and encapsulates it in an API that is a little more amenable to how the server side
// of our websocket shim is implemented.
type Connection struct {
done func() <-chan struct{}
cancel context.CancelFunc
clientMessages chan *message
serverMessages chan *message
protocolVersion int
subprotocol string
done func() <-chan struct{}
cancel context.CancelFunc
clientMessages chan *message
serverMessages chan *message
protocolVersion int
subprotocol string
mu sync.Mutex
lastActivityTime time.Time
}

// This map defines the set of headers that should be stripped from the WS request, as they
Expand All @@ -87,6 +89,20 @@ func stripWSHeader(header http.Header) http.Header {
return result
}

// updateActivity updates the last activity timestamp.
func (conn *Connection) updateActivity() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.lastActivityTime = time.Now()
}

// lastActivity returns the last activity timestamp.
func (conn *Connection) lastActivity() time.Time {
conn.mu.Lock()
defer conn.mu.Unlock()
return conn.lastActivityTime
}

// NewConnection creates and returns a new Connection.
func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) {
ctx, cancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -137,9 +153,6 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er
return
case clientMsg, ok := <-clientMessages:
if !ok {
// The Connection object was explicitly closed. We record that by passing `nil` to
// the error callback.
errCallback(nil)
return
}
if clientMsg == nil {
Expand All @@ -162,11 +175,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er
}
}()
return &Connection{
done: ctx.Done,
cancel: cancel,
clientMessages: clientMessages,
serverMessages: serverMessages,
subprotocol: serverConn.Subprotocol(),
done: ctx.Done,
cancel: cancel,
clientMessages: clientMessages,
serverMessages: serverMessages,
subprotocol: serverConn.Subprotocol(),
lastActivityTime: time.Now(),
}, nil
}

Expand All @@ -184,6 +198,7 @@ func (conn *Connection) Close() {
//
// The returned error value is non-nill if the connection has been closed.
func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error {
conn.updateActivity()
var clientMessage *message
if textMsg, ok := msg.(string); ok {
clientMessage = &message{
Expand Down Expand Up @@ -244,6 +259,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
// The server messages channel has been closed.
return nil, fmt.Errorf("attempt to read a server message from a closed websocket connection")
}
conn.updateActivity()
msgs = append(msgs, serverMsg.Serialize(conn.protocolVersion))
for {
select {
Expand Down Expand Up @@ -307,3 +323,4 @@ func injectWebsocketMessage(msg *message, injectionPath []string, injectionValue

return &message{Type: msg.Type, Data: newMsgBytes}, nil
}

84 changes: 39 additions & 45 deletions agent/websockets/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package websockets

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -31,9 +32,9 @@ import (
"sync"
"sync/atomic"
"text/template"
"time"

"context"
"github.com/google/inverting-proxy/agent/metrics"
"google3/third_party/golang/invertingproxy/agent/metrics/metrics"
)

const (
Expand Down Expand Up @@ -295,36 +296,33 @@ type sessionMessage struct {
Subprotocol string `json:"s,omitempty"`
}

type connectionErrorHandler struct {
mu sync.Mutex
firstError error
reportedError bool
}

func (c *connectionErrorHandler) Error() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.firstError
}

func (c *connectionErrorHandler) ReportError(err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.reportedError {
return
}
c.reportedError = true
c.firstError = err
if err != nil {
log.Printf("Websocket failure: %v", err)
}
}

func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler {
func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler {
var connections sync.Map
var sessionCount uint64
// Background goroutine to clean up inactive websocket shim connections.
go func() {
ticker := time.NewTicker(min(timeout, 30*time.Second))
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
connections.Range(func(key, value any) bool {
sessionID := key.(string)
conn := value.(*Connection)
if time.Since(conn.lastActivity()) > timeout {
log.Printf("Closing inactive websocket shim session %q after timeout", sessionID)
conn.Close()
connections.Delete(sessionID)
}
return true // Continue iteration
})
}
}
}()

mux := http.NewServeMux()
errorHandler := &connectionErrorHandler{}
openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionID := fmt.Sprintf("%d", atomic.AddUint64(&sessionCount, 1))
targetURL := *(r.URL)
Expand All @@ -333,7 +331,10 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
if originalHost := r.Host; rewriteHost && originalHost != "" {
r.Header.Set("Host", originalHost)
}
conn, err := NewConnection(ctx, targetURL.String(), r.Header, errorHandler.ReportError)
conn, err := NewConnection(ctx, targetURL.String(), r.Header,
func(err error) {
log.Printf("Websocket failure: %v", err)
})
if err != nil {
log.Printf("Failed to dial the websocket server %q: %v\n", targetURL.String(), err)
statusCode := http.StatusInternalServerError
Expand All @@ -351,9 +352,9 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
}
}
resp := &sessionMessage{
ID: sessionID,
Message: targetURL.String(),
Version: conn.protocolVersion,
ID: sessionID,
Message: targetURL.String(),
Version: conn.protocolVersion,
Subprotocol: conn.Subprotocol(),
}
respBytes, err := json.Marshal(resp)
Expand Down Expand Up @@ -469,11 +470,7 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
}
if err := conn.SendClientMessage(msg.Message, enableWebsocketInjection, injectedHeaders); err != nil {
statusCode := http.StatusBadRequest
errorMessage := fmt.Sprintf("attempt to send data on a closed session: %q", msg.ID)
if closureReason := errorHandler.Error(); closureReason != nil {
errorMessage = fmt.Sprintf("attempt to send data on a closed session: %q, closure reason: %q", msg.ID, closureReason)
}
http.Error(w, errorMessage, statusCode)
http.Error(w, fmt.Sprintf("attempt to send data on a closed session: %q", msg.ID), statusCode)
metricHandler.WriteResponseCodeMetric(statusCode)
return
}
Expand Down Expand Up @@ -515,11 +512,7 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
serverMsgs, err := conn.ReadServerMessages()
if err != nil {
statusCode := http.StatusBadRequest
errorMessage := fmt.Sprintf("attempt to read data from a closed session: %q", msg.ID)
if closureReason := errorHandler.Error(); closureReason != nil {
errorMessage = fmt.Sprintf("attempt to read data from a closed session: %q, closure reason: %q", msg.ID, closureReason)
}
http.Error(w, errorMessage, statusCode)
http.Error(w, fmt.Sprintf("attempt to read data from a closed session: %q", msg.ID), statusCode)
metricHandler.WriteResponseCodeMetric(statusCode)
connections.Delete(msg.ID)
return
Expand Down Expand Up @@ -548,13 +541,14 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
// openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original
// targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it
// is finished processing the request.
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) {
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) {
mux := http.NewServeMux()
if shimPath != "" {
shimPath = path.Clean("/"+shimPath) + "/"
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler)
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler, timeout)
mux.Handle(shimPath, shimServer)
}
mux.Handle("/", wrapped)
return mux, nil
}