From 84e24b1292f1d56b116a753dad5831ee2f02e0e1 Mon Sep 17 00:00:00 2001 From: Scott Kidder Date: Mon, 4 Nov 2019 17:20:34 -0800 Subject: [PATCH] Allow ExternalId to be supplied to STS AssumeRole API request --- auth_assumerole.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/auth_assumerole.go b/auth_assumerole.go index e71773e..f225686 100644 --- a/auth_assumerole.go +++ b/auth_assumerole.go @@ -14,11 +14,19 @@ import ( // stsAuth object is used to authenticate to STS to fetch temporary credentials // for the desired role. func NewAuthWithAssumedRole(roleArn, sessionName, region string, stsAuth Auth) (Auth, error) { + return NewAuthWithAssumedRoleAndExternalID(roleArn, sessionName, region, "", stsAuth) +} + +// NewAuthWithAssumedRoleAndExternalID works just like NewAuthWithAssumedRole, except +// that it includes the supplied externalID parameter as the 'ExternalId' parameter +// in the authentication request to STS. +func NewAuthWithAssumedRoleAndExternalID(roleArn, sessionName, region, externalID string, stsAuth Auth) (Auth, error) { return newCachedMutexedWarmedUpAuth(&stsCreds{ RoleARN: roleArn, SessionName: sessionName, Region: region, STSAuth: stsAuth, + ExternalID: externalID, }) } @@ -27,15 +35,20 @@ type stsCreds struct { SessionName string Region string STSAuth Auth + ExternalID string } func (sts *stsCreds) ExpiringKeyForSigning(now time.Time) (*SigningKey, time.Time, error) { - r, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://sts.%s.amazonaws.com/?%s", sts.Region, (url.Values{ + params := url.Values{ "Version": []string{"2011-06-15"}, "Action": []string{"AssumeRole"}, "RoleSessionName": []string{sts.SessionName}, "RoleArn": []string{sts.RoleARN}, - }).Encode()), bytes.NewReader([]byte{})) + } + if sts.ExternalID != "" { + params["ExternalId"] = []string{sts.ExternalID} + } + r, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://sts.%s.amazonaws.com/?%s", sts.Region, params.Encode()), bytes.NewReader([]byte{})) if err != nil { return nil, time.Time{}, err }