From 16e9d5d27ca8bc5b896c449f74cb23d19c341989 Mon Sep 17 00:00:00 2001 From: Daniel Borkmann Date: Mon, 7 Jul 2025 10:06:00 +0000 Subject: [PATCH] cilium, socklb: Add states filter for the termination iteration Add a state filter to the iterator and skip TCP sockets which are in closing or time wait state. There is no need to spend time to iterate these. Technically, there is no harm since when the client app closes the socket and it goes into time wait state, then upon close the socket LB removes the socket from the revnat map in cil_sock_release.. but then again, no need to iterate through these. Suggested-by: Yusuke Suzuki Signed-off-by: Daniel Borkmann --- pkg/datapath/sockets/sockets.go | 62 +++++++++++++++++++--- pkg/datapath/sockets/sockets_test.go | 1 + pkg/loadbalancer/reconciler/termination.go | 4 ++ 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/pkg/datapath/sockets/sockets.go b/pkg/datapath/sockets/sockets.go index 42b056230b516..dfaab9e09912b 100644 --- a/pkg/datapath/sockets/sockets.go +++ b/pkg/datapath/sockets/sockets.go @@ -31,6 +31,53 @@ var ( networkOrder = binary.BigEndian ) +func stateMask(ms ...int) uint32 { + var out uint32 + for _, m := range ms { + out |= 1 << m + } + return out +} + +// StateFilterTCP is a mask of all states we consider for socket termination. +// Instead of destroying all states, we make some notable omissions which are +// documented below: +// +// - TCP_CLOSE: Calls to close a socket in TCP_CLOSE state will result in +// ENOENT, this is also confusing as it is the same err code returned +// when a socket that doesn't exist is destroyed. +// +// - TCP_TIME_WAIT: Socket may enter this state post close/FIN-wait states +// to catch any leftover traffic that may not have arrived yet. This is +// for security reasons, as well as avoiding late traffic from entering +// a new socket bound to the same addr/port. Technically, for the socket +// LB its not necessary as we remove the key from the rev NAT map in the +// cil_sock_release() hook, so these sockets won't be found. On the other +// hand we also do not need to waste time to iterate them. +var StateFilterTCP = stateMask( + // The following states emit RST (net/ipv4/tcp.c#L3228-L3235) + netlink.TCP_ESTABLISHED, + netlink.TCP_CLOSE_WAIT, + netlink.TCP_FIN_WAIT1, + netlink.TCP_FIN_WAIT2, + netlink.TCP_SYN_RECV, + // Sockets in SYN-RECV state are simply removed from request queue + // and freed in memory (net/ipv4/tcp.c#L4878-L4885) + netlink.TCP_NEW_SYN_REC, + // Sockets in TCP_LISTEN are moved to closing state + // (net/ipv4/tcp.c#L4908) + netlink.TCP_CLOSE, + // Following are handled without any special consideration/just closed + netlink.TCP_SYN_SENT, + netlink.TCP_CLOSING, + netlink.TCP_LAST_ACK, + netlink.TCP_LISTEN, +) + +// StateFilterUDP is a mask of all states we consider for socket termination. +// There are no state omissions. +const StateFilterUDP = 0xffff + // Iterate iterates netlink sockets via a callback. func Iterate(proto uint8, family uint8, stateFilter uint32, fn func(*netlink.Socket, error) error) error { return iterate(proto, family, stateFilter, func(s *Socket, err error) error { @@ -64,6 +111,7 @@ type SocketFilter struct { DestPort uint16 Family uint8 Protocol uint8 + States uint32 // Optional callback function to determine whether a filtered socket needs to be destroyed DestroyCB DestroySocketCB } @@ -77,8 +125,6 @@ type DestroySocketCB func(id netlink.SocketID) bool // Supported protocols in the filter: unix.IPPROTO_UDP, unix.IPPROTO_TCP func Destroy(logger *slog.Logger, filter SocketFilter) error { family := filter.Family - protocol := filter.Protocol - if family != syscall.AF_INET && family != syscall.AF_INET6 { return fmt.Errorf("unsupported family for socket destroy: %d", family) } @@ -87,10 +133,10 @@ func Destroy(logger *slog.Logger, filter SocketFilter) error { // Query sockets matching the passed filter, and then destroy the filtered // sockets. - switch protocol { + switch filter.Protocol { case unix.IPPROTO_UDP, unix.IPPROTO_TCP: redo: - err := filterAndDestroySockets(family, protocol, func(sock netlink.SocketID, err error) { + err := filterAndDestroySockets(family, filter.Protocol, filter.States, func(sock netlink.SocketID, err error) { if err != nil { errs = errors.Join(errs, fmt.Errorf("socket with filter [%v]: %w", filter, err)) failed++ @@ -98,7 +144,7 @@ func Destroy(logger *slog.Logger, filter SocketFilter) error { } if filter.MatchSocket(sock) { logger.Info("", logfields.Socket, sock) - if err := destroySocket(logger, sock, family, protocol, 0xffff, true); err != nil { + if err := destroySocket(logger, sock, family, filter.Protocol, filter.States, true); err != nil { errs = errors.Join(errs, fmt.Errorf("destroying socket with filter [%v]: %w", filter, err)) failed++ return @@ -125,7 +171,7 @@ func Destroy(logger *slog.Logger, filter SocketFilter) error { goto redo } default: - return fmt.Errorf("unsupported protocol for socket destroy: %d", protocol) + return fmt.Errorf("unsupported protocol for socket destroy: %d", filter.Protocol) } if success > 0 || failed > 0 || errs != nil { logger.Info( @@ -150,8 +196,8 @@ func (f *SocketFilter) MatchSocket(socket netlink.SocketID) bool { return false } -func filterAndDestroySockets(family, protocol uint8, socketCB func(socket netlink.SocketID, err error)) error { - return iterateNetlinkSockets(protocol, family, 0xffff, func(sockInfo *Socket, err error) error { +func filterAndDestroySockets(family, protocol uint8, states uint32, socketCB func(socket netlink.SocketID, err error)) error { + return iterateNetlinkSockets(protocol, family, states, func(sockInfo *Socket, err error) error { socketCB(sockInfo.ID, err) return nil }) diff --git a/pkg/datapath/sockets/sockets_test.go b/pkg/datapath/sockets/sockets_test.go index 3d217e96e7fd8..e820807076e8c 100644 --- a/pkg/datapath/sockets/sockets_test.go +++ b/pkg/datapath/sockets/sockets_test.go @@ -257,6 +257,7 @@ func TestDestroy(t *testing.T) { DestPort: uint16(dport), Family: unix.AF_INET, Protocol: unix.IPPROTO_UDP, + States: StateFilterUDP, DestroyCB: func(id netlink.SocketID) bool { matches++ return true diff --git a/pkg/loadbalancer/reconciler/termination.go b/pkg/loadbalancer/reconciler/termination.go index 73d9dd454158a..68ba8a42c53cd 100644 --- a/pkg/loadbalancer/reconciler/termination.go +++ b/pkg/loadbalancer/reconciler/termination.go @@ -177,6 +177,7 @@ func terminateConnectionsToBackend(p socketTerminationParams, l3n4Addr lb.L3n4Ad var ( family uint8 protocol uint8 + states uint32 ) ip := net.IP(l3n4Addr.AddrCluster.Addr().AsSlice()) l4Addr := l3n4Addr.L4Addr @@ -184,8 +185,10 @@ func terminateConnectionsToBackend(p socketTerminationParams, l3n4Addr lb.L3n4Ad switch l3n4Addr.Protocol { case lb.UDP: protocol = unix.IPPROTO_UDP + states = sockets.StateFilterUDP case lb.TCP: protocol = unix.IPPROTO_TCP + states = sockets.StateFilterTCP default: return } @@ -214,6 +217,7 @@ func terminateConnectionsToBackend(p socketTerminationParams, l3n4Addr lb.L3n4Ad return p.SocketDestroyer.Destroy(sockets.SocketFilter{ Family: family, Protocol: protocol, + States: states, DestIp: ip, DestPort: l4Addr.Port, DestroyCB: checkSockInRevNat,