From 434649a78a946a42f9f8107945da0e261506721a Mon Sep 17 00:00:00 2001 From: nickgarlis Date: Mon, 10 Nov 2025 13:24:21 +0100 Subject: [PATCH] Add NetInterval helper When creating set elements that represent a network, the interval range must be half-open [start, end) rather than inclusive [start, end]. For example, for 10.0.0.0/24, the expected range is 10.0.0.0 to 10.0.1.0 instead of 10.0.0.0 to 10.0.0.255. This change introduces a NetInterval helper that returns the correct range given a CIDR string. --- util.go | 39 ++++++++++++++++ util_test.go | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) diff --git a/util.go b/util.go index b040ae4c..7188bb06 100644 --- a/util.go +++ b/util.go @@ -87,3 +87,42 @@ func NetFirstAndLastIP(networkCIDR string) (first, last net.IP, err error) { return first, last, nil } + +// nextIp returns the next IP address after the given one. +// If the next address overflows, the sentinel values 0.0.0.0 (IPv4) +// or :: (IPv6) are returned. +func nextIP(ip net.IP) net.IP { + if ip == nil { + return nil + } + + next := make(net.IP, len(ip)) + copy(next, ip) + + for i := len(next) - 1; i >= 0; i-- { + next[i]++ + if next[i] != 0 { + return next + } + } + + // All bytes overflowed to 0 + return next +} + +// NetInterval returns the half-open ([start, end)) interval of a CIDR string. +// This is the range that nftables uses for interval matching with set elements. +// Unlike NetFirstAndLastIP, the end value is one past the last IP in the +// network. If the last IP is overflowed, the end value will be a zero IP. +// +// For example, for the CIDR "10.0.0.0/24", NetInterval returns +// first=10.0.0.0 and last=10.0.1.0. Note that last is one more than the +// broadcast address of the CIDR. +func NetInterval(cidr string) (net.IP, net.IP, error) { + first, last, err := NetFirstAndLastIP(cidr) + if err != nil { + return first, last, err + } + + return first, nextIP(last), nil +} diff --git a/util_test.go b/util_test.go index 9deee506..d766e5ab 100644 --- a/util_test.go +++ b/util_test.go @@ -76,3 +76,128 @@ func TestNetFirstAndLastIP(t *testing.T) { }) } } + +func TestNetInterval(t *testing.T) { + tests := []struct { + name string + cidr string + wantFirstIP net.IP + wantLastIP net.IP + wantErr bool + }{ + { + name: "Test Invalid", + cidr: "invalid-cidr", + wantFirstIP: nil, + wantLastIP: nil, + wantErr: true, + }, + { + name: "Test IPV4 /0", + cidr: "0.0.0.0/0", + wantFirstIP: net.IP{0, 0, 0, 0}, + wantLastIP: net.IP{0, 0, 0, 0}, + wantErr: false, + }, + { + name: "Test IPV4 /8", + cidr: "10.0.0.0/8", + wantFirstIP: net.IP{10, 0, 0, 0}, + wantLastIP: net.IP{11, 0, 0, 0}, + wantErr: false, + }, + { + name: "Test IPV4 /16", + cidr: "10.0.0.0/16", + wantFirstIP: net.IP{10, 0, 0, 0}, + wantLastIP: net.IP{10, 1, 0, 0}, + wantErr: false, + }, + { + name: "Test IPV4 /24", + cidr: "10.0.0.0/24", + wantFirstIP: net.IP{10, 0, 0, 0}, + wantLastIP: net.IP{10, 0, 1, 0}, + wantErr: false, + }, + { + name: "Test IPV4 /31 near max", + cidr: "255.255.255.255/31", + wantFirstIP: net.IP{255, 255, 255, 254}, + wantLastIP: net.IP{0, 0, 0, 0}, + wantErr: false, + }, + { + name: "Test IPV4 /32", + cidr: "10.0.0.1/32", + wantFirstIP: net.IP{10, 0, 0, 1}, + wantLastIP: net.IP{10, 0, 0, 2}, + wantErr: false, + }, + { + name: "Test IPv4 /0 with max", + cidr: "255.255.255.255/0", + wantFirstIP: net.IP{0, 0, 0, 0}, + wantLastIP: net.IP{0, 0, 0, 0}, + wantErr: false, + }, + { + name: "Test IPv6 /0", + cidr: "::/0", + wantFirstIP: net.ParseIP("::"), + wantLastIP: net.ParseIP("::"), + wantErr: false, + }, + { + name: "Test IPv6 /48", + cidr: "2001:db8::/48", + wantFirstIP: net.ParseIP("2001:db8::"), + wantLastIP: net.ParseIP("2001:db8:1::"), + wantErr: false, + }, + { + name: "Test IPv6 /64", + cidr: "2001:db8::/64", + wantFirstIP: net.ParseIP("2001:db8::"), + wantLastIP: net.ParseIP("2001:db8::1:0:0:0:0"), + wantErr: false, + }, + { + name: "Test IPv6 /120 near max", + cidr: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00/120", + wantFirstIP: net.ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00"), + wantLastIP: net.ParseIP("::"), + wantErr: false, + }, + { + name: "Test IPv6 /128", + cidr: "2001:db8::1/128", + wantFirstIP: net.ParseIP("2001:db8::1"), + wantLastIP: net.ParseIP("2001:db8::2"), + wantErr: false, + }, + { + name: "Test IPv6 /0 with max", + cidr: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/0", + wantFirstIP: net.ParseIP("::"), + wantLastIP: net.ParseIP("::"), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotFirstIP, gotLastIP, err := NetInterval(tt.cidr) + if (err != nil) != tt.wantErr { + t.Errorf("NetInterval() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotFirstIP, tt.wantFirstIP) { + t.Errorf("NetInterval() gotFirstIP = %v, want %v", gotFirstIP, tt.wantFirstIP) + } + if !reflect.DeepEqual(gotLastIP, tt.wantLastIP) { + t.Errorf("NetInterval() gotLastIP = %v, want %v", gotLastIP, tt.wantLastIP) + } + }) + } +}