Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ type Plugin struct {
ctx context.Context
logger logger.Interface

stub stub.Stub
namespace string
stub stub.Stub
}

// NewPlugin creates a new NRI plugin for injecting CDI devices
func NewPlugin(ctx context.Context, logger logger.Interface) *Plugin {
func NewPlugin(ctx context.Context, logger logger.Interface, namespace string) *Plugin {
return &Plugin{
ctx: ctx,
logger: logger,
ctx: ctx,
logger: logger,
namespace: namespace,
}
}

Expand All @@ -69,25 +71,31 @@ func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *ap
ctx := p.ctx
pluginLogger := p.stub.Logger()

devices := parseCDIDevices(pod, nriCDIAnnotationDomain, ctr.Name)
devices := p.parseCDIDevices(pod, nriCDIAnnotationDomain, ctr.Name)
if len(devices) == 0 {
pluginLogger.Debugf(ctx, "%s: no CDI devices annotated...", containerName(pod, ctr))
return nil
}

pluginLogger.Infof(ctx, "%s: injecting CDI devices %v...", containerName(pod, ctr), devices)
for _, name := range devices {
a.AddCDIDevice(
&api.CDIDevice{
Name: name,
},
)
pluginLogger.Infof(ctx, "%s: injected CDI device %q...", containerName(pod, ctr), name)
}

return nil
}

func parseCDIDevices(pod *api.PodSandbox, key, container string) []string {
// parseCDIDevices processes the podSpec and determines which containers which need CDI devices injected to them
func (p *Plugin) parseCDIDevices(pod *api.PodSandbox, key, container string) []string {
if p.namespace != pod.Namespace {
p.logger.Debugf("pod %s/%s is not in the toolkit's namespace %s. Skipping CDI device injection...", pod.Namespace, pod.Name, p.namespace)
return nil
}

cdiDeviceNames, ok := plugin.GetEffectiveAnnotation(pod, key, container)
if !ok {
return nil
Expand Down
13 changes: 12 additions & 1 deletion cmd/nvidia-ctk-installer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type options struct {
enableNRIPlugin bool
nriPluginIndex uint
nriSocket string
nriNamespace string

toolkitOptions toolkit.Options

Expand Down Expand Up @@ -150,6 +151,12 @@ func (a app) build() *cli.Command {
Destination: &options.nriSocket,
Sources: cli.EnvVars("NRI_SOCKET"),
},
&cli.StringFlag{
Name: "nri-namespace",
Usage: "Specify the kubernetes namespace the toolkit's NRI plugin is running in.",
Destination: &options.nriNamespace,
Sources: cli.EnvVars("NRI_NAMESPACE"),
},
&cli.StringFlag{
Name: "runtime",
Aliases: []string{"r"},
Expand Down Expand Up @@ -226,6 +233,10 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
return fmt.Errorf("invalid toolkit.pid path %v", o.pidFile)
}

if o.enableNRIPlugin && len(o.nriNamespace) == 0 {
return fmt.Errorf("the NRI namespace must be specified when the NRI plugin is enabled")
}

if err := a.toolkit.ValidateOptions(&o.toolkitOptions); err != nil {
return err
}
Expand Down Expand Up @@ -353,7 +364,7 @@ func (a *app) startNRIPluginServer(ctx context.Context, opts *options) (*nri.Plu
retryBackoff = 2 * time.Second
)

plugin := nri.NewPlugin(ctx, a.logger)
plugin := nri.NewPlugin(ctx, a.logger, opts.nriNamespace)
retriable := func() error {
return plugin.Start(ctx, opts.nriSocket, fmt.Sprintf("%02d", opts.nriPluginIndex))
}
Expand Down