diff --git a/pkg/asset/cluster/tfvars/tfvars.go b/pkg/asset/cluster/tfvars/tfvars.go index 45b5437c96d..2b55f751e3d 100644 --- a/pkg/asset/cluster/tfvars/tfvars.go +++ b/pkg/asset/cluster/tfvars/tfvars.go @@ -248,13 +248,19 @@ func (t *TerraformVariables) Generate(ctx context.Context, parents asset.Parents } } - sess, err := installConfig.AWS.Session(ctx) - if err != nil { - return err - } object := "bootstrap.ign" bucket := fmt.Sprintf("%s-bootstrap", clusterID.InfraID) - url, err := awsconfig.PresignedS3URL(sess, installConfig.Config.Platform.AWS.Region, bucket, object) + + platformAWS := installConfig.Config.Platform.AWS + client, err := awsconfig.NewS3Client(ctx, awsconfig.EndpointOptions{ + Region: platformAWS.Region, + Endpoints: platformAWS.ServiceEndpoints, + }) + if err != nil { + return fmt.Errorf("failed to create s3 client: %w", err) + } + + url, err := awsconfig.PresignedS3URL(ctx, client, bucket, object) if err != nil { return err } diff --git a/pkg/asset/installconfig/aws/clients.go b/pkg/asset/installconfig/aws/clients.go index 1d2465b76a8..92a3aa0ec27 100644 --- a/pkg/asset/installconfig/aws/clients.go +++ b/pkg/asset/installconfig/aws/clients.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/route53" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/servicequotas" "github.com/aws/aws-sdk-go-v2/service/sts" ) @@ -116,3 +117,22 @@ func NewServiceQuotasClient(ctx context.Context, endpointOpts EndpointOptions, o return servicequotas.NewFromConfig(cfg, sqOpts...), nil } + +// NewS3Client creates a new S3 API client. +func NewS3Client(ctx context.Context, endpointOpts EndpointOptions, optFns ...func(*s3.Options)) (*s3.Client, error) { + cfg, err := GetConfigWithOptions(ctx, config.WithRegion(endpointOpts.Region)) + if err != nil { + return nil, err + } + + s3Opts := []func(*s3.Options){ + func(o *s3.Options) { + o.EndpointResolverV2 = &S3EndpointResolver{ + ServiceEndpointResolver: NewServiceEndpointResolver(endpointOpts), + } + }, + } + s3Opts = append(s3Opts, optFns...) + + return s3.NewFromConfig(cfg, s3Opts...), nil +} diff --git a/pkg/asset/installconfig/aws/endpoints.go b/pkg/asset/installconfig/aws/endpoints.go index b2571ede286..5cadd280d2a 100644 --- a/pkg/asset/installconfig/aws/endpoints.go +++ b/pkg/asset/installconfig/aws/endpoints.go @@ -187,6 +187,27 @@ func (s *ServiceQuotasEndpointResolver) ResolveEndpoint(ctx context.Context, par return servicequotas.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) } +// S3EndpointResolver implements EndpointResolverV2 interface for S3. +type S3EndpointResolver struct { + *ServiceEndpointResolver +} + +// ResolveEndpoint for S3. +func (s *S3EndpointResolver) ResolveEndpoint(ctx context.Context, params s3.EndpointParameters) (smithyendpoints.Endpoint, error) { + params.UseDualStack = aws.Bool(s.endpointOptions.UseDualStack) + params.UseFIPS = aws.Bool(s.endpointOptions.UseFIPS) + + // If custom endpoint not found, return default endpoint for the service. + endpoint, ok := s.endpoints[s3.ServiceID] + if !ok { + return s3.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) + } + + params.Endpoint = aws.String(endpoint.URL) + params.Region = aws.String(s.endpointOptions.Region) + return s3.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) +} + // GetDefaultServiceEndpoint will get the default service endpoint for a service and region. // Note: This uses the v1 EndpointResolver, which exposes the partition ID. func GetDefaultServiceEndpoint(ctx context.Context, service string, opts EndpointOptions) (aws.Endpoint, error) { //nolint: staticcheck diff --git a/pkg/asset/installconfig/aws/presign.go b/pkg/asset/installconfig/aws/presign.go index 8611cb9f284..65ec27a6353 100644 --- a/pkg/asset/installconfig/aws/presign.go +++ b/pkg/asset/installconfig/aws/presign.go @@ -1,24 +1,33 @@ package aws import ( + "context" + "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +const ( + // PresignExpireDuration defines the expiration duration for the generated presign url. + // Currently, this is used for bootstrap ignition. + PresignExpireDuration = 60 * time.Minute ) // PresignedS3URL returns a presigned S3 URL for a bucket/object pair -func PresignedS3URL(session *session.Session, region string, bucket string, object string) (string, error) { - client := s3.New(session, aws.NewConfig().WithRegion(region)) - req, _ := client.GetObjectRequest(&s3.GetObjectInput{ +func PresignedS3URL(ctx context.Context, client *s3.Client, bucket string, object string) (string, error) { + presignClient := s3.NewPresignClient(client) + + req, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(bucket), Key: aws.String(object), + }, func(po *s3.PresignOptions) { + po.Expires = PresignExpireDuration }) - presignedURL, err := req.Presign(60 * time.Minute) if err != nil { - return "", err + return "", fmt.Errorf("failed to get presigned url for object %s in bucket %s: %w", object, bucket, err) } - return presignedURL, nil + return req.URL, nil } diff --git a/pkg/infrastructure/aws/clusterapi/aws.go b/pkg/infrastructure/aws/clusterapi/aws.go index 3c2185a5b74..2c5d02ac0d7 100644 --- a/pkg/infrastructure/aws/clusterapi/aws.go +++ b/pkg/infrastructure/aws/clusterapi/aws.go @@ -456,20 +456,14 @@ func (p *Provider) PostDestroy(ctx context.Context, in clusterapi.PostDestroyerI // removeS3Bucket deletes an s3 bucket given its name. func removeS3Bucket(ctx context.Context, region string, bucketName string, endpoints []awstypes.ServiceEndpoint) error { - cfg, err := configv2.LoadDefaultConfig(ctx, configv2.WithRegion(region)) + client, err := awsconfig.NewS3Client(ctx, awsconfig.EndpointOptions{ + Region: region, + Endpoints: endpoints, + }) if err != nil { - return fmt.Errorf("failed to load AWS config: %w", err) + return fmt.Errorf("failed to create s3 client: %w", err) } - client := s3.NewFromConfig(cfg, func(options *s3.Options) { - options.Region = region - for _, endpoint := range endpoints { - if strings.EqualFold(endpoint.Name, "s3") { - options.BaseEndpoint = aws.String(endpoint.URL) - } - } - }) - paginator := s3.NewListObjectsV2Paginator(client, &s3.ListObjectsV2Input{Bucket: aws.String(bucketName)}) for paginator.HasMorePages() { page, err := paginator.NextPage(ctx)