Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 54 additions & 8 deletions pkg/datapath/sockets/sockets.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,53 @@ var (
networkOrder = binary.BigEndian
)

func stateMask(ms ...int) uint32 {
var out uint32
for _, m := range ms {
out |= 1 << m
}
return out
}

// StateFilterTCP is a mask of all states we consider for socket termination.
// Instead of destroying all states, we make some notable omissions which are
// documented below:
//
// - TCP_CLOSE: Calls to close a socket in TCP_CLOSE state will result in
// ENOENT, this is also confusing as it is the same err code returned
// when a socket that doesn't exist is destroyed.
//
// - TCP_TIME_WAIT: Socket may enter this state post close/FIN-wait states
// to catch any leftover traffic that may not have arrived yet. This is
// for security reasons, as well as avoiding late traffic from entering
// a new socket bound to the same addr/port. Technically, for the socket
// LB its not necessary as we remove the key from the rev NAT map in the
// cil_sock_release() hook, so these sockets won't be found. On the other
// hand we also do not need to waste time to iterate them.
var StateFilterTCP = stateMask(
// The following states emit RST (net/ipv4/tcp.c#L3228-L3235)
netlink.TCP_ESTABLISHED,
netlink.TCP_CLOSE_WAIT,
netlink.TCP_FIN_WAIT1,
netlink.TCP_FIN_WAIT2,
netlink.TCP_SYN_RECV,
// Sockets in SYN-RECV state are simply removed from request queue
// and freed in memory (net/ipv4/tcp.c#L4878-L4885)
netlink.TCP_NEW_SYN_REC,
// Sockets in TCP_LISTEN are moved to closing state
// (net/ipv4/tcp.c#L4908)
netlink.TCP_CLOSE,
Copy link

Copilot AI Jan 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment at lines 46-48 states that TCP_CLOSE is excluded from the filter because destroying sockets in this state results in ENOENT. However, TCP_CLOSE is included in the StateFilterTCP mask at line 69. This contradicts the documentation. Either remove TCP_CLOSE from the mask or update the comment to accurately reflect which states are included.

Copilot uses AI. Check for mistakes.
// Following are handled without any special consideration/just closed
netlink.TCP_SYN_SENT,
netlink.TCP_CLOSING,
netlink.TCP_LAST_ACK,
netlink.TCP_LISTEN,
)

// StateFilterUDP is a mask of all states we consider for socket termination.
// There are no state omissions.
const StateFilterUDP = 0xffff

// Iterate iterates netlink sockets via a callback.
func Iterate(proto uint8, family uint8, stateFilter uint32, fn func(*netlink.Socket, error) error) error {
return iterate(proto, family, stateFilter, func(s *Socket, err error) error {
Expand Down Expand Up @@ -64,6 +111,7 @@ type SocketFilter struct {
DestPort uint16
Family uint8
Protocol uint8
States uint32
// Optional callback function to determine whether a filtered socket needs to be destroyed
DestroyCB DestroySocketCB
}
Expand All @@ -77,8 +125,6 @@ type DestroySocketCB func(id netlink.SocketID) bool
// Supported protocols in the filter: unix.IPPROTO_UDP, unix.IPPROTO_TCP
func Destroy(logger *slog.Logger, filter SocketFilter) error {
family := filter.Family
protocol := filter.Protocol

if family != syscall.AF_INET && family != syscall.AF_INET6 {
return fmt.Errorf("unsupported family for socket destroy: %d", family)
}
Expand All @@ -87,18 +133,18 @@ func Destroy(logger *slog.Logger, filter SocketFilter) error {

// Query sockets matching the passed filter, and then destroy the filtered
// sockets.
switch protocol {
switch filter.Protocol {
case unix.IPPROTO_UDP, unix.IPPROTO_TCP:
redo:
err := filterAndDestroySockets(family, protocol, func(sock netlink.SocketID, err error) {
err := filterAndDestroySockets(family, filter.Protocol, filter.States, func(sock netlink.SocketID, err error) {
if err != nil {
errs = errors.Join(errs, fmt.Errorf("socket with filter [%v]: %w", filter, err))
failed++
return
}
if filter.MatchSocket(sock) {
logger.Info("", logfields.Socket, sock)
if err := destroySocket(logger, sock, family, protocol, 0xffff, true); err != nil {
if err := destroySocket(logger, sock, family, filter.Protocol, filter.States, true); err != nil {
errs = errors.Join(errs, fmt.Errorf("destroying socket with filter [%v]: %w", filter, err))
failed++
return
Expand All @@ -125,7 +171,7 @@ func Destroy(logger *slog.Logger, filter SocketFilter) error {
goto redo
}
default:
return fmt.Errorf("unsupported protocol for socket destroy: %d", protocol)
return fmt.Errorf("unsupported protocol for socket destroy: %d", filter.Protocol)
}
if success > 0 || failed > 0 || errs != nil {
logger.Info(
Expand All @@ -150,8 +196,8 @@ func (f *SocketFilter) MatchSocket(socket netlink.SocketID) bool {
return false
}

func filterAndDestroySockets(family, protocol uint8, socketCB func(socket netlink.SocketID, err error)) error {
return iterateNetlinkSockets(protocol, family, 0xffff, func(sockInfo *Socket, err error) error {
func filterAndDestroySockets(family, protocol uint8, states uint32, socketCB func(socket netlink.SocketID, err error)) error {
return iterateNetlinkSockets(protocol, family, states, func(sockInfo *Socket, err error) error {
socketCB(sockInfo.ID, err)
return nil
})
Expand Down
1 change: 1 addition & 0 deletions pkg/datapath/sockets/sockets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ func TestDestroy(t *testing.T) {
DestPort: uint16(dport),
Family: unix.AF_INET,
Protocol: unix.IPPROTO_UDP,
States: StateFilterUDP,
DestroyCB: func(id netlink.SocketID) bool {
matches++
return true
Expand Down
4 changes: 4 additions & 0 deletions pkg/loadbalancer/reconciler/termination.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,18 @@ func terminateConnectionsToBackend(p socketTerminationParams, l3n4Addr lb.L3n4Ad
var (
family uint8
protocol uint8
states uint32
)
ip := net.IP(l3n4Addr.AddrCluster.Addr().AsSlice())
l4Addr := l3n4Addr.L4Addr

switch l3n4Addr.Protocol {
case lb.UDP:
protocol = unix.IPPROTO_UDP
states = sockets.StateFilterUDP
case lb.TCP:
protocol = unix.IPPROTO_TCP
states = sockets.StateFilterTCP
default:
return
}
Expand Down Expand Up @@ -214,6 +217,7 @@ func terminateConnectionsToBackend(p socketTerminationParams, l3n4Addr lb.L3n4Ad
return p.SocketDestroyer.Destroy(sockets.SocketFilter{
Family: family,
Protocol: protocol,
States: states,
DestIp: ip,
DestPort: l4Addr.Port,
DestroyCB: checkSockInRevNat,
Expand Down