Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
27 changes: 27 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package nri

import (
"context"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// toNriLogger adapts a logger.Interface to a logger whose methods take a
// leading context.Context argument (the shape passed to stub.WithLogger).
// The context is ignored by every method; the message is forwarded to the
// wrapped logger unchanged.
type toNriLogger struct {
	logger.Interface
}

// Debugf forwards the formatted message to the wrapped logger at debug level.
func (l toNriLogger) Debugf(_ context.Context, format string, args ...any) {
	l.Interface.Debugf(format, args...)
}

// Errorf forwards the formatted message to the wrapped logger at error level.
func (l toNriLogger) Errorf(_ context.Context, format string, args ...any) {
	l.Interface.Errorf(format, args...)
}

// Infof forwards the formatted message to the wrapped logger at info level.
func (l toNriLogger) Infof(_ context.Context, format string, args ...any) {
	l.Interface.Infof(format, args...)
}

// Warnf forwards the formatted message to the wrapped logger's Warningf
// method, bridging the naming difference between the two logger interfaces.
func (l toNriLogger) Warnf(_ context.Context, format string, args ...any) {
	l.Interface.Warningf(format, args...)
}
142 changes: 142 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/**
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nri

import (
"context"
"fmt"
"os"
"strings"

"github.com/containerd/nri/pkg/api"
"github.com/containerd/nri/pkg/plugin"
"github.com/containerd/nri/pkg/stub"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// Compile-time interface checks
var (
_ stub.Plugin = (*Plugin)(nil)
)

const (
// nriCDIAnnotationDomain is the domain name used for CDI device annotations
nriCDIAnnotationDomain = "nvidia.cdi.k8s.io"
)

// Plugin is the NRI plugin that adds CDI devices, requested through pod
// annotations, to container adjustments at container creation time.
type Plugin struct {
// ctx is the context supplied to NewPlugin; it is passed to the stub logger
// when emitting log messages.
ctx context.Context
// logger is the toolkit logger; Start wraps it for use by the NRI stub.
logger logger.Interface

// stub is assigned in Start and is nil until then; Stop tolerates a nil stub.
stub stub.Stub
}

// NewPlugin returns a Plugin that holds the supplied context and logger.
// The NRI stub is not created here; it is set up when Start is called.
func NewPlugin(ctx context.Context, logger logger.Interface) *Plugin {
	p := new(Plugin)
	p.ctx = ctx
	p.logger = logger
	return p
}

// CreateContainer is invoked for container creation requests. It builds a
// fresh container adjustment, lets injectCDIDevices populate it, and returns
// it. No container updates are ever returned. The request context is unused.
func (p *Plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
	adjustment := new(api.ContainerAdjustment)

	err := p.injectCDIDevices(pod, ctr, adjustment)
	if err != nil {
		return nil, nil, err
	}

	return adjustment, nil, nil
}

// injectCDIDevices looks up the CDI device names annotated for ctr on pod
// (under nriCDIAnnotationDomain) and adds each one to the adjustment a.
// When no devices are annotated it only logs a debug message. It currently
// always returns nil.
func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {
	logCtx := p.ctx
	nriLog := p.stub.Logger()

	names := parseCDIDevices(pod, nriCDIAnnotationDomain, ctr.Name)
	if len(names) == 0 {
		nriLog.Debugf(logCtx, "%s: no CDI devices annotated...", containerName(pod, ctr))
		return nil
	}

	for i := range names {
		device := &api.CDIDevice{Name: names[i]}
		a.AddCDIDevice(device)
		nriLog.Infof(logCtx, "%s: injected CDI device %q...", containerName(pod, ctr), names[i])
	}

	return nil
}

// parseCDIDevices returns the CDI device names requested for the given
// container via the pod annotation identified by key. The annotation value is
// treated as a comma-separated list. It returns nil when
// plugin.GetEffectiveAnnotation reports that no annotation applies.
//
// NOTE(review): entries are not trimmed and empty entries are kept, so a value
// such as "a, b" or "a," yields " b" or "" as a device name — confirm this is
// acceptable to the consumer of the adjustment.
func parseCDIDevices(pod *api.PodSandbox, key, container string) []string {
cdiDeviceNames, ok := plugin.GetEffectiveAnnotation(pod, key, container)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. So we've switched to using the format as implemented in the NRI plugin directly? That's better. We may need to update our documentation / implementation to use nvidia.cdi.k8s.io/pod for the cases where all containers need access to the device.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Yes, I will make sure that this is reflected in our documentation

if !ok {
return nil
}

// Multiple devices are listed in a single annotation value, separated by commas.
cdiDevices := strings.Split(cdiDeviceNames, ",")
return cdiDevices
}

// containerName builds the identifier used for a container in log messages:
// "<pod name>/<container name>" when a pod is given, otherwise just the
// container name.
func containerName(pod *api.PodSandbox, container *api.Container) string {
	if pod == nil {
		return container.Name
	}
	qualified := pod.Name + "/" + container.Name
	return qualified
}

// Start creates the NRI stub for this plugin and starts it.
//
// The stub is configured with the plugin index nriPluginIdx, a logger adapter
// around p.logger, and an on-close callback that logs that the connection is
// down. If nriSocketPath is non-empty it must exist on disk (checked with
// os.Stat) and is passed to the stub as the socket path; if it is empty no
// socket-path option is set (presumably the stub then uses its own default —
// confirm against the stub package).
//
// It returns a wrapped error if the socket path cannot be stat'ed, if the stub
// cannot be created, or if starting the stub fails; otherwise nil.
func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) error {
pluginOpts := []stub.Option{
stub.WithPluginIdx(nriPluginIdx),
stub.WithLogger(toNriLogger{p.logger}),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR: Note that the stub has a WithOnClose() option that allows one to specify a callback when connection is lost. Does it make sense to at least pass a function that logs so that we can collect data on when / how often this happens?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can look into that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done

// Log when the stub reports that its connection has been closed.
stub.WithOnClose(func() {
p.logger.Infof("NRI ttrpc connection to %s is down. NRI plugin stopped.", nriSocketPath)
}),
}
// Only validate and forward the socket path when one was explicitly provided.
if len(nriSocketPath) > 0 {
_, err := os.Stat(nriSocketPath)
if err != nil {
return fmt.Errorf("failed to find valid nri socket at %s: %w", nriSocketPath, err)
}
pluginOpts = append(pluginOpts, stub.WithSocketPath(nriSocketPath))
}

var err error
// The plugin itself is handed to the stub; the resulting stub is kept on p so
// that Stop and the logging in injectCDIDevices can use it.
if p.stub, err = stub.New(p, pluginOpts...); err != nil {
return fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err)
}
err = p.stub.Start(ctx)
if err != nil {
return fmt.Errorf("plugin exited with error: %w", err)
}
return nil
}

// Stop stops the underlying NRI stub. It is a no-op when called on a nil
// Plugin or on a Plugin whose stub has not been created yet.
func (p *Plugin) Stop() {
	if p != nil && p.stub != nil {
		p.stub.Stop()
	}
}
100 changes: 90 additions & 10 deletions cmd/nvidia-ctk-installer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import (
"os/signal"
"path/filepath"
"syscall"
"time"

nriapi "github.com/containerd/nri/pkg/api"
"github.com/urfave/cli/v3"
"golang.org/x/sys/unix"

"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/nri"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
Expand All @@ -26,6 +29,8 @@ const (
toolkitSubDir = "toolkit"

defaultRuntime = "docker"

defaultNRIPluginIdx uint = 10
)

var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}}
Expand All @@ -44,6 +49,10 @@ type options struct {
sourceRoot string
packageType string

enableNRIPlugin bool
nriPluginIndex uint
nriSocket string

toolkitOptions toolkit.Options
runtimeOptions runtime.Options
}
Expand Down Expand Up @@ -73,7 +82,7 @@ type app struct {
toolkit *toolkit.Installer
}

// NewApp creates the CLI app fro the specified options.
// NewApp creates the CLI app from the specified options.
func NewApp(logger logger.Interface) *cli.Command {
a := app{
logger: logger,
Expand All @@ -93,8 +102,8 @@ func (a app) build() *cli.Command {
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
return ctx, a.Before(cmd, &options)
},
Action: func(_ context.Context, cmd *cli.Command) error {
return a.Run(cmd, &options)
Action: func(ctx context.Context, cmd *cli.Command) error {
return a.Run(ctx, cmd, &options)
},
Flags: []cli.Flag{
&cli.BoolFlag{
Expand All @@ -104,6 +113,34 @@ func (a app) build() *cli.Command {
Destination: &options.noDaemon,
Sources: cli.EnvVars("NO_DAEMON"),
},
&cli.BoolFlag{
Name: "enable-nri-plugin",
Aliases: []string{"p"},
Usage: "if set to true, the toolkit will stand up an NRI Plugin server used to inject CDI devices " +
"to containers. Note that this option will be ignored if --no-daemon is set.",
Destination: &options.enableNRIPlugin,
Sources: cli.EnvVars("ENABLE_NRI_PLUGIN"),
},
&cli.UintFlag{
Name: "nri-plugin-index",
Usage: "Specify the plugin index to register to NRI",
Value: defaultNRIPluginIdx,
Destination: &options.nriPluginIndex,
Sources: cli.EnvVars("NRI_PLUGIN_INDEX"),
Action: func(ctx context.Context, c *cli.Command, u uint) error {
if u > 99 {
return fmt.Errorf("nri-plugin-index must be in the range [0,99]")
}
return nil
},
},
&cli.StringFlag{
Name: "nri-socket",
Usage: "Specify the path to the NRI socket file to register the NRI plugin server",
Value: nriapi.DefaultSocketPath,
Destination: &options.nriSocket,
Sources: cli.EnvVars("NRI_SOCKET"),
},
&cli.StringFlag{
Name: "runtime",
Aliases: []string{"r"},
Expand Down Expand Up @@ -194,7 +231,7 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
// Run installs the NVIDIA Container Toolkit and updates the requested runtime.
// If the application is run as a daemon, the application waits and unconfigures
// the runtime on termination.
func (a *app) Run(c *cli.Command, o *options) error {
func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
err := a.initialize(o.pidFile)
if err != nil {
return fmt.Errorf("unable to initialize: %v", err)
Expand All @@ -216,20 +253,32 @@ func (a *app) Run(c *cli.Command, o *options) error {
return fmt.Errorf("unable to install toolkit: %v", err)
}

err = runtime.Setup(c, &o.runtimeOptions, o.runtime)
if err != nil {
return fmt.Errorf("unable to setup runtime: %v", err)
if !o.enableNRIPlugin {
err = runtime.Setup(c, &o.runtimeOptions, o.runtime)
if err != nil {
return fmt.Errorf("unable to setup runtime: %w", err)
}
}

if !o.noDaemon {
if o.enableNRIPlugin {
nriPlugin, err := a.startNRIPluginServer(ctx, o)
if err != nil {
a.logger.Errorf("unable to start NRI plugin server: %v", err)
}
defer nriPlugin.Stop()
}

err = a.waitForSignal()
if err != nil {
return fmt.Errorf("unable to wait for signal: %v", err)
}

err = runtime.Cleanup(c, &o.runtimeOptions, o.runtime)
if err != nil {
return fmt.Errorf("unable to cleanup runtime: %v", err)
if !o.enableNRIPlugin {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@elezar This should no longer be required once #1521 gets in. I can make the necessary changes in a follow-up if this PR gets merged before #1521

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure let's keep these separate for now. I can rebase #1521 once we merge this.

err = runtime.Cleanup(c, &o.runtimeOptions, o.runtime)
if err != nil {
return fmt.Errorf("unable to cleanup runtime: %v", err)
}
}
}

Expand Down Expand Up @@ -287,6 +336,37 @@ func (a *app) waitForSignal() error {
return nil
}

// startNRIPluginServer creates the NRI plugin and tries to start it against
// the socket and plugin index configured in opts.
//
// Starting is attempted up to maxRetryAttempts times with a fixed retryBackoff
// pause between attempts; each failed attempt is logged as a warning. The
// pause is interrupted if ctx is cancelled, in which case the function returns
// immediately instead of sleeping through the remaining attempts.
//
// It returns the started plugin on success. On failure it returns a nil plugin
// and either the error from the last attempt or an error wrapping ctx.Err().
func (a *app) startNRIPluginServer(ctx context.Context, opts *options) (*nri.Plugin, error) {
	a.logger.Infof("Starting the NRI Plugin server....")

	const (
		maxRetryAttempts = 5
		retryBackoff     = 2 * time.Second
	)

	plugin := nri.NewPlugin(ctx, a.logger)
	// The index is formatted as a zero-padded two-digit string; it does not
	// change between attempts, so format it once.
	pluginIdx := fmt.Sprintf("%02d", opts.nriPluginIndex)

	var err error
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		if err = plugin.Start(ctx, opts.nriSocket, pluginIdx); err == nil {
			return plugin, nil
		}
		a.logger.Warningf("Attempt %d - error starting the NRI plugin: %v", attempt, err)

		// No point in waiting after the final attempt.
		if attempt == maxRetryAttempts {
			break
		}

		// Back off before retrying, but give up early if the context is done.
		select {
		case <-ctx.Done():
			return nil, fmt.Errorf("cancelled while retrying NRI plugin start: %w (last error: %v)", ctx.Err(), err)
		case <-time.After(retryBackoff):
		}
	}

	a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
	return nil, err
}

func (a *app) shutdown(pidFile string) {
a.logger.Infof("Shutting Down")

Expand Down
1 change: 1 addition & 0 deletions cmd/nvidia-ctk-installer/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ version = 2
"nvidia-ctk-installer",
"--toolkit-install-dir=" + toolkitRoot,
"--no-daemon",
"--enable-nri-plugin=false",
"--cdi-output-dir=" + cdiOutputDir,
"--config=" + runtimeConfigFile,
"--drop-in-config=" + runtimeDropInConfigFile,
Expand Down
12 changes: 10 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.25.0
require (
github.com/NVIDIA/go-nvlib v0.9.1-0.20251202135446-d0f42ba016dd
github.com/NVIDIA/go-nvml v0.13.0-1
github.com/containerd/nri v0.11.0
github.com/google/uuid v1.6.0
github.com/moby/sys/mountinfo v0.7.2
github.com/moby/sys/reexec v0.1.0
Expand All @@ -25,18 +26,25 @@ require (

require (
cyphar.com/go-pathrs v0.2.1 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/ttrpc v1.2.7 // indirect
github.com/cyphar/filepath-securejoin v0.6.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/knqyf263/go-plugin v0.9.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/moby/sys/capability v0.4.0 // indirect
github.com/opencontainers/cgroups v0.0.6 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/tetratelabs/wazero v1.10.1 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d // indirect
google.golang.org/grpc v1.57.1 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.4.0 // indirect
)
Loading
Loading