Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 91 additions & 28 deletions src/azure-cli/azure/cli/command_modules/vm/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,12 +1419,12 @@ def get_vm_to_update_by_aaz(cmd, resource_group_name, vm_name):
from .operations.vm import VMShow

vm = VMShow(cli_ctx=cmd.cli_ctx)(command_args={
'resource_group': resource_group_name,
"resource_group": resource_group_name,
"vm_name": vm_name
})

# To avoid unnecessary permission check of image
storage_profile = vm.get('storageProfile', {})
storage_profile = vm.get("storageProfile", {})
storage_profile["imageReference"] = None

return vm
Expand Down Expand Up @@ -1739,6 +1739,43 @@ def set_vm(cmd, instance, lro_operation=None, no_wait=False):
return LongRunningOperation(cmd.cli_ctx)(poller)


# Notes: vm format is in snake_case
def set_vm_by_aaz(cmd, vm, no_wait=False):
from .aaz.latest.vm import Create as _VMCreate

parsed_id = _parse_rg_name(vm["id"])
vm["resource_group"] = parsed_id[0]
vm["vm_name"] = parsed_id[1]

class SetVM(_VMCreate):
def pre_operations(self):
args = self.ctx.args
args.no_wait = no_wait

def _output(self, *args, **kwargs):
from azure.cli.core.aaz import AAZUndefined, has_value

# Resolve flatten conflict
# When the type field conflicts, the type in inner layer is ignored and the outer layer is applied
if has_value(self.ctx.vars.instance.resources):
for resource in self.ctx.vars.instance.resources:
if has_value(resource.type):
resource.type = AAZUndefined

result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True)
if result.get('osProfile', {}).get('secrets', []):
for secret in result['osProfile']['secrets']:
for cert in secret.get('vaultCertificates', []):
if not cert.get('certificateStore'):
cert['certificateStore'] = None
return result

vm = LongRunningOperation(cmd.cli_ctx)(
SetVM(cli_ctx=cmd.cli_ctx)(command_args=vm))

return vm


def patch_vm(cmd, resource_group_name, vm_name, vm):
client = _compute_client_factory(cmd.cli_ctx)
poller = client.virtual_machines.begin_update(resource_group_name, vm_name, vm)
Expand Down Expand Up @@ -3288,51 +3325,75 @@ def get_vm_format_secret(cmd, secrets, certificate_store=None, keyvault=None, re
def add_vm_secret(cmd, resource_group_name, vm_name, keyvault, certificate, certificate_store=None):
from azure.mgmt.core.tools import parse_resource_id
from ._vm_utils import create_data_plane_keyvault_certificate_client, get_key_vault_base_url
VaultSecretGroup, SubResource, VaultCertificate = cmd.get_models(
'VaultSecretGroup', 'SubResource', 'VaultCertificate')
vm = get_vm_to_update(cmd, resource_group_name, vm_name)
from .operations.vm import convert_show_result_to_snake_case
vm = get_vm_to_update_by_aaz(cmd, resource_group_name, vm_name)
vm = convert_show_result_to_snake_case(vm)

if '://' not in certificate: # has a cert name rather a full url?
keyvault_client = create_data_plane_keyvault_certificate_client(
cmd.cli_ctx, get_key_vault_base_url(cmd.cli_ctx, parse_resource_id(keyvault)['name']))
cert_info = keyvault_client.get_certificate(certificate)
certificate = cert_info.secret_id

if not _is_linux_os(vm):
if not _is_linux_os_by_aaz(vm):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the _is_linux_os_by_aaz function is based on the snake case, so we might need to convert the vm to snake case first

certificate_store = certificate_store or 'My'
elif certificate_store:
raise CLIError('Usage error: --certificate-store is only applicable on Windows VM')
vault_cert = VaultCertificate(certificate_url=certificate, certificate_store=certificate_store)
vault_secret_group = next((x for x in vm.os_profile.secrets
if x.source_vault and x.source_vault.id.lower() == keyvault.lower()), None)
vault_cert = {
'certificate_store': certificate_store,
'certificate_url': certificate
}
vault_secret_group = next((x for x in vm.get('os_profile', {}).get('secrets', [])
if x.get('source_vault', {}).get('id', '').lower() == keyvault.lower()), None)
if vault_secret_group:
vault_secret_group.vault_certificates.append(vault_cert)
certs = vault_secret_group.get('vault_certificates', [])
certs.append(vault_cert)
vault_secret_group['vault_certificates'] = certs
else:
vault_secret_group = VaultSecretGroup(source_vault=SubResource(id=keyvault), vault_certificates=[vault_cert])
vm.os_profile.secrets.append(vault_secret_group)
vm = set_vm(cmd, vm)
return vm.os_profile.secrets
vault_secret_group = {
'source_vault': {
'id': keyvault
},
'vault_certificates': [vault_cert]
}

if not vm.get('os_profile'):
vm['os_profile'] = {'secret': []}

if not vm.get('os_profile').get('secrets'):
vm['os_profile']['secrets'] = []

vm['os_profile']['secrets'].append(vault_secret_group)

vm = set_vm_by_aaz(cmd, vm)
return vm.get('osProfile', {}).get('secrets', [])


def list_vm_secrets(cmd, resource_group_name, vm_name):
vm = get_vm(cmd, resource_group_name, vm_name)
if vm.os_profile:
return vm.os_profile.secrets
return []
vm = get_vm_by_aaz(cmd, resource_group_name, vm_name)

if vm.get('osProfile', {}).get('secrets', []):
for secret in vm['osProfile']['secrets']:
for cert in secret.get('vaultCertificates', []):
if not cert.get('certificateStore'):
cert['certificateStore'] = None

return vm.get('osProfile', {}).get('secrets', [])


def remove_vm_secret(cmd, resource_group_name, vm_name, keyvault, certificate=None):
vm = get_vm_to_update(cmd, resource_group_name, vm_name)
from .operations.vm import convert_show_result_to_snake_case
vm = get_vm_to_update_by_aaz(cmd, resource_group_name, vm_name)

# support 2 kinds of filter:
# a. if only keyvault is supplied, we delete its whole vault group.
# b. if both keyvault and certificate are supplied, we only delete the specific cert entry.

to_keep = vm.os_profile.secrets
to_keep = vm.get('osProfile', {}).get('secrets', [])
keyvault_matched = []
if keyvault:
keyvault = keyvault.lower()
keyvault_matched = [x for x in to_keep if x.source_vault and x.source_vault.id.lower() == keyvault]
keyvault_matched = [x for x in to_keep if x.get('sourceVault', {}).get('id', '').lower() == keyvault]

if keyvault and not certificate:
to_keep = [x for x in to_keep if x not in keyvault_matched]
Expand All @@ -3342,13 +3403,15 @@ def remove_vm_secret(cmd, resource_group_name, vm_name, keyvault, certificate=No
if '://' not in cert_url_pattern: # just a cert name?
cert_url_pattern = '/' + cert_url_pattern + '/'
for x in temp:
x.vault_certificates = ([v for v in x.vault_certificates
if not (v.certificate_url and cert_url_pattern in v.certificate_url.lower())])
to_keep = [x for x in to_keep if x.vault_certificates] # purge all groups w/o any cert entries

vm.os_profile.secrets = to_keep
vm = set_vm(cmd, vm)
return vm.os_profile.secrets
x['vaultCertificates'] = [v for v in x.get('vaultCertificates')
if not (v.get('certificateUrl') and
cert_url_pattern in v.get('certificateUrl', '').lower())]
to_keep = [x for x in to_keep if x.get('vaultCertificates')] # purge all groups w/o any cert entries

vm['osProfile']['secrets'] = to_keep
vm = convert_show_result_to_snake_case(vm)
vm = set_vm_by_aaz(cmd, vm)
return vm.get('osProfile', {}).get('secrets', [])
# endregion


Expand Down
2 changes: 2 additions & 0 deletions src/azure-cli/azure/cli/command_modules/vm/operations/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ def __call__(self, *args, **kwargs):

def convert_show_result_to_snake_case(result):
new_result = {}
if "id" in result:
new_result["id"] = result["id"]
if "extendedLocation" in result:
new_result["extended_location"] = result["extendedLocation"]
if "identity" in result:
Expand Down
Loading