Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion ec2instanceconnectcli/EC2InstanceConnectCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class EC2InstanceConnectCommand(object):
Generates commands relevant for the client.
"""

def __init__(self, program, instance_bundles, key_file, flags, program_command, logger):
def __init__(self, program, instance_bundles, key_file, ssh_config, flags, program_command, logger):
"""
Utility class to generate program specific command.

Expand All @@ -35,6 +35,7 @@ def __init__(self, program, instance_bundles, key_file, flags, program_command,
self.program = program
self.instance_bundles = instance_bundles
self.key_file = key_file
self.ssh_config = ssh_config
self.flags = flags
self.program_command = program_command

Expand All @@ -45,6 +46,9 @@ def get_command(self):
# Start with protocol & identity file
command = "{0} -i {1}".format(self.program, self.key_file)

# Add ssh_config if using ssm
if self.ssh_config is not None:
command = "{0} -F {1}".format(command, self.ssh_config)
# Next add command flags if present
if len(self.flags) > 0:
command = "{0} {1}".format(command, self.flags)
Expand Down Expand Up @@ -75,6 +79,8 @@ def _get_target(instance_bundle):
target = ''
if instance_bundle.get('host_info', None):
target = "{0}@{1}".format(instance_bundle['username'], instance_bundle['host_info'])
if instance_bundle.get('ssm', False):
target = "{0}@{1}".format(instance_bundle['username'], instance_bundle['instance_id'])
# file will exist only for SFTP and SCP operations.
if instance_bundle.get('file', None):
target = "{0}:{1}".format(target, instance_bundle['file']).lstrip(':')
Expand Down
94 changes: 94 additions & 0 deletions ec2instanceconnectcli/EC2InstanceConnectSSHConfig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import os
import tempfile

class EC2InstanceConnectSSHConfig(object):
"""
Generates a temporary ssh config file
"""
def __init__(self, instance_bundles, logger):
"""
:param instance_bundles: list of dicts that provide dns name, zone, etc information about EC2 instances
:type instance_bundles: list
:param logger: EC2 Instance Connect CLI logger to use for log messages
:type logger: ec2instanceconnectcli.EC2InstanceConnectLogger.EC2InstanceConnectLogger
"""
self.logger = logger
self.instance_bundles = instance_bundles
self.region = self.get_region()
self.tempf = self.write_config()

def get_region(self):
region = None
for bundle in self.instance_bundles:
region = bundle['region']

return region

def add_instance_name(self):
"""
Add ec2 instance id to ssh config
"""
config = "host {0}\n".format(self.instance_bundles[0]['instance_id'])
return config

def add_ssm_proxy_command(self, config, region=None):
"""
Add ProxyCommand to ssh config

:param config: Initial config string
:type config: basestring
:param region: Region string
:type region: basestring
"""

region_cli = ''

if region is not None:
region_cli = "--region {0}".format(region)

proxy_command = "ProxyCommand sh -c \"aws {0} ssm start-session --target %h --document-name AWS-StartSSHSession --parameters 'portNumber=%p'\"".format(region_cli)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Command should include the --profile argument if it was used when called mssh

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also it should adapt when executed on windows environments

ProxyCommand C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe "aws ssm start-session --target %h --document-name AWS-StartSSHSession --parameters portNumber=%p"

return "{0} {1}".format(config, proxy_command)

def write_config(self):
"""
Write ssh config file to temporary directory
"""
config = self.add_instance_name()
if self.instance_bundles[0]['ssm']:
config = self.add_ssm_proxy_command(config, self.region)
tempf = tempfile.NamedTemporaryFile(delete=False)
with open(tempf.name, 'w') as f:
f.write(config)
os.chmod(tempf.name, 0o600)
tempf.file.close()
return tempf

def get_config_file(self):
"""
Returns temporary ssh config file.

:return: ssh config filepath
:rtype: basestring
"""
return self.tempf.name

def __del__(self):
"""
Remove the temp ssh config file.
"""
if self.tempf is not None:
self.logger.debug('Deleting the ssh_config file: {0}'.format(self.tempf.name))
os.remove(self.tempf.name)
4 changes: 1 addition & 3 deletions ec2instanceconnectcli/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def parseargs(args, mode='ssh'):
:return: Args split into three pieces: EC2 instance information, command flags, and and the actual command to run
:rtype: tuple
"""

if len(args) < 2:
raise AssertionError('Missing target')
if len(args[1]) < 1:
Expand All @@ -60,6 +59,7 @@ def parseargs(args, mode='ssh'):
"""
instance_bundles = [
{
'ssm': args[0].ssm_connect,
'profile': args[0].profile,
'instance_id': args[0].instance_id,
'region': args[0].region,
Expand Down Expand Up @@ -124,7 +124,6 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False):

# This is either a flag or a flag value
flags = '{0} {1}'.format(flags, raw_command[command_index])

if raw_command[command_index][0] == '-':
# Flag
is_flagged = True
Expand All @@ -143,7 +142,6 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False):
command_index += 1

flags = flags.strip()

"""
Target host and command or file list
"""
Expand Down
8 changes: 6 additions & 2 deletions ec2instanceconnectcli/mops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse

from ec2instanceconnectcli.EC2InstanceConnectCLI import EC2InstanceConnectCLI
from ec2instanceconnectcli.EC2InstanceConnectSSHConfig import EC2InstanceConnectSSHConfig
from ec2instanceconnectcli.EC2InstanceConnectKey import EC2InstanceConnectKey
from ec2instanceconnectcli.EC2InstanceConnectCommand import EC2InstanceConnectCommand
from ec2instanceconnectcli.EC2InstanceConnectLogger import EC2InstanceConnectLogger
Expand All @@ -37,7 +38,7 @@ def main(program, mode):
usage = ""
if mode == "ssh":
usage="""
mssh [-t instance_id] [-u profile] [-z availability_zone] [-r region] [supported ssh flags] target [command]
mssh [-t instance_id] [-u profile] [-z availability_zone] [-r region] [-ssm] [supported ssh flags] target [command]

target => [user@]instance_id | [user@]hostname
[supported ssh flags] => [-l login_name] [-p port]
Expand All @@ -54,6 +55,7 @@ def main(program, mode):
parser.add_argument('-z', '--zone', action='store', help='Availability zone', type=str, metavar='')
parser.add_argument('-u', '--profile', action='store', help='AWS Config Profile', type=str, default=DEFAULT_PROFILE, metavar='')
parser.add_argument('-t', '--instance_id', action='store', help='EC2 Instance ID. Required if target is hostname', type=str, default=DEFAULT_INSTANCE, metavar='')
parser.add_argument('-ssm', '--ssm_connect', action='store_true', help='Connect to EC2 Instance ID through SSM')
parser.add_argument('-d', '--debug', action="store_true", help='Turn on debug logging')

args = parser.parse_known_args()
Expand All @@ -66,9 +68,11 @@ def main(program, mode):
parser.print_help()
sys.exit(1)

#Generate temporary ssh config
ssh_config = EC2InstanceConnectSSHConfig(instance_bundles, logger.get_logger())
#Generate temp key
cli_key = EC2InstanceConnectKey(logger.get_logger())
cli_command = EC2InstanceConnectCommand(program, instance_bundles, cli_key.get_priv_key_file(), flags, program_command, logger.get_logger())
cli_command = EC2InstanceConnectCommand(program, instance_bundles, cli_key.get_priv_key_file(), ssh_config.get_config_file(), flags, program_command, logger.get_logger())

try:
# TODO: Handling for if the '-i' flag is passed
Expand Down
58 changes: 48 additions & 10 deletions tests/test_EC2ConnectCLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_mssh_no_target(self,
mock_push_key,
mock_run):
mock_file = 'identity'
ssh_config = 'ssh_config'
flag = '-f flag'
command = 'command arg'
logger = EC2InstanceConnectLogger()
Expand All @@ -37,11 +38,11 @@ def test_mssh_no_target(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

expected_command = "ssh -i {0} {1} {2}@{3} {4}".format(mock_file, flag, self.default_user,
expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user,
self.public_ip, command)

# Check that we successfully get to the run
Expand All @@ -59,6 +60,7 @@ def test_mssh_no_target_no_public_ip(self,
mock_run):
mock_file = "identity"
flag = '-f flag'
ssh_config = 'ssh_config'
command = 'command arg'
logger = EC2InstanceConnectLogger()
instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id,
Expand All @@ -68,11 +70,11 @@ def test_mssh_no_target_no_public_ip(self,
mock_instance_data.return_value = self.private_instance_info
mock_push_key.return_value = None

cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

expected_command = "ssh -i {0} {1} {2}@{3} {4}".format(mock_file, flag, self.default_user,
expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user,
self.private_ip, command)

# Check that we successfully get to the run
Expand All @@ -88,6 +90,7 @@ def test_mssh_with_target(self,
mock_push_key,
mock_run):
mock_file = 'identity'
ssh_config = 'ssh_config'
flag = '-f flag'
command = 'command arg'
host = '0.0.0.0'
Expand All @@ -99,11 +102,43 @@ def test_mssh_with_target(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

expected_command = "ssh -i {0} {1} {2}@{3} {4}".format(mock_file, flag, self.default_user,
expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user,
host, command)
# Check that we successfully get to the run
# Since both target and availability_zone are provided, mock_instance_data should not be called
self.assertFalse(mock_instance_data.called)
self.assertTrue(mock_push_key.called)
mock_run.assert_called_with(expected_command)

@mock.patch('ec2instanceconnectcli.EC2InstanceConnectCLI.EC2InstanceConnectCLI.run_command')
@mock.patch('ec2instanceconnectcli.key_publisher.push_public_key')
@mock.patch('ec2instanceconnectcli.ec2_util.get_instance_data')
def test_mssh_with_ssm_target(self,
mock_instance_data,
mock_push_key,
mock_run):
mock_file = 'identity'
ssh_config = 'ssh_config'
flag = '-f flag'
command = 'command arg'
host = self.instance_id
logger = EC2InstanceConnectLogger()
instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id,
'target': host, 'zone': self.availability_zone, 'region': self.region, 'ssm': True,
'profile': self.profile}]

mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

expected_command = "ssh -i {0} -F {1} {2} {3}@{4} {5}".format(mock_file, ssh_config, flag, self.default_user,
host, command)
# Check that we successfully get to the run
# Since both target and availability_zone are provided, mock_instance_data should not be called
Expand All @@ -119,6 +154,7 @@ def test_msftp(self,
mock_push_key,
mock_run):
mock_file = 'identity'
ssh_config = 'ssh_config'
flag = '-f flag'
command = 'file2 file3'
logger = EC2InstanceConnectLogger()
Expand All @@ -129,10 +165,11 @@ def test_msftp(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

expected_command = "sftp -i {0} {1} {2}@{3}:{4} {5}".format(mock_file, flag, self.default_user,
expected_command = "sftp -i {0} -F {1} {2} {3}@{4}:{5} {6}".format(mock_file, ssh_config, flag, self.default_user,
self.public_ip, 'file1', command)

cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, ssh_config,
flag, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

Expand All @@ -150,6 +187,7 @@ def test_mscp(self,
mock_run):
mock_file = 'identity'
flag = '-f flag'
ssh_config = 'ssh_config'
command = 'file2 file3'
logger = EC2InstanceConnectLogger()
instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id,
Expand All @@ -162,12 +200,12 @@ def test_mscp(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

expected_command = "scp -i {0} {1} {2}@{3}:{4} {5} {6}@{7}:{8}".format(mock_file, flag, self.default_user,
expected_command = "scp -i {0} -F {1} {2} {3}@{4}:{5} {6} {7}@{8}:{9}".format(mock_file, ssh_config, flag, self.default_user,
self.public_ip, 'file1', command,
self.default_user,
self.public_ip, 'file4')

cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, ssh_config, flag, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

Expand Down
Loading