Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bedrock_agentcore/identity/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def requires_access_token(
token_poller: Optional[TokenPoller] = None,
custom_state: Optional[str] = None,
custom_parameters: Optional[Dict[str, str]] = None,
require_par: Optional[bool] = None,
) -> Callable:
"""Decorator that fetches an OAuth2 access token before calling the decorated function.

Expand All @@ -46,6 +47,8 @@ def requires_access_token(
custom_state: A state that allows applications to verify the validity of callbacks to callback_url
custom_parameters: A map of custom parameters to include in authorization request to the credential provider
Note: these parameters are in addition to standard OAuth 2.0 flow parameters
require_par: Whether to require Pushed Authorization Request (PAR). Set to False to disable PAR
requirement for identity servers that don't support PAR. Defaults to None (backend default).

Returns:
Decorator function
Expand All @@ -67,6 +70,7 @@ async def _get_token() -> str:
token_poller=token_poller,
custom_state=custom_state,
custom_parameters=custom_parameters,
require_par=require_par,
)

@wraps(func)
Expand Down
9 changes: 8 additions & 1 deletion src/bedrock_agentcore/services/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ async def get_token(
token_poller: Optional[TokenPoller] = None,
custom_state: Optional[str] = None,
custom_parameters: Optional[Dict[str, str]] = None,
require_par: Optional[bool] = None,
) -> str:
"""Get an OAuth2 access token for the specified provider.

Expand All @@ -184,6 +185,8 @@ async def get_token(
custom_state: A state that allows applications to verify the validity of callbacks to callback_url
custom_parameters: A map of custom parameters to include in authorization request to the credential provider
Note: these parameters are in addition to standard OAuth 2.0 flow parameters
require_par: Whether to require Pushed Authorization Request (PAR). Set to False to disable PAR
requirement for identity servers that don't support PAR. Defaults to None (backend default).

Returns:
The access token string
Expand Down Expand Up @@ -211,6 +214,8 @@ async def get_token(
req["customState"] = custom_state
if custom_parameters:
req["customParameters"] = custom_parameters
if require_par is not None:
req["requirePar"] = require_par

response = self.dp_client.get_resource_oauth2_token(**req)

Expand All @@ -236,8 +241,10 @@ async def get_token(
req["sessionUri"] = response["sessionUri"]

# Poll for the token
# Create a copy of req to avoid modifying the original during polling
poll_req = req.copy()
active_poller = token_poller or _DefaultApiTokenPoller(
auth_url, lambda: self.dp_client.get_resource_oauth2_token(**req).get("accessToken", None)
auth_url, lambda: self.dp_client.get_resource_oauth2_token(**poll_req).get("accessToken", None)
)
return await active_poller.poll_for_token()

Expand Down
70 changes: 70 additions & 0 deletions tests/bedrock_agentcore/services/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,76 @@ async def test_get_token_with_custom_parameters(self):
customParameters=custom_parameters,
)

@pytest.mark.asyncio
async def test_get_token_with_require_par_disabled(self):
"""Test get_token with require_par set to False to disable PAR."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read", "write"]
agent_identity_token = "test-agent-token"
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
agent_identity_token=agent_identity_token,
auth_flow="USER_FEDERATION",
require_par=False,
)

assert result == expected_token
mock_client.get_resource_oauth2_token.assert_called_once_with(
resourceCredentialProviderName=provider_name,
scopes=scopes,
oauth2Flow="USER_FEDERATION",
workloadIdentityToken=agent_identity_token,
requirePar=False,
)

@pytest.mark.asyncio
async def test_get_token_with_require_par_enabled(self):
"""Test get_token with require_par set to True to enable PAR."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client

identity_client = IdentityClient(region)

provider_name = "test-provider"
scopes = ["read", "write"]
agent_identity_token = "test-agent-token"
expected_token = "test-access-token"

mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token}

result = await identity_client.get_token(
provider_name=provider_name,
scopes=scopes,
agent_identity_token=agent_identity_token,
auth_flow="USER_FEDERATION",
require_par=True,
)

assert result == expected_token
mock_client.get_resource_oauth2_token.assert_called_once_with(
resourceCredentialProviderName=provider_name,
scopes=scopes,
oauth2Flow="USER_FEDERATION",
workloadIdentityToken=agent_identity_token,
requirePar=True,
)

@pytest.mark.asyncio
async def test_get_api_key_success(self):
"""Test successful API key retrieval."""
Expand Down