diff --git a/cerberus/aws_auth.py b/cerberus/aws_auth.py index 3d3d451..02540da 100644 --- a/cerberus/aws_auth.py +++ b/cerberus/aws_auth.py @@ -17,8 +17,11 @@ # For python 2.7 from __future__ import print_function from botocore import session, awsrequest, auth +from boto3.session import Session import logging import sys +import os +from datetime import datetime from . import CerberusClientException, CLIENT_VERSION from .network_util import throw_if_bad_response, post_with_retry @@ -28,6 +31,75 @@ logger = logging.getLogger(__name__) +def _get_aws_role_session_name(): + return os.environ.get( + "AWS_ROLE_SESSION_NAME", + "cerberus-python-client-{}".format( + datetime.now() + .replace(microsecond=0, tzinfo=None) + .isoformat() + .replace(":", "-") + .replace(".", "-") + ), + ) + + +def _get_aws_web_identity_token(): + web_identity_token = None + web_identity_token_file = os.environ.get( + "AWS_WEB_IDENTITY_TOKEN_FILE", "" + ) + if web_identity_token_file: + with open(web_identity_token_file, "r") as web_identity_token_file_io: + web_identity_token = web_identity_token_file_io.read().strip() + return web_identity_token + + +def _get_aws_credentials(aws_session=None): + """ + Retrieve AWS credentials from a boto3 session. + + Parameters: + + - session (boto3.session.Session|None) = None: A boto3 session from + which to infer credentials. If not provided, a session will be + created using [AWS environment variables](https://go.aws/42HrK0R). + """ + if not aws_session: + arn = None + profile_name = os.environ.get("AWS_PROFILE", None) + aws_session = Session( + profile_name=profile_name + ) + if not profile_name: + # We only infer an assumed AWS role if not using a profile. + # For profiles, the assumed AWS role should be in + # the profile parameters. + arn = os.environ.get("AWS_ROLE_ARN", "") + if arn: + web_identity_token = _get_aws_web_identity_token() + session_name = _get_aws_role_session_name() + if web_identity_token: + credentials = session.client( + "sts" + ).assume_role_with_web_identity( + RoleArn=arn, + RoleSessionName=session_name, + WebIdentityToken=web_identity_token, + )["Credentials"] + else: + credentials = session.client("sts").assume_role( + RoleArn=arn, + RoleSessionName=session_name, + )["Credentials"] + aws_session = Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + return aws_session.get_credentials() + + class AWSAuth(object): """Class to authenticate with an IAM Role""" CN_REGIONS = {"cn-north-1", "cn-northwest-1"} @@ -42,11 +114,7 @@ def __init__(self, cerberus_url, region, aws_session=None, verbose=None): def _get_v4_signed_headers(self): """Returns V4 signed get-caller-identity request headers""" - if self.aws_session is None: - boto_session = session.Session() - creds = boto_session.get_credentials() - else: - creds = self.aws_session.get_credentials() + creds = _get_aws_credentials(self.aws_session) if creds is None: raise CerberusClientException("Unable to locate AWS credentials") readonly_credentials = creds.get_frozen_credentials()