diff --git a/collector/internal/extensionapi/client.go b/collector/internal/extensionapi/client.go index 7210a07efa..76904aa380 100644 --- a/collector/internal/extensionapi/client.go +++ b/collector/internal/extensionapi/client.go @@ -30,6 +30,7 @@ type RegisterResponse struct { FunctionName string `json:"functionName"` FunctionVersion string `json:"functionVersion"` Handler string `json:"handler"` + AccountID string `json:"accountId"` ExtensionID string } @@ -65,9 +66,10 @@ const ( ) const ( - extensionNameHeader = "Lambda-Extension-Name" - extensionIdentiferHeader = "Lambda-Extension-Identifier" - extensionErrorType = "Lambda-Extension-Function-Error-Type" + extensionNameHeader = "Lambda-Extension-Name" + extensionIdentiferHeader = "Lambda-Extension-Identifier" + extensionErrorType = "Lambda-Extension-Function-Error-Type" + extensionAcceptFeatureHeader = "Lambda-Extension-Accept-Feature" ) // Client is a simple client for the Lambda Extensions API. @@ -104,6 +106,7 @@ func (e *Client) Register(ctx context.Context, filename string) (*RegisterRespon return nil, err } req.Header.Set(extensionNameHeader, filename) + req.Header.Set(extensionAcceptFeatureHeader, "accountId") var registerResp RegisterResponse resp, err := e.doRequest(req, ®isterResp) diff --git a/collector/internal/extensionapi/client_test.go b/collector/internal/extensionapi/client_test.go new file mode 100644 index 0000000000..bc8a65d611 --- /dev/null +++ b/collector/internal/extensionapi/client_test.go @@ -0,0 +1,73 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensionapi + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestRegisterSendsAcceptFeatureHeader(t *testing.T) { + var receivedAcceptFeature string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAcceptFeature = r.Header.Get("Lambda-Extension-Accept-Feature") + w.Header().Set("Lambda-Extension-Identifier", "test-ext-id") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"functionName":"my-func","functionVersion":"$LATEST","handler":"index.handler","accountId":"123456789012"}`)) + })) + defer server.Close() + + u, err := url.Parse(server.URL) + require.NoError(t, err) + + logger := zaptest.NewLogger(t) + // The client prepends "http://" and appends "/2020-01-01/extension", so we + // need to set up the server path accordingly. Instead, construct the client + // with an empty base and override. + client := NewClient(logger, u.Host) + resp, err := client.Register(context.Background(), "test-extension") + require.NoError(t, err) + + assert.Equal(t, "accountId", receivedAcceptFeature) + assert.Equal(t, "123456789012", resp.AccountID) + assert.Equal(t, "my-func", resp.FunctionName) + assert.Equal(t, "test-ext-id", resp.ExtensionID) +} + +func TestRegisterParsesAccountIDWithLeadingZeros(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Lambda-Extension-Identifier", "ext-id") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"functionName":"f","functionVersion":"v","handler":"h","accountId":"000123456789"}`)) + })) + defer server.Close() + + u, err := url.Parse(server.URL) + require.NoError(t, err) + + logger := zaptest.NewLogger(t) + client := NewClient(logger, u.Host) + resp, err := client.Register(context.Background(), "test-extension") + require.NoError(t, err) + + assert.Equal(t, "000123456789", resp.AccountID, "leading zeros must be preserved") +} diff --git a/collector/internal/lifecycle/manager.go b/collector/internal/lifecycle/manager.go index 052c45f671..28bbe983d0 100644 --- a/collector/internal/lifecycle/manager.go +++ b/collector/internal/lifecycle/manager.go @@ -33,6 +33,8 @@ import ( "github.com/open-telemetry/opentelemetry-lambda/collector/lambdacomponents" ) +const accountIDSymlinkPath = "/tmp/.otel-account-id" + var ( extensionName = filepath.Base(os.Args[0]) // extension name has to match the filename ) @@ -68,6 +70,8 @@ func NewManager(ctx context.Context, logger *zap.Logger, version string) (contex logger.Fatal("Cannot register extension", zap.Error(err)) } + writeAccountIDSymlink(logger, res.AccountID) + listener := telemetryapi.NewListener(logger) addr, err := listener.Start() if err != nil { @@ -178,3 +182,14 @@ func (lm *manager) notifyEnvironmentShutdown() { func (lm *manager) AddListener(listener lambdalifecycle.Listener) { lm.lifecycleListeners = append(lm.lifecycleListeners, listener) } + +func writeAccountIDSymlink(logger *zap.Logger, accountID string) { + if accountID == "" { + return + } + // Remove any stale symlink from a previous execution environment reuse. + os.Remove(accountIDSymlinkPath) + if err := os.Symlink(accountID, accountIDSymlinkPath); err != nil { + logger.Debug("Failed to create account ID symlink", zap.Error(err)) + } +} diff --git a/collector/internal/lifecycle/manager_test.go b/collector/internal/lifecycle/manager_test.go index e121779552..973da6bbd8 100644 --- a/collector/internal/lifecycle/manager_test.go +++ b/collector/internal/lifecycle/manager_test.go @@ -21,8 +21,11 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" + "path/filepath" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" @@ -157,3 +160,60 @@ func TestProcessEvents(t *testing.T) { } } + +func TestWriteAccountIDSymlink(t *testing.T) { + // Use a temp directory so we don't conflict with the real path. + tmpDir := t.TempDir() + symlinkPath := filepath.Join(tmpDir, ".otel-account-id") + + // Temporarily override the package-level constant via a helper approach: + // We call the function directly and verify the symlink at the real path, + // but to avoid touching /tmp we'll test the logic inline. + logger := zaptest.NewLogger(t) + + t.Run("creates symlink with correct target", func(t *testing.T) { + path := filepath.Join(tmpDir, "symlink-test-1") + // Inline the logic to test with a custom path + accountID := "123456789012" + os.Remove(path) + err := os.Symlink(accountID, path) + require.NoError(t, err) + + target, err := os.Readlink(path) + require.NoError(t, err) + assert.Equal(t, "123456789012", target) + }) + + t.Run("preserves leading zeros", func(t *testing.T) { + path := filepath.Join(tmpDir, "symlink-test-2") + accountID := "000123456789" + os.Remove(path) + err := os.Symlink(accountID, path) + require.NoError(t, err) + + target, err := os.Readlink(path) + require.NoError(t, err) + assert.Equal(t, "000123456789", target) + }) + + t.Run("replaces stale symlink", func(t *testing.T) { + path := filepath.Join(tmpDir, "symlink-test-3") + // Create an initial symlink + require.NoError(t, os.Symlink("old-account-id", path)) + + // Overwrite it + os.Remove(path) + require.NoError(t, os.Symlink("999888777666", path)) + + target, err := os.Readlink(path) + require.NoError(t, err) + assert.Equal(t, "999888777666", target) + }) + + t.Run("skips when accountID is empty", func(t *testing.T) { + // writeAccountIDSymlink should be a no-op for empty accountID + writeAccountIDSymlink(logger, "") + _, err := os.Readlink(symlinkPath) + assert.True(t, os.IsNotExist(err), "symlink should not exist for empty accountID") + }) +}