Skip to content

Commit bdd9db0

Browse files
committed
Add NetFromRange and NetFromIntervalRange helpers
When retrieving set elements it can be desired to format the result in the same way `nft` would, which is merging intervals to CIDR representations. To make this easier, introduce helper functions which allow for conversion of IP address ranges to CIDR networks. Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
1 parent 1db35da commit bdd9db0

File tree

2 files changed

+196
-0
lines changed

2 files changed

+196
-0
lines changed

util.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ package nftables
1616

1717
import (
1818
"encoding/binary"
19+
"errors"
20+
"fmt"
1921
"net"
22+
"net/netip"
2023

2124
"github.com/google/nftables/binaryutil"
2225
"golang.org/x/sys/unix"
@@ -126,3 +129,64 @@ func NetInterval(cidr string) (net.IP, net.IP, error) {
126129

127130
return first, nextIP(last), nil
128131
}
132+
133+
// cidrString returns an IPnet given a network address and CIDR mask
134+
func cidrString(address net.IP, cidr int) (*net.IPNet, error) {
135+
136+
// TODO: why is CIDR for IPv6 off by 1 in some cases?
137+
if address.To4() == nil {
138+
cidr = cidr + 1
139+
}
140+
141+
_, out, err := net.ParseCIDR(fmt.Sprintf("%s/%d", address, cidr))
142+
if err != nil {
143+
return nil, err
144+
}
145+
146+
return out, err
147+
}
148+
149+
// NetFromRange returns a CIDR IP network given a start and end address
150+
func NetFromRange(first []byte, last []byte) (*net.IPNet, error) {
151+
ip1 := net.IP(first)
152+
ip2 := net.IP(last)
153+
154+
maxLen := 32
155+
ip1parts := ip1.To4()
156+
157+
if ip1parts == nil && ip2.To4() != nil {
158+
return nil, errors.New("Cannot mix IPv4 and IPv6.")
159+
}
160+
161+
if ip1parts == nil {
162+
maxLen = 128
163+
}
164+
165+
for l := maxLen; l >= 0; l-- {
166+
cidrmask := net.CIDRMask(l, maxLen)
167+
ipmask := ip2.Mask(cidrmask)
168+
ipnet := net.IPNet{
169+
IP: ipmask,
170+
Mask: cidrmask,
171+
}
172+
173+
if ipnet.Contains(ip1) {
174+
return cidrString(ipmask, l)
175+
}
176+
}
177+
178+
return nil, errors.New("Failed to construct network from range.")
179+
}
180+
181+
// NetFromNetInterval returns a CIDR IP network given a start and end address as found in intervals.
182+
// This is similar to NetFromRange, but subtracts one address from the end of the range.
183+
func NetFromIntervalRange(first []byte, last []byte) (out *net.IPNet, err error) {
184+
ip2, ok := netip.AddrFromSlice(last)
185+
if !ok {
186+
return nil, errors.New("Failed to construct slice from network.")
187+
}
188+
189+
previous := ip2.Prev()
190+
191+
return NetFromRange(first, previous.AsSlice())
192+
}

util_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,135 @@ func TestNetInterval(t *testing.T) {
201201
})
202202
}
203203
}
204+
205+
func TestNetFromRange(t *testing.T) {
206+
tests := []struct {
207+
name string
208+
first string
209+
last string
210+
wantNet string
211+
wantErr bool
212+
}{
213+
{
214+
first: "0.0.0.1",
215+
last: "255.255.255.254",
216+
wantNet: "0.0.0.0/0",
217+
wantErr: false,
218+
},
219+
{
220+
first: "192.168.4.0",
221+
last: "192.168.4.255",
222+
wantNet: "192.168.4.0/24",
223+
wantErr: false,
224+
},
225+
{
226+
first: "192.0.2.17",
227+
last: "192.0.2.30",
228+
wantNet: "192.0.2.16/28",
229+
wantErr: false,
230+
},
231+
{
232+
first: "2001:db8:100::",
233+
last: "2001:db8:100:ffff:ffff:ffff:ffff:ffff",
234+
wantNet: "2001:db8:100::/48",
235+
wantErr: false,
236+
},
237+
{
238+
first: "2001:db8:100::",
239+
last: "192.0.2.30",
240+
wantNet: "",
241+
wantErr: true,
242+
},
243+
{
244+
first: "192.0.2.30",
245+
last: "2001:db8:100::",
246+
wantNet: "",
247+
wantErr: true,
248+
},
249+
}
250+
251+
for _, tt := range tests {
252+
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
253+
gotNet, err := NetFromRange(net.ParseIP(tt.first), net.ParseIP(tt.last))
254+
if (err != nil) != tt.wantErr {
255+
t.Errorf("NetFromRange() error = %v, wantErr = %v", err, tt.wantErr)
256+
}
257+
258+
if tt.wantNet == "" {
259+
return
260+
}
261+
262+
_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
263+
if err != nil {
264+
t.Fatalf("NetFromRange() error parsing test network = %v", err)
265+
}
266+
267+
if !reflect.DeepEqual(gotNet, wantNetParsed) {
268+
t.Errorf("NetFromRange() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
269+
}
270+
})
271+
}
272+
}
273+
274+
func TestNetFromIntervalRange(t *testing.T) {
275+
tests := []struct {
276+
name string
277+
first string
278+
last string
279+
wantNet string
280+
wantErr bool
281+
}{
282+
{
283+
first: "192.0.2.16",
284+
last: "192.0.2.32",
285+
wantNet: "192.0.2.16/28",
286+
wantErr: false,
287+
},
288+
{
289+
first: "2001:db8:101::",
290+
last: "2001:db8:100::1",
291+
wantNet: "2001:db8:100::/48",
292+
wantErr: false,
293+
},
294+
{
295+
first: "2a02:1748:f7df:9c81::",
296+
last: "2a02:1748:f7df:9c80::1",
297+
wantNet: "2a02:1748:f7df:9c80::/64",
298+
wantErr: false,
299+
},
300+
{
301+
first: "2001:db8:100::",
302+
last: "192.0.2.30",
303+
wantNet: "",
304+
wantErr: true,
305+
},
306+
{
307+
first: "192.0.2.30",
308+
last: "2001:db8:100::",
309+
wantNet: "",
310+
wantErr: true,
311+
},
312+
}
313+
314+
for _, tt := range tests {
315+
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
316+
gotNet, err := NetFromIntervalRange(net.ParseIP(tt.first), net.ParseIP(tt.last))
317+
if (err != nil) != tt.wantErr {
318+
t.Errorf("NetFromIntervalRange() error = %v, wantErr = %v", err, tt.wantErr)
319+
}
320+
321+
if tt.wantNet == "" {
322+
return
323+
}
324+
325+
_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
326+
if err != nil {
327+
t.Fatalf("NetFromIntervalRange() error parsing test network = %v", err)
328+
}
329+
330+
if !reflect.DeepEqual(gotNet, wantNetParsed) {
331+
t.Errorf("NetFromIntervalRange() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
332+
}
333+
})
334+
}
335+
}

0 commit comments

Comments
 (0)