Skip to content

Commit cf6569c

Browse files
authored
Ensure socket buffer draining during Flush (#330)
This commit re-implements response handling for batch messages. The current implementation uses heuristics to drain the socket buffer by counting sent messages and making a blocking recvmsg call for each one. While this generally works, it has several problems:

- No kernel guarantees: The kernel makes no guarantee about the number of ACKs returned. Error conditions like EPERM, ENOBUFS, and ENOMEM can cause unexpected behavior.
- Vulnerable to kernel bugs: As described in #329.
- Complex code: ACK counting complicates the response handling logic in Flush.

The new approach follows nft's pattern by leveraging netlink's synchronous request-response behavior. When a blocking sendmsg completes, the kernel has already processed the request and queued any response data in the socket's receive buffer. This means we can simply check whether data is available rather than blocking indefinitely while waiting for a specific number of responses.

The implementation uses pselect6 (the same syscall nft uses) to poll the socket for readability.

Fixes: #329
1 parent df852a3 commit cf6569c

File tree

4 files changed

+230
-45
lines changed

4 files changed

+230
-45
lines changed

conn.go

Lines changed: 73 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"math"
2121
"os"
2222
"sync"
23-
"syscall"
2423

2524
"github.com/google/nftables/binaryutil"
2625
"github.com/google/nftables/expr"
@@ -262,71 +261,100 @@ func (cc *Conn) flush(genID uint32) error {
262261
return err
263262
}
264263

265-
messages, err := conn.SendMessages(batch)
264+
sentMsgs, err := conn.SendMessages(batch)
266265
if err != nil {
267266
return fmt.Errorf("SendMessages: %w", err)
268267
}
269268

270269
var errs error
271270

272-
// Fetch replies. Each message with the Echo flag triggers a reply of the same
273-
// type. Additionally, if the first message of the batch has the Echo flag, we
274-
// get a reply of type NFT_MSG_NEWGEN, which we ignore.
275-
replyIndex := 0
276-
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
277-
replyIndex++
278-
}
279-
replies, err := conn.Receive()
280-
for err == nil && len(replies) != 0 {
281-
reply := replies[0]
282-
if reply.Header.Type == netlink.Error && reply.Header.Sequence == messages[1].Header.Sequence {
283-
// The next message is the acknowledgement for the first message in the
284-
// batch; stop looking for replies.
271+
seqToMsgMap := cc.getSeqToMsgMap(sentMsgs)
272+
273+
for {
274+
ready, err := cc.isReadReady(conn)
275+
if err != nil {
276+
return err
277+
}
278+
279+
// Since SendMessages is blocking and netlink communication is synchronous,
280+
// the kernel has already processed the request and queued any responses by
281+
// the time SendMessages returns. Therefore, if isReadReady returns false on
282+
// the first call, it means there are no messages coming at all and we can
283+
// safely exit.
284+
if !ready {
285285
break
286-
} else if replyIndex < len(cc.messages) {
287-
msg := messages[replyIndex+1]
288-
if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type {
289-
// The only messages which set the echo flag are rule create messages.
290-
err := cc.messages[replyIndex].rule.handleCreateReply(reply)
291-
if err != nil {
292-
errs = errors.Join(errs, err)
293-
}
294-
replyIndex++
295-
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
296-
replyIndex++
297-
}
298-
}
299286
}
300-
replies = replies[1:]
301-
if len(replies) == 0 {
302-
replies, err = conn.Receive()
287+
288+
replies, err := conn.Receive()
289+
if err != nil {
290+
errs = errors.Join(errs, fmt.Errorf("receive: %w", err))
303291
}
304-
}
305292

306-
// Fetch the requested acknowledgement for each message we sent.
307-
for i := range cc.messages {
308-
if i != 0 {
309-
_, err = conn.Receive()
293+
if len(replies) == 0 && cc.TestDial != nil {
294+
// When using a test dial function, we don't always get a reply for each
295+
// sent message. Additionally, there is no buffer to poll for more data,
296+
// so we stop here.
297+
break
310298
}
311-
if err != nil {
312-
if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) || errors.Is(err, syscall.ENOMEM) {
313-
// Kernel will only send one error to user space.
314-
return err
299+
300+
for _, reply := range replies {
301+
if err := cc.handleEchoReply(seqToMsgMap, reply); err != nil {
302+
errs = errors.Join(errs, err)
315303
}
316-
errs = errors.Join(errs, err)
317304
}
318305
}
319306

320307
if errs != nil {
321-
return fmt.Errorf("conn.Receive: %w", errs)
322-
}
323-
if replyIndex < len(cc.messages) {
324-
return fmt.Errorf("missing reply for message %d in batch", replyIndex)
308+
return errs
325309
}
326310

327311
return nil
328312
}
329313

314+
// getSeqToMsgMap returns a map of the cc.messages that were sent, indexed by
315+
// their sequence number as included in the sent netlink messages. The returned
316+
// map will not include the batch begin and end messages.
317+
func (cc *Conn) getSeqToMsgMap(sentMsgs []netlink.Message) map[uint32]netlinkMessage {
318+
seqToMsgMap := make(map[uint32]netlinkMessage)
319+
for i, msg := range sentMsgs {
320+
if i == 0 || i == len(sentMsgs)-1 {
321+
// Skip batch begin and end messages.
322+
continue
323+
}
324+
if i-1 >= len(cc.messages) {
325+
// Should not happen, but be defensive.
326+
break
327+
}
328+
// Update the header in the original message, as the sequence number
329+
// and possibly other fields have been updated by the the underlying
330+
// netlink library.
331+
cc.messages[i-1].Header = msg.Header
332+
seqToMsgMap[msg.Header.Sequence] = cc.messages[i-1]
333+
}
334+
335+
return seqToMsgMap
336+
}
337+
338+
func (cc *Conn) handleEchoReply(seqToMsgMap map[uint32]netlinkMessage, reply netlink.Message) error {
339+
sentMsg, ok := seqToMsgMap[reply.Header.Sequence]
340+
if !ok {
341+
// We don't have a record of sending this message, ignore.
342+
return nil
343+
}
344+
345+
if sentMsg.Header.Flags&netlink.Echo == 0 {
346+
return nil
347+
}
348+
349+
switch reply.Header.Type {
350+
case newRuleHeaderType:
351+
// The only messages which set the echo flag are rule create messages.
352+
return sentMsg.rule.handleCreateReply(reply)
353+
default:
354+
return nil
355+
}
356+
}
357+
330358
// FlushRuleset flushes the entire ruleset. See also
331359
// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
332360
func (cc *Conn) FlushRuleset() {

internal/nftest/util.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package nftest
2+
3+
import (
4+
"fmt"
5+
"math"
6+
"runtime"
7+
8+
"golang.org/x/sys/unix"
9+
)
10+
11+
// AsUnprivileged temporarily drops the effective UID to an unprivileged
12+
// value (65535) while executing the provided function. It requires the
13+
// process to be running as root to be able to regain privileges afterwards.
14+
func AsUnprivileged(fn func() error) error {
15+
runtime.LockOSThread()
16+
defer runtime.UnlockOSThread()
17+
18+
targetUID := math.MaxUint16
19+
_, euid, suid := unix.Getresuid()
20+
21+
if euid != 0 && suid != 0 {
22+
return fmt.Errorf("must be run as root to regain privileges (euid=%d suid=%d)", euid, suid)
23+
}
24+
25+
// Drop privileges by changing only the effective UID
26+
if err := unix.Setresuid(-1, targetUID, -1); err != nil {
27+
return fmt.Errorf("failed to drop effective UID to %d: %w", targetUID, err)
28+
}
29+
30+
// Restore when done
31+
defer func() {
32+
if err := unix.Setresuid(-1, euid, -1); err != nil {
33+
panic(fmt.Sprintf("failed to restore euid=%d: %v", euid, err))
34+
}
35+
}()
36+
37+
return fn()
38+
}

nftables_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"github.com/google/nftables/internal/nftest"
3535
"github.com/google/nftables/xt"
3636
"github.com/mdlayher/netlink"
37+
"github.com/vishvananda/netns"
3738
"golang.org/x/sys/unix"
3839
)
3940

@@ -7947,3 +7948,82 @@ func TestGetPortID(t *testing.T) {
79477948
t.Fatalf("conn.GetPortID() returned invalid port ID: %d", pid)
79487949
}
79497950
}
7951+
7952+
func TestSocketDrainingOnErrors(t *testing.T) {
7953+
tests := []struct {
7954+
name string
7955+
setupError func(conn *nftables.Conn, ns netns.NsHandle) error
7956+
expectedError error
7957+
description string
7958+
}{
7959+
{
7960+
name: "short_circuited_error",
7961+
setupError: func(conn *nftables.Conn, ns netns.NsHandle) error {
7962+
// Add multiple tables but trigger EPERM before kernel processes all
7963+
conn.AddTable(&nftables.Table{Name: "table1", Family: nftables.TableFamilyIPv4})
7964+
conn.AddTable(&nftables.Table{Name: "table2", Family: nftables.TableFamilyIPv4})
7965+
conn.AddTable(&nftables.Table{Name: "table3", Family: nftables.TableFamilyIPv4})
7966+
7967+
// Drop privileges to trigger immediate EPERM (short-circuit)
7968+
return nftest.AsUnprivileged(conn.Flush)
7969+
},
7970+
expectedError: syscall.EPERM,
7971+
description: "kernel returns EPERM immediately without processing all messages",
7972+
},
7973+
{
7974+
name: "non_short_circuited_error",
7975+
setupError: func(conn *nftables.Conn, ns netns.NsHandle) error {
7976+
// Use new connection to create an owned table
7977+
newConn, err := nftables.New(nftables.WithNetNSFd(int(ns)), nftables.AsLasting())
7978+
if err != nil {
7979+
return err
7980+
}
7981+
defer newConn.CloseLasting()
7982+
7983+
ownedTable := &nftables.Table{
7984+
Name: "owned-table",
7985+
Family: nftables.TableFamilyIPv4,
7986+
Flags: nftables.TableFlagOwner | nftables.TableFlagPersist,
7987+
}
7988+
newConn.AddTable(ownedTable)
7989+
if err := newConn.Flush(); err != nil {
7990+
return err
7991+
}
7992+
7993+
// Use old connection to try deleting owned table (will fail)
7994+
conn.DelTable(ownedTable)
7995+
conn.AddTable(&nftables.Table{Name: "table2", Family: nftables.TableFamilyIPv4})
7996+
conn.AddTable(&nftables.Table{Name: "table3", Family: nftables.TableFamilyIPv4})
7997+
7998+
return conn.Flush()
7999+
},
8000+
expectedError: syscall.EPERM,
8001+
description: "kernel processes all messages but returns EPERM for owned table deletion",
8002+
},
8003+
}
8004+
8005+
for _, tt := range tests {
8006+
t.Run(tt.name, func(t *testing.T) {
8007+
_, newNS := nftest.OpenSystemConn(t, *enableSysTests)
8008+
conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
8009+
if err != nil {
8010+
t.Fatalf("nftables.New() failed: %v", err)
8011+
}
8012+
defer nftest.CleanupSystemConn(t, newNS)
8013+
defer conn.FlushRuleset()
8014+
defer conn.CloseLasting()
8015+
8016+
err = tt.setupError(conn, newNS)
8017+
8018+
if err == nil || !errors.Is(err, tt.expectedError) {
8019+
t.Fatalf("expected %v error, got %v", tt.expectedError, err)
8020+
}
8021+
8022+
// Verify socket is properly drained. If not, this will fail as
8023+
// there will be leftover messages in the socket buffer.
8024+
if _, err := conn.ListTables(); err != nil {
8025+
t.Fatalf("ListTables failed after error - socket not properly drained: %v", err)
8026+
}
8027+
})
8028+
}
8029+
}

socket.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package nftables
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/mdlayher/netlink"
7+
"golang.org/x/sys/unix"
8+
)
9+
10+
// isReadReady reports whether the netlink connection is ready for reading.
11+
// It uses pselect6 with a zero timeout on the underlying raw connection.
12+
// This allows for an efficient check of socket readiness without blocking.
13+
// If the Conn was created with a TestDial function, it assumes readiness.
14+
func (cc *Conn) isReadReady(conn *netlink.Conn) (bool, error) {
15+
if cc.TestDial != nil {
16+
return true, nil
17+
}
18+
19+
rawConn, err := conn.SyscallConn()
20+
if err != nil {
21+
return false, fmt.Errorf("get raw conn: %w", err)
22+
}
23+
24+
var n int
25+
var opErr error
26+
err = rawConn.Control(func(fd uintptr) {
27+
var readfds unix.FdSet
28+
readfds.Zero()
29+
readfds.Set(int(fd))
30+
31+
ts := &unix.Timespec{} // zero timeout: immediate return
32+
n, opErr = unix.Pselect(int(fd)+1, &readfds, nil, nil, ts, nil)
33+
})
34+
if err != nil {
35+
return false, err
36+
}
37+
38+
return n > 0, opErr
39+
}

0 commit comments

Comments
 (0)