From 06d80f2a82dcdbb4be889c54c76465e4b2e43c8f Mon Sep 17 00:00:00 2001 From: Johnny Liu Date: Wed, 14 Jan 2026 22:30:45 +0800 Subject: [PATCH] validate azure user-assigned identity existence --- pkg/asset/installconfig/azure/client.go | 30 ++++ .../azure/mock/azureclient_generated.go | 14 ++ pkg/asset/installconfig/azure/validation.go | 59 +++++++ .../installconfig/azure/validation_test.go | 148 ++++++++++++++++++ 4 files changed, 251 insertions(+) diff --git a/pkg/asset/installconfig/azure/client.go b/pkg/asset/installconfig/azure/client.go index e1376195283..e559be76a0d 100644 --- a/pkg/asset/installconfig/azure/client.go +++ b/pkg/asset/installconfig/azure/client.go @@ -14,6 +14,7 @@ import ( azmarketplace "github.com/Azure/azure-sdk-for-go/profiles/latest/marketplaceordering/mgmt/marketplaceordering" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" azstorage "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" @@ -46,6 +47,7 @@ type API interface { CheckIfExistsStorageAccount(ctx context.Context, resourceGroup, storageAccountName, region string) error GetRegionAvailabilityZones(ctx context.Context, region string) ([]string, error) CheckSubnetNatgateway(ctx context.Context, resourceGroup, virtualNetwork, subnet string) (bool, error) + GetUserAssignedIdentity(ctx context.Context, subscriptionID, resourceGroup, name string) error } // APIVersion describes to the version to use for Azure API calls that support both azure and azurestack. @@ -566,3 +568,31 @@ func (c *Client) CheckSubnetNatgateway(ctx context.Context, resourceGroup, virtu } return false, fmt.Errorf("unable to get subnet nat gateway") } + +// GetUserAssignedIdentity checks if a user-assigned identity exists in the specified resource group. +func (c *Client) GetUserAssignedIdentity(ctx context.Context, subscriptionID, resourceGroup, name string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Use the subscription ID from the function parameter if provided, otherwise use session default + subID := subscriptionID + if subID == "" { + subID = c.ssn.Credentials.SubscriptionID + } + + clientOptions := arm.ClientOptions{ + ClientOptions: policy.ClientOptions{ + // Don't override APIVersion for managed identities - let SDK use the default + // API version which supports user-assigned identities. The generic APIVersion + // constant (2019-11-01) doesn't support the managed identity API. + Cloud: c.ssn.CloudConfig, + }, + } + client, err := armmsi.NewUserAssignedIdentitiesClient(subID, c.ssn.TokenCreds, &clientOptions) + if err != nil { + return fmt.Errorf("failed to create user-assigned identities client: %w", err) + } + + _, err = client.Get(ctx, resourceGroup, name, nil) + return err +} diff --git a/pkg/asset/installconfig/azure/mock/azureclient_generated.go b/pkg/asset/installconfig/azure/mock/azureclient_generated.go index ee62a871bc2..18eb90b533c 100644 --- a/pkg/asset/installconfig/azure/mock/azureclient_generated.go +++ b/pkg/asset/installconfig/azure/mock/azureclient_generated.go @@ -268,6 +268,20 @@ func (mr *MockAPIMockRecorder) GetStorageEndpointSuffix(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStorageEndpointSuffix", reflect.TypeOf((*MockAPI)(nil).GetStorageEndpointSuffix), ctx) } +// GetUserAssignedIdentity mocks base method. +func (m *MockAPI) GetUserAssignedIdentity(ctx context.Context, subscriptionID, resourceGroup, name string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserAssignedIdentity", ctx, subscriptionID, resourceGroup, name) + ret0, _ := ret[0].(error) + return ret0 +} + +// GetUserAssignedIdentity indicates an expected call of GetUserAssignedIdentity. +func (mr *MockAPIMockRecorder) GetUserAssignedIdentity(ctx, subscriptionID, resourceGroup, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserAssignedIdentity", reflect.TypeOf((*MockAPI)(nil).GetUserAssignedIdentity), ctx, subscriptionID, resourceGroup, name) +} + // GetVMCapabilities mocks base method. func (m *MockAPI) GetVMCapabilities(ctx context.Context, instanceType, region string) (map[string]string, error) { m.ctrl.T.Helper() diff --git a/pkg/asset/installconfig/azure/validation.go b/pkg/asset/installconfig/azure/validation.go index 6b4f58a5fe9..2a6c33eaa7a 100644 --- a/pkg/asset/installconfig/azure/validation.go +++ b/pkg/asset/installconfig/azure/validation.go @@ -658,6 +658,62 @@ func ValidatePublicDNS(ic *types.InstallConfig, azureDNS *DNSConfig) error { return nil } +// ValidateUserAssignedIdentities ensures the user-assigned identities exist and are valid. +func ValidateUserAssignedIdentities(client API, ic *types.InstallConfig) field.ErrorList { + allErrs := field.ErrorList{} + + // Validate default machine platform identities + if ic.Platform.Azure.DefaultMachinePlatform != nil && + ic.Platform.Azure.DefaultMachinePlatform.Identity != nil && + ic.Platform.Azure.DefaultMachinePlatform.Identity.Type == capz.VMIdentityUserAssigned { + for idx, identity := range ic.Platform.Azure.DefaultMachinePlatform.Identity.UserAssignedIdentities { + fieldPath := field.NewPath("platform").Child("azure", "defaultMachinePlatform", "identity", "userAssignedIdentities").Index(idx) + if err := validateUserAssignedIdentity(client, &identity, fieldPath); err != nil { + allErrs = append(allErrs, err) + } + } + } + + // Validate control plane identities + if ic.ControlPlane != nil && + ic.ControlPlane.Platform.Azure != nil && + ic.ControlPlane.Platform.Azure.Identity != nil && + ic.ControlPlane.Platform.Azure.Identity.Type == capz.VMIdentityUserAssigned { + for idx, identity := range ic.ControlPlane.Platform.Azure.Identity.UserAssignedIdentities { + fieldPath := field.NewPath("controlPlane").Child("platform", "azure", "identity", "userAssignedIdentities").Index(idx) + if err := validateUserAssignedIdentity(client, &identity, fieldPath); err != nil { + allErrs = append(allErrs, err) + } + } + } + + // Validate compute pool identities + for compIdx, compute := range ic.Compute { + if compute.Platform.Azure != nil && + compute.Platform.Azure.Identity != nil && + compute.Platform.Azure.Identity.Type == capz.VMIdentityUserAssigned { + for idIdx, identity := range compute.Platform.Azure.Identity.UserAssignedIdentities { + fieldPath := field.NewPath("compute").Index(compIdx).Child("platform", "azure", "identity", "userAssignedIdentities").Index(idIdx) + if err := validateUserAssignedIdentity(client, &identity, fieldPath); err != nil { + allErrs = append(allErrs, err) + } + } + } + } + + return allErrs +} + +func validateUserAssignedIdentity(client API, identity *aztypes.UserAssignedIdentity, fieldPath *field.Path) *field.Error { + err := client.GetUserAssignedIdentity(context.TODO(), identity.Subscription, identity.ResourceGroup, identity.Name) + if err != nil { + errMsg := fmt.Sprintf("failed to validate user-assigned identity '%s' in resource group '%s': %v", + identity.Name, identity.ResourceGroup, err) + return field.Invalid(fieldPath, identity.Name, errMsg) + } + return nil +} + // ValidateForProvisioning validates if the install config is valid for provisioning the cluster. func ValidateForProvisioning(client API, ic *types.InstallConfig) error { allErrs := field.ErrorList{} @@ -667,6 +723,9 @@ func ValidateForProvisioning(client API, ic *types.InstallConfig) error { if ic.Azure.CloudName == aztypes.StackCloud { allErrs = append(allErrs, checkAzureStackClusterOSImageSet(ic.Azure.ClusterOSImage, field.NewPath("platform").Child("azure"))...) } + + allErrs = append(allErrs, ValidateUserAssignedIdentities(client, ic)...) + return allErrs.ToAggregate() } diff --git a/pkg/asset/installconfig/azure/validation_test.go b/pkg/asset/installconfig/azure/validation_test.go index 400eb363792..7f32a207235 100644 --- a/pkg/asset/installconfig/azure/validation_test.go +++ b/pkg/asset/installconfig/azure/validation_test.go @@ -367,6 +367,67 @@ var ( ic.Compute[0].Platform.Azure.OSDisk.SecurityProfile = &azure.VMDiskSecurityProfile{DiskEncryptionSet: validDiskEncryptionSetConfig()} } + validUserAssignedIdentityName = "valid-identity" + validUserAssignedIdentityResourceGroup = "valid-identity-rg" + validUserAssignedIdentitySubscription = "valid-sub-id" + invalidUserAssignedIdentityName = "invalid-identity" + + validUserAssignedIdentityConfig = func() *azure.VMIdentity { + return &azure.VMIdentity{ + Type: "UserAssigned", + UserAssignedIdentities: []azure.UserAssignedIdentity{ + { + Name: validUserAssignedIdentityName, + ResourceGroup: validUserAssignedIdentityResourceGroup, + Subscription: validUserAssignedIdentitySubscription, + }, + }, + } + } + + invalidUserAssignedIdentityConfig = func() *azure.VMIdentity { + return &azure.VMIdentity{ + Type: "UserAssigned", + UserAssignedIdentities: []azure.UserAssignedIdentity{ + { + Name: invalidUserAssignedIdentityName, + ResourceGroup: validUserAssignedIdentityResourceGroup, + Subscription: validUserAssignedIdentitySubscription, + }, + }, + } + } + + validUserAssignedIdentityDefaultMachinePlatform = func(ic *types.InstallConfig) { + ic.Azure.DefaultMachinePlatform.Identity = validUserAssignedIdentityConfig() + } + validUserAssignedIdentityControlPlane = func(ic *types.InstallConfig) { + ic.ControlPlane.Platform.Azure.Identity = validUserAssignedIdentityConfig() + } + validUserAssignedIdentityCompute = func(ic *types.InstallConfig) { + ic.Compute[0].Platform.Azure.Identity = validUserAssignedIdentityConfig() + } + + invalidUserAssignedIdentityDefaultMachinePlatform = func(ic *types.InstallConfig) { + ic.Azure.DefaultMachinePlatform.Identity = invalidUserAssignedIdentityConfig() + } + invalidUserAssignedIdentityControlPlane = func(ic *types.InstallConfig) { + ic.ControlPlane.Platform.Azure.Identity = invalidUserAssignedIdentityConfig() + } + invalidUserAssignedIdentityCompute = func(ic *types.InstallConfig) { + ic.Compute[0].Platform.Azure.Identity = invalidUserAssignedIdentityConfig() + } + + noUserAssignedIdentity = func(ic *types.InstallConfig) { + ic.Azure.DefaultMachinePlatform.Identity = &azure.VMIdentity{ + Type: "None", + } + ic.ControlPlane.Platform.Azure.Identity = nil + if len(ic.Compute) > 0 { + ic.Compute[0].Platform.Azure.Identity = nil + } + } + validOSImageCompute = func(ic *types.InstallConfig) { ic.Compute[0].Platform.Azure.OSImage = validOSImage } @@ -1636,3 +1697,90 @@ func TestAzureStackDiskType(t *testing.T) { }) } } + +func TestValidateUserAssignedIdentities(t *testing.T) { + cases := []struct { + name string + edits editFunctions + errorMsg string + }{ + { + name: "Valid user-assigned identity for default machine platform", + edits: editFunctions{validUserAssignedIdentityDefaultMachinePlatform}, + errorMsg: "", + }, + { + name: "Invalid user-assigned identity not found for default machine platform", + edits: editFunctions{invalidUserAssignedIdentityDefaultMachinePlatform}, + errorMsg: fmt.Sprintf(`platform.azure.defaultMachinePlatform.identity.userAssignedIdentities\[0\]: Invalid value: "%s": failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup), + }, + { + name: "Valid user-assigned identity for control plane", + edits: editFunctions{validUserAssignedIdentityControlPlane}, + errorMsg: "", + }, + { + name: "Invalid user-assigned identity not found for control plane", + edits: editFunctions{invalidUserAssignedIdentityControlPlane}, + errorMsg: fmt.Sprintf(`controlPlane.platform.azure.identity.userAssignedIdentities\[0\]: Invalid value: "%s": failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup), + }, + { + name: "Valid user-assigned identity for compute", + edits: editFunctions{validUserAssignedIdentityCompute}, + errorMsg: "", + }, + { + name: "Invalid user-assigned identity not found for compute", + edits: editFunctions{invalidUserAssignedIdentityCompute}, + errorMsg: fmt.Sprintf(`compute\[0\].platform.azure.identity.userAssignedIdentities\[0\]: Invalid value: "%s": failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup), + }, + { + name: "No user-assigned identities specified", + edits: editFunctions{noUserAssignedIdentity}, + errorMsg: "", + }, + { + name: "Multiple valid identities in different pools", + edits: editFunctions{ + validUserAssignedIdentityDefaultMachinePlatform, + validUserAssignedIdentityControlPlane, + validUserAssignedIdentityCompute, + }, + errorMsg: "", + }, + { + name: "Mix of valid and invalid identities", + edits: editFunctions{ + validUserAssignedIdentityDefaultMachinePlatform, + invalidUserAssignedIdentityControlPlane, + }, + errorMsg: fmt.Sprintf(`failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup), + }, + } + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + azureClient := mock.NewMockAPI(mockCtrl) + + // Setup mock expectations for valid and invalid identities + azureClient.EXPECT().GetUserAssignedIdentity(gomock.Any(), validUserAssignedIdentitySubscription, validUserAssignedIdentityResourceGroup, validUserAssignedIdentityName).Return(nil).AnyTimes() + azureClient.EXPECT().GetUserAssignedIdentity(gomock.Any(), validUserAssignedIdentitySubscription, validUserAssignedIdentityResourceGroup, invalidUserAssignedIdentityName).Return(fmt.Errorf("resource not found")).AnyTimes() + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + editedInstallConfig := validInstallConfig() + for _, edit := range tc.edits { + edit(editedInstallConfig) + } + + errors := ValidateUserAssignedIdentities(azureClient, editedInstallConfig) + aggregatedErrors := errors.ToAggregate() + if tc.errorMsg != "" { + assert.Regexp(t, tc.errorMsg, aggregatedErrors) + } else { + assert.NoError(t, aggregatedErrors) + } + }) + } +}