diff --git a/ec2instanceconnectcli/EC2InstanceConnectCommand.py b/ec2instanceconnectcli/EC2InstanceConnectCommand.py index 15adf19..e89749a 100644 --- a/ec2instanceconnectcli/EC2InstanceConnectCommand.py +++ b/ec2instanceconnectcli/EC2InstanceConnectCommand.py @@ -16,7 +16,7 @@ class EC2InstanceConnectCommand(object): Generates commands relevant for the client. """ - def __init__(self, program, instance_bundles, key_file, flags, program_command, logger): + def __init__(self, program, instance_bundles, key_file, ssh_config, flags, program_command, logger): """ Utility class to generate program specific command. @@ -35,6 +35,7 @@ def __init__(self, program, instance_bundles, key_file, flags, program_command, self.program = program self.instance_bundles = instance_bundles self.key_file = key_file + self.ssh_config = ssh_config self.flags = flags self.program_command = program_command @@ -45,6 +46,9 @@ def get_command(self): # Start with protocol & identity file command = "{0} -i {1}".format(self.program, self.key_file) + # Add ssh_config if using ssm + if self.ssh_config is not None: + command = "{0} -F {1}".format(command, self.ssh_config) # Next add command flags if present if len(self.flags) > 0: command = "{0} {1}".format(command, self.flags) @@ -75,6 +79,8 @@ def _get_target(instance_bundle): target = '' if instance_bundle.get('host_info', None): target = "{0}@{1}".format(instance_bundle['username'], instance_bundle['host_info']) + if instance_bundle.get('ssm', False): + target = "{0}@{1}".format(instance_bundle['username'], instance_bundle['instance_id']) # file will exist only for SFTP and SCP operations. if instance_bundle.get('file', None): target = "{0}:{1}".format(target, instance_bundle['file']).lstrip(':') diff --git a/ec2instanceconnectcli/EC2InstanceConnectSSHConfig.py b/ec2instanceconnectcli/EC2InstanceConnectSSHConfig.py new file mode 100644 index 0000000..61d8f4e --- /dev/null +++ b/ec2instanceconnectcli/EC2InstanceConnectSSHConfig.py @@ -0,0 +1,94 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import os +import tempfile + +class EC2InstanceConnectSSHConfig(object): + """ + Generates a temporary ssh config file + """ + def __init__(self, instance_bundles, logger): + """ + :param instance_bundles: list of dicts that provide dns name, zone, etc information about EC2 instances + :type instance_bundles: list + :param logger: EC2 Instance Connect CLI logger to use for log messages + :type logger: ec2instanceconnectcli.EC2InstanceConnectLogger.EC2InstanceConnectLogger + """ + self.logger = logger + self.instance_bundles = instance_bundles + self.region = self.get_region() + self.tempf = self.write_config() + + def get_region(self): + region = None + for bundle in self.instance_bundles: + region = bundle['region'] + + return region + + def add_instance_name(self): + """ + Add ec2 instance id to ssh config + """ + config = "host {0}\n".format(self.instance_bundles[0]['instance_id']) + return config + + def add_ssm_proxy_command(self, config, region=None): + """ + Add ProxyCommand to ssh config + + :param config: Initial config string + :type config: basestring + :param region: Region string + :type region: basestring + """ + + region_cli = '' + + if region is not None: + region_cli = "--region {0}".format(region) + + proxy_command = "ProxyCommand sh -c \"aws {0} ssm start-session --target %h --document-name AWS-StartSSHSession --parameters 'portNumber=%p'\"".format(region_cli) + return "{0} {1}".format(config, proxy_command) + + def write_config(self): + """ + Write ssh config file to temporary directory + """ + config = self.add_instance_name() + if self.instance_bundles[0]['ssm']: + config = self.add_ssm_proxy_command(config, self.region) + tempf = tempfile.NamedTemporaryFile(delete=False) + with open(tempf.name, 'w') as f: + f.write(config) + os.chmod(tempf.name, 0o600) + tempf.file.close() + return tempf + + def get_config_file(self): + """ + Returns temporary ssh config file. + + :return: ssh config filepath + :rtype: basestring + """ + return self.tempf.name + + def __del__(self): + """ + Remove the temp ssh config file. + """ + if self.tempf is not None: + self.logger.debug('Deleting the ssh_config file: {0}'.format(self.tempf.name)) + os.remove(self.tempf.name) diff --git a/ec2instanceconnectcli/input_parser.py b/ec2instanceconnectcli/input_parser.py index 9442b01..9c1172a 100644 --- a/ec2instanceconnectcli/input_parser.py +++ b/ec2instanceconnectcli/input_parser.py @@ -46,7 +46,6 @@ def parseargs(args, mode='ssh'): :return: Args split into three pieces: EC2 instance information, command flags, and and the actual command to run :rtype: tuple """ - if len(args) < 2: raise AssertionError('Missing target') if len(args[1]) < 1: @@ -60,6 +59,7 @@ def parseargs(args, mode='ssh'): """ instance_bundles = [ { + 'ssm': args[0].ssm_connect, 'profile': args[0].profile, 'instance_id': args[0].instance_id, 'region': args[0].region, @@ -124,7 +124,6 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): # This is either a flag or a flag value flags = '{0} {1}'.format(flags, raw_command[command_index]) - if raw_command[command_index][0] == '-': # Flag is_flagged = True @@ -143,7 +142,6 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): command_index += 1 flags = flags.strip() - """ Target host and command or file list """ diff --git a/ec2instanceconnectcli/mops.py b/ec2instanceconnectcli/mops.py index d3144c0..68199ab 100644 --- a/ec2instanceconnectcli/mops.py +++ b/ec2instanceconnectcli/mops.py @@ -15,6 +15,7 @@ import argparse from ec2instanceconnectcli.EC2InstanceConnectCLI import EC2InstanceConnectCLI +from ec2instanceconnectcli.EC2InstanceConnectSSHConfig import EC2InstanceConnectSSHConfig from ec2instanceconnectcli.EC2InstanceConnectKey import EC2InstanceConnectKey from ec2instanceconnectcli.EC2InstanceConnectCommand import EC2InstanceConnectCommand from ec2instanceconnectcli.EC2InstanceConnectLogger import EC2InstanceConnectLogger @@ -37,7 +38,7 @@ def main(program, mode): usage = "" if mode == "ssh": usage=""" - mssh [-t instance_id] [-u profile] [-z availability_zone] [-r region] [supported ssh flags] target [command] + mssh [-t instance_id] [-u profile] [-z availability_zone] [-r region] [-ssm] [supported ssh flags] target [command] target => [user@]instance_id | [user@]hostname [supported ssh flags] => [-l login_name] [-p port] @@ -54,6 +55,7 @@ def main(program, mode): parser.add_argument('-z', '--zone', action='store', help='Availability zone', type=str, metavar='') parser.add_argument('-u', '--profile', action='store', help='AWS Config Profile', type=str, default=DEFAULT_PROFILE, metavar='') parser.add_argument('-t', '--instance_id', action='store', help='EC2 Instance ID. Required if target is hostname', type=str, default=DEFAULT_INSTANCE, metavar='') + parser.add_argument('-ssm', '--ssm_connect', action='store_true', help='Connect to EC2 Instance ID through SSM') parser.add_argument('-d', '--debug', action="store_true", help='Turn on debug logging') args = parser.parse_known_args() @@ -66,9 +68,11 @@ def main(program, mode): parser.print_help() sys.exit(1) + #Generate temporary ssh config + ssh_config = EC2InstanceConnectSSHConfig(instance_bundles, logger.get_logger()) #Generate temp key cli_key = EC2InstanceConnectKey(logger.get_logger()) - cli_command = EC2InstanceConnectCommand(program, instance_bundles, cli_key.get_priv_key_file(), flags, program_command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand(program, instance_bundles, cli_key.get_priv_key_file(), ssh_config.get_config_file(), flags, program_command, logger.get_logger()) try: # TODO: Handling for if the '-i' flag is passed diff --git a/tests/test_EC2ConnectCLI.py b/tests/test_EC2ConnectCLI.py index 910f567..904b5cf 100644 --- a/tests/test_EC2ConnectCLI.py +++ b/tests/test_EC2ConnectCLI.py @@ -27,6 +27,7 @@ def test_mssh_no_target(self, mock_push_key, mock_run): mock_file = 'identity' + ssh_config = 'ssh_config' flag = '-f flag' command = 'command arg' logger = EC2InstanceConnectLogger() @@ -37,11 +38,11 @@ def test_mssh_no_target(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - expected_command = "ssh -i {0} {1} {2}@{3} {4}".format(mock_file, flag, self.default_user, + expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user, self.public_ip, command) # Check that we successfully get to the run @@ -59,6 +60,7 @@ def test_mssh_no_target_no_public_ip(self, mock_run): mock_file = "identity" flag = '-f flag' + ssh_config = 'ssh_config' command = 'command arg' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -68,11 +70,11 @@ def test_mssh_no_target_no_public_ip(self, mock_instance_data.return_value = self.private_instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - expected_command = "ssh -i {0} {1} {2}@{3} {4}".format(mock_file, flag, self.default_user, + expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user, self.private_ip, command) # Check that we successfully get to the run @@ -88,6 +90,7 @@ def test_mssh_with_target(self, mock_push_key, mock_run): mock_file = 'identity' + ssh_config = 'ssh_config' flag = '-f flag' command = 'command arg' host = '0.0.0.0' @@ -99,11 +102,43 @@ def test_mssh_with_target(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - expected_command = "ssh -i {0} {1} {2}@{3} {4}".format(mock_file, flag, self.default_user, + expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user, + host, command) + # Check that we successfully get to the run + # Since both target and availability_zone are provided, mock_instance_data should not be called + self.assertFalse(mock_instance_data.called) + self.assertTrue(mock_push_key.called) + mock_run.assert_called_with(expected_command) + + @mock.patch('ec2instanceconnectcli.EC2InstanceConnectCLI.EC2InstanceConnectCLI.run_command') + @mock.patch('ec2instanceconnectcli.key_publisher.push_public_key') + @mock.patch('ec2instanceconnectcli.ec2_util.get_instance_data') + def test_mssh_with_ssm_target(self, + mock_instance_data, + mock_push_key, + mock_run): + mock_file = 'identity' + ssh_config = 'ssh_config' + flag = '-f flag' + command = 'command arg' + host = self.instance_id + logger = EC2InstanceConnectLogger() + instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, + 'target': host, 'zone': self.availability_zone, 'region': self.region, 'ssm': True, + 'profile': self.profile}] + + mock_instance_data.return_value = self.instance_info + mock_push_key.return_value = None + + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger()) + cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) + cli.invoke_command() + + expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user, host, command) # Check that we successfully get to the run # Since both target and availability_zone are provided, mock_instance_data should not be called @@ -119,6 +154,7 @@ def test_msftp(self, mock_push_key, mock_run): mock_file = 'identity' + ssh_config = 'ssh_config' flag = '-f flag' command = 'file2 file3' logger = EC2InstanceConnectLogger() @@ -129,10 +165,11 @@ def test_msftp(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - expected_command = "sftp -i {0} {1} {2}@{3}:{4} {5}".format(mock_file, flag, self.default_user, + expected_command = "sftp -i {0} -F {1} {2} {3}@{4}:{5} {6}".format(mock_file, ssh_config, flag, self.default_user, self.public_ip, 'file1', command) - cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, ssh_config, + flag, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() @@ -150,6 +187,7 @@ def test_mscp(self, mock_run): mock_file = 'identity' flag = '-f flag' + ssh_config = 'ssh_config' command = 'file2 file3' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -162,12 +200,12 @@ def test_mscp(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - expected_command = "scp -i {0} {1} {2}@{3}:{4} {5} {6}@{7}:{8}".format(mock_file, flag, self.default_user, + expected_command = "scp -i {0} -F {1} {2} {3}@{4}:{5} {6} {7}@{8}:{9}".format(mock_file, ssh_config, flag, self.default_user, self.public_ip, 'file1', command, self.default_user, self.public_ip, 'file4') - cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() diff --git a/tests/test_input_parser.py b/tests/test_input_parser.py index 6afc2d4..b7eac7c 100644 --- a/tests/test_input_parser.py +++ b/tests/test_input_parser.py @@ -33,6 +33,7 @@ class TestInputParser(TestBase): parser.add_argument('-T', '--dest_instance_id', action='store', type=str, default='', help='EC2 Instance ID. Required if destination is a second instance and is given as a DNS name' 'or IP address') + parser.add_argument('-ssm', '--ssm_connect', action='store_true', help='Connect to instance with instance id through SSM') def test_basic_target(self): args = self.parser.parse_known_args(['-u', self.profile, self.instance_id]) @@ -40,7 +41,17 @@ def test_basic_target(self): bundles, flags, command = input_parser.parseargs(args) self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, - 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) + 'target': None, 'zone': None, 'region': None, 'ssm': False, 'profile': self.profile}]) + self.assertEqual(flags, '') + self.assertEqual(command, '') + + def test_ssm_target(self): + args = self.parser.parse_known_args(['-u', self.profile, '-ssm', self.instance_id]) + + bundles, flags, command = input_parser.parseargs(args) + + self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, + 'target': None, 'zone': None, 'region': None, 'ssm': True, 'profile': self.profile}]) self.assertEqual(flags, '') self.assertEqual(command, '') @@ -50,7 +61,7 @@ def test_username(self): bundles, flags, command = input_parser.parseargs(args) self.assertEqual(bundles, [{'username': 'myuser', 'instance_id': self.instance_id, - 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) + 'target': None, 'zone': None, 'region': None, 'ssm': False, 'profile': self.profile}]) self.assertEqual(flags, '') self.assertEqual(command, '') @@ -62,7 +73,7 @@ def test_dns_name(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': self.dns_name, 'zone': self.availability_zone, - 'region': self.region, 'profile': self.profile}]) + 'region': self.region, 'ssm': False, 'profile': self.profile}]) self.assertEqual(flags, '') self.assertEqual(command, '') @@ -72,7 +83,7 @@ def test_flags(self): bundles, flags, command = input_parser.parseargs(args) self.assertEqual(bundles, [{'username': 'login', 'instance_id': self.instance_id, - 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) + 'target': None, 'zone': None, 'region': None, 'ssm': False, 'profile': self.profile}]) self.assertEqual(flags, '-1 -l login') self.assertEqual(command, '') @@ -82,7 +93,7 @@ def test_command(self): bundles, flags, command = input_parser.parseargs(args) self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, - 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) + 'target': None, 'zone': None, 'region': None, 'ssm': False, 'profile': self.profile}]) self.assertEqual(flags, '') self.assertEqual(command, 'uname -a') @@ -93,7 +104,7 @@ def test_sftp(self): bundles, flags, command = input_parser.parseargs(args, 'sftp') self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, - 'target': None, 'zone': None, 'region': None, 'profile': self.profile, + 'target': None, 'zone': None, 'region': None, 'ssm': False, 'profile': self.profile, 'file': 'first_file'}]) self.assertEqual(flags, '') self.assertEqual(command, 'second_file')