Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions collector/internal/extensionapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type RegisterResponse struct {
FunctionName string `json:"functionName"`
FunctionVersion string `json:"functionVersion"`
Handler string `json:"handler"`
AccountID string `json:"accountId"`
ExtensionID string
}

Expand Down Expand Up @@ -65,9 +66,10 @@ const (
)

const (
extensionNameHeader = "Lambda-Extension-Name"
extensionIdentiferHeader = "Lambda-Extension-Identifier"
extensionErrorType = "Lambda-Extension-Function-Error-Type"
extensionNameHeader = "Lambda-Extension-Name"
extensionIdentiferHeader = "Lambda-Extension-Identifier"
extensionErrorType = "Lambda-Extension-Function-Error-Type"
extensionAcceptFeatureHeader = "Lambda-Extension-Accept-Feature"
)

// Client is a simple client for the Lambda Extensions API.
Expand Down Expand Up @@ -104,6 +106,7 @@ func (e *Client) Register(ctx context.Context, filename string) (*RegisterRespon
return nil, err
}
req.Header.Set(extensionNameHeader, filename)
req.Header.Set(extensionAcceptFeatureHeader, "accountId")

var registerResp RegisterResponse
resp, err := e.doRequest(req, &registerResp)
Expand Down
73 changes: 73 additions & 0 deletions collector/internal/extensionapi/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright The OpenTelemetry Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package extensionapi

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)

func TestRegisterSendsAcceptFeatureHeader(t *testing.T) {
var receivedAcceptFeature string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAcceptFeature = r.Header.Get("Lambda-Extension-Accept-Feature")
w.Header().Set("Lambda-Extension-Identifier", "test-ext-id")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"functionName":"my-func","functionVersion":"$LATEST","handler":"index.handler","accountId":"123456789012"}`))
}))
defer server.Close()

u, err := url.Parse(server.URL)
require.NoError(t, err)

logger := zaptest.NewLogger(t)
// The client prepends "http://" and appends "/2020-01-01/extension", so we
// need to set up the server path accordingly. Instead, construct the client
// with an empty base and override.
client := NewClient(logger, u.Host)
resp, err := client.Register(context.Background(), "test-extension")
require.NoError(t, err)

assert.Equal(t, "accountId", receivedAcceptFeature)
assert.Equal(t, "123456789012", resp.AccountID)
assert.Equal(t, "my-func", resp.FunctionName)
assert.Equal(t, "test-ext-id", resp.ExtensionID)
}

func TestRegisterParsesAccountIDWithLeadingZeros(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Lambda-Extension-Identifier", "ext-id")
w.WriteHeader(200)
_, _ = w.Write([]byte(`{"functionName":"f","functionVersion":"v","handler":"h","accountId":"000123456789"}`))
}))
defer server.Close()

u, err := url.Parse(server.URL)
require.NoError(t, err)

logger := zaptest.NewLogger(t)
client := NewClient(logger, u.Host)
resp, err := client.Register(context.Background(), "test-extension")
require.NoError(t, err)

assert.Equal(t, "000123456789", resp.AccountID, "leading zeros must be preserved")
}
15 changes: 15 additions & 0 deletions collector/internal/lifecycle/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"github.com/open-telemetry/opentelemetry-lambda/collector/lambdacomponents"
)

const accountIDSymlinkPath = "/tmp/.otel-account-id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit:

Suggested change
const accountIDSymlinkPath = "/tmp/.otel-account-id"
const accountIDSymlinkPath = "/tmp/.otel-aws-account-id"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you've already made the PRs for all the contrib repos, so to save the effort of changing the name everywhere, .otel-account-id is probably fine.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree since all the other PRs have already been created, we can leave it as is.


var (
extensionName = filepath.Base(os.Args[0]) // extension name has to match the filename
)
Expand Down Expand Up @@ -68,6 +70,8 @@ func NewManager(ctx context.Context, logger *zap.Logger, version string) (contex
logger.Fatal("Cannot register extension", zap.Error(err))
}

writeAccountIDSymlink(logger, res.AccountID)

listener := telemetryapi.NewListener(logger)
addr, err := listener.Start()
if err != nil {
Expand Down Expand Up @@ -178,3 +182,14 @@ func (lm *manager) notifyEnvironmentShutdown() {
func (lm *manager) AddListener(listener lambdalifecycle.Listener) {
lm.lifecycleListeners = append(lm.lifecycleListeners, listener)
}

func writeAccountIDSymlink(logger *zap.Logger, accountID string) {
if accountID == "" {
return
}
// Remove any stale symlink from a previous execution environment reuse.
os.Remove(accountIDSymlinkPath)
if err := os.Symlink(accountID, accountIDSymlinkPath); err != nil {
logger.Debug("Failed to create account ID symlink", zap.Error(err))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should make this a Warning log instead of debug?

}
}
60 changes: 60 additions & 0 deletions collector/internal/lifecycle/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"

Expand Down Expand Up @@ -157,3 +160,60 @@ func TestProcessEvents(t *testing.T) {
}

}

func TestWriteAccountIDSymlink(t *testing.T) {
// Use a temp directory so we don't conflict with the real path.
tmpDir := t.TempDir()
symlinkPath := filepath.Join(tmpDir, ".otel-account-id")

// Temporarily override the package-level constant via a helper approach:
// We call the function directly and verify the symlink at the real path,
// but to avoid touching /tmp we'll test the logic inline.
logger := zaptest.NewLogger(t)

t.Run("creates symlink with correct target", func(t *testing.T) {
path := filepath.Join(tmpDir, "symlink-test-1")
// Inline the logic to test with a custom path
accountID := "123456789012"
os.Remove(path)
err := os.Symlink(accountID, path)
require.NoError(t, err)

target, err := os.Readlink(path)
require.NoError(t, err)
assert.Equal(t, "123456789012", target)
})

t.Run("preserves leading zeros", func(t *testing.T) {
path := filepath.Join(tmpDir, "symlink-test-2")
accountID := "000123456789"
os.Remove(path)
err := os.Symlink(accountID, path)
require.NoError(t, err)

target, err := os.Readlink(path)
require.NoError(t, err)
assert.Equal(t, "000123456789", target)
})

t.Run("replaces stale symlink", func(t *testing.T) {
path := filepath.Join(tmpDir, "symlink-test-3")
// Create an initial symlink
require.NoError(t, os.Symlink("old-account-id", path))

// Overwrite it
os.Remove(path)
require.NoError(t, os.Symlink("999888777666", path))

target, err := os.Readlink(path)
require.NoError(t, err)
assert.Equal(t, "999888777666", target)
})

t.Run("skips when accountID is empty", func(t *testing.T) {
// writeAccountIDSymlink should be a no-op for empty accountID
writeAccountIDSymlink(logger, "")
_, err := os.Readlink(symlinkPath)
assert.True(t, os.IsNotExist(err), "symlink should not exist for empty accountID")
})
}
Loading