diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 75050c733..a41ba9fbc 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -387,9 +387,12 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 consume_tokens( configuration.quota_limiters, + configuration.token_usage_history, user_id, input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens, + model_id=model_id, + provider_id=provider_id, ) store_conversation_into_cache( diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index aac676f23..266d3fbf7 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -275,9 +275,12 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat ) consume_tokens( configuration.quota_limiters, + configuration.token_usage_history, context.user_id, input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens, + model_id=context.model_id, + provider_id=context.provider_id, ) referenced_documents = parse_referenced_documents_from_responses_api( cast(OpenAIResponseObject, latest_response_object) diff --git a/src/utils/quota.py b/src/utils/quota.py index a08b37741..524f4d33b 100644 --- a/src/utils/quota.py +++ b/src/utils/quota.py @@ -1,5 +1,7 @@ """Quota handling helper functions.""" +from typing import Optional + import psycopg2 from fastapi import HTTPException @@ -7,27 +9,44 @@ from models.responses import InternalServerErrorResponse, QuotaExceededResponse from quota.quota_exceed_error import QuotaExceedError from quota.quota_limiter import QuotaLimiter +from quota.token_usage_history import TokenUsageHistory logger = get_logger(__name__) +# pylint: disable=R0913,R0917 def consume_tokens( quota_limiters: list[QuotaLimiter], + token_usage_history: Optional[TokenUsageHistory], user_id: str, input_tokens: int, output_tokens: int, + model_id: str, + provider_id: str, ) -> None: """Consume tokens from cluster and/or user quotas. Parameters: quota_limiters: List of quota limiter instances to consume tokens from. + token_usage_history: Optional instance of TokenUsageHistory class that records used tokens user_id: Identifier of the user consuming tokens. input_tokens: Number of input tokens to consume. output_tokens: Number of output tokens to consume. + model_id: Model identification + provider_id: Provider identification Returns: None """ + # record token usage history + if token_usage_history is not None: + token_usage_history.consume_tokens( + user_id=user_id, + provider=provider_id, + model=model_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) # consume tokens all configured quota limiters for quota_limiter in quota_limiters: quota_limiter.consume_tokens( diff --git a/tests/integration/endpoints/test_query_v2_integration.py b/tests/integration/endpoints/test_query_v2_integration.py index 9d4108fce..c20d1a7ac 100644 --- a/tests/integration/endpoints/test_query_v2_integration.py +++ b/tests/integration/endpoints/test_query_v2_integration.py @@ -1150,7 +1150,9 @@ async def test_query_v2_endpoint_quota_integration( mock_consume.assert_called_once() consume_args = mock_consume.call_args user_id, _, _, _ = test_auth - assert consume_args.args[1] == user_id + assert consume_args.args[2] == user_id + assert consume_args.kwargs["model_id"] == "test-model" + assert consume_args.kwargs["provider_id"] == "test-provider" assert consume_args.kwargs["input_tokens"] == 100 assert consume_args.kwargs["output_tokens"] == 50