Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pkg/asset/installconfig/azure/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
azmarketplace "github.com/Azure/azure-sdk-for-go/profiles/latest/marketplaceordering/mgmt/marketplaceordering"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
azstorage "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage"
Expand Down Expand Up @@ -46,6 +47,7 @@ type API interface {
CheckIfExistsStorageAccount(ctx context.Context, resourceGroup, storageAccountName, region string) error
GetRegionAvailabilityZones(ctx context.Context, region string) ([]string, error)
CheckSubnetNatgateway(ctx context.Context, resourceGroup, virtualNetwork, subnet string) (bool, error)
GetUserAssignedIdentity(ctx context.Context, subscriptionID, resourceGroup, name string) error
}

// APIVersion describes to the version to use for Azure API calls that support both azure and azurestack.
Expand Down Expand Up @@ -566,3 +568,31 @@ func (c *Client) CheckSubnetNatgateway(ctx context.Context, resourceGroup, virtu
}
return false, fmt.Errorf("unable to get subnet nat gateway")
}

// GetUserAssignedIdentity checks if a user-assigned identity exists in the specified resource group.
func (c *Client) GetUserAssignedIdentity(ctx context.Context, subscriptionID, resourceGroup, name string) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

// Use the subscription ID from the function parameter if provided, otherwise use session default
subID := subscriptionID
if subID == "" {
subID = c.ssn.Credentials.SubscriptionID
}

clientOptions := arm.ClientOptions{
ClientOptions: policy.ClientOptions{
// Don't override APIVersion for managed identities - let SDK use the default
// API version which supports user-assigned identities. The generic APIVersion
// constant (2019-11-01) doesn't support the managed identity API.
Cloud: c.ssn.CloudConfig,
},
}
client, err := armmsi.NewUserAssignedIdentitiesClient(subID, c.ssn.TokenCreds, &clientOptions)
if err != nil {
return fmt.Errorf("failed to create user-assigned identities client: %w", err)
}

_, err = client.Get(ctx, resourceGroup, name, nil)
return err
}
14 changes: 14 additions & 0 deletions pkg/asset/installconfig/azure/mock/azureclient_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions pkg/asset/installconfig/azure/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,62 @@ func ValidatePublicDNS(ic *types.InstallConfig, azureDNS *DNSConfig) error {
return nil
}

// ValidateUserAssignedIdentities ensures the user-assigned identities exist and are valid.
func ValidateUserAssignedIdentities(client API, ic *types.InstallConfig) field.ErrorList {
allErrs := field.ErrorList{}

// Validate default machine platform identities
if ic.Platform.Azure.DefaultMachinePlatform != nil &&
ic.Platform.Azure.DefaultMachinePlatform.Identity != nil &&
ic.Platform.Azure.DefaultMachinePlatform.Identity.Type == capz.VMIdentityUserAssigned {
for idx, identity := range ic.Platform.Azure.DefaultMachinePlatform.Identity.UserAssignedIdentities {
fieldPath := field.NewPath("platform").Child("azure", "defaultMachinePlatform", "identity", "userAssignedIdentities").Index(idx)
if err := validateUserAssignedIdentity(client, &identity, fieldPath); err != nil {
allErrs = append(allErrs, err)
}
}
}

// Validate control plane identities
if ic.ControlPlane != nil &&
ic.ControlPlane.Platform.Azure != nil &&
ic.ControlPlane.Platform.Azure.Identity != nil &&
ic.ControlPlane.Platform.Azure.Identity.Type == capz.VMIdentityUserAssigned {
for idx, identity := range ic.ControlPlane.Platform.Azure.Identity.UserAssignedIdentities {
fieldPath := field.NewPath("controlPlane").Child("platform", "azure", "identity", "userAssignedIdentities").Index(idx)
if err := validateUserAssignedIdentity(client, &identity, fieldPath); err != nil {
allErrs = append(allErrs, err)
}
}
}

// Validate compute pool identities
for compIdx, compute := range ic.Compute {
if compute.Platform.Azure != nil &&
compute.Platform.Azure.Identity != nil &&
compute.Platform.Azure.Identity.Type == capz.VMIdentityUserAssigned {
for idIdx, identity := range compute.Platform.Azure.Identity.UserAssignedIdentities {
fieldPath := field.NewPath("compute").Index(compIdx).Child("platform", "azure", "identity", "userAssignedIdentities").Index(idIdx)
if err := validateUserAssignedIdentity(client, &identity, fieldPath); err != nil {
allErrs = append(allErrs, err)
}
}
}
}

return allErrs
}

func validateUserAssignedIdentity(client API, identity *aztypes.UserAssignedIdentity, fieldPath *field.Path) *field.Error {
err := client.GetUserAssignedIdentity(context.TODO(), identity.Subscription, identity.ResourceGroup, identity.Name)
if err != nil {
errMsg := fmt.Sprintf("failed to validate user-assigned identity '%s' in resource group '%s': %v",
identity.Name, identity.ResourceGroup, err)
return field.Invalid(fieldPath, identity.Name, errMsg)
}
return nil
}

// ValidateForProvisioning validates if the install config is valid for provisioning the cluster.
func ValidateForProvisioning(client API, ic *types.InstallConfig) error {
allErrs := field.ErrorList{}
Expand All @@ -667,6 +723,9 @@ func ValidateForProvisioning(client API, ic *types.InstallConfig) error {
if ic.Azure.CloudName == aztypes.StackCloud {
allErrs = append(allErrs, checkAzureStackClusterOSImageSet(ic.Azure.ClusterOSImage, field.NewPath("platform").Child("azure"))...)
}

allErrs = append(allErrs, ValidateUserAssignedIdentities(client, ic)...)

return allErrs.ToAggregate()
}

Expand Down
148 changes: 148 additions & 0 deletions pkg/asset/installconfig/azure/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,67 @@ var (
ic.Compute[0].Platform.Azure.OSDisk.SecurityProfile = &azure.VMDiskSecurityProfile{DiskEncryptionSet: validDiskEncryptionSetConfig()}
}

validUserAssignedIdentityName = "valid-identity"
validUserAssignedIdentityResourceGroup = "valid-identity-rg"
validUserAssignedIdentitySubscription = "valid-sub-id"
invalidUserAssignedIdentityName = "invalid-identity"

validUserAssignedIdentityConfig = func() *azure.VMIdentity {
return &azure.VMIdentity{
Type: "UserAssigned",
UserAssignedIdentities: []azure.UserAssignedIdentity{
{
Name: validUserAssignedIdentityName,
ResourceGroup: validUserAssignedIdentityResourceGroup,
Subscription: validUserAssignedIdentitySubscription,
},
},
}
}

invalidUserAssignedIdentityConfig = func() *azure.VMIdentity {
return &azure.VMIdentity{
Type: "UserAssigned",
UserAssignedIdentities: []azure.UserAssignedIdentity{
{
Name: invalidUserAssignedIdentityName,
ResourceGroup: validUserAssignedIdentityResourceGroup,
Subscription: validUserAssignedIdentitySubscription,
},
},
}
}

validUserAssignedIdentityDefaultMachinePlatform = func(ic *types.InstallConfig) {
ic.Azure.DefaultMachinePlatform.Identity = validUserAssignedIdentityConfig()
}
validUserAssignedIdentityControlPlane = func(ic *types.InstallConfig) {
ic.ControlPlane.Platform.Azure.Identity = validUserAssignedIdentityConfig()
}
validUserAssignedIdentityCompute = func(ic *types.InstallConfig) {
ic.Compute[0].Platform.Azure.Identity = validUserAssignedIdentityConfig()
}

invalidUserAssignedIdentityDefaultMachinePlatform = func(ic *types.InstallConfig) {
ic.Azure.DefaultMachinePlatform.Identity = invalidUserAssignedIdentityConfig()
}
invalidUserAssignedIdentityControlPlane = func(ic *types.InstallConfig) {
ic.ControlPlane.Platform.Azure.Identity = invalidUserAssignedIdentityConfig()
}
invalidUserAssignedIdentityCompute = func(ic *types.InstallConfig) {
ic.Compute[0].Platform.Azure.Identity = invalidUserAssignedIdentityConfig()
}

noUserAssignedIdentity = func(ic *types.InstallConfig) {
ic.Azure.DefaultMachinePlatform.Identity = &azure.VMIdentity{
Type: "None",
}
ic.ControlPlane.Platform.Azure.Identity = nil
if len(ic.Compute) > 0 {
ic.Compute[0].Platform.Azure.Identity = nil
}
}

validOSImageCompute = func(ic *types.InstallConfig) {
ic.Compute[0].Platform.Azure.OSImage = validOSImage
}
Expand Down Expand Up @@ -1636,3 +1697,90 @@ func TestAzureStackDiskType(t *testing.T) {
})
}
}

func TestValidateUserAssignedIdentities(t *testing.T) {
cases := []struct {
name string
edits editFunctions
errorMsg string
}{
{
name: "Valid user-assigned identity for default machine platform",
edits: editFunctions{validUserAssignedIdentityDefaultMachinePlatform},
errorMsg: "",
},
{
name: "Invalid user-assigned identity not found for default machine platform",
edits: editFunctions{invalidUserAssignedIdentityDefaultMachinePlatform},
errorMsg: fmt.Sprintf(`platform.azure.defaultMachinePlatform.identity.userAssignedIdentities\[0\]: Invalid value: "%s": failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup),
},
{
name: "Valid user-assigned identity for control plane",
edits: editFunctions{validUserAssignedIdentityControlPlane},
errorMsg: "",
},
{
name: "Invalid user-assigned identity not found for control plane",
edits: editFunctions{invalidUserAssignedIdentityControlPlane},
errorMsg: fmt.Sprintf(`controlPlane.platform.azure.identity.userAssignedIdentities\[0\]: Invalid value: "%s": failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup),
},
{
name: "Valid user-assigned identity for compute",
edits: editFunctions{validUserAssignedIdentityCompute},
errorMsg: "",
},
{
name: "Invalid user-assigned identity not found for compute",
edits: editFunctions{invalidUserAssignedIdentityCompute},
errorMsg: fmt.Sprintf(`compute\[0\].platform.azure.identity.userAssignedIdentities\[0\]: Invalid value: "%s": failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup),
},
{
name: "No user-assigned identities specified",
edits: editFunctions{noUserAssignedIdentity},
errorMsg: "",
},
{
name: "Multiple valid identities in different pools",
edits: editFunctions{
validUserAssignedIdentityDefaultMachinePlatform,
validUserAssignedIdentityControlPlane,
validUserAssignedIdentityCompute,
},
errorMsg: "",
},
{
name: "Mix of valid and invalid identities",
edits: editFunctions{
validUserAssignedIdentityDefaultMachinePlatform,
invalidUserAssignedIdentityControlPlane,
},
errorMsg: fmt.Sprintf(`failed to validate user-assigned identity '%s' in resource group '%s'`, invalidUserAssignedIdentityName, validUserAssignedIdentityResourceGroup),
},
}

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

azureClient := mock.NewMockAPI(mockCtrl)

// Setup mock expectations for valid and invalid identities
azureClient.EXPECT().GetUserAssignedIdentity(gomock.Any(), validUserAssignedIdentitySubscription, validUserAssignedIdentityResourceGroup, validUserAssignedIdentityName).Return(nil).AnyTimes()
azureClient.EXPECT().GetUserAssignedIdentity(gomock.Any(), validUserAssignedIdentitySubscription, validUserAssignedIdentityResourceGroup, invalidUserAssignedIdentityName).Return(fmt.Errorf("resource not found")).AnyTimes()

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
editedInstallConfig := validInstallConfig()
for _, edit := range tc.edits {
edit(editedInstallConfig)
}

errors := ValidateUserAssignedIdentities(azureClient, editedInstallConfig)
aggregatedErrors := errors.ToAggregate()
if tc.errorMsg != "" {
assert.Regexp(t, tc.errorMsg, aggregatedErrors)
} else {
assert.NoError(t, aggregatedErrors)
}
})
}
}