diff --git a/src/bedrock_agentcore/identity/auth.py b/src/bedrock_agentcore/identity/auth.py index 4bd99b3..9b4337c 100644 --- a/src/bedrock_agentcore/identity/auth.py +++ b/src/bedrock_agentcore/identity/auth.py @@ -31,6 +31,7 @@ def requires_access_token( token_poller: Optional[TokenPoller] = None, custom_state: Optional[str] = None, custom_parameters: Optional[Dict[str, str]] = None, + require_par: Optional[bool] = None, ) -> Callable: """Decorator that fetches an OAuth2 access token before calling the decorated function. @@ -46,6 +47,8 @@ def requires_access_token( custom_state: A state that allows applications to verify the validity of callbacks to callback_url custom_parameters: A map of custom parameters to include in authorization request to the credential provider Note: these parameters are in addition to standard OAuth 2.0 flow parameters + require_par: Whether to require Pushed Authorization Request (PAR). Set to False to disable PAR + requirement for identity servers that don't support PAR. Defaults to None (backend default). Returns: Decorator function @@ -67,6 +70,7 @@ async def _get_token() -> str: token_poller=token_poller, custom_state=custom_state, custom_parameters=custom_parameters, + require_par=require_par, ) @wraps(func) diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index 80d30ac..ecd09b8 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -169,6 +169,7 @@ async def get_token( token_poller: Optional[TokenPoller] = None, custom_state: Optional[str] = None, custom_parameters: Optional[Dict[str, str]] = None, + require_par: Optional[bool] = None, ) -> str: """Get an OAuth2 access token for the specified provider. @@ -184,6 +185,8 @@ async def get_token( custom_state: A state that allows applications to verify the validity of callbacks to callback_url custom_parameters: A map of custom parameters to include in authorization request to the credential provider Note: these parameters are in addition to standard OAuth 2.0 flow parameters + require_par: Whether to require Pushed Authorization Request (PAR). Set to False to disable PAR + requirement for identity servers that don't support PAR. Defaults to None (backend default). Returns: The access token string @@ -211,6 +214,8 @@ async def get_token( req["customState"] = custom_state if custom_parameters: req["customParameters"] = custom_parameters + if require_par is not None: + req["requirePar"] = require_par response = self.dp_client.get_resource_oauth2_token(**req) @@ -236,8 +241,10 @@ async def get_token( req["sessionUri"] = response["sessionUri"] # Poll for the token + # Create a copy of req to avoid modifying the original during polling + poll_req = req.copy() active_poller = token_poller or _DefaultApiTokenPoller( - auth_url, lambda: self.dp_client.get_resource_oauth2_token(**req).get("accessToken", None) + auth_url, lambda: self.dp_client.get_resource_oauth2_token(**poll_req).get("accessToken", None) ) return await active_poller.poll_for_token() diff --git a/tests/bedrock_agentcore/services/test_identity.py b/tests/bedrock_agentcore/services/test_identity.py index d3801c1..4cafb08 100644 --- a/tests/bedrock_agentcore/services/test_identity.py +++ b/tests/bedrock_agentcore/services/test_identity.py @@ -343,6 +343,76 @@ async def test_get_token_with_custom_parameters(self): customParameters=custom_parameters, ) + @pytest.mark.asyncio + async def test_get_token_with_require_par_disabled(self): + """Test get_token with require_par set to False to disable PAR.""" + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read", "write"] + agent_identity_token = "test-agent-token" + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + agent_identity_token=agent_identity_token, + auth_flow="USER_FEDERATION", + require_par=False, + ) + + assert result == expected_token + mock_client.get_resource_oauth2_token.assert_called_once_with( + resourceCredentialProviderName=provider_name, + scopes=scopes, + oauth2Flow="USER_FEDERATION", + workloadIdentityToken=agent_identity_token, + requirePar=False, + ) + + @pytest.mark.asyncio + async def test_get_token_with_require_par_enabled(self): + """Test get_token with require_par set to True to enable PAR.""" + region = "us-west-2" + + with patch("boto3.client") as mock_boto_client: + mock_client = Mock() + mock_boto_client.return_value = mock_client + + identity_client = IdentityClient(region) + + provider_name = "test-provider" + scopes = ["read", "write"] + agent_identity_token = "test-agent-token" + expected_token = "test-access-token" + + mock_client.get_resource_oauth2_token.return_value = {"accessToken": expected_token} + + result = await identity_client.get_token( + provider_name=provider_name, + scopes=scopes, + agent_identity_token=agent_identity_token, + auth_flow="USER_FEDERATION", + require_par=True, + ) + + assert result == expected_token + mock_client.get_resource_oauth2_token.assert_called_once_with( + resourceCredentialProviderName=provider_name, + scopes=scopes, + oauth2Flow="USER_FEDERATION", + workloadIdentityToken=agent_identity_token, + requirePar=True, + ) + @pytest.mark.asyncio async def test_get_api_key_success(self): """Test successful API key retrieval."""