diff --git a/images/virtualization-dra/cmd/usb-gateway/app/app.go b/images/virtualization-dra/cmd/usb-gateway/app/app.go new file mode 100644 index 0000000000..3f5db92bb1 --- /dev/null +++ b/images/virtualization-dra/cmd/usb-gateway/app/app.go @@ -0,0 +1,97 @@ +package app + +import ( + "fmt" + + "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/component-base/cli/flag" + + "github.com/deckhouse/virtualization-dra/internal/controller/resourceclaim" + "github.com/deckhouse/virtualization-dra/internal/informer" + "github.com/deckhouse/virtualization-dra/pkg/logger" +) + +func NewUSBGatewayCommand() *cobra.Command { + o := &usbOptions{} + + cmd := &cobra.Command{ + Use: "usb-gateway", + Short: "USB gateway", + Long: "USB gateway", + SilenceUsage: true, + SilenceErrors: true, + PreRunE: func(cmd *cobra.Command, args []string) error { + if err := o.Validate(); err != nil { + return err + } + log := o.Logging.Complete() + logger.SetDefaultLogger(log) + return nil + }, + RunE: o.Run, + } + + fs := cmd.Flags() + for _, f := range o.NamedFlags().FlagSets { + fs.AddFlagSet(f) + } + + return cmd +} + +type usbOptions struct { + Kubeconfig string + NodeName string + Logging *logger.Options +} + +func (o *usbOptions) NamedFlags() (fs flag.NamedFlagSets) { + mfs := fs.FlagSet("usb-gateway") + mfs.StringVar(&o.Kubeconfig, "kubeconfig", o.Kubeconfig, "Path to kubeconfig file") + mfs.StringVar(&o.NodeName, "node-name", o.NodeName, "Node name") + + o.Logging.AddFlags(fs.FlagSet("logging")) + + return fs +} + +func (o *usbOptions) Validate() error { + if o.NodeName == "" { + return fmt.Errorf("NodeName is required") + } + + return nil +} + +func (o *usbOptions) Run(cmd *cobra.Command, _ []string) error { + cfg, err := clientcmd.BuildConfigFromFlags("", o.Kubeconfig) + if err != nil { + return fmt.Errorf("failed to get rest config: %w", err) + } + + client, err := kubernetes.NewForConfig(cfg) + if err != nil { + return fmt.Errorf("failed to create kubernetes client: %w", err) + } + + f := informer.NewFactory(client, nil) + resourceClaimInformer := f.ResourceClaim() + + f.Start(cmd.Context().Done()) + f.WaitForCacheSync(cmd.Context().Done()) + + c, err := resourceclaim.NewController(resourceClaimInformer) + if err != nil { + return fmt.Errorf("failed to create resourceclaim controller: %w", err) + } + + group, ctx := errgroup.WithContext(cmd.Context()) + group.Go(func() error { + return c.Run(ctx, 1) + }) + + return group.Wait() +} diff --git a/images/virtualization-dra/cmd/usb-gateway/main.go b/images/virtualization-dra/cmd/usb-gateway/main.go new file mode 100644 index 0000000000..8e5b7a5d15 --- /dev/null +++ b/images/virtualization-dra/cmd/usb-gateway/main.go @@ -0,0 +1,20 @@ +package main + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/deckhouse/virtualization-dra/cmd/usb-gateway/app" +) + +func main() { + ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + + if err := app.NewUSBGatewayCommand().ExecuteContext(ctx); err != nil { + slog.Error("failed to execute command", slog.Any("err", err)) + os.Exit(1) + } +} diff --git a/images/virtualization-dra/internal/controller/resourceclaim/controller.go b/images/virtualization-dra/internal/controller/resourceclaim/controller.go new file mode 100644 index 0000000000..494dc38204 --- /dev/null +++ b/images/virtualization-dra/internal/controller/resourceclaim/controller.go @@ -0,0 +1,144 @@ +package resourceclaim + +import ( + "context" + "fmt" + "log/slog" + "time" + + resourcev1beta1 "k8s.io/api/resource/v1beta1" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" +) + +const controllerName = "resourceclaim-controller" + +var ( + keyFunc = cache.DeletionHandlingMetaNamespaceKeyFunc +) + +type Controller struct { + resourceClaimIndexer cache.Indexer + queue workqueue.TypedRateLimitingInterface[string] + log *slog.Logger + hasSynced cache.InformerSynced +} + +func NewController(resourceClaimInformer cache.SharedIndexInformer) (*Controller, error) { + queue := workqueue.NewTypedRateLimitingQueueWithConfig( + workqueue.DefaultTypedControllerRateLimiter[string](), + workqueue.TypedRateLimitingQueueConfig[string]{Name: "resourceclaim-controller"}, + ) + log := slog.With(slog.String("controller", controllerName)) + + c := &Controller{ + resourceClaimIndexer: resourceClaimInformer.GetIndexer(), + queue: queue, + log: log, + } + + _, err := resourceClaimInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: c.addResourceClaim, + UpdateFunc: c.updateResourceClaim, + DeleteFunc: c.deleteResourceClaim, + }) + if err != nil { + return nil, fmt.Errorf("unable to add event handler to resourceclaim informer: %w", err) + } + + c.hasSynced = resourceClaimInformer.HasSynced + + return c, nil +} + +func (c *Controller) addResourceClaim(obj interface{}) { + if rc, ok := obj.(*resourcev1beta1.ResourceClaim); ok { + c.enqueueResourceClaim(rc) + } +} + +func (c *Controller) deleteResourceClaim(obj interface{}) { + if rc, ok := obj.(*resourcev1beta1.ResourceClaim); ok { + c.enqueueResourceClaim(rc) + } +} + +func (c *Controller) updateResourceClaim(oldObj, newObj interface{}) { + oldRC, ok := oldObj.(*resourcev1beta1.ResourceClaim) + if !ok { + return + } + newRC, ok := newObj.(*resourcev1beta1.ResourceClaim) + if !ok { + return + } + + if oldRC.Status.Allocation == nil { + c.enqueueResourceClaim(newRC) + } +} + +func (c *Controller) enqueueResourceClaim(rc *resourcev1beta1.ResourceClaim) { + key, err := keyFunc(rc) + if err != nil { + utilruntime.HandleError(fmt.Errorf("couldn't get key for object %#v: %w", rc, err)) + return + } + c.queueAdd(key) +} + +func (c *Controller) queueAdd(key string) { + c.queue.Add(key) +} + +func (c *Controller) Run(ctx context.Context, workers int) error { + defer utilruntime.HandleCrash() + defer c.queue.ShutDown() + + c.log.Info("Starting controller") + defer c.log.Info("Shutting down controller") + + if !cache.WaitForCacheSync(ctx.Done(), c.hasSynced) { + return fmt.Errorf("failed to wait for caches to sync") + } + + c.log.Info("Starting workers controller") + for i := 0; i < workers; i++ { + go wait.UntilWithContext(ctx, c.worker, time.Second) + } + + <-ctx.Done() + return nil +} + +func (c *Controller) worker(ctx context.Context) { + workFunc := func(ctx context.Context) bool { + key, quit := c.queue.Get() + if quit { + return true + } + defer c.queue.Done(key) + + if err := c.sync(key); err != nil { + c.log.Error("re-enqueuing", slog.String("key", key), slog.Any("err", err)) + c.queue.AddRateLimited(key) + } else { + c.log.Info(fmt.Sprintf("processed ResourceClaim %v", key)) + c.queue.Forget(key) + } + return false + } + for { + quit := workFunc(ctx) + + if quit { + return + } + } +} + +func (c *Controller) sync(key string) error { + return nil +} diff --git a/images/virtualization-dra/internal/informer/informer.go b/images/virtualization-dra/internal/informer/informer.go new file mode 100644 index 0000000000..436a7ecf4f --- /dev/null +++ b/images/virtualization-dra/internal/informer/informer.go @@ -0,0 +1,95 @@ +package informer + +import ( + "log/slog" + "math/rand/v2" + "sync" + "time" + + corev1 "k8s.io/api/core/v1" + resourcev1beta1 "k8s.io/api/resource/v1beta1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/cache" +) + +func NewFactory(clientSet *kubernetes.Clientset, resync *time.Duration) *Factory { + var defaultResync time.Duration + if resync != nil { + defaultResync = *resync + } else { + defaultResync = resyncPeriod(12 * time.Hour) + } + + return &Factory{ + clientSet: clientSet, + defaultResync: defaultResync, + informers: make(map[string]cache.SharedIndexInformer), + } +} + +type Factory struct { + clientSet *kubernetes.Clientset + defaultResync time.Duration + + informers map[string]cache.SharedIndexInformer + startedInformers map[string]struct{} + mu sync.Mutex +} + +func (f *Factory) Start(stopCh <-chan struct{}) { + f.mu.Lock() + defer f.mu.Unlock() + + for name, informer := range f.informers { + if _, found := f.startedInformers[name]; found { + // skip informers that have already started. + slog.Info("SKIPPING informer", slog.String("name", name)) + continue + } + slog.Info("STARTING informer", slog.String("name", name)) + go informer.Run(stopCh) + f.startedInformers[name] = struct{}{} + } +} + +func (f *Factory) WaitForCacheSync(stopCh <-chan struct{}) { + var syncs []cache.InformerSynced + + f.mu.Lock() + for name, informer := range f.informers { + slog.Info("Waiting for cache sync of informer", slog.String("name", name)) + syncs = append(syncs, informer.HasSynced) + } + f.mu.Unlock() + + cache.WaitForCacheSync(stopCh, syncs...) +} + +func (f *Factory) ResourceClaim() cache.SharedIndexInformer { + return f.getInformer("resourceClaimInformer", func() cache.SharedIndexInformer { + lw := cache.NewListWatchFromClient(f.clientSet.ResourceV1beta1().RESTClient(), "resourceclaims", corev1.NamespaceAll, fields.Everything()) + return cache.NewSharedIndexInformer(lw, &resourcev1beta1.ResourceClaim{}, f.defaultResync, cache.Indexers{cache.NamespaceIndex: cache.MetaNamespaceIndexFunc}) + }) +} + +func (f *Factory) getInformer(key string, newFunc func() cache.SharedIndexInformer) cache.SharedIndexInformer { + f.mu.Lock() + defer f.mu.Unlock() + + informer, ok := f.informers[key] + if ok { + return informer + } + + informer = newFunc() + f.informers[key] = informer + + return informer +} + +// resyncPeriod computes the time interval a shared informer waits before resyncing with the api server +func resyncPeriod(minResyncPeriod time.Duration) time.Duration { + factor := rand.Float64() + 1 + return time.Duration(float64(minResyncPeriod.Nanoseconds()) * factor) +} diff --git a/images/virtualization-dra/internal/usbip/attacher.go b/images/virtualization-dra/internal/usbip/attacher.go new file mode 100644 index 0000000000..19f79ebea9 --- /dev/null +++ b/images/virtualization-dra/internal/usbip/attacher.go @@ -0,0 +1,17 @@ +package usbip + +func NewUSBAttacher() USBAttacher { + return &usbAttacher{} +} + +type usbAttacher struct{} + +func (a usbAttacher) Attach(busID string) error { + //TODO implement me + panic("implement me") +} + +func (a usbAttacher) Detach(busID string) error { + //TODO implement me + panic("implement me") +} diff --git a/images/virtualization-dra/internal/usbip/binder.go b/images/virtualization-dra/internal/usbip/binder.go new file mode 100644 index 0000000000..c7ff28adf9 --- /dev/null +++ b/images/virtualization-dra/internal/usbip/binder.go @@ -0,0 +1,174 @@ +package usbip + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +func NewUSBBinder() USBBinder { + return &usbBinder{} +} + +type usbBinder struct{} + +// Bind binds the USB device to the USBIP server. +// https://github.com/torvalds/linux/blob/40fbbd64bba6c6e7a72885d2f59b6a3be9991eeb/tools/usb/usbip/src/usbip_bind.c#L130 +func (b *usbBinder) Bind(busID string) error { + devInfo, err := b.getUSBDeviceInfo(busID) + if err != nil { + return fmt.Errorf("device with bus ID %s does not exist: %w", busID, err) + } + + if strings.Contains(devInfo.DevPath, "vhci_hcd") { + return fmt.Errorf("bind loop detected: device %s is already attached to vhci_hcd", busID) + } + + err = b.unbindOther(devInfo) + if err != nil { + return fmt.Errorf("failed to unbind other devices: %w", err) + } + + if err = b.modifyMatchBusID(busID, true); err != nil { + return err + } + + if err = b.bindUsbip(busID); err != nil { + return fmt.Errorf("failed to bind usb device: %w: %w", err, b.modifyMatchBusID(busID, false)) + } + + return nil +} + +// Unbind unbinds the USB device from the USBIP server. +// https://github.com/torvalds/linux/blob/40fbbd64bba6c6e7a72885d2f59b6a3be9991eeb/tools/usb/usbip/src/usbip_unbind.c#L30 +func (b *usbBinder) Unbind(busID string) error { + devInfo, err := b.getUSBDeviceInfo(busID) + if err != nil { + return fmt.Errorf("device with bus ID %s does not exist: %w", busID, err) + } + + if devInfo.Driver != usbipHostDriverName { + return fmt.Errorf("device %s is not bound to %s driver", devInfo.BusID, usbipHostDriverName) + } + + if err = b.unbindUsbip(busID); err != nil { + return fmt.Errorf("failed to unbind usb device %s: %w", busID, err) + } + + // notify driver of unbind + if err = b.modifyMatchBusID(busID, false); err != nil { + return fmt.Errorf("failed to modify match bus ID %s: %w", busID, err) + } + + // Trigger new probing + if err = b.rebindUsbip(busID); err != nil { + return fmt.Errorf("failed to rebind usb device %s: %w", busID, err) + } + + return nil +} + +type usbDeviceInfo struct { + BusID string + DevClass string + Driver string + DevPath string + IsHub bool +} + +func (b *usbBinder) getUSBDeviceInfo(busID string) (*usbDeviceInfo, error) { + path := getUSBDevicePath(busID) + + if _, err := os.Stat(path); err != nil { + return nil, err + } + + info := &usbDeviceInfo{ + BusID: busID, + } + + bDevClassPath := filepath.Join(path, "bDeviceClass") + if data, err := os.ReadFile(bDevClassPath); err == nil { + info.DevClass = strings.TrimSpace(string(data)) + info.IsHub = info.DevClass == "09" // 09 = USB Hub class + } + + driverLink := filepath.Join(path, "driver") + if link, err := os.Readlink(driverLink); err == nil { + info.Driver = filepath.Base(link) + } + + ueventPath := filepath.Join(path, "uevent") + if data, err := os.ReadFile(ueventPath); err == nil { + lines := strings.Split(string(data), "\n") + for _, line := range lines { + if strings.HasPrefix(line, "DEVPATH=") { + info.DevPath = strings.TrimPrefix(line, "DEVPATH=") + break + } + } + } + + return info, nil +} + +func (b *usbBinder) unbindOther(devInfo *usbDeviceInfo) error { + if devInfo.IsHub { + return fmt.Errorf("skip unbinding of hub %s", devInfo.BusID) + } + + if devInfo.Driver == "" { + // no driver bound to the device + return nil + } + + if devInfo.Driver == usbipHostDriverName { + return fmt.Errorf("device %s is already bound to %s", devInfo.BusID, usbipHostDriverName) + } + + unbindPath := unbindAttrPath(devInfo.Driver) + + if err := writeSysfsAttr(unbindPath, busIDAttr{busID: devInfo.BusID}); err != nil { + return fmt.Errorf("error unbinding device %s from driver %s: %w", devInfo.BusID, devInfo.Driver, err) + } + + return nil +} + +func (b *usbBinder) bindUsbip(busID string) error { + return writeSysfsAttr(bindAttrPath(usbipHostDriverName), busIDAttr{busID: busID}) +} + +func (b *usbBinder) unbindUsbip(busID string) error { + return writeSysfsAttr(unbindAttrPath(usbipHostDriverName), busIDAttr{busID: busID}) +} + +func (b *usbBinder) rebindUsbip(busID string) error { + return writeSysfsAttr(rebindAttrPath(usbipHostDriverName), busIDAttr{busID: busID}) +} + +func (b *usbBinder) modifyMatchBusID(busID string, add bool) error { + return writeSysfsAttr(matchBusIDAttrPath(usbipHostDriverName), modifyMatchBusIDAttr{busID: busID, add: add}) +} + +type modifyMatchBusIDAttr struct { + busID string + add bool +} + +func (a modifyMatchBusIDAttr) Complete() (string, error) { + if a.add { + return fmt.Sprintf("add %s", a.busID), nil + } + return fmt.Sprintf("del %s", a.busID), nil +} + +type busIDAttr struct { + busID string +} + +func (a busIDAttr) Complete() (string, error) { + return a.busID, nil +} diff --git a/images/virtualization-dra/internal/usbip/interfaces.go b/images/virtualization-dra/internal/usbip/interfaces.go new file mode 100644 index 0000000000..6ee39c6422 --- /dev/null +++ b/images/virtualization-dra/internal/usbip/interfaces.go @@ -0,0 +1,19 @@ +package usbip + +type ServerInterface interface { + USBBinder +} + +type ClientInterface interface { + USBAttacher +} + +type USBBinder interface { + Bind(busID string) error + Unbind(busID string) error +} + +type USBAttacher interface { + Attach(busID string) error + Detach(busID string) error +} diff --git a/images/virtualization-dra/internal/usbip/protocol/common.go b/images/virtualization-dra/internal/usbip/protocol/common.go new file mode 100644 index 0000000000..5aa1658450 --- /dev/null +++ b/images/virtualization-dra/internal/usbip/protocol/common.go @@ -0,0 +1,105 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "io" +) + +type USBVersion uint16 + +const ( + Version USBVersion = 0x0111 +) + +type Op uint16 + +// Common header for all the kinds of PDUs. +const ( + OpRequest Op = 0x80 << 8 + OpReply Op = 0x00 << 8 +) + +// Dummy Code +const ( + OpUnspec Op = 0x00 + OpReqUnspec Op = OpRequest | OpUnspec + OpRepUnspec Op = OpReply | OpUnspec +) + +// Retrieve USB device information. (still not used) +const ( + OpDevInfo Op = 0x02 + OpReqDevInfo Op = OpRequest | OpDevInfo + OpRepDevInfo Op = OpReply | OpDevInfo +) + +// Import a remote USB device. +const ( + OpImport Op = 0x03 + OpReqImport Op = OpRequest | OpImport + OpRepImport Op = OpReply | OpImport +) + +// Negotiate IPSec encryption key. (still not used) +const ( + OpCrypkey Op = 0x04 + OpReqCrypkey Op = OpRequest | OpCrypkey + OpRepCrypkey Op = OpReply | OpCrypkey +) + +// Retrieve the list of exported USB devices. +const ( + OpDevList Op = 0x05 + OpReqDevList Op = OpRequest | OpDevList + OpRepDevList Op = OpReply | OpDevList +) + +// Export a USB device to a remote host. +const ( + OpExport Op = 0x06 + OpReqExport Op = OpRequest | OpExport + OpRepExport Op = OpReply | OpExport +) + +// un-Export a USB device from a remote host. +const ( + OpUnexport Op = 0x07 + OpReqUnexport Op = OpRequest | OpUnexport + OpRepUnexport Op = OpReply | OpUnexport +) + +type OpStatus uint32 + +const ( + OpStatusOk OpStatus = 0x00000000 + OPStatusError OpStatus = 0x00000001 +) + +type OpCommon struct { + Version USBVersion + Code Op + Status OpStatus +} + +func (op *OpCommon) Decode(r io.Reader) error { + buf := make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read OpCommon: %w", err) + } + + op.Version = USBVersion(binary.BigEndian.Uint16(buf[0:2])) + op.Code = Op(binary.BigEndian.Uint16(buf[2:4])) + op.Status = OpStatus(binary.BigEndian.Uint32(buf[4:8])) + return nil +} + +func (op *OpCommon) Encode(w io.Writer) error { + buf := make([]byte, 8) + binary.BigEndian.PutUint16(buf[0:2], uint16(op.Version)) + binary.BigEndian.PutUint16(buf[2:4], uint16(op.Code)) + binary.BigEndian.PutUint32(buf[4:8], uint32(op.Status)) + _, err := w.Write(buf) + return err +} diff --git a/images/virtualization-dra/internal/usbip/protocol/device_list.go b/images/virtualization-dra/internal/usbip/protocol/device_list.go new file mode 100644 index 0000000000..f9c0ab8b8a --- /dev/null +++ b/images/virtualization-dra/internal/usbip/protocol/device_list.go @@ -0,0 +1,219 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "io" +) + +func NewDeviceList(status OpStatus, devices []USBDeviceInfo) *DeviceList { + return &DeviceList{ + OpCommon: OpCommon{ + Version: Version, + Code: OpReqDevList, + Status: status, + }, + Ndev: uint32(len(devices)), + Devices: devices, + } +} + +type DeviceList struct { + OpCommon + + Ndev uint32 + Devices []USBDeviceInfo +} + +func (d *DeviceList) Encode(w io.Writer) error { + if err := d.OpCommon.Encode(w); err != nil { + return fmt.Errorf("failed to encode OpCommon: %w", err) + } + + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf[0:4], d.Ndev) + + if _, err := w.Write(buf); err != nil { + return fmt.Errorf("failed to write Ndev to writer: %w", err) + } + + for _, dev := range d.Devices { + if err := dev.Encode(w); err != nil { + return fmt.Errorf("failed to encode USBDeviceInfo: %w", err) + } + } + + return nil +} + +const ( + sysfsPathMax = 256 + sysfsBusIdMax = 32 +) + +type USBDeviceInfo struct { + USBDevice + Interfaces []USBDeviceInterface +} + +func (d *USBDeviceInfo) Decode(r io.Reader) error { + if err := d.USBDevice.Decode(r); err != nil { + return fmt.Errorf("unable to decode USBDevice: %w", err) + } + + d.Interfaces = make([]USBDeviceInterface, d.BNumInterfaces) + for i := 0; i < int(d.BNumInterfaces); i++ { + if err := d.Interfaces[i].Decode(r); err != nil { + return fmt.Errorf("unable to decode USBDeviceInterface: %w", err) + } + } + + return nil +} + +func (d *USBDeviceInfo) Encode(w io.Writer) error { + if err := d.USBDevice.Encode(w); err != nil { + return fmt.Errorf("unable to encode USBDevice: %w", err) + } + + for _, iface := range d.Interfaces { + if err := iface.Encode(w); err != nil { + return fmt.Errorf("unable to encode USBDeviceInterface: %w", err) + } + } + + return nil +} + +type USBDevice struct { + Path [sysfsPathMax]byte + BusID [sysfsBusIdMax]byte + + Busnum uint32 + Devnum uint32 + Speed uint32 + + IDVendor uint16 + IDProduct uint16 + BcdDevice uint16 + + BDeviceClass uint8 + BDeviceSubClass uint8 + BDeviceProtocol uint8 + BConfigurationValue uint8 + BNumConfigurations uint8 + BNumInterfaces uint8 +} + +func (u *USBDevice) Decode(r io.Reader) error { + buf := make([]byte, sysfsPathMax+sysfsBusIdMax+12+8+6) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read USBDevice from reader: %w", err) + } + + copy(u.Path[:], buf[0:sysfsPathMax]) + copy(u.BusID[:], buf[sysfsPathMax:sysfsPathMax+sysfsBusIdMax]) + + pass := sysfsPathMax + sysfsBusIdMax + + u.Busnum = binary.BigEndian.Uint32(buf[pass : pass+4]) + pass += 4 + u.Devnum = binary.BigEndian.Uint32(buf[pass : pass+4]) + pass += 4 + u.Speed = binary.BigEndian.Uint32(buf[pass : pass+4]) + pass += 4 + + u.IDVendor = binary.BigEndian.Uint16(buf[pass : pass+2]) + pass += 2 + u.IDProduct = binary.BigEndian.Uint16(buf[pass : pass+2]) + pass += 2 + u.BcdDevice = binary.BigEndian.Uint16(buf[pass : pass+2]) + pass += 2 + + u.BDeviceClass = uint8(binary.BigEndian.Uint16(buf[pass : pass+1])) + pass += 1 + u.BDeviceSubClass = uint8(binary.BigEndian.Uint16(buf[pass : pass+1])) + pass += 1 + u.BDeviceProtocol = uint8(binary.BigEndian.Uint16(buf[pass : pass+1])) + pass += 1 + u.BConfigurationValue = uint8(binary.BigEndian.Uint16(buf[pass : pass+1])) + pass += 1 + u.BNumConfigurations = uint8(binary.BigEndian.Uint16(buf[pass : pass+1])) + pass += 1 + u.BNumInterfaces = uint8(binary.BigEndian.Uint16(buf[pass : pass+1])) + + return nil +} + +func (u *USBDevice) Encode(w io.Writer) error { + buf := make([]byte, sysfsPathMax+sysfsBusIdMax+12+8+6) + + copy(buf[0:sysfsPathMax], u.Path[:]) + copy(buf[sysfsPathMax:sysfsPathMax+sysfsBusIdMax], u.BusID[:]) + + pass := sysfsPathMax + sysfsBusIdMax + + binary.BigEndian.PutUint32(buf[pass:pass+4], u.Busnum) + pass += 4 + binary.BigEndian.PutUint32(buf[pass:pass+4], u.Devnum) + pass += 4 + binary.BigEndian.PutUint32(buf[pass:pass+4], u.Speed) + pass += 4 + + binary.BigEndian.PutUint16(buf[pass:pass+2], u.IDVendor) + pass += 2 + binary.BigEndian.PutUint16(buf[pass:pass+2], u.IDProduct) + pass += 2 + binary.BigEndian.PutUint16(buf[pass:pass+2], u.BcdDevice) + pass += 2 + + binary.BigEndian.PutUint16(buf[pass:pass+1], uint16(u.BDeviceClass)) + pass += 1 + binary.BigEndian.PutUint16(buf[pass:pass+1], uint16(u.BDeviceSubClass)) + pass += 1 + binary.BigEndian.PutUint16(buf[pass:pass+1], uint16(u.BDeviceProtocol)) + pass += 1 + binary.BigEndian.PutUint16(buf[pass:pass+1], uint16(u.BConfigurationValue)) + pass += 1 + binary.BigEndian.PutUint16(buf[pass:pass+1], uint16(u.BNumConfigurations)) + pass += 1 + binary.BigEndian.PutUint16(buf[pass:pass+1], uint16(u.BNumInterfaces)) + + _, err := w.Write(buf) + return fmt.Errorf("failed to write USBDevice to writer: %w", err) +} + +type USBDeviceInterface struct { + BInterfaceClass uint8 + BInterfaceSubClass uint8 + BInterfaceProtocol uint8 + padding uint8 +} + +func (u *USBDeviceInterface) Decode(r io.Reader) error { + buf := make([]byte, 4) + _, err := io.ReadFull(r, buf) + if err != nil { + return fmt.Errorf("failed to read USBDeviceInterface from reader: %w", err) + } + + u.BInterfaceClass = buf[0] + u.BInterfaceSubClass = buf[1] + u.BInterfaceProtocol = buf[2] + u.padding = buf[3] + + return nil +} + +func (u *USBDeviceInterface) Encode(w io.Writer) error { + buf := make([]byte, 4) + + buf[0] = u.BInterfaceClass + buf[1] = u.BInterfaceSubClass + buf[2] = u.BInterfaceProtocol + buf[3] = u.padding + + _, err := w.Write(buf) + return fmt.Errorf("failed to write USBDeviceInterface to writer: %w", err) +} diff --git a/images/virtualization-dra/internal/usbip/sysfs.go b/images/virtualization-dra/internal/usbip/sysfs.go new file mode 100644 index 0000000000..38677d32ec --- /dev/null +++ b/images/virtualization-dra/internal/usbip/sysfs.go @@ -0,0 +1,57 @@ +package usbip + +import ( + "fmt" + "os" +) + +type sysfsAttr interface { + Complete() (string, error) +} + +func writeSysfsAttr(attrPath string, newValue sysfsAttr) error { + value, err := newValue.Complete() + if err != nil { + return err + } + + f, err := os.OpenFile(attrPath, os.O_WRONLY, 0644) + if err != nil { + return err + } + defer f.Close() + + _, err = f.WriteString(value) + return err +} + +const ( + bindAttrPathTmpl = "/sys/bus/usb/drivers/%s/bind" + unbindAttrPathTmpl = "/sys/bus/usb/drivers/%s/unbind" + rebindAttrPathTmpl = "/sys/bus/usb/drivers/%s/rebind" + matchBusIDAttrPathTmpl = "/sys/bus/usb/drivers/%s/match_busid" + + usbDevicesTmpl = "/sys/bus/usb/devices/%s" + + usbipHostDriverName = "usbip-host" +) + +func getUSBDevicePath(busID string) string { + return fmt.Sprintf(usbDevicesTmpl, busID) +} + +func bindAttrPath(driver string) string { + return fmt.Sprintf(bindAttrPathTmpl, driver) +} + +func unbindAttrPath(driver string) string { + return fmt.Sprintf(unbindAttrPathTmpl, driver) +} + +func rebindAttrPath(driver string) string { + return fmt.Sprintf(rebindAttrPathTmpl, driver) +} + +func matchBusIDAttrPath(driver string) string { + return fmt.Sprintf(matchBusIDAttrPathTmpl, driver) +} diff --git a/images/virtualization-dra/internal/usbip/usbip.go b/images/virtualization-dra/internal/usbip/usbip.go new file mode 100644 index 0000000000..0f2a4d42ea --- /dev/null +++ b/images/virtualization-dra/internal/usbip/usbip.go @@ -0,0 +1 @@ +package usbip diff --git a/images/virtualization-dra/internal/usbip/usbipd.go b/images/virtualization-dra/internal/usbip/usbipd.go new file mode 100644 index 0000000000..0eba3043db --- /dev/null +++ b/images/virtualization-dra/internal/usbip/usbipd.go @@ -0,0 +1,191 @@ +package usbip + +import ( + "crypto/tls" + "fmt" + "log/slog" + "net" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/deckhouse/virtualization-dra/internal/usbip/protocol" +) + +const ( + defaultMaxTCPConnection = 100 + defaultGracefulShutdownTimeout = 30 * time.Second +) + +func makeOptions(opts ...Option) *options { + options := &options{ + maxTCPConnection: defaultMaxTCPConnection, + gracefulShutdownTimeout: defaultGracefulShutdownTimeout, + } + + for _, opt := range opts { + opt(options) + } + + return options +} + +type options struct { + tlsConfig *tls.Config + gracefulShutdownTimeout time.Duration + maxTCPConnection int +} + +type Option func(*options) + +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(o *options) { + o.tlsConfig = tlsConfig + } +} +func WithGracefulShutdownTimeout(timeout time.Duration) Option { + return func(o *options) { + o.gracefulShutdownTimeout = timeout + } +} + +func WithMaxTCPConnection(maxTCPConnection int) Option { + return func(o *options) { + o.maxTCPConnection = maxTCPConnection + } +} + +func NewUSBIPD(port int, opts ...Option) *USBIPD { + options := makeOptions(opts...) + return &USBIPD{ + addr: ":" + strconv.Itoa(port), + tlsConfig: options.tlsConfig, + gracefulShutdownTimeout: options.gracefulShutdownTimeout, + logger: slog.Default().With(slog.String("component", "usbipd")), + maxTCPConnection: options.maxTCPConnection, + } + +} + +type USBIPD struct { + addr string + tlsConfig *tls.Config + gracefulShutdownTimeout time.Duration + logger *slog.Logger + maxTCPConnection int + + listener net.Listener + connWg sync.WaitGroup + connCount atomic.Int64 + quit chan struct{} +} + +func (u *USBIPD) Start() (err error) { + if u.tlsConfig != nil { + u.listener, err = tls.Listen("tcp", u.addr, u.tlsConfig) + if err != nil { + return err + } + } else { + u.listener, err = net.Listen("tcp", u.addr) + if err != nil { + return err + } + } + + u.connWg.Add(1) + go func() { + var connCount atomic.Int64 + defer u.connWg.Done() + for { + conn, err := u.listener.Accept() + // Error occurred when + // 1. Connection error + // 2. The listener is closed (quit channel is closed) + if err != nil { + select { + case <-u.quit: + return + default: + u.logger.Error("unsable to accept request", slog.String("address", u.addr), slog.Any("err", err)) + } + } else { + // Check if TCP connection reached the limit specified in given config + count := connCount.Load() + if count+1 > int64(u.maxTCPConnection) { + u.logger.Error("maximum TCP connection reached, drop the connection", slog.Int64("count", count)) + conn.Close() + continue + } + + // TCP connection handler + u.connWg.Add(1) + connCount.Add(1) + go func() { + defer connCount.Add(-1) + defer u.connWg.Done() + u.logger.Info("new connection established", slog.String("addr", conn.RemoteAddr().String())) + if err := u.handleConnection(conn); err != nil { + u.logger.Error("failed to handle connection", slog.Any("err", err), slog.String("addr", conn.RemoteAddr().String())) + } + u.logger.Info("connection closed", slog.String("addr", conn.RemoteAddr().String())) + }() + } + } + }() + + return nil +} + +// https://docs.kernel.org/usb/usbip_protocol.html +func (u *USBIPD) handleConnection(conn net.Conn) error { + opCommon := protocol.OpCommon{} + if err := opCommon.Decode(conn); err != nil { + return fmt.Errorf("failed to decode OpCommon: %w", err) + } + + if opCommon.Version != protocol.Version { + return fmt.Errorf("unsupported USBIP version: %d", opCommon.Version) + } + + if opCommon.Status != protocol.OpStatusOk { + return fmt.Errorf("request failed: %d", opCommon.Status) + } + + switch opCommon.Code { + case protocol.OpReqDevList: + if err := u.handleDeviceList(conn); err != nil { + return fmt.Errorf("failed to handle OpReqDevList: %w", err) + } + case protocol.OpReqImport: + if err := u.handleImportRequest(conn); err != nil { + return fmt.Errorf("failed to handle OpReqImport: %w", err) + } + case protocol.OpReqDevInfo, protocol.OpReqCrypkey: + // nothing to do + default: + return fmt.Errorf("unsupported OpCommon.Code: %d", opCommon.Code) + } + + return nil +} + +func (u *USBIPD) handleDeviceList(conn net.Conn) error { + info, err := u.getUSBDeviceInfo() + if err != nil { + devList := protocol.NewDeviceList(protocol.OPStatusError, nil) + return devList.Encode(conn) + } + + devList := protocol.NewDeviceList(protocol.OpStatusOk, info) + return devList.Encode(conn) +} + +func (u *USBIPD) handleImportRequest(conn net.Conn) error { + return nil +} + +func (u *USBIPD) getUSBDeviceInfo() ([]protocol.USBDeviceInfo, error) { + return nil, nil +} diff --git a/images/virtualization-dra/internal/usbip/usbipd_config.go b/images/virtualization-dra/internal/usbip/usbipd_config.go new file mode 100644 index 0000000000..bc0f982db0 --- /dev/null +++ b/images/virtualization-dra/internal/usbip/usbipd_config.go @@ -0,0 +1,140 @@ +package usbip + +import ( + "crypto/tls" + "crypto/x509" + "flag" + "fmt" + "os" + "time" +) + +type ClientAuthType tls.ClientAuthType + +func (c *ClientAuthType) String() string { + cc := tls.ClientAuthType(*c) + return cc.String() +} + +func (c *ClientAuthType) Set(s string) error { + switch s { + case "NoClientCert": + *c = ClientAuthType(tls.NoClientCert) + case "RequestClientCert": + *c = ClientAuthType(tls.RequestClientCert) + case "RequireAnyClientCert": + *c = ClientAuthType(tls.RequireAnyClientCert) + case "VerifyClientCertIfGiven": + *c = ClientAuthType(tls.VerifyClientCertIfGiven) + case "RequireAndVerifyClientCert": + *c = ClientAuthType(tls.RequireAndVerifyClientCert) + default: + return fmt.Errorf("invalid client auth type: %s", s) + } + return nil +} + +type USBIPDConfig struct { + ServerCertificateFile string + ServerKeyFile string + + RootCAFile string + + ClientCAFile string + ClientAuthType *ClientAuthType + clientAuthType int + InsecureSkipVerify bool + + Port int + GracefulShutdownTimeout time.Duration +} + +func (c *USBIPDConfig) AddFlags(fs *flag.FlagSet) { + fs.IntVar(&c.Port, "usbipd-port", 0, "USBIPD port") + fs.StringVar(&c.ServerCertificateFile, "usbipd-server-certificate-file", "", "USBIPD server certificate file") + fs.StringVar(&c.ServerKeyFile, "usbipd-server-key-file", "", "USBIPD server key file") + fs.StringVar(&c.RootCAFile, "usbipd-root-ca-file", "", "USBIPD root CA file") + fs.StringVar(&c.ClientCAFile, "usbipd-client-ca-file", "", "USBIPD client CA file") + fs.Var(c.ClientAuthType, "usbipd-client-auth-type", "USBIPD client auth type") + fs.BoolVar(&c.InsecureSkipVerify, "usbipd-insecure-skip-verify", false, "USBIPD insecure skip verify") + fs.DurationVar(&c.GracefulShutdownTimeout, "usbipd-graceful-shutdown-timeout", 0, "USBIPD graceful shutdown timeout") +} + +func (c *USBIPDConfig) Validate() error { + if c.Port == 0 { + return fmt.Errorf("port is required") + } + + if c.ServerCertificateFile != "" && c.ServerKeyFile == "" { + return fmt.Errorf("server key file is required if server certificate file is provided") + } + + if c.ServerCertificateFile == "" && c.ServerKeyFile != "" { + return fmt.Errorf("server certificate file is required if server key file is provided") + } + + return nil +} + +func (c *USBIPDConfig) Complete() (*USBIPD, error) { + var opts []Option + if c.GracefulShutdownTimeout != 0 { + opts = append(opts, WithGracefulShutdownTimeout(c.GracefulShutdownTimeout)) + } + + var serverCertificate *tls.Certificate + if c.ServerCertificateFile != "" && c.ServerKeyFile != "" { + certificate, err := tls.LoadX509KeyPair(c.ServerCertificateFile, c.ServerKeyFile) + if err != nil { + return nil, err + } + serverCertificate = &certificate + } + + rootCACertPool, err := loadCAPoolFromFile(c.RootCAFile) + if err != nil { + return nil, err + } + + clientCACertPool, err := loadCAPoolFromFile(c.ClientCAFile) + if err != nil { + return nil, err + } + + if serverCertificate != nil || rootCACertPool != nil || clientCACertPool != nil { + tlsConfig := &tls.Config{ + RootCAs: rootCACertPool, + ClientCAs: clientCACertPool, + InsecureSkipVerify: c.InsecureSkipVerify, + } + if serverCertificate != nil { + tlsConfig.Certificates = []tls.Certificate{*serverCertificate} + } + if c.ClientAuthType != nil { + tlsConfig.ClientAuth = tls.ClientAuthType(*c.ClientAuthType) + } + + opts = append(opts, WithTLSConfig(tlsConfig)) + } + + return NewUSBIPD(c.Port, opts...), nil + +} + +func loadCAPoolFromFile(file string) (*x509.CertPool, error) { + if file == "" { + return nil, nil + } + + caCertPool := x509.NewCertPool() + caCertPEMBlock, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate: %w", err) + } + + if !caCertPool.AppendCertsFromPEM(caCertPEMBlock) { + return nil, fmt.Errorf("failed to parse CA certificate") + } + + return caCertPool, nil +}