diff --git a/pkg/helm/actions/config.go b/pkg/helm/actions/config.go index 2f88846154c..be5d2c062ab 100644 --- a/pkg/helm/actions/config.go +++ b/pkg/helm/actions/config.go @@ -50,6 +50,9 @@ func GetActionConfigurations(host, ns, token string, transport *http.RoundTrippe } conf := new(action.Configuration) conf.Init(confFlags, ns, "secrets", klog.Infof) - + err = GetDefaultOCIRegistry(conf) + if err != nil { + klog.V(4).Infof("Failed to get default OCI registry: %v", err) + } return conf } diff --git a/pkg/helm/actions/get_registry.go b/pkg/helm/actions/get_registry.go new file mode 100644 index 00000000000..1f9643f58d1 --- /dev/null +++ b/pkg/helm/actions/get_registry.go @@ -0,0 +1,44 @@ +package actions + +import ( + "crypto/tls" + "fmt" + "net/http" + + "helm.sh/helm/v3/pkg/action" + "helm.sh/helm/v3/pkg/registry" +) + +// newRegistryClient is a package-level variable to allow mocking in tests +var newRegistryClient = registry.NewClient + +func GetDefaultOCIRegistry(conf *action.Configuration) error { + return GetOCIRegistry(conf, false, false) +} + +func GetOCIRegistry(conf *action.Configuration, skipTLSVerify bool, plainHTTP bool) error { + if conf == nil { + return fmt.Errorf("action configuration cannot be nil") + } + opts := []registry.ClientOption{ + registry.ClientOptDebug(false), + } + if plainHTTP { + opts = append(opts, registry.ClientOptPlainHTTP()) + } + if skipTLSVerify { + opts = append(opts, registry.ClientOptHTTPClient(&http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + })) + } + registryClient, err := newRegistryClient(opts...) + if err != nil { + return fmt.Errorf("failed to create registry client: %w", err) + } + conf.RegistryClient = registryClient + return nil +} diff --git a/pkg/helm/actions/get_registry_test.go b/pkg/helm/actions/get_registry_test.go new file mode 100644 index 00000000000..c3f232d5b6a --- /dev/null +++ b/pkg/helm/actions/get_registry_test.go @@ -0,0 +1,132 @@ +package actions + +import ( + "errors" + "io" + "testing" + + "github.com/stretchr/testify/require" + "helm.sh/helm/v3/pkg/action" + "helm.sh/helm/v3/pkg/chartutil" + kubefake "helm.sh/helm/v3/pkg/kube/fake" + "helm.sh/helm/v3/pkg/registry" + "helm.sh/helm/v3/pkg/storage" + "helm.sh/helm/v3/pkg/storage/driver" +) + +func TestGetDefaultOCIRegistry_Success(t *testing.T) { + store := storage.Init(driver.NewMemory()) + conf := &action.Configuration{ + RESTClientGetter: FakeConfig{}, + Releases: store, + KubeClient: &kubefake.PrintingKubeClient{Out: io.Discard}, + Capabilities: chartutil.DefaultCapabilities, + } + require.Nil(t, conf.RegistryClient, "Registry Client should be nil") + + // Store original values + originalReleases := conf.Releases + originalKubeClient := conf.KubeClient + originalCapabilities := conf.Capabilities + + err := GetDefaultOCIRegistry(conf) + require.NoError(t, err) + require.NotNil(t, conf.RegistryClient, "Registry Client should not be nil") + + // Verify other configuration fields are not modified. + require.Equal(t, originalReleases, conf.Releases, "Releases should not be modified") + require.Equal(t, originalKubeClient, conf.KubeClient, "KubeClient should not be modified") + require.Equal(t, originalCapabilities, conf.Capabilities, "Capabilities should not be modified") + +} + +func TestGetOCIRegistry_NilConfig(t *testing.T) { + err := GetOCIRegistry(nil, false, false) + require.Error(t, err) + require.Contains(t, err.Error(), "action configuration cannot be nil") +} + +func TestGetOCIRegistry_Success(t *testing.T) { + tests := []struct { + name string + skipTLSVerify bool + plainHTTP bool + }{ + { + name: "default options", + skipTLSVerify: false, + plainHTTP: false, + }, + { + name: "with skipTLSVerify", + skipTLSVerify: true, + plainHTTP: false, + }, + { + name: "with plainHTTP", + skipTLSVerify: false, + plainHTTP: true, + }, + { + name: "with both skipTLSVerify and plainHTTP", + skipTLSVerify: true, + plainHTTP: true, + }, + } + + for _, tt := range tests { + originalNewRegistryClient := newRegistryClient + defer func() { + newRegistryClient = originalNewRegistryClient + }() + newRegistryClient = func(options ...registry.ClientOption) (*registry.Client, error) { + count := 0 + if tt.plainHTTP { + count += 1 + } + if tt.skipTLSVerify { + count += 1 + } + require.Equal(t, count, len(options)-1, "Expected %d options, got %d", count, len(options)) + return ®istry.Client{}, nil + } + t.Run(tt.name, func(t *testing.T) { + store := storage.Init(driver.NewMemory()) + conf := &action.Configuration{ + RESTClientGetter: FakeConfig{}, + Releases: store, + KubeClient: &kubefake.PrintingKubeClient{Out: io.Discard}, + Capabilities: chartutil.DefaultCapabilities, + } + require.Nil(t, conf.RegistryClient, "Registry Client should be nil initially") + + err := GetOCIRegistry(conf, tt.skipTLSVerify, tt.plainHTTP) + require.NoError(t, err) + require.NotNil(t, conf.RegistryClient, "Registry Client should not be nil after GetOCIRegistry") + }) + } +} + +func TestGetOCIRegistry_NewClientError(t *testing.T) { + // Save original function and restore after test + originalNewRegistryClient := newRegistryClient + defer func() { newRegistryClient = originalNewRegistryClient }() + + // Mock newRegistryClient to return an error + newRegistryClient = func(options ...registry.ClientOption) (*registry.Client, error) { + return nil, errors.New("mock registry client error") + } + + store := storage.Init(driver.NewMemory()) + conf := &action.Configuration{ + RESTClientGetter: FakeConfig{}, + Releases: store, + KubeClient: &kubefake.PrintingKubeClient{Out: io.Discard}, + Capabilities: chartutil.DefaultCapabilities, + } + + err := GetOCIRegistry(conf, false, false) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to create registry client") + require.Contains(t, err.Error(), "mock registry client error") +} diff --git a/pkg/helm/handlers/handler_test.go b/pkg/helm/handlers/handler_test.go index e80bf4d7994..8fb3d993f4e 100644 --- a/pkg/helm/handlers/handler_test.go +++ b/pkg/helm/handlers/handler_test.go @@ -63,6 +63,7 @@ var fakeReleaseManifest = "manifest-data" func fakeHelmHandler() helmHandlers { return helmHandlers{ getActionConfigurations: getFakeActionConfigurations, + getDefaultOCIRegistry: fakeGetDefaultOCIRegistry, } } @@ -201,6 +202,10 @@ func getFakeActionConfigurations(string, string, string, *http.RoundTripper) *ac } } +func fakeGetDefaultOCIRegistry(conf *action.Configuration) error { + return nil +} + func TestHelmHandlers_HandleHelmList(t *testing.T) { tests := []struct { name string diff --git a/pkg/helm/handlers/handlers.go b/pkg/helm/handlers/handlers.go index 9261ba8a156..6866724796f 100644 --- a/pkg/helm/handlers/handlers.go +++ b/pkg/helm/handlers/handlers.go @@ -39,6 +39,7 @@ func New(apiUrl string, transport http.RoundTripper, kubeversionGetter version.K uninstallReleaseAsync: actions.UninstallReleaseAsync, rollbackRelease: actions.RollbackRelease, getReleaseHistory: actions.GetReleaseHistory, + getDefaultOCIRegistry: actions.GetDefaultOCIRegistry, } h.newProxy = func(bearerToken string) (getter chartproxy.Proxy, err error) { @@ -72,6 +73,7 @@ type helmHandlers struct { getChart func(chartUrl string, conf *action.Configuration, namespace string, client dynamic.Interface, coreClient corev1client.CoreV1Interface, filesCleanup bool, indexEntry string) (*chart.Chart, error) getReleaseHistory func(releaseName string, conf *action.Configuration) ([]*release.Release, error) newProxy func(bearerToken string) (chartproxy.Proxy, error) + getDefaultOCIRegistry func(*action.Configuration) error } func (h *helmHandlers) restConfig(bearerToken string) *rest.Config { @@ -117,6 +119,11 @@ func (h *helmHandlers) HandleHelmInstall(user *auth.User, w http.ResponseWriter, } conf := h.getActionConfigurations(h.ApiServerHost, req.Namespace, user.Token, &h.Transport) + err = h.getDefaultOCIRegistry(conf) + if err != nil { + serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: fmt.Sprintf("Failed to get default registry: %v", err)}) + return + } handlerClients, err := NewHandlerClients(conf) if err != nil { serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: err.Error()}) @@ -143,6 +150,11 @@ func (h *helmHandlers) HandleHelmInstallAsync(user *auth.User, w http.ResponseWr } conf := h.getActionConfigurations(h.ApiServerHost, req.Namespace, user.Token, &h.Transport) + err = h.getDefaultOCIRegistry(conf) + if err != nil { + serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: fmt.Sprintf("Failed to get default registry: %v", err)}) + return + } handlerClients, err := NewHandlerClients(conf) if err != nil { serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: err.Error()}) @@ -210,6 +222,11 @@ func (h *helmHandlers) HandleChartGet(user *auth.User, w http.ResponseWriter, r indexEntry := params.Get("indexEntry") // scope request to default namespace conf := h.getActionConfigurations(h.ApiServerHost, "default", user.Token, &h.Transport) + err := h.getDefaultOCIRegistry(conf) + if err != nil { + serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: fmt.Sprintf("Failed to get default registry: %v", err)}) + return + } handlerClients, err := NewHandlerClients(conf) if err != nil { serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: err.Error()}) @@ -237,6 +254,11 @@ func (h *helmHandlers) HandleUpgradeRelease(user *auth.User, w http.ResponseWrit } conf := h.getActionConfigurations(h.ApiServerHost, req.Namespace, user.Token, &h.Transport) + err = h.getDefaultOCIRegistry(conf) + if err != nil { + serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: fmt.Sprintf("Failed to get default registry: %v", err)}) + return + } handlerClients, err := NewHandlerClients(conf) if err != nil { serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: err.Error()}) @@ -267,6 +289,11 @@ func (h *helmHandlers) HandleUpgradeReleaseAsync(user *auth.User, w http.Respons } conf := h.getActionConfigurations(h.ApiServerHost, req.Namespace, user.Token, &h.Transport) + err = h.getDefaultOCIRegistry(conf) + if err != nil { + serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: fmt.Sprintf("Failed to get default registry: %v", err)}) + return + } handlerClients, err := NewHandlerClients(conf) if err != nil { serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: err.Error()}) @@ -314,6 +341,11 @@ func (h *helmHandlers) HandleRollbackRelease(user *auth.User, w http.ResponseWri } conf := h.getActionConfigurations(h.ApiServerHost, req.Namespace, user.Token, &h.Transport) + err = h.getDefaultOCIRegistry(conf) + if err != nil { + serverutils.SendResponse(w, http.StatusBadGateway, serverutils.ApiError{Err: fmt.Sprintf("Failed to get default registry: %v", err)}) + return + } rel, err := h.rollbackRelease(req.Name, req.Version, conf) if err != nil { if err.Error() == actions.ErrReleaseRevisionNotFound.Error() {