diff --git a/go.mod b/go.mod index 820c8cb8..28cac721 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/google/cel-go v0.26.0 github.com/google/go-cmp v0.7.0 github.com/insomniacslk/dhcp v0.0.0-20250417080101-5f8cf70e8c5f + github.com/jaypipes/ghw v0.17.0 github.com/mdlayher/genetlink v1.3.2 github.com/mdlayher/netlink v1.7.2 github.com/prometheus/client_golang v1.23.0 @@ -47,6 +48,7 @@ require ( cel.dev/expr v0.24.0 // indirect cloud.google.com/go/auth v0.16.3 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + github.com/StackExchange/wmi v1.2.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -58,6 +60,7 @@ require ( github.com/fxamacker/cbor/v2 v2.8.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-openapi/jsonpointer v0.21.1 // indirect github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/swag v0.23.1 // indirect @@ -68,6 +71,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jaypipes/pcidb v1.0.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -75,6 +79,7 @@ require ( github.com/mailru/easyjson v0.9.0 // indirect github.com/mdlayher/packet v1.1.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect @@ -116,6 +121,7 @@ require ( 
gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + howett.net/plist v1.0.0 // indirect k8s.io/cri-api v0.34.0-beta.0 // indirect k8s.io/kube-openapi v0.0.0-20250710124328-f3f2b991d03b // indirect k8s.io/kubelet v0.33.3 // indirect diff --git a/go.sum b/go.sum index cf84f635..0d61cc82 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ cloud.google.com/go/container v1.44.0 h1:JEHeW535svvNwJrjrlQ/cdjd15LCWrPKnHsulru cloud.google.com/go/container v1.44.0/go.mod h1:tVK2o4UZUTkg9WpBcgj4qRzwGA1dSFdWA3mil3YkLIQ= github.com/Mellanox/rdmamap v1.1.0 h1:A/W1wAXw+6vm58f3VklrIylgV+eDJlPVIMaIKuxgUT4= github.com/Mellanox/rdmamap v1.1.0/go.mod h1:fN+/V9lf10ABnDCwTaXRjeeWijLt2iVLETnK+sx/LY8= +github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= +github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -44,6 +46,9 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-openapi/jsonpointer v0.21.1 h1:whnzv/pNXtK2FbX/W9yJfRmE2gsmkfahjMKB0fZvcic= github.com/go-openapi/jsonpointer v0.21.1/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk= 
github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= @@ -81,6 +86,11 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/insomniacslk/dhcp v0.0.0-20250417080101-5f8cf70e8c5f h1:dd33oobuIv9PcBVqvbEiCXEbNTomOHyj3WFuC5YiPRU= github.com/insomniacslk/dhcp v0.0.0-20250417080101-5f8cf70e8c5f/go.mod h1:zhFlBeJssZ1YBCMZ5Lzu1pX4vhftDvU10WUVb1uXKtM= +github.com/jaypipes/ghw v0.17.0 h1:EVLJeNcy5z6GK/Lqby0EhBpynZo+ayl8iJWY0kbEUJA= +github.com/jaypipes/ghw v0.17.0/go.mod h1:In8SsaDqlb1oTyrbmTC14uy+fbBMvp+xdqX51MidlD8= +github.com/jaypipes/pcidb v1.0.1 h1:WB2zh27T3nwg8AE8ei81sNRb9yWBii3JGNJtT7K9Oic= +github.com/jaypipes/pcidb v1.0.1/go.mod h1:6xYUz/yYEyOkIkUt2t2J2folIuZ4Yg6uByCGFXMCeE4= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= @@ -122,6 +132,8 @@ github.com/mdlayher/packet v1.1.2 h1:3Up1NG6LZrsgDVn6X4L9Ge/iyRyxFEFD9o6Pr3Q1nQY github.com/mdlayher/packet v1.1.2/go.mod h1:GEu1+n9sG5VtiRE4SydOmX5GTwyyYlteZiFU+x0kew4= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -242,6 +254,7 @@ golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -285,9 +298,12 @@ gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSP gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= +howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= k8s.io/cloud-provider-gcp v0.0.0-20250326051131-7056e3facd39 h1:2m5DoDX46TPMmpcLRzrOWdBqouWChgbp4F/qlf/lIGc= k8s.io/cloud-provider-gcp v0.0.0-20250326051131-7056e3facd39/go.mod 
h1:NZrMafedcWXEFDubORCpHuWMp8cUS1TItObinH7vpwg= k8s.io/component-helpers v0.34.0-beta.0 h1:VtrSd4nRzYbHSr6hryNvIojiDOjf5MbjAWjyjGBjFjY= diff --git a/install.yaml b/install.yaml index 95a149a4..b3cfd5a4 100644 --- a/install.yaml +++ b/install.yaml @@ -150,6 +150,10 @@ spec: - name: bpf-programs mountPath: /sys/fs/bpf mountPropagation: HostToContainer + - name: pci-ids-hwdata + mountPath: /usr/share/hwdata/ + - name: pci-ids-misc + mountPath: /usr/share/misc/ volumes: - name: device-plugin hostPath: @@ -172,4 +176,10 @@ spec: - name: bpf-programs hostPath: path: /sys/fs/bpf + - name: pci-ids-hwdata + hostPath: + path: /usr/share/hwdata/ + - name: pci-ids-misc + hostPath: + path: /usr/share/misc/ --- diff --git a/pkg/apis/attributes.go b/pkg/apis/attributes.go new file mode 100644 index 00000000..289e8336 --- /dev/null +++ b/pkg/apis/attributes.go @@ -0,0 +1,42 @@ +/* +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package apis + +const ( + AttrPrefix = "dra.net" + + AttrInterfaceName = AttrPrefix + "/" + "ifName" + AttrMac = AttrPrefix + "/" + "mac" + AttrPCIVendor = AttrPrefix + "/" + "pciVendor" + AttrPCIDevice = AttrPrefix + "/" + "pciDevice" + AttrPCISubsystem = AttrPrefix + "/" + "pciSubsystem" + AttrNUMANode = AttrPrefix + "/" + "numaNode" + AttrMTU = AttrPrefix + "/" + "mtu" + AttrEncapsulation = AttrPrefix + "/" + "encapsulation" + AttrAlias = AttrPrefix + "/" + "alias" + AttrState = AttrPrefix + "/" + "state" + AttrType = AttrPrefix + "/" + "type" + AttrIPv4 = AttrPrefix + "/" + "ipv4" + AttrIPv6 = AttrPrefix + "/" + "ipv6" + AttrTCFilterNames = AttrPrefix + "/" + "tcFilterNames" + AttrTCXProgramNames = AttrPrefix + "/" + "tcxProgramNames" + AttrEBPF = AttrPrefix + "/" + "ebpf" + AttrSRIOV = AttrPrefix + "/" + "sriov" + AttrSRIOVVfs = AttrPrefix + "/" + "sriovVfs" + AttrVirtual = AttrPrefix + "/" + "virtual" + AttrRDMA = AttrPrefix + "/" + "rdma" +) diff --git a/pkg/cloudprovider/gce/gce.go b/pkg/cloudprovider/gce/gce.go index 50f867c4..8c32210f 100644 --- a/pkg/cloudprovider/gce/gce.go +++ b/pkg/cloudprovider/gce/gce.go @@ -41,6 +41,16 @@ const ( GPUDirectRDMA GPUDirectSupport = "GPUDirect-RDMA" ) +const ( + GCEAttrPrefix = "gce.dra.net" + + AttrGCEBlock = GCEAttrPrefix + "/" + "block" + AttrGCESubblock = GCEAttrPrefix + "/" + "subblock" + AttrGCEHost = GCEAttrPrefix + "/" + "host" + AttrGCENetworkName = GCEAttrPrefix + "/" + "networkName" + AttrGCENetworkProjectNumber = GCEAttrPrefix + "/" + "networkProjectNumber" +) + var ( // https://cloud.google.com/compute/docs/accelerator-optimized-machines#network-protocol // machine types have a one to one mapping to a network protocol in google cloud diff --git a/pkg/driver/dra_hooks.go b/pkg/driver/dra_hooks.go index 018380b1..f737fb8a 100644 --- a/pkg/driver/dra_hooks.go +++ b/pkg/driver/dra_hooks.go @@ -25,7 +25,6 @@ import ( "github.com/google/dranet/pkg/apis" "github.com/google/dranet/pkg/filter" - 
"github.com/google/dranet/pkg/names" "github.com/Mellanox/rdmamap" "github.com/vishvananda/netlink" @@ -165,11 +164,15 @@ func (np *NetworkDriver) prepareResourceClaim(ctx context.Context, claim *resour }, Network: netconf, } - ifName := names.GetOriginalName(result.Device) + ifName, err := np.netdb.GetNetInterfaceName(result.Device) + if err != nil { + errorList = append(errorList, fmt.Errorf("failed to get network interface for device %s: %w", result.Device, err)) + continue + } // Get Network configuration and merge it link, err := nlHandle.LinkByName(ifName) if err != nil { - errorList = append(errorList, fmt.Errorf("fail to get network interface %s", ifName)) + errorList = append(errorList, fmt.Errorf("failed to get netlink to interface %s: %w", ifName, err)) continue } @@ -298,15 +301,10 @@ func (np *NetworkDriver) prepareResourceClaim(ctx context.Context, claim *resour } } - device := kubeletplugin.Device{ - Requests: []string{result.Request}, - PoolName: result.Pool, - DeviceName: result.Device, - } // TODO: support for multiple pods sharing the same device // we'll create the subinterface here for _, uid := range podUIDs { - np.podConfigStore.Set(uid, device.DeviceName, podCfg) + np.podConfigStore.Set(uid, result.Device, podCfg) } klog.V(4).Infof("Claim Resources for pods %v : %#v", podUIDs, podCfg) } diff --git a/pkg/driver/nri_hooks.go b/pkg/driver/nri_hooks.go index 0c1ea27b..02325535 100644 --- a/pkg/driver/nri_hooks.go +++ b/pkg/driver/nri_hooks.go @@ -21,8 +21,6 @@ import ( "fmt" "time" - "github.com/google/dranet/pkg/names" - "github.com/containerd/nri/pkg/api" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -52,7 +50,7 @@ func (np *NetworkDriver) Synchronize(_ context.Context, pods []*api.PodSandbox, // host network pods are skipped if ns != "" { // store the Pod metadata in the db - np.netdb.AddPodNetns(podKey(pod), ns) + np.netdb.AddPodNetNs(podKey(pod), ns) } } @@ -106,7 +104,7 @@ func (np *NetworkDriver) RunPodSandbox(ctx context.Context, pod *api.PodSandbox) return 
fmt.Errorf("RunPodSandbox pod %s/%s using host network can not claim host devices", pod.Namespace, pod.Name) } // store the Pod metadata in the db - np.netdb.AddPodNetns(podKey(pod), ns) + np.netdb.AddPodNetNs(podKey(pod), ns) // Track all the status updates needed for the resource claims of the pod. statusUpdates := map[types.NamespacedName]*resourceapply.ResourceClaimStatusApplyConfiguration{} @@ -126,7 +124,13 @@ func (np *NetworkDriver) RunPodSandbox(ctx context.Context, pod *api.PodSandbox) WithDriver(np.driverName). WithPool(np.nodeName) - ifName := names.GetOriginalName(deviceName) + ifName, err := np.netdb.GetNetInterfaceName(deviceName) + if err != nil { + // Tip: We can tweak the code here to simply log + // this without returning and continue execution for cases where the + // device has no associated network interface. + return fmt.Errorf("failed to get network interface for device %s: %w", deviceName, err) + } klog.V(2).Infof("RunPodSandbox processing Network device: %s", ifName) // TODO config options to rename the device and pass parameters @@ -229,10 +233,11 @@ func (np *NetworkDriver) RunPodSandbox(ctx context.Context, pod *api.PodSandbox) // to avoid disrupting the pod shutdown. The kernel will do the cleanup once the namespace // is deleted. func (np *NetworkDriver) StopPodSandbox(ctx context.Context, pod *api.PodSandbox) error { + defer np.netdb.Sync() // Sync is expected to return immediately. 
klog.V(2).Infof("StopPodSandbox Pod %s/%s UID %s", pod.Namespace, pod.Name, pod.Uid) start := time.Now() defer func() { - np.netdb.RemovePodNetns(podKey(pod)) + np.netdb.RemovePodNetNs(podKey(pod)) klog.V(2).Infof("StopPodSandbox Pod %s/%s UID %s took %v", pod.Namespace, pod.Name, pod.Uid, time.Since(start)) }() @@ -247,7 +252,7 @@ func (np *NetworkDriver) StopPodSandbox(ctx context.Context, pod *api.PodSandbox // some version of containerd does not send the network namespace information on this hook so // we workaround it using the local copy we have in the db to associate interfaces with Pods via // the network namespace id. - ns = np.netdb.GetPodNamespace(podKey(pod)) + ns = np.netdb.GetPodNetNs(podKey(pod)) if ns == "" { klog.Infof("StopPodSandbox pod %s/%s using host network ... skipping", pod.Namespace, pod.Name) return nil @@ -255,10 +260,13 @@ func (np *NetworkDriver) StopPodSandbox(ctx context.Context, pod *api.PodSandbox } for deviceName, config := range podConfig { - ifName := names.GetOriginalName(deviceName) - - if err := nsDetachNetdev(ns, config.Network.Interface.Name, ifName); err != nil { - klog.Infof("fail to return network device %s : %v", deviceName, err) + ifName, err := np.netdb.GetNetInterfaceName(deviceName) + if err == nil { + if err := nsDetachNetdev(ns, config.Network.Interface.Name, ifName); err != nil { + klog.Infof("fail to return network device %s : %v", deviceName, err) + } + } else { + klog.V(2).Infof("Failed to identify network interface for device %s: %v; expect kernel to do the cleanup", deviceName, err) } if !np.rdmaSharedMode && config.RDMADevice.LinkDev != "" { @@ -272,7 +280,7 @@ func (np *NetworkDriver) StopPodSandbox(ctx context.Context, pod *api.PodSandbox func (np *NetworkDriver) RemovePodSandbox(_ context.Context, pod *api.PodSandbox) error { klog.V(2).Infof("RemovePodSandbox Pod %s/%s UID %s", pod.Namespace, pod.Name, pod.Uid) - np.netdb.RemovePodNetns(podKey(pod)) + np.netdb.RemovePodNetNs(podKey(pod)) return nil } 
diff --git a/pkg/inventory/db.go b/pkg/inventory/db.go index b2937e3d..52f4a3f1 100644 --- a/pkg/inventory/db.go +++ b/pkg/inventory/db.go @@ -19,14 +19,16 @@ package inventory import ( "context" "fmt" - "net" + "maps" "strings" "sync" "time" - "github.com/Mellanox/rdmamap" + "github.com/google/dranet/pkg/apis" "github.com/google/dranet/pkg/cloudprovider" - "github.com/google/dranet/pkg/names" + + "github.com/Mellanox/rdmamap" + "github.com/jaypipes/ghw" "github.com/vishvananda/netlink" "github.com/vishvananda/netns" "golang.org/x/time/rate" @@ -43,20 +45,23 @@ const ( maxInterval = 1 * time.Minute ) -var ( - // ignoredInterfaceNames is a set of network interface names that are typically - // created by CNI plugins or are otherwise not relevant for DRA resource exposure. - ignoredInterfaceNames = sets.New("cilium_net", "cilium_host", "docker") -) - type DB struct { instance *cloudprovider.CloudInstance - mu sync.RWMutex - podStore map[int]string // key: netnsid path value: Pod namespace/name - podNsStore map[string]string // key pod value: netns path - - rateLimiter *rate.Limiter + mu sync.RWMutex + // netNsForPod gives the network namespace for a pod, indexed by the pod's + // "namespace/name" key. + netNsForPod map[string]string + // deviceStore is an in-memory cache of the available devices on the node. + // It is keyed by the normalized PCI address of the device. The value is a + // resourceapi.Device object that contains the device's attributes. + // The deviceStore is periodically updated by the Run method. + deviceStore map[string]*resourceapi.Device + + rateLimiter *rate.Limiter + // syncCh is a channel that signals a request for an immediate device + // discovery (the rateLimiter is still honored). 
+ syncCh chan any notifications chan []resourceapi.Device hasDevices bool } @@ -64,79 +69,17 @@ type DB struct { func New() *DB { return &DB{ rateLimiter: rate.NewLimiter(rate.Every(minInterval), 1), - podStore: map[int]string{}, - podNsStore: map[string]string{}, + syncCh: make(chan any, 1), + netNsForPod: map[string]string{}, notifications: make(chan []resourceapi.Device), } } -func (db *DB) AddPodNetns(pod string, netnsPath string) { - db.mu.Lock() - defer db.mu.Unlock() - ns, err := netns.GetFromPath(netnsPath) - if err != nil { - klog.Infof("fail to get pod %s network namespace %s handle: %v", pod, netnsPath, err) - return - } - defer ns.Close() - id, err := netlink.GetNetNsIdByFd(int(ns)) - if err != nil { - klog.Infof("fail to get pod %s network namespace %s netnsid: %v", pod, netnsPath, err) - return - } - db.podStore[id] = pod - db.podNsStore[pod] = netnsPath -} - -func (db *DB) RemovePodNetns(pod string) { - db.mu.Lock() - defer db.mu.Unlock() - delete(db.podNsStore, pod) - for k, v := range db.podStore { - if v == pod { - delete(db.podStore, k) - return - } - } -} - -// GetPodName allows to get the Pod name from the namespace Id -// that comes in the link id from the veth pair interface -func (db *DB) GetPodName(netnsid int) string { - db.mu.RLock() - defer db.mu.RUnlock() - return db.podStore[netnsid] -} - -// GetPodNamespace allows to get the Pod network namespace -func (db *DB) GetPodNamespace(pod string) string { - db.mu.RLock() - defer db.mu.RUnlock() - return db.podNsStore[pod] -} - func (db *DB) Run(ctx context.Context) error { defer close(db.notifications) - nlHandle, err := netlink.NewHandle() - if err != nil { - return fmt.Errorf("error creating netlink handle %v", err) - } - // Resources are published periodically or if there is a netlink notification - // indicating a new interfaces was added or changed - nlChannel := make(chan netlink.LinkUpdate) - doneCh := make(chan struct{}) - defer close(doneCh) - if err := 
netlink.LinkSubscribe(nlChannel, doneCh); err != nil { - klog.Error(err, "error subscribing to netlink interfaces, only syncing periodically", "interval", maxInterval.String()) - } - // Obtain data that will not change after the startup db.instance = getInstanceProperties(ctx) - // TODO: it is not common but may happen in edge cases that the default gateway changes - // revisit once we have more evidence this can be a potential problem or break some use - // cases. - gwInterfaces := getDefaultGwInterfaces() for { err := db.rateLimiter.Wait(ctx) @@ -144,82 +87,145 @@ func (db *DB) Run(ctx context.Context) error { klog.Error(err, "unexpected rate limited error trying to get system interfaces") } - devices := []resourceapi.Device{} - ifaces, err := nlHandle.LinkList() + pci, err := ghw.PCI() if err != nil { - klog.Error(err, "unexpected error trying to get system interfaces") + return fmt.Errorf("error getting PCI info: %v", err) } - for _, iface := range ifaces { - klog.V(7).InfoS("Checking network interface", "name", iface.Attrs().Name) - if gwInterfaces.Has(iface.Attrs().Name) { - klog.V(4).Infof("iface %s is an uplink interface", iface.Attrs().Name) - continue - } - if ignoredInterfaceNames.Has(iface.Attrs().Name) { - klog.V(4).Infof("iface %s is in the list of ignored interfaces", iface.Attrs().Name) - continue - } - - // skip loopback interfaces - if iface.Attrs().Flags&net.FlagLoopback != 0 { - continue - } - - // publish this network interface - device, err := db.netdevToDRAdev(iface) - if err != nil { - klog.V(2).Infof("could not obtain attributes for iface %s : %v", iface.Attrs().Name, err) + devices := []resourceapi.Device{} + for _, pciDev := range pci.Devices { + if !isNetworkDevice(pciDev) { continue } - - devices = append(devices, *device) - klog.V(4).Infof("Found following network interface %s", iface.Attrs().Name) + // TODO(gauravkghildiyal): Exclude device for default interface. 
+ devices = append(devices, *db.pciToDRAdev(pciDev)) } + // Future improvement: We have identified the relevant physical network + // devices. If need be, we could now inspect the network namespaces + // (indicated by `netNsForPod`) to determine the current state of + // network interfaces. This can be achieved by finding interfaces within + // those namespaces that have a PCI address matching a device in our + // deviceStore. + klog.V(4).Infof("Found %d devices", len(devices)) if len(devices) > 0 || db.hasDevices { db.hasDevices = len(devices) > 0 + db.updateDeviceStore(devices) db.notifications <- devices } select { - // trigger a reconcile - case <-nlChannel: - // drain the channel so we only sync once - for len(nlChannel) > 0 { - <-nlChannel - } case <-time.After(maxInterval): + case <-db.syncCh: case <-ctx.Done(): return ctx.Err() } } } +func (db *DB) Sync() { + select { + case db.syncCh <- struct{}{}: + default: + } +} + func (db *DB) GetResources(ctx context.Context) <-chan []resourceapi.Device { return db.notifications } -func (db *DB) netdevToDRAdev(link netlink.Link) (*resourceapi.Device, error) { - ifName := link.Attrs().Name - device := resourceapi.Device{ +func (db *DB) pciToDRAdev(pciDev *ghw.PCIDevice) *resourceapi.Device { + device := &resourceapi.Device{ + Name: NormalizePCIAddress(pciDev.Address), Attributes: make(map[resourceapi.QualifiedName]resourceapi.DeviceAttribute), Capacity: make(map[resourceapi.QualifiedName]resourceapi.DeviceCapacity), } - // Set the device name. It will be normalized only if necessary. - device.Name = names.SetDeviceName(ifName) - // expose the real interface name as an attribute in case it is normalized. 
- device.Attributes["dra.net/ifName"] = resourceapi.DeviceAttribute{StringValue: &ifName} - - linkType := link.Type() - linkAttrs := link.Attrs() - - // identify the namespace holding the link as the other end of a veth pair - netnsid := link.Attrs().NetNsID - if podName := db.GetPodName(netnsid); podName != "" { - device.Attributes["dra.net/pod"] = resourceapi.DeviceAttribute{StringValue: &podName} + + db.addPCIAttributes(device, pciDev) + db.addNetDevAttributes(device, pciDev.Address) + db.addRDMADevAttributes(device, pciDev.Address) + db.addCloudProviderAttributes(device) + + return device +} + +func (db *DB) addPCIAttributes(device *resourceapi.Device, pciDev *ghw.PCIDevice) { + if pciDev.Vendor != nil { + device.Attributes[apis.AttrPCIVendor] = resourceapi.DeviceAttribute{StringValue: &pciDev.Vendor.Name} + } + if pciDev.Product != nil { + device.Attributes[apis.AttrPCIDevice] = resourceapi.DeviceAttribute{StringValue: &pciDev.Product.Name} + } + if pciDev.Subsystem != nil { + device.Attributes[apis.AttrPCISubsystem] = resourceapi.DeviceAttribute{StringValue: &pciDev.Subsystem.ID} } + if pciDev.Node != nil { + device.Attributes[apis.AttrNUMANode] = resourceapi.DeviceAttribute{IntValue: ptr.To(int64(pciDev.Node.ID))} + } + + pcieRootAttr, err := deviceattribute.GetPCIeRootAttributeByPCIBusID(pciDev.Address) + if err != nil { + klog.Infof("Could not get pci root attribute: %v", err) + } else { + device.Attributes[pcieRootAttr.Name] = pcieRootAttr.Value + } +} + +func (db *DB) addNetDevAttributes(device *resourceapi.Device, pciAddress string) { + ifName, err := GetNetworkInterface(pciAddress) + if err != nil { + klog.Infof("Could not get network interface for pci device %s: %v; Will re-use any existing device attributes.", pciAddress, err) + prevDevice, exists := db.GetDevice(NormalizePCIAddress(pciAddress)) + if exists { + // This merging is a best-effort attempt to preserve device attributes + // when a network interface is not in the root namespace (e.g., 
inside a + // pod). It relies entirely on the existing `deviceStore` cache. + // + // This approach has a key limitation: if the agent restarts, the cache + // is empty. If an interface is already in a pod's namespace at startup, + // its network-related attributes cannot be discovered and will be missing + // until the interface is returned to the host. + // + // This is an acceptable limitation because the missing attributes + // are mutable (e.g., IP address, Interface Name). Users should base + // resource selections on stable, immutable device properties (like + // PCI Root, RDMA type) rather than transient interface state. + mergeDeviceAttributes(device, prevDevice, + apis.AttrInterfaceName, + apis.AttrMac, + apis.AttrEncapsulation, + apis.AttrAlias, + apis.AttrState, + apis.AttrType, + apis.AttrIPv4, + apis.AttrIPv6, + apis.AttrTCFilterNames, + apis.AttrTCXProgramNames, + apis.AttrEBPF, + apis.AttrSRIOV, + apis.AttrSRIOVVfs, + apis.AttrVirtual, + ) + } + return + } + // Expose the interface name as an attribute. 
+ device.Attributes[apis.AttrInterfaceName] = resourceapi.DeviceAttribute{StringValue: &ifName} + + link, err := netlink.LinkByName(ifName) + if err != nil { + klog.Infof("Could not get link for interface %s: %v", ifName, err) + return + } + + device.Attributes[apis.AttrMac] = resourceapi.DeviceAttribute{StringValue: ptr.To(link.Attrs().HardwareAddr.String())} + device.Attributes[apis.AttrMTU] = resourceapi.DeviceAttribute{IntValue: ptr.To(int64(link.Attrs().MTU))} + device.Attributes[apis.AttrEncapsulation] = resourceapi.DeviceAttribute{StringValue: ptr.To(link.Attrs().EncapType)} + device.Attributes[apis.AttrAlias] = resourceapi.DeviceAttribute{StringValue: ptr.To(link.Attrs().Alias)} + device.Attributes[apis.AttrState] = resourceapi.DeviceAttribute{StringValue: ptr.To(link.Attrs().OperState.String())} + device.Attributes[apis.AttrType] = resourceapi.DeviceAttribute{StringValue: ptr.To(link.Type())} + v4 := sets.Set[string]{} v6 := sets.Set[string]{} if ips, err := netlink.AddrList(link, netlink.FAMILY_ALL); err == nil && len(ips) > 0 { @@ -235,101 +241,127 @@ func (db *DB) netdevToDRAdev(link netlink.Link) (*resourceapi.Device, error) { } } if v4.Len() > 0 { - device.Attributes["dra.net/ipv4"] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(v4.UnsortedList(), ","))} + device.Attributes[apis.AttrIPv4] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(v4.UnsortedList(), ","))} } if v6.Len() > 0 { - device.Attributes["dra.net/ipv6"] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(v6.UnsortedList(), ","))} + device.Attributes[apis.AttrIPv6] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(v6.UnsortedList(), ","))} } - mac := link.Attrs().HardwareAddr.String() - device.Attributes["dra.net/mac"] = resourceapi.DeviceAttribute{StringValue: &mac} - mtu := int64(link.Attrs().MTU) - device.Attributes["dra.net/mtu"] = resourceapi.DeviceAttribute{IntValue: &mtu} } - device.Attributes["dra.net/encapsulation"] = 
resourceapi.DeviceAttribute{StringValue: &linkAttrs.EncapType} - operState := linkAttrs.OperState.String() - device.Attributes["dra.net/state"] = resourceapi.DeviceAttribute{StringValue: &operState} - device.Attributes["dra.net/alias"] = resourceapi.DeviceAttribute{StringValue: &linkAttrs.Alias} - device.Attributes["dra.net/type"] = resourceapi.DeviceAttribute{StringValue: &linkType} - // Get eBPF properties from the interface using the legacy tc hooks isEbpf := false filterNames, ok := getTcFilters(link) if ok { isEbpf = true - device.Attributes["dra.net/tcFilterNames"] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(filterNames, ","))} + device.Attributes[apis.AttrTCFilterNames] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(filterNames, ","))} } // Get eBPF properties from the interface using the tcx hooks programNames, ok := getTcxFilters(link) if ok { isEbpf = true - device.Attributes["dra.net/tcxProgramNames"] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(programNames, ","))} + device.Attributes[apis.AttrTCXProgramNames] = resourceapi.DeviceAttribute{StringValue: ptr.To(strings.Join(programNames, ","))} } - device.Attributes["dra.net/ebpf"] = resourceapi.DeviceAttribute{BoolValue: &isEbpf} + device.Attributes[apis.AttrEBPF] = resourceapi.DeviceAttribute{BoolValue: &isEbpf} - isRDMA := rdmamap.IsRDmaDeviceForNetdevice(ifName) - device.Attributes["dra.net/rdma"] = resourceapi.DeviceAttribute{BoolValue: &isRDMA} // from https://github.com/k8snetworkplumbingwg/sriov-network-device-plugin/blob/ed1c14dd4c313c7dd9fe4730a60358fbeffbfdd4/pkg/netdevice/netDeviceProvider.go#L99 isSRIOV := sriovTotalVFs(ifName) > 0 - device.Attributes["dra.net/sriov"] = resourceapi.DeviceAttribute{BoolValue: &isSRIOV} + device.Attributes[apis.AttrSRIOV] = resourceapi.DeviceAttribute{BoolValue: &isSRIOV} if isSRIOV { vfs := int64(sriovNumVFs(ifName)) - device.Attributes["dra.net/sriovVfs"] = resourceapi.DeviceAttribute{IntValue: &vfs} + 
device.Attributes[apis.AttrSRIOVVfs] = resourceapi.DeviceAttribute{IntValue: &vfs} } if isVirtual(ifName, sysnetPath) { - device.Attributes["dra.net/virtual"] = resourceapi.DeviceAttribute{BoolValue: ptr.To(true)} + device.Attributes[apis.AttrVirtual] = resourceapi.DeviceAttribute{BoolValue: ptr.To(true)} } else { - addPCIAttributes(&device, ifName, sysnetPath) - } - - mac := link.Attrs().HardwareAddr.String() - for name, attribute := range getProviderAttributes(mac, db.instance) { - device.Attributes[name] = attribute + device.Attributes[apis.AttrVirtual] = resourceapi.DeviceAttribute{BoolValue: ptr.To(false)} } +} - return &device, nil +func (db *DB) addRDMADevAttributes(device *resourceapi.Device, pciAddress string) { + rdmaDevices := rdmamap.GetRdmaDevicesForPcidev(pciAddress) + isRDMA := len(rdmaDevices) != 0 + device.Attributes[apis.AttrRDMA] = resourceapi.DeviceAttribute{BoolValue: &isRDMA} } -func addPCIAttributes(device *resourceapi.Device, ifName string, path string) { - device.Attributes["dra.net/virtual"] = resourceapi.DeviceAttribute{BoolValue: ptr.To(false)} +func (db *DB) addCloudProviderAttributes(device *resourceapi.Device) { + mac := device.Attributes[apis.AttrMac] + if mac.StringValue == nil { + return + } + maps.Copy(device.Attributes, getProviderAttributes(*mac.StringValue, db.instance)) +} - address, err := bdfAddress(ifName, path) +func (db *DB) AddPodNetNs(pod string, netNsPath string) { + db.mu.Lock() + defer db.mu.Unlock() + ns, err := netns.GetFromPath(netNsPath) if err != nil { - klog.Infof("Could not get bdf address : %v", err) - } else { - if err := setPciRootAttr(device, address); err != nil { - klog.Infof("Could not get pci root attribute : %v", err) - } + klog.Infof("fail to get pod %s network namespace %s handle: %v", pod, netNsPath, err) + return } + defer ns.Close() + db.netNsForPod[pod] = netNsPath +} - entry, err := ids(ifName, path) - if err == nil { - if entry.Vendor != "" { - device.Attributes["dra.net/pciVendor"] = 
resourceapi.DeviceAttribute{StringValue: &entry.Vendor} - } - if entry.Device != "" { - device.Attributes["dra.net/pciDevice"] = resourceapi.DeviceAttribute{StringValue: &entry.Device} - } - if entry.Subsystem != "" { - device.Attributes["dra.net/pciSubsystem"] = resourceapi.DeviceAttribute{StringValue: &entry.Subsystem} - } - } else { - klog.Infof("could not get pci vendor information : %v", err) +func (db *DB) RemovePodNetNs(pod string) { + db.mu.Lock() + defer db.mu.Unlock() + delete(db.netNsForPod, pod) +} + +// GetPodNamespace allows to get the Pod network namespace +func (db *DB) GetPodNetNs(pod string) string { + db.mu.RLock() + defer db.mu.RUnlock() + return db.netNsForPod[pod] +} + +func (db *DB) updateDeviceStore(devices []resourceapi.Device) { + db.mu.Lock() + defer db.mu.Unlock() + db.deviceStore = map[string]*resourceapi.Device{} + for _, device := range devices { + db.deviceStore[device.Name] = &device } +} + +func (db *DB) GetDevice(deviceName string) (*resourceapi.Device, bool) { + db.mu.Lock() + defer db.mu.Unlock() + device, exists := db.deviceStore[deviceName] + return device, exists +} - numa, err := numaNode(ifName, path) - if err == nil { - device.Attributes["dra.net/numaNode"] = resourceapi.DeviceAttribute{IntValue: &numa} +func (db *DB) GetNetInterfaceName(deviceName string) (string, error) { + pciAddress := DeNormalizePCIAddress(deviceName) + device, exists := db.GetDevice(deviceName) + if !exists { + klog.Infof("device %s not found in store, using sysfs to get interface name", deviceName) + return GetNetworkInterface(pciAddress) + } + if device.Attributes[apis.AttrInterfaceName].StringValue == nil { + klog.Infof("device %s has no interface name in store, using sysfs to get interface name", deviceName) + return GetNetworkInterface(pciAddress) } + return *device.Attributes[apis.AttrInterfaceName].StringValue, nil } -func setPciRootAttr(device *resourceapi.Device, address *pciAddress) error { - pcieRootAttr, err := 
deviceattribute.GetPCIeRootAttributeByPCIBusID(address.bus) - if err != nil { - return err +// mergeDeviceAttributes copies a selective list of attributes from a source device +// to a destination device. This is useful in scenarios where a device's state +// cannot be fully determined (e.g., a network interface is down), allowing the +// driver to preserve and reuse previously known attributes from the device store. +func mergeDeviceAttributes(dest *resourceapi.Device, src *resourceapi.Device, attrsToCopy ...resourceapi.QualifiedName) { + if dest == nil || src == nil || src.Attributes == nil { + return + } + if dest.Attributes == nil { + dest.Attributes = make(map[resourceapi.QualifiedName]resourceapi.DeviceAttribute) + } + for _, attr := range attrsToCopy { + if val, ok := src.Attributes[attr]; ok { + dest.Attributes[attr] = val + } } - device.Attributes[pcieRootAttr.Name] = pcieRootAttr.Value - return nil } diff --git a/pkg/inventory/net.go b/pkg/inventory/net.go index f2f14d9a..d641a369 100644 --- a/pkg/inventory/net.go +++ b/pkg/inventory/net.go @@ -22,55 +22,8 @@ import ( "github.com/vishvananda/netlink" "k8s.io/apimachinery/pkg/util/sets" - "k8s.io/klog/v2" ) -func getDefaultGwInterfaces() sets.Set[string] { - interfaces := sets.Set[string]{} - filter := &netlink.Route{} - routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, filter, netlink.RT_FILTER_TABLE) - if err != nil { - return interfaces - } - - for _, r := range routes { - if r.Family != netlink.FAMILY_V4 && r.Family != netlink.FAMILY_V6 { - continue - } - - if r.Dst != nil && !r.Dst.IP.IsUnspecified() { - continue - } - - // no multipath - if len(r.MultiPath) == 0 { - if r.Gw == nil { - continue - } - intfLink, err := netlink.LinkByIndex(r.LinkIndex) - if err != nil { - klog.Infof("Failed to get interface link for route %v : %v", r, err) - continue - } - interfaces.Insert(intfLink.Attrs().Name) - } - - for _, nh := range r.MultiPath { - if nh.Gw == nil { - continue - } - intfLink, err := 
netlink.LinkByIndex(r.LinkIndex) - if err != nil { - klog.Infof("Failed to get interface link for route %v : %v", r, err) - continue - } - interfaces.Insert(intfLink.Attrs().Name) - } - } - klog.V(4).Infof("Found following interfaces for the default gateway: %v", interfaces.UnsortedList()) - return interfaces -} - func getTcFilters(link netlink.Link) ([]string, bool) { isTcEBPF := false filterNames := sets.Set[string]{} diff --git a/pkg/inventory/pci.go b/pkg/inventory/pci.go new file mode 100644 index 00000000..e0bbdb22 --- /dev/null +++ b/pkg/inventory/pci.go @@ -0,0 +1,86 @@ +/* +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package inventory + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/jaypipes/ghw" + "k8s.io/klog/v2" +) + +const ( + pciClassNetwork = "02" + // The digit 1 indicates the first versioned naming scheme, allowing + // different future naming schemes. + normalizedNamePrefix = "net1" +) + +func isNetworkDevice(dev *ghw.PCIDevice) bool { + return strings.HasPrefix(dev.Class.ID, pciClassNetwork) +} + +// GetNetworkInterface returns the network interface name for a given PCI address. 
+func GetNetworkInterface(pciAddress string) (string, error) { + pciPath := filepath.Join(sysBusPciDevicesPath, pciAddress, "net") + if _, err := os.Stat(pciPath); os.IsNotExist(err) { + return "", fmt.Errorf("no net directory for pci device %s", pciAddress) + } + files, err := os.ReadDir(pciPath) + if err != nil { + return "", err + } + if len(files) == 0 { + return "", fmt.Errorf("no interface found for pci device %s", pciAddress) + } + klog.V(3).Infof("found interface %s for pci device %s", files[0].Name(), pciAddress) + return files[0].Name(), nil +} + +// NormalizePCIAddress takes a PCI address and converts it into a DNS-1123 +// acceptable format. +func NormalizePCIAddress(pciAddress string) string { + if pciAddress == "" { + return "" + } + // Replace ":" and "." with "-" to make it DNS-1123 compliant. + // A PCI address like "0000:8a:00.0" becomes "0000-8a-00-0". + r := strings.NewReplacer(":", "-", ".", "-") + return normalizedNamePrefix + "-" + r.Replace(pciAddress) +} + +// DeNormalizePCIAddress takes a normalized PCI address and converts it back to +// a standard PCI address format. 
+func DeNormalizePCIAddress(normalizedAddress string) string { + if normalizedAddress == "" { + return "" + } + if !strings.HasPrefix(normalizedAddress, normalizedNamePrefix+"-") { + klog.Errorf("invalid normalized PCI address format: missing '%v-' prefix: %s", normalizedNamePrefix, normalizedAddress) + return "" + } + pciAddress := strings.TrimPrefix(normalizedAddress, normalizedNamePrefix+"-") + parts := strings.Split(pciAddress, "-") + if len(parts) != 4 { + klog.Errorf("invalid normalized PCI address format: expected 4 parts, got %d for %s", len(parts), normalizedAddress) + return "" + } + return fmt.Sprintf("%s:%s:%s.%s", parts[0], parts[1], parts[2], parts[3]) +} diff --git a/pkg/inventory/pci_test.go b/pkg/inventory/pci_test.go new file mode 100644 index 00000000..a0d299af --- /dev/null +++ b/pkg/inventory/pci_test.go @@ -0,0 +1,80 @@ +/* +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package inventory + +import ( + "testing" +) + +func TestNormalizePCIAddress(t *testing.T) { + testCases := []struct { + name string + pciAddress string + want string + }{ + { + name: "Standard PCI Address", + pciAddress: "0000:8a:00.0", + want: "net1-0000-8a-00-0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := NormalizePCIAddress(tc.pciAddress); got != tc.want { + t.Errorf("NormalizePCIAddress(%v) = %v, want %v", tc.pciAddress, got, tc.want) + } + }) + } +} + +func TestDeNormalizePCIAddress(t *testing.T) { + testCases := []struct { + name string + normalizedAddress string + want string + }{ + { + name: "Standard Normalized Address", + normalizedAddress: "net1-0000-8a-00-0", + want: "0000:8a:00.0", + }, + { + name: "Empty Normalized Address", + normalizedAddress: "", + want: "", + }, + { + name: "Invalid Format - No Prefix", + normalizedAddress: "0000-8a-00-0", + want: "", + }, + { + name: "Invalid Format - Wrong Number of Parts", + normalizedAddress: "net1-0000-8a-00", + want: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := DeNormalizePCIAddress(tc.normalizedAddress); got != tc.want { + t.Errorf("DeNormalizePCIAddress(%v) = %v, want %v", tc.normalizedAddress, got, tc.want) + } + }) + } +} diff --git a/pkg/inventory/sysfs.go b/pkg/inventory/sysfs.go index e7fddda6..adacf21f 100644 --- a/pkg/inventory/sysfs.go +++ b/pkg/inventory/sysfs.go @@ -18,13 +18,11 @@ package inventory import ( "bytes" - "fmt" "os" "path/filepath" "strconv" "strings" - "github.com/google/dranet/pkg/pcidb" "k8s.io/klog/v2" ) @@ -37,7 +35,8 @@ const ( // that is accessing the directory. Each of these symbolic // links refers to entries in the /sys/devices directory. 
// https://man7.org/linux/man-pages/man5/sysfs.5.html - sysdevPath = "/sys/devices" + sysdevPath = "/sys/devices" + sysBusPciDevicesPath = "/sys/bus/pci/devices" ) func realpath(ifName string, syspath string) string { @@ -95,111 +94,3 @@ func sriovNumVFs(name string) int { } return t } - -func numaNode(ifName string, syspath string) (int64, error) { - // /sys/class/net//device/numa_node - numeNode, err := os.ReadFile(filepath.Join(syspath, ifName, "device/numa_node")) - if err != nil { - return 0, err - } - numa, err := strconv.ParseInt(strings.TrimSpace(string(numeNode)), 10, 32) - if err != nil { - return 0, err - } - return numa, nil -} - -// pciAddress BDF Notation -// [domain:]bus:device.function -// https://wiki.xenproject.org/wiki/Bus:Device.Function_(BDF)_Notation -type pciAddress struct { - // There might be several independent sets of PCI devices - // (e.g. several host PCI controllers on a mainboard chipset) - domain string - bus string - device string - // One PCI device (e.g. pluggable card) may implement several functions - // (e.g. sound card and joystick controller used to be a common combo), - // so PCI provides for up to 8 separate functions on a single PCI device. - function string -} - -// The PCI root is the root PCI device, derived from the -// pciAddress of a device. Spec is defined from the DRA KEP. -// https://github.com/kubernetes/enhancements/pull/5316 -type pciRoot struct { - domain string - // The root may have a different host bus than the PCI device. - // e.g https://uefi.org/specs/UEFI/2.10/14_Protocols_PCI_Bus_Support.html#server-system-with-four-pci-root-bridges - bus string -} - -func bdfAddress(ifName string, path string) (*pciAddress, error) { - address := &pciAddress{} - // https://docs.kernel.org/PCI/sysfs-pci.html - // realpath /sys/class/net/ens4/device - // /sys/devices/pci0000:00/0000:00:04.0/virtio1 - // The topmost element describes the PCI domain and bus number. 
- // PCI domain: 0000 Bus: 00 Device: 04 Function: 0 - sysfsPath := realpath(ifName, path) - bfd := strings.Split(sysfsPath, "/") - if len(bfd) < 5 { - return nil, fmt.Errorf("could not find corresponding PCI address: %v", bfd) - } - - klog.V(4).Infof("pci address: %s", bfd[4]) - pci := strings.Split(bfd[4], ":") - // Simple BDF notation - switch len(pci) { - case 2: - address.bus = pci[0] - f := strings.Split(pci[1], ".") - if len(f) != 2 { - return nil, fmt.Errorf("could not find corresponding PCI device and function: %v", pci) - } - address.device = f[0] - address.function = f[1] - case 3: - address.domain = pci[0] - address.bus = pci[1] - f := strings.Split(pci[2], ".") - if len(f) != 2 { - return nil, fmt.Errorf("could not find corresponding PCI device and function: %v", pci) - } - address.device = f[0] - address.function = f[1] - default: - return nil, fmt.Errorf("could not find corresponding PCI address: %v", pci) - } - return address, nil -} - -func ids(ifName string, path string) (*pcidb.Entry, error) { - // PCI data - var device, subsystemVendor, subsystemDevice []byte - vendor, err := os.ReadFile(filepath.Join(path, ifName, "device/vendor")) - if err != nil { - return nil, err - } - // device, subsystemVendor and subsystemDevice are best effort - device, err = os.ReadFile(filepath.Join(sysdevPath, ifName, "device/device")) - if err == nil { - subsystemVendor, err = os.ReadFile(filepath.Join(sysdevPath, ifName, "device/subsystem_vendor")) - if err == nil { - subsystemDevice, _ = os.ReadFile(filepath.Join(sysdevPath, ifName, "device/subsystem_device")) - } - } - - // remove the 0x prefix - entry, err := pcidb.GetDevice( - strings.TrimPrefix(strings.TrimSpace(string(vendor)), "0x"), - strings.TrimPrefix(strings.TrimSpace(string(device)), "0x"), - strings.TrimPrefix(strings.TrimSpace(string(subsystemVendor)), "0x"), - strings.TrimPrefix(strings.TrimSpace(string(subsystemDevice)), "0x"), - ) - - if err != nil { - return nil, err - } - return entry, nil -} 
diff --git a/pkg/names/names.go b/pkg/names/names.go deleted file mode 100644 index 281b7566..00000000 --- a/pkg/names/names.go +++ /dev/null @@ -1,71 +0,0 @@ -/* -Copyright 2025 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package names - -import ( - "encoding/base32" - "strings" - - "k8s.io/apimachinery/pkg/util/validation" - "k8s.io/klog/v2" -) - -const ( - // NormalizedPrefix is added to device names that had to be encoded - // because their original interface name was not DNS-1123 compliant. - NormalizedPrefix = "normalized-" -) - -// SetDeviceName determines the appropriate name for a device in Kubernetes. -// If the original interface name (ifName) is already a valid DNS-1123 label, -// it's returned as is. Otherwise, it's encoded using Base32, prefixed with -// NormalizedPrefix, and returned. -// Linux interface names (often limited by IFNAMSIZ, typically 16) plus the -// base32 encoding and the normalized prefix (11) are within the DNS-1123 label, -// which has a maximum length of 63. 
-func SetDeviceName(ifName string) string { - if ifName == "" { - return "" - } - if len(validation.IsDNS1123Label(ifName)) == 0 { - return ifName - } - - klog.V(4).Infof("Interface name '%s' is not DNS-1123 compliant, normalizing.", ifName) - encodedPayload := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString([]byte(ifName)) - normalizedName := NormalizedPrefix + strings.ToLower(encodedPayload) - - return normalizedName -} - -// GetOriginalName retrieves the original interface name from a deviceName. -// If deviceName was prefixed with NormalizedPrefix (indicating it was encoded), -// it decodes the name. Otherwise, it assumes deviceName is the original name. -func GetOriginalName(deviceName string) string { - if strings.HasPrefix(deviceName, NormalizedPrefix) { - encodedPart := strings.TrimPrefix(deviceName, NormalizedPrefix) - encodedPart = strings.ToUpper(encodedPart) // base32 uses uppercase only - decodedBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encodedPart) - if err != nil { - klog.Warningf("Failed to decode Base32 device name payload '%s' from full name '%s': %v. Returning the full deviceName as fallback.", - encodedPart, deviceName, err) - return deviceName - } - return string(decodedBytes) - } - return deviceName -} diff --git a/pkg/names/names_test.go b/pkg/names/names_test.go deleted file mode 100644 index 2efbfe2a..00000000 --- a/pkg/names/names_test.go +++ /dev/null @@ -1,76 +0,0 @@ -/* -Copyright 2025 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package names - -import ( - "strings" - "testing" -) - -func TestSetDeviceName(t *testing.T) { - tests := []struct { - name string - ifName string - want string - }{ - {"already compliant", "eth0", "eth0"}, - {"already compliant with hyphen", "my-device-1", "my-device-1"}, - {"needs normalization colons", "eth:0", NormalizedPrefix + "mv2gqorq"}, - {"needs normalization uppercase", "ETH0", NormalizedPrefix + "ivkeqma"}, - {"needs normalization underscore", "eth_int", NormalizedPrefix + "mv2gqx3jnz2a"}, - {"empty string", "", ""}, - { - name: "long name needs normalization", - ifName: "very_long_interface_name_that_is_not_dns_compliant_at_all_and_exceeds_limits", - // base32 of the above is much longer, this is just to check prefixing - want: NormalizedPrefix + "ozsxe6k7nrxw4z27nfxhizlsmzqwgzk7nzqw2zk7orugc5c7nfzv63tporpwi3ttl5rw63lqnruwc3tul5qxix3bnrwf6ylomrpwk6ddmvswi427nruw22luom", - }, - {"already compliant max length", strings.Repeat("a", 63), strings.Repeat("a", 63)}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := SetDeviceName(tt.ifName); got != tt.want { - t.Errorf("SetDeviceName(%q) = %q, want %q", tt.ifName, got, tt.want) - } - }) - } -} - -func TestSetAndGetOriginalName(t *testing.T) { - testIfNames := []string{ - "eth0", - "my-nic", - "eth:0:1", - "veth_pair_A", - "UPPERCASE_NIC", - "a.b.c.d.e.f", - strings.Repeat("a_b", 30), // long non-compliant name - "", - } - - for _, ifName := range testIfNames { - t.Run(ifName, func(t *testing.T) { - deviceName := SetDeviceName(ifName) - originalName := GetOriginalName(deviceName) - if originalName != ifName { - t.Errorf("SetDeviceName -> GetOriginalName round trip failed for %q: SetDeviceName returned %q, GetOriginalName returned %q", - ifName, deviceName, originalName) - } - }) - } -}