From 05266d6dbb0f5cc8b0b1a5ec294b7cf8dc7a1189 Mon Sep 17 00:00:00 2001 From: Georg Pfuetzenreuter Date: Sat, 6 Dec 2025 16:45:17 +0100 Subject: [PATCH 1/2] Add NetFromRange and NetFromInterval helpers When retrieving set elements it can be desired to format the result in the same way `nft` would, which is merging intervals to CIDR representations. To make this easier, introduce helper functions which allow for conversion of IP address ranges to CIDR networks. Signed-off-by: Georg Pfuetzenreuter --- util.go | 87 ++++++++++++++++++++ util_test.go | 218 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 305 insertions(+) diff --git a/util.go b/util.go index 7188bb0..eb9f3e2 100644 --- a/util.go +++ b/util.go @@ -16,12 +16,19 @@ package nftables import ( "encoding/binary" + "errors" "net" + "net/netip" "github.com/google/nftables/binaryutil" "golang.org/x/sys/unix" ) +var ( + MaxIPv4 = net.IP{255, 255, 255, 255} + MaxIPv6 = net.IP{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} +) + func extraHeader(family uint8, resID uint16) []byte { return append([]byte{ family, @@ -126,3 +133,83 @@ func NetInterval(cidr string) (net.IP, net.IP, error) { return first, nextIP(last), nil } + +// endIp returns the last address in a given network. +func endIp(netIp net.IP, mask net.IPMask) net.IP { + ip := make(net.IP, len(netIp)) + copy(ip, netIp) + + for i := 0; i < len(mask); i++ { + ipIdx := len(ip) - i - 1 + ip[ipIdx] = netIp[ipIdx] | ^mask[len(mask)-i-1] + } + + return ip +} + +// NetFromRange returns a CIDR IP network given a start and end address. +// If the network is an exact match, ok will be true. +func NetFromRange(first net.IP, last net.IP) (*net.IPNet, bool, error) { + ip1 := net.IP(first) + ip2 := net.IP(last) + + maxLen := 32 + isIpv6 := ip1.To4() == nil + + if isIpv6 && ip2.To4() != nil || !isIpv6 && ip2.To4() == nil { + return nil, false, errors.New("Cannot mix IPv4 and IPv6 or process empty IP.") + } + + if isIpv6 { + maxLen = 128 + } + + var match *net.IPNet + for l := maxLen; l >= -1; l-- { + cidrmask := net.CIDRMask(l, maxLen) + ipmask := ip2.Mask(cidrmask) + ipnet := net.IPNet{ + IP: ipmask, + Mask: cidrmask, + } + + if ipnet.Contains(ip1) { + match = &ipnet + break + } + + } + + matchFirst := match.IP.Equal(ip1) + + // short-circuit if first address is not start of the network + if !matchFirst { + return match, matchFirst, nil + } + + return match, endIp(match.IP, match.Mask).Equal(ip2), nil +} + +// NetFromInterval returns a CIDR IP network given a start and end address as found in intervals. +// This is similar to NetFromRange, but subtracts one address from the end of the range. +// If the resulting network is an exact match, ok will be true. +func NetFromInterval(first net.IP, last net.IP) (out *net.IPNet, ok bool, err error) { + var previous net.IP + + if len(last) == 0 { + if first.To4() == nil { + previous = MaxIPv6 + } else { + previous = MaxIPv4 + } + } else { + ip2, ok := netip.AddrFromSlice(last) + if !ok { + return nil, false, errors.New("Failed to construct slice from network.") + } + + previous = ip2.Prev().AsSlice() + } + + return NetFromRange(first, previous) +} diff --git a/util_test.go b/util_test.go index d766e5a..ff9d09e 100644 --- a/util_test.go +++ b/util_test.go @@ -201,3 +201,221 @@ func TestNetInterval(t *testing.T) { }) } } + +func TestEndIp(t *testing.T) { + tests := []struct { + network string + wantEndIp string + }{ + { + network: "10.0.0.0/24", + wantEndIp: "10.0.0.255", + }, + { + network: "192.168.4.32/27", + wantEndIp: "192.168.4.63", + }, + { + network: "2001:db8:100::/64", + wantEndIp: "2001:db8:100:0:ffff:ffff:ffff:ffff", + }, + { + network: "2001:db8:100:a:b::50/64", + wantEndIp: "2001:db8:100:a:ffff:ffff:ffff:ffff", + }, + } + for _, tt := range tests { + taddr, tnet, err := net.ParseCIDR(tt.network) + if err != nil { + t.Fatalf("endIp() error parsing test CIDR = %v", err) + } + + t.Run(tnet.String(), func(t *testing.T) { + gotEndIp := endIp(taddr, tnet.Mask) + if !gotEndIp.Equal(net.ParseIP(tt.wantEndIp)) { + t.Errorf("endIp() gotEndIp = %s, wantEndIp = %s", gotEndIp, tt.wantEndIp) + } + }) + } +} + +func TestNetFromRange(t *testing.T) { + tests := []struct { + name string + first string + last string + wantNet string + wantOk bool + wantErr bool + }{ + { + first: "0.0.0.0", + last: "255.255.255.255", + wantNet: "0.0.0.0/0", + wantOk: true, + wantErr: false, + }, + { + first: "0.0.0.1", + last: "255.255.255.254", + wantNet: "0.0.0.0/0", + wantOk: false, + wantErr: false, + }, + { + first: "192.168.4.0", + last: "192.168.4.255", + wantNet: "192.168.4.0/24", + wantOk: true, + wantErr: false, + }, + { + first: "192.0.2.16", + last: "192.0.2.30", + wantNet: "192.0.2.16/28", + wantOk: false, + wantErr: false, + }, + { + first: "2001:db8:100::", + last: "2001:db8:100:ffff:ffff:ffff:ffff:ffff", + wantNet: "2001:db8:100::/48", + wantOk: true, + wantErr: false, + }, + { + first: "2001:db8:100::100", + last: "2001:db8:100:0:ffff:ffff:ffff:ffff", + wantNet: "2001:db8:100::/64", + wantOk: false, + wantErr: false, + }, + { + first: "2001:db8:100::", + last: "192.0.2.30", + wantNet: "", + wantOk: true, + wantErr: true, + }, + { + first: "192.0.2.30", + last: "2001:db8:100::", + wantNet: "", + wantOk: true, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.first+"-"+tt.last, func(t *testing.T) { + gotNet, gotOk, err := NetFromRange(net.ParseIP(tt.first), net.ParseIP(tt.last)) + if (err != nil) != tt.wantErr { + t.Errorf("NetFromRange() error = %v, wantErr = %v", err, tt.wantErr) + } + + if tt.wantNet == "" { + return + } + + _, wantNetParsed, err := net.ParseCIDR(tt.wantNet) + if err != nil { + t.Fatalf("NetFromRange() error parsing test network = %v", err) + } + + if tt.wantOk != gotOk { + t.Errorf("NetFromRange() gotOk = %t, wantOk = %t", gotOk, tt.wantOk) + } + + if !reflect.DeepEqual(gotNet, wantNetParsed) { + t.Errorf("NetFromRange() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed) + } + }) + } +} + +func TestNetFromInterval(t *testing.T) { + tests := []struct { + name string + first string + last string + wantNet string + wantOk bool + wantErr bool + }{ + { + first: "192.0.2.16", + last: "192.0.2.32", + wantNet: "192.0.2.16/28", + wantOk: true, + wantErr: false, + }, + { + first: "128.0.0.0", + last: "", + wantNet: "128.0.0.0/1", + wantOk: true, + wantErr: false, + }, + { + first: "2001:db8:100::", + last: "2001:db8:101::", + wantNet: "2001:db8:100::/48", + wantOk: true, + wantErr: false, + }, + { + first: "2001:db8:a1:11::", + last: "2001:db8:a1:12::", + wantNet: "2001:db8:a1:11::/64", + wantOk: true, + wantErr: false, + }, + { + first: "2001:db8:100::100", + last: "2001:db8:100:0:ffff:ffff:ffff:ffff", + wantNet: "2001:db8:100::/64", + wantOk: false, + wantErr: false, + }, + { + first: "2001:db8:100::", + last: "192.0.2.30", + wantNet: "", + wantOk: true, + wantErr: true, + }, + { + first: "192.0.2.30", + last: "2001:db8:100::", + wantNet: "", + wantOk: true, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.first+"-"+tt.last, func(t *testing.T) { + gotNet, gotOk, err := NetFromInterval(net.ParseIP(tt.first), net.ParseIP(tt.last)) + if (err != nil) != tt.wantErr { + t.Errorf("NetFromInterval() error = %v, wantErr = %v", err, tt.wantErr) + } + + if tt.wantNet == "" { + return + } + + _, wantNetParsed, err := net.ParseCIDR(tt.wantNet) + if err != nil { + t.Fatalf("NetFromInterval() error parsing test network = %v", err) + } + + if tt.wantOk != gotOk { + t.Errorf("NetFromInterval() gotOk = %t, wantOk = %t", gotOk, tt.wantOk) + } + + if !reflect.DeepEqual(gotNet, wantNetParsed) { + t.Errorf("NetFromInterval() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed) + } + }) + } +} From abe5282d456672b4d378a327e0e2d0d4fafe5eef Mon Sep 17 00:00:00 2001 From: Georg Pfuetzenreuter Date: Sun, 14 Dec 2025 14:53:08 +0100 Subject: [PATCH 2/2] Only return exact networks from NetFromRange If "ok" is not true, do not return the nearest possible CIDR, but instead none at all to avoid ambiguity. Signed-off-by: Georg Pfuetzenreuter --- util.go | 12 +++++++++--- util_test.go | 16 ++++++++++++---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/util.go b/util.go index eb9f3e2..6e363cf 100644 --- a/util.go +++ b/util.go @@ -148,7 +148,7 @@ func endIp(netIp net.IP, mask net.IPMask) net.IP { } // NetFromRange returns a CIDR IP network given a start and end address. -// If the network is an exact match, ok will be true. +// If an exact match is found, ok will be true. If not, no IPNet will be returned, and ok will be false. func NetFromRange(first net.IP, last net.IP) (*net.IPNet, bool, error) { ip1 := net.IP(first) ip2 := net.IP(last) @@ -184,10 +184,16 @@ func NetFromRange(first net.IP, last net.IP) (*net.IPNet, bool, error) { // short-circuit if first address is not start of the network if !matchFirst { - return match, matchFirst, nil + return nil, matchFirst, nil } - return match, endIp(match.IP, match.Mask).Equal(ip2), nil + matchSecond := endIp(match.IP, match.Mask).Equal(ip2) + + if !matchSecond { + return nil, matchSecond, nil + } + + return match, true, nil } // NetFromInterval returns a CIDR IP network given a start and end address as found in intervals. diff --git a/util_test.go b/util_test.go index ff9d09e..646fa7d 100644 --- a/util_test.go +++ b/util_test.go @@ -258,7 +258,7 @@ func TestNetFromRange(t *testing.T) { { first: "0.0.0.1", last: "255.255.255.254", - wantNet: "0.0.0.0/0", + wantNet: "", // not exactly 0.0.0.0/0 wantOk: false, wantErr: false, }, @@ -272,7 +272,7 @@ func TestNetFromRange(t *testing.T) { { first: "192.0.2.16", last: "192.0.2.30", - wantNet: "192.0.2.16/28", + wantNet: "", // not exactly 192.0.2.16/28 wantOk: false, wantErr: false, }, @@ -286,7 +286,7 @@ func TestNetFromRange(t *testing.T) { { first: "2001:db8:100::100", last: "2001:db8:100:0:ffff:ffff:ffff:ffff", - wantNet: "2001:db8:100::/64", + wantNet: "", // not exactly 2001:db8:100::/64 wantOk: false, wantErr: false, }, @@ -314,6 +314,10 @@ func TestNetFromRange(t *testing.T) { } if tt.wantNet == "" { + if gotNet != nil { + t.Errorf("NetFromInterval() gotNet = %v, wantNet = nil", gotNet) + } + return } @@ -373,7 +377,7 @@ func TestNetFromInterval(t *testing.T) { { first: "2001:db8:100::100", last: "2001:db8:100:0:ffff:ffff:ffff:ffff", - wantNet: "2001:db8:100::/64", + wantNet: "", // not exactly 2001:db8:100::/64 wantOk: false, wantErr: false, }, @@ -401,6 +405,10 @@ func TestNetFromInterval(t *testing.T) { } if tt.wantNet == "" { + if gotNet != nil { + t.Errorf("NetFromInterval() gotNet = %v, wantNet = nil", gotNet) + } + return }