From 155a316d29a07ecda46f25fb9b728355d94ed067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 19 Jan 2026 17:12:25 +0100 Subject: [PATCH 1/6] feat: added pricing management --- .env.ci | 5 +- nilai-api/src/nilai_api/app.py | 9 +- nilai-api/src/nilai_api/config/__init__.py | 9 +- nilai-api/src/nilai_api/config/auth.py | 3 + .../src/nilai_api/config/config-a779.yaml | 31 --- .../src/nilai_api/config/config-e176.yaml | 35 --- .../src/nilai_api/config/config-f910.yaml | 35 --- nilai-api/src/nilai_api/config/config.yaml | 41 ++++ nilai-api/src/nilai_api/config/pricing.py | 21 ++ nilai-api/src/nilai_api/credit.py | 31 ++- nilai-api/src/nilai_api/pricing_service.py | 169 +++++++++++++ nilai-api/src/nilai_api/routers/pricing.py | 140 +++++++++++ tests/e2e/test_pricing.py | 222 +++++++++++++++++ tests/unit/nilai_api/test_pricing_service.py | 230 ++++++++++++++++++ 14 files changed, 866 insertions(+), 115 deletions(-) delete mode 100644 nilai-api/src/nilai_api/config/config-a779.yaml delete mode 100644 nilai-api/src/nilai_api/config/config-e176.yaml delete mode 100644 nilai-api/src/nilai_api/config/config-f910.yaml create mode 100644 nilai-api/src/nilai_api/config/pricing.py create mode 100644 nilai-api/src/nilai_api/pricing_service.py create mode 100644 nilai-api/src/nilai_api/routers/pricing.py create mode 100644 tests/e2e/test_pricing.py create mode 100644 tests/unit/nilai_api/test_pricing_service.py diff --git a/.env.ci b/.env.ci index f25246e7..0ecaea1d 100644 --- a/.env.ci +++ b/.env.ci @@ -23,9 +23,12 @@ ATTESTATION_HOST = "attestation" ATTESTATION_PORT = 8080 # nilAuth Trusted URLs -NILAUTH_TRUSTED_ROOT_ISSUERS = "http://nilauth-credit-server:3000" # "http://nilauth:30921" +NILAUTH_TRUSTED_ROOT_ISSUERS = "http://nilauth-credit-server:3000" CREDIT_API_TOKEN = "n i l l i o n" +# Admin token for pricing management API +ADMIN_TOKEN = "SecretAdminToken" + # Postgres Docker Compose Config POSTGRES_HOST = "postgres" POSTGRES_USER = "user" diff --git a/nilai-api/src/nilai_api/app.py b/nilai-api/src/nilai_api/app.py index 8a4e7ac4..d9e7e205 100644 --- a/nilai-api/src/nilai_api/app.py +++ b/nilai-api/src/nilai_api/app.py @@ -5,7 +5,8 @@ from fastapi import Depends, FastAPI from nilai_api.auth import get_auth_info from nilai_api.rate_limiting import setup_redis_conn -from nilai_api.routers import private, public +from nilai_api.routers import private, public, pricing +from nilai_api.pricing_service import PricingService, set_pricing_service from nilai_api import config from contextlib import asynccontextmanager from fastapi.middleware.cors import CORSMiddleware @@ -16,6 +17,11 @@ async def lifespan(app: FastAPI): client, rate_limit_command = await setup_redis_conn(config.CONFIG.redis.url) + # Initialize pricing service + pricing_service = PricingService(client) + await pricing_service.initialize_from_config() + set_pricing_service(pricing_service) + yield {"redis": client, "redis_rate_limit_command": rate_limit_command} @@ -88,6 +94,7 @@ async def lifespan(app: FastAPI): app.include_router(public.router) app.include_router(private.router, dependencies=[Depends(get_auth_info)]) +app.include_router(pricing.router, dependencies=[Depends(get_auth_info)]) app.add_middleware( CORSMiddleware, diff --git a/nilai-api/src/nilai_api/config/__init__.py b/nilai-api/src/nilai_api/config/__init__.py index 3f19f85e..18b77ac9 100644 --- a/nilai-api/src/nilai_api/config/__init__.py +++ b/nilai-api/src/nilai_api/config/__init__.py @@ -8,6 +8,7 @@ from .nildb import NilDBConfig from .web_search import WebSearchSettings from .rate_limiting import RateLimitingConfig +from .pricing import LLMPricingConfig, LLMPriceConfig from .utils import create_config_model, CONFIG_DATA @@ -37,6 +38,9 @@ class NilAIConfig(BaseModel): nildb: NilDBConfig = create_config_model( NilDBConfig, "nildb", CONFIG_DATA, "NILDB_" ) + llm_pricing: LLMPricingConfig = create_config_model( + LLMPricingConfig, "llm_pricing", CONFIG_DATA + ) def prettify(self): """Print the config in a pretty format removing passwords and other sensitive information""" @@ -66,7 +70,10 @@ def prettify(self): CONFIG = NilAIConfig() __all__ = [ # Main config object - "CONFIG" + "CONFIG", + # Pricing config for external use + "LLMPriceConfig", + "LLMPricingConfig", ] logging.info(CONFIG.prettify()) diff --git a/nilai-api/src/nilai_api/config/auth.py b/nilai-api/src/nilai_api/config/auth.py index 77358792..4e64e812 100644 --- a/nilai-api/src/nilai_api/config/auth.py +++ b/nilai-api/src/nilai_api/config/auth.py @@ -13,6 +13,9 @@ class AuthConfig(BaseModel): auth_token: Optional[str] = Field( default=None, description="Auth token for e2e tests and development" ) + admin_token: Optional[str] = Field( + default=None, description="Admin token for pricing updates" + ) @property def credit_service_url(self) -> str: diff --git a/nilai-api/src/nilai_api/config/config-a779.yaml b/nilai-api/src/nilai_api/config/config-a779.yaml deleted file mode 100644 index 40db5a67..00000000 --- a/nilai-api/src/nilai_api/config/config-a779.yaml +++ /dev/null @@ -1,31 +0,0 @@ -# Web Search Configuration -web_search: - api_key: null - api_path: "https://api.search.brave.com/res/v1/web/search" - count: 3 - lang: "en" - country: "us" - timeout: 20.0 - max_concurrent_requests: 20 - rps: 20 - -# Rate Limiting Configuration -rate_limiting: - user_rate_limit: null # For-good rate limit - user_rate_limit_minute: 10 - user_rate_limit_hour: 100 - user_rate_limit_day: 500 - web_search_rate_limit_minute: null - web_search_rate_limit_hour: null - web_search_rate_limit_day: null - web_search_rate_limit: 100 # For-good rate limit - model_concurrent_rate_limit: - meta-llama/Llama-3.2-1B-Instruct: 45 - meta-llama/Llama-3.2-3B-Instruct: 50 - meta-llama/Llama-3.1-8B-Instruct: 30 - cognitivecomputations/Dolphin3.0-Llama3.1-8B: 30 - deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: 5 - hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: 5 - openai/gpt-oss-20b: 50 - google/gemma-3-27b-it: 50 - default: 50 diff --git a/nilai-api/src/nilai_api/config/config-e176.yaml b/nilai-api/src/nilai_api/config/config-e176.yaml deleted file mode 100644 index 7c83f1eb..00000000 --- a/nilai-api/src/nilai_api/config/config-e176.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# In production, this file is automatically generated by the `ansible` playbook. -# Configuration with structured sections and default values - - -# Web Search Configuration -web_search: - api_key: null - api_path: "https://api.search.brave.com/res/v1/web/search" - count: 3 - lang: "en" - country: "us" - timeout: 20.0 - max_concurrent_requests: 20 - rps: 20 - -# Rate Limiting Configuration -rate_limiting: - user_rate_limit: null # For-good rate limit - user_rate_limit_minute: 100 - user_rate_limit_hour: 1000 - user_rate_limit_day: 10000 - web_search_rate_limit_minute: null - web_search_rate_limit_hour: null - web_search_rate_limit_day: 500 - web_search_rate_limit: null # For-good rate limit - model_concurrent_rate_limit: - meta-llama/Llama-3.2-1B-Instruct: 45 - meta-llama/Llama-3.2-3B-Instruct: 50 - meta-llama/Llama-3.1-8B-Instruct: 30 - cognitivecomputations/Dolphin3.0-Llama3.1-8B: 30 - deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: 5 - hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: 5 - openai/gpt-oss-20b: 50 - google/gemma-3-27b-it: 50 - default: 50 diff --git a/nilai-api/src/nilai_api/config/config-f910.yaml b/nilai-api/src/nilai_api/config/config-f910.yaml deleted file mode 100644 index 2fbab210..00000000 --- a/nilai-api/src/nilai_api/config/config-f910.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# In production, this file is automatically generated by the `ansible` playbook. -# Configuration with structured sections and default values - - -# Web Search Configuration -web_search: - api_key: null - api_path: "https://api.search.brave.com/res/v1/web/search" - count: 3 - lang: "en" - country: "us" - timeout: 20.0 - max_concurrent_requests: 20 - rps: 20 - -# Rate Limiting Configuration -rate_limiting: - user_rate_limit: null # For-good rate limit - user_rate_limit_minute: 100 - user_rate_limit_hour: 1000 - user_rate_limit_day: 10000 - web_search_rate_limit_minute: 1 - web_search_rate_limit_hour: 3 - web_search_rate_limit_day: 72 - web_search_rate_limit: null # For-good rate limit - model_concurrent_rate_limit: - meta-llama/Llama-3.2-1B-Instruct: 45 - meta-llama/Llama-3.2-3B-Instruct: 50 - meta-llama/Llama-3.1-8B-Instruct: 30 - cognitivecomputations/Dolphin3.0-Llama3.1-8B: 30 - deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: 5 - hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: 5 - openai/gpt-oss-20b: 50 - google/gemma-3-27b-it: 50 - default: 50 diff --git a/nilai-api/src/nilai_api/config/config.yaml b/nilai-api/src/nilai_api/config/config.yaml index 565113bf..e8271fc2 100644 --- a/nilai-api/src/nilai_api/config/config.yaml +++ b/nilai-api/src/nilai_api/config/config.yaml @@ -9,6 +9,7 @@ auth: strategy: "api_key" nilauth_trusted_root_issuers: - http://nilauth-credit-server:3000 + admin_token: null # Set via ADMIN_TOKEN env var for pricing management # Documentation Configuration docs: @@ -46,3 +47,43 @@ rate_limiting: openai/gpt-oss-20b: 50 google/gemma-3-27b-it: 50 default: 50 + +# LLM Pricing Configuration +llm_pricing: + default: + prompt_tokens_price: 2.0 + completion_tokens_price: 2.0 + web_search_cost: 0.05 + models: + meta-llama/Llama-3.2-1B-Instruct: + prompt_tokens_price: 3.0 + completion_tokens_price: 3.0 + web_search_cost: 0.05 + meta-llama/Llama-3.2-3B-Instruct: + prompt_tokens_price: 3.0 + completion_tokens_price: 3.0 + web_search_cost: 0.05 + meta-llama/Llama-3.1-8B-Instruct: + prompt_tokens_price: 3.0 + completion_tokens_price: 3.0 + web_search_cost: 0.05 + cognitivecomputations/Dolphin3.0-Llama3.1-8B: + prompt_tokens_price: 3.0 + completion_tokens_price: 3.0 + web_search_cost: 0.05 + deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: + prompt_tokens_price: 5.0 + completion_tokens_price: 5.0 + web_search_cost: 0.05 + hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: + prompt_tokens_price: 8.0 + completion_tokens_price: 8.0 + web_search_cost: 0.05 + openai/gpt-oss-20b: + prompt_tokens_price: 4.0 + completion_tokens_price: 4.0 + web_search_cost: 0.05 + google/gemma-3-27b-it: + prompt_tokens_price: 5.0 + completion_tokens_price: 5.0 + web_search_cost: 0.05 diff --git a/nilai-api/src/nilai_api/config/pricing.py b/nilai-api/src/nilai_api/config/pricing.py new file mode 100644 index 00000000..1fa041fb --- /dev/null +++ b/nilai-api/src/nilai_api/config/pricing.py @@ -0,0 +1,21 @@ +from typing import Dict +from pydantic import BaseModel, Field + + +class LLMPriceConfig(BaseModel): + """Pricing configuration for a single LLM model.""" + + prompt_tokens_price: float = Field( + default=2.0, description="Cost per 1M prompt tokens" + ) + completion_tokens_price: float = Field( + default=2.0, description="Cost per 1M completion tokens" + ) + web_search_cost: float = Field(default=0.05, description="Cost per web search") + + +class LLMPricingConfig(BaseModel): + """Container for all LLM pricing configurations.""" + + default: LLMPriceConfig = Field(default_factory=LLMPriceConfig) + models: Dict[str, LLMPriceConfig] = Field(default_factory=dict) diff --git a/nilai-api/src/nilai_api/credit.py b/nilai-api/src/nilai_api/credit.py index b9d7ea6f..e456f840 100644 --- a/nilai-api/src/nilai_api/credit.py +++ b/nilai-api/src/nilai_api/credit.py @@ -11,6 +11,7 @@ ) from nilai_api.config import CONFIG +from nilai_api.pricing_service import get_pricing_service from nuc.envelope import NucTokenEnvelope @@ -51,6 +52,22 @@ def default() -> "LLMCost": prompt_tokens_price=2.0, completion_tokens_price=2.0, web_search_cost=0.05 ) + @staticmethod + async def from_redis(model_name: str) -> "LLMCost": + """Fetch pricing from Redis for a specific model.""" + try: + pricing_service = get_pricing_service() + price_config = await pricing_service.get_price(model_name) + return LLMCost( + prompt_tokens_price=price_config.prompt_tokens_price, + completion_tokens_price=price_config.completion_tokens_price, + web_search_cost=price_config.web_search_cost, + ) + except RuntimeError: + # Pricing service not initialized, use default + logger.warning("Pricing service not initialized, using default pricing") + return LLMCost.default() + def total_cost( self, prompt_tokens: int, completion_tokens: int, web_searches: int ) -> float: @@ -87,14 +104,6 @@ class LLMResponse(BaseModel): LLMCostDict: TypeAlias = dict[str, LLMCost] - -MyCostDictionary: LLMCostDict = { - "meta-llama/Llama-3.2-1B-Instruct": LLMCost( - prompt_tokens_price=3.0, completion_tokens_price=3.0, web_search_cost=0.05 - ), - "default": LLMCost.default(), -} - # Configure the singleton credit client CreditClientSingleton.configure( base_url=CONFIG.auth.credit_service_url, @@ -138,10 +147,10 @@ async def extractor(request: Request) -> str: return extractor -def llm_cost_calculator(llm_cost_dict: LLMCostDict): +def llm_cost_calculator(): async def calculator(request: Request, response_data: dict) -> float: model_name = getattr(request, "model", "default") - llm_cost = llm_cost_dict.get(model_name, LLMCost.default()) + llm_cost = await LLMCost.from_redis(model_name) total_cost = 0.0 usage: Optional[LLMUsage] = response_data.get("usage", None) if usage is None: @@ -158,7 +167,7 @@ async def calculator(request: Request, response_data: dict) -> float: _base_llm_meter = create_metering_dependency( credential_extractor=credential_extractor(), estimated_cost=2.0, - cost_calculator=llm_cost_calculator(MyCostDictionary), + cost_calculator=llm_cost_calculator(), public_identifiers=CONFIG.auth.auth_strategy == "nuc", ) diff --git a/nilai-api/src/nilai_api/pricing_service.py b/nilai-api/src/nilai_api/pricing_service.py new file mode 100644 index 00000000..1cd86d58 --- /dev/null +++ b/nilai-api/src/nilai_api/pricing_service.py @@ -0,0 +1,169 @@ +import logging +from typing import Dict, Optional + +from redis.asyncio import Redis + +from nilai_api.config import CONFIG +from nilai_api.config.pricing import LLMPriceConfig + +logger = logging.getLogger(__name__) + +# Redis key prefix for pricing data +REDIS_PRICING_PREFIX = "nilai:pricing:" +REDIS_PRICING_ALL_KEY = "nilai:pricing:_all" + + +class PricingService: + """Redis-backed pricing service for thread-safe operations.""" + + def __init__(self, redis_client: Redis): + self._redis = redis_client + self._default_config = CONFIG.llm_pricing.default + self._model_configs = CONFIG.llm_pricing.models + + async def initialize_from_config(self) -> None: + """Load defaults from YAML into Redis on startup (only if not already set).""" + try: + # Check if pricing data already exists in Redis + existing_keys = await self._redis.keys(f"{REDIS_PRICING_PREFIX}*") + if existing_keys: + logger.info( + "Pricing data already exists in Redis, skipping initialization" + ) + return + + # Initialize default pricing + await self._set_price_in_redis("default", self._default_config) + + # Initialize model-specific pricing + for model_name, price_config in self._model_configs.items(): + await self._set_price_in_redis(model_name, price_config) + + logger.info( + f"Initialized pricing from config: default + {len(self._model_configs)} models" + ) + except Exception as e: + logger.error(f"Failed to initialize pricing from config: {e}") + raise + + async def _set_price_in_redis( + self, model_name: str, config: LLMPriceConfig + ) -> None: + """Set price for a model in Redis using pipeline for atomicity.""" + key = f"{REDIS_PRICING_PREFIX}{model_name}" + config_json = config.model_dump_json() + + async with self._redis.pipeline(transaction=True) as pipe: + # Set individual model key + pipe.set(key, config_json) + # Update hash for bulk retrieval + pipe.hset(REDIS_PRICING_ALL_KEY, model_name, config_json) + await pipe.execute() + + async def get_price(self, model_name: str) -> LLMPriceConfig: + """ + Get price from Redis for a specific model. + + Falls back to default pricing if model not found or Redis unavailable. + """ + try: + key = f"{REDIS_PRICING_PREFIX}{model_name}" + config_json = await self._redis.get(key) + + if config_json: + return LLMPriceConfig.model_validate_json(config_json) + + # Fallback to default pricing from Redis + default_key = f"{REDIS_PRICING_PREFIX}default" + default_json = await self._redis.get(default_key) + + if default_json: + return LLMPriceConfig.model_validate_json(default_json) + + # Last resort: return config default + logger.warning( + f"No pricing found in Redis for model '{model_name}', using config default" + ) + return self._default_config + + except Exception as e: + logger.error(f"Failed to get price from Redis: {e}, using config default") + return self._default_config + + async def get_all_prices(self) -> Dict[str, LLMPriceConfig]: + """Get all prices from Redis hash.""" + try: + all_prices: Dict[str, str] = await self._redis.hgetall( + REDIS_PRICING_ALL_KEY + ) # type: ignore[assignment] + + result = {} + for model_name, config_json in all_prices.items(): + # Handle bytes if Redis returns bytes + if isinstance(model_name, bytes): + model_name = model_name.decode("utf-8") + if isinstance(config_json, bytes): + config_json = config_json.decode("utf-8") + + result[model_name] = LLMPriceConfig.model_validate_json(config_json) + + return result + + except Exception as e: + logger.error(f"Failed to get all prices from Redis: {e}") + # Fallback to config defaults + result = {"default": self._default_config} + result.update(self._model_configs) + return result + + async def set_price(self, model_name: str, config: LLMPriceConfig) -> None: + """Atomic update of price for a model using Redis pipeline.""" + await self._set_price_in_redis(model_name, config) + logger.info(f"Updated pricing for model '{model_name}'") + + async def delete_price(self, model_name: str) -> bool: + """ + Remove custom pricing for a model (will use default). + + Returns True if the key existed and was deleted, False otherwise. + """ + if model_name == "default": + raise ValueError("Cannot delete default pricing") + + key = f"{REDIS_PRICING_PREFIX}{model_name}" + + async with self._redis.pipeline(transaction=True) as pipe: + # Check if key exists + pipe.exists(key) + # Delete individual model key + pipe.delete(key) + # Remove from hash + pipe.hdel(REDIS_PRICING_ALL_KEY, model_name) + results = await pipe.execute() + + existed = results[0] > 0 + if existed: + logger.info(f"Deleted pricing for model '{model_name}'") + return existed + + async def price_exists(self, model_name: str) -> bool: + """Check if a custom price exists for a model.""" + key = f"{REDIS_PRICING_PREFIX}{model_name}" + return await self._redis.exists(key) > 0 + + +# Global pricing service instance +_pricing_service: Optional[PricingService] = None + + +def set_pricing_service(service: PricingService) -> None: + """Set the global pricing service instance.""" + global _pricing_service + _pricing_service = service + + +def get_pricing_service() -> PricingService: + """Get the global pricing service instance.""" + if _pricing_service is None: + raise RuntimeError("Pricing service not initialized") + return _pricing_service diff --git a/nilai-api/src/nilai_api/routers/pricing.py b/nilai-api/src/nilai_api/routers/pricing.py new file mode 100644 index 00000000..0b2d00ef --- /dev/null +++ b/nilai-api/src/nilai_api/routers/pricing.py @@ -0,0 +1,140 @@ +import logging +from typing import Dict + +from fastapi import APIRouter, Depends, HTTPException, Request, status + +from nilai_api.auth import get_auth_info, AuthenticationInfo +from nilai_api.config import CONFIG +from nilai_api.config.pricing import LLMPriceConfig +from nilai_api.pricing_service import get_pricing_service + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/v1/pricing", tags=["Pricing"]) + + +def verify_admin_token(request: Request) -> None: + """Verify that the request has a valid admin token.""" + admin_token = CONFIG.auth.admin_token + if not admin_token: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin operations are disabled (no admin token configured)", + ) + + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + else: + token = auth_header + + if token != admin_token: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid admin token", + ) + + +@router.get("", response_model=Dict[str, LLMPriceConfig]) +async def get_all_prices( + auth_info: AuthenticationInfo = Depends(get_auth_info), +) -> Dict[str, LLMPriceConfig]: + """ + Get all model prices. + + Returns a dictionary mapping model names to their pricing configurations. + """ + pricing_service = get_pricing_service() + return await pricing_service.get_all_prices() + + +@router.get("/{model_name:path}", response_model=LLMPriceConfig) +async def get_model_price( + model_name: str, + auth_info: AuthenticationInfo = Depends(get_auth_info), +) -> LLMPriceConfig: + """ + Get price for a specific model. + + - **model_name**: The model name (e.g., `meta-llama/Llama-3.2-1B-Instruct`) + - **Returns**: Pricing configuration for the model + + If no specific pricing is set for the model, returns default pricing. + """ + pricing_service = get_pricing_service() + return await pricing_service.get_price(model_name) + + +@router.put("/{model_name:path}", response_model=LLMPriceConfig) +async def update_model_price( + model_name: str, + price_config: LLMPriceConfig, + request: Request, + auth_info: AuthenticationInfo = Depends(get_auth_info), +) -> LLMPriceConfig: + """ + Update price for a specific model (admin only). + + - **model_name**: The model name (e.g., `meta-llama/Llama-3.2-1B-Instruct`) + - **price_config**: New pricing configuration + + Requires admin token in Authorization header. + """ + verify_admin_token(request) + + # Validate price values + if price_config.prompt_tokens_price < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="prompt_tokens_price must be non-negative", + ) + if price_config.completion_tokens_price < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="completion_tokens_price must be non-negative", + ) + if price_config.web_search_cost < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="web_search_cost must be non-negative", + ) + + pricing_service = get_pricing_service() + await pricing_service.set_price(model_name, price_config) + + logger.info(f"Admin updated pricing for model '{model_name}'") + return price_config + + +@router.delete("/{model_name:path}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_model_price( + model_name: str, + request: Request, + auth_info: AuthenticationInfo = Depends(get_auth_info), +) -> None: + """ + Delete custom price for a model (admin only). + + - **model_name**: The model name to delete pricing for + + After deletion, the model will use default pricing. + Requires admin token in Authorization header. + """ + verify_admin_token(request) + + if model_name == "default": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot delete default pricing", + ) + + pricing_service = get_pricing_service() + existed = await pricing_service.delete_price(model_name) + + if not existed: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No custom pricing found for model '{model_name}'", + ) + + logger.info(f"Admin deleted pricing for model '{model_name}'") diff --git a/tests/e2e/test_pricing.py b/tests/e2e/test_pricing.py new file mode 100644 index 00000000..1e98fe52 --- /dev/null +++ b/tests/e2e/test_pricing.py @@ -0,0 +1,222 @@ +"""E2E tests for pricing API endpoints.""" + +import pytest +from .config import BASE_URL, api_key_getter +from nilai_api.config import CONFIG +import httpx + + +@pytest.fixture +def http_client(): + """Create an HTTPX client with user authentication.""" + invocation_token: str = api_key_getter() + # Use base URL without /v1 since pricing endpoint is at /v1/pricing + base = BASE_URL.rsplit("/v1", 1)[0] + return httpx.Client( + base_url=base, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + verify=False, + timeout=30.0, + ) + + +@pytest.fixture +def admin_http_client(): + """Create an HTTPX client with admin authentication.""" + admin_token = CONFIG.auth.admin_token + if not admin_token: + pytest.skip("Admin token not configured") + + base = BASE_URL.rsplit("/v1", 1)[0] + return httpx.Client( + base_url=base, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {admin_token}", + }, + verify=False, + timeout=30.0, + ) + + +class TestPricingGet: + """Tests for GET pricing endpoints.""" + + def test_get_all_prices(self, http_client): + """Test getting all model prices.""" + response = http_client.get("/v1/pricing") + assert response.status_code == 200 + + prices = response.json() + assert isinstance(prices, dict) + assert "default" in prices + + # Check default pricing structure + default = prices["default"] + assert "prompt_tokens_price" in default + assert "completion_tokens_price" in default + assert "web_search_cost" in default + + def test_get_specific_model_price(self, http_client): + """Test getting price for a specific model.""" + response = http_client.get("/v1/pricing/meta-llama/Llama-3.2-1B-Instruct") + assert response.status_code == 200 + + price = response.json() + assert "prompt_tokens_price" in price + assert "completion_tokens_price" in price + assert "web_search_cost" in price + + def test_get_unknown_model_returns_default(self, http_client): + """Test that unknown model returns default pricing.""" + response = http_client.get("/v1/pricing/unknown/model-that-does-not-exist") + assert response.status_code == 200 + + price = response.json() + # Should return default pricing + assert price["prompt_tokens_price"] == 2.0 + assert price["completion_tokens_price"] == 2.0 + assert price["web_search_cost"] == 0.05 + + def test_get_default_price(self, http_client): + """Test getting the default price directly.""" + response = http_client.get("/v1/pricing/default") + assert response.status_code == 200 + + price = response.json() + assert price["prompt_tokens_price"] == 2.0 + assert price["completion_tokens_price"] == 2.0 + + +class TestPricingUpdateDelete: + """Tests for PUT/DELETE pricing endpoints (admin only).""" + + def test_update_price_without_admin_token_fails(self, http_client): + """Test that updating price without admin token fails.""" + response = http_client.put( + "/v1/pricing/test-model", + json={ + "prompt_tokens_price": 10.0, + "completion_tokens_price": 10.0, + "web_search_cost": 0.1, + }, + ) + assert response.status_code == 403 + + def test_delete_price_without_admin_token_fails(self, http_client): + """Test that deleting price without admin token fails.""" + response = http_client.delete("/v1/pricing/test-model") + assert response.status_code == 403 + + @pytest.mark.skipif( + not CONFIG.auth.admin_token, reason="Admin token not configured" + ) + def test_update_price_with_admin_token(self, admin_http_client, http_client): + """Test updating a model price with admin token.""" + model_name = "e2e-test-model" + new_price = { + "prompt_tokens_price": 25.0, + "completion_tokens_price": 30.0, + "web_search_cost": 0.5, + } + + # Update the price + response = admin_http_client.put(f"/v1/pricing/{model_name}", json=new_price) + assert response.status_code == 200 + + result = response.json() + assert result["prompt_tokens_price"] == 25.0 + assert result["completion_tokens_price"] == 30.0 + assert result["web_search_cost"] == 0.5 + + # Verify with a regular GET request + get_response = http_client.get(f"/v1/pricing/{model_name}") + assert get_response.status_code == 200 + + fetched = get_response.json() + assert fetched["prompt_tokens_price"] == 25.0 + assert fetched["completion_tokens_price"] == 30.0 + + # Clean up + admin_http_client.delete(f"/v1/pricing/{model_name}") + + @pytest.mark.skipif( + not CONFIG.auth.admin_token, reason="Admin token not configured" + ) + def test_delete_price_with_admin_token(self, admin_http_client, http_client): + """Test deleting a custom price with admin token.""" + model_name = "e2e-delete-test-model" + + # First create a custom price + new_price = { + "prompt_tokens_price": 50.0, + "completion_tokens_price": 50.0, + "web_search_cost": 1.0, + } + admin_http_client.put(f"/v1/pricing/{model_name}", json=new_price) + + # Delete it + response = admin_http_client.delete(f"/v1/pricing/{model_name}") + assert response.status_code == 204 + + # Verify it now returns default pricing + get_response = http_client.get(f"/v1/pricing/{model_name}") + fetched = get_response.json() + assert fetched["prompt_tokens_price"] == 2.0 # Default + + @pytest.mark.skipif( + not CONFIG.auth.admin_token, reason="Admin token not configured" + ) + def test_delete_nonexistent_price_returns_404(self, admin_http_client): + """Test that deleting a non-existent price returns 404.""" + response = admin_http_client.delete("/v1/pricing/nonexistent-model-xyz-12345") + assert response.status_code == 404 + + @pytest.mark.skipif( + not CONFIG.auth.admin_token, reason="Admin token not configured" + ) + def test_delete_default_fails(self, admin_http_client): + """Test that deleting default pricing fails.""" + response = admin_http_client.delete("/v1/pricing/default") + assert response.status_code == 400 + + @pytest.mark.skipif( + not CONFIG.auth.admin_token, reason="Admin token not configured" + ) + def test_update_price_with_invalid_values(self, admin_http_client): + """Test that updating price with negative values fails.""" + response = admin_http_client.put( + "/v1/pricing/test-model", + json={ + "prompt_tokens_price": -1.0, + "completion_tokens_price": 10.0, + "web_search_cost": 0.1, + }, + ) + assert response.status_code == 400 + + +class TestPricingAuth: + """Tests for pricing authentication requirements.""" + + def test_get_prices_requires_auth(self): + """Test that getting prices requires authentication.""" + base = BASE_URL.rsplit("/v1", 1)[0] + client = httpx.Client(base_url=base, verify=False, timeout=30.0) + + response = client.get("/v1/pricing") + # Should return 401 or 403 without auth + assert response.status_code in [401, 403, 422] + + def test_get_model_price_requires_auth(self): + """Test that getting a model price requires authentication.""" + base = BASE_URL.rsplit("/v1", 1)[0] + client = httpx.Client(base_url=base, verify=False, timeout=30.0) + + response = client.get("/v1/pricing/default") + assert response.status_code in [401, 403, 422] diff --git a/tests/unit/nilai_api/test_pricing_service.py b/tests/unit/nilai_api/test_pricing_service.py new file mode 100644 index 00000000..31f122a9 --- /dev/null +++ b/tests/unit/nilai_api/test_pricing_service.py @@ -0,0 +1,230 @@ +import pytest +import pytest_asyncio + +from nilai_api.config.pricing import LLMPriceConfig +from nilai_api.pricing_service import ( + PricingService, + get_pricing_service, + set_pricing_service, + REDIS_PRICING_PREFIX, + REDIS_PRICING_ALL_KEY, +) +from nilai_api.rate_limiting import setup_redis_conn + + +@pytest_asyncio.fixture +async def redis_client(redis_server): + """Create a Redis client connected to the test container.""" + host_ip = redis_server.get_container_host_ip() + host_port = redis_server.get_exposed_port(6379) + client, _ = await setup_redis_conn(f"redis://{host_ip}:{host_port}") + yield client + await client.aclose() + + +@pytest_asyncio.fixture +async def pricing_service(redis_client): + """Create a PricingService instance and clean up Redis keys before/after.""" + # Clean up any existing pricing keys + keys = await redis_client.keys(f"{REDIS_PRICING_PREFIX}*") + if keys: + await redis_client.delete(*keys) + + service = PricingService(redis_client) + set_pricing_service(service) + yield service + + # Clean up after test + keys = await redis_client.keys(f"{REDIS_PRICING_PREFIX}*") + if keys: + await redis_client.delete(*keys) + + +@pytest.mark.asyncio +async def test_initialize_from_config(pricing_service, redis_client): + """Test that pricing is initialized from config into Redis.""" + await pricing_service.initialize_from_config() + + # Check that default pricing was set + default_key = f"{REDIS_PRICING_PREFIX}default" + default_json = await redis_client.get(default_key) + assert default_json is not None + + default_config = LLMPriceConfig.model_validate_json(default_json) + assert default_config.prompt_tokens_price == 2.0 + assert default_config.completion_tokens_price == 2.0 + assert default_config.web_search_cost == 0.05 + + +@pytest.mark.asyncio +async def test_initialize_skips_if_data_exists(pricing_service, redis_client): + """Test that initialization is skipped if data already exists.""" + # First initialization + await pricing_service.initialize_from_config() + + # Modify a value directly in Redis + custom_config = LLMPriceConfig( + prompt_tokens_price=99.0, completion_tokens_price=99.0, web_search_cost=99.0 + ) + await redis_client.set( + f"{REDIS_PRICING_PREFIX}default", custom_config.model_dump_json() + ) + + # Second initialization should be skipped + await pricing_service.initialize_from_config() + + # Verify the custom value is still there + default_json = await redis_client.get(f"{REDIS_PRICING_PREFIX}default") + default_config = LLMPriceConfig.model_validate_json(default_json) + assert default_config.prompt_tokens_price == 99.0 + + +@pytest.mark.asyncio +async def test_get_price_returns_model_specific_price(pricing_service): + """Test getting price for a specific model.""" + await pricing_service.initialize_from_config() + + # Get price for a model that should have specific pricing + price = await pricing_service.get_price("meta-llama/Llama-3.2-1B-Instruct") + assert price.prompt_tokens_price == 3.0 + assert price.completion_tokens_price == 3.0 + + +@pytest.mark.asyncio +async def test_get_price_falls_back_to_default(pricing_service): + """Test that unknown models fall back to default pricing.""" + await pricing_service.initialize_from_config() + + # Get price for an unknown model + price = await pricing_service.get_price("unknown/model") + assert price.prompt_tokens_price == 2.0 + assert price.completion_tokens_price == 2.0 + assert price.web_search_cost == 0.05 + + +@pytest.mark.asyncio +async def test_set_price(pricing_service, redis_client): + """Test setting price for a model.""" + await pricing_service.initialize_from_config() + + # Set a new price + new_config = LLMPriceConfig( + prompt_tokens_price=10.0, completion_tokens_price=15.0, web_search_cost=0.1 + ) + await pricing_service.set_price("test-model", new_config) + + # Verify it was set correctly + price = await pricing_service.get_price("test-model") + assert price.prompt_tokens_price == 10.0 + assert price.completion_tokens_price == 15.0 + assert price.web_search_cost == 0.1 + + # Verify it's in the hash + hash_value = await redis_client.hget(REDIS_PRICING_ALL_KEY, "test-model") + assert hash_value is not None + + +@pytest.mark.asyncio +async def test_get_all_prices(pricing_service): + """Test getting all prices.""" + await pricing_service.initialize_from_config() + + all_prices = await pricing_service.get_all_prices() + + assert "default" in all_prices + assert all_prices["default"].prompt_tokens_price == 2.0 + + # Check that model-specific prices are included + assert "meta-llama/Llama-3.2-1B-Instruct" in all_prices + + +@pytest.mark.asyncio +async def test_delete_price(pricing_service): + """Test deleting a custom price.""" + await pricing_service.initialize_from_config() + + # Add a custom price + custom_config = LLMPriceConfig( + prompt_tokens_price=50.0, completion_tokens_price=50.0, web_search_cost=1.0 + ) + await pricing_service.set_price("custom-model", custom_config) + + # Verify it exists + price = await pricing_service.get_price("custom-model") + assert price.prompt_tokens_price == 50.0 + + # Delete it + existed = await pricing_service.delete_price("custom-model") + assert existed is True + + # Verify it falls back to default now + price = await pricing_service.get_price("custom-model") + assert price.prompt_tokens_price == 2.0 + + +@pytest.mark.asyncio +async def test_delete_price_returns_false_if_not_exists(pricing_service): + """Test deleting a non-existent price returns False.""" + await pricing_service.initialize_from_config() + + existed = await pricing_service.delete_price("nonexistent-model") + assert existed is False + + +@pytest.mark.asyncio +async def test_delete_default_raises_error(pricing_service): + """Test that deleting default pricing raises an error.""" + await pricing_service.initialize_from_config() + + with pytest.raises(ValueError, match="Cannot delete default pricing"): + await pricing_service.delete_price("default") + + +@pytest.mark.asyncio +async def test_price_exists(pricing_service): + """Test checking if a price exists.""" + await pricing_service.initialize_from_config() + + assert await pricing_service.price_exists("default") is True + assert ( + await pricing_service.price_exists("meta-llama/Llama-3.2-1B-Instruct") is True + ) + assert await pricing_service.price_exists("nonexistent-model") is False + + +@pytest.mark.asyncio +async def test_get_pricing_service_not_initialized(): + """Test that get_pricing_service raises error when not initialized.""" + # Reset the global service + from nilai_api import pricing_service as ps_module + + old_service = ps_module._pricing_service + ps_module._pricing_service = None + + try: + with pytest.raises(RuntimeError, match="Pricing service not initialized"): + get_pricing_service() + finally: + ps_module._pricing_service = old_service + + +@pytest.mark.asyncio +async def test_update_existing_price(pricing_service): + """Test updating an existing model's price.""" + await pricing_service.initialize_from_config() + + # Get original price + original = await pricing_service.get_price("meta-llama/Llama-3.2-1B-Instruct") + assert original.prompt_tokens_price == 3.0 + + # Update it + new_config = LLMPriceConfig( + prompt_tokens_price=100.0, completion_tokens_price=100.0, web_search_cost=5.0 + ) + await pricing_service.set_price("meta-llama/Llama-3.2-1B-Instruct", new_config) + + # Verify update + updated = await pricing_service.get_price("meta-llama/Llama-3.2-1B-Instruct") + assert updated.prompt_tokens_price == 100.0 + assert updated.completion_tokens_price == 100.0 + assert updated.web_search_cost == 5.0 From 00263048da91db2f58163c178f6e60aec6911ebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 19 Jan 2026 18:05:57 +0100 Subject: [PATCH 2/6] fix: use admin-only auth for pricing PUT/DELETE endpoints --- nilai-api/src/nilai_api/routers/pricing.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/nilai-api/src/nilai_api/routers/pricing.py b/nilai-api/src/nilai_api/routers/pricing.py index 0b2d00ef..64dc2e80 100644 --- a/nilai-api/src/nilai_api/routers/pricing.py +++ b/nilai-api/src/nilai_api/routers/pricing.py @@ -13,8 +13,8 @@ router = APIRouter(prefix="/v1/pricing", tags=["Pricing"]) -def verify_admin_token(request: Request) -> None: - """Verify that the request has a valid admin token.""" +async def verify_admin_token(request: Request) -> None: + """Dependency to verify that the request has a valid admin token.""" admin_token = CONFIG.auth.admin_token if not admin_token: raise HTTPException( @@ -69,8 +69,7 @@ async def get_model_price( async def update_model_price( model_name: str, price_config: LLMPriceConfig, - request: Request, - auth_info: AuthenticationInfo = Depends(get_auth_info), + _: None = Depends(verify_admin_token), ) -> LLMPriceConfig: """ Update price for a specific model (admin only). @@ -80,7 +79,6 @@ async def update_model_price( Requires admin token in Authorization header. """ - verify_admin_token(request) # Validate price values if price_config.prompt_tokens_price < 0: @@ -109,8 +107,7 @@ async def update_model_price( @router.delete("/{model_name:path}", status_code=status.HTTP_204_NO_CONTENT) async def delete_model_price( model_name: str, - request: Request, - auth_info: AuthenticationInfo = Depends(get_auth_info), + _: None = Depends(verify_admin_token), ) -> None: """ Delete custom price for a model (admin only). @@ -120,7 +117,6 @@ async def delete_model_price( After deletion, the model will use default pricing. Requires admin token in Authorization header. """ - verify_admin_token(request) if model_name == "default": raise HTTPException( From 462f135cd6c2bbfa4f70a8eae9f8d88008ecd968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Tue, 20 Jan 2026 10:37:43 +0100 Subject: [PATCH 3/6] fix: use admin-only auth for pricing PUT/DELETE endpoints --- nilai-api/src/nilai_api/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nilai-api/src/nilai_api/app.py b/nilai-api/src/nilai_api/app.py index d9e7e205..a01cc593 100644 --- a/nilai-api/src/nilai_api/app.py +++ b/nilai-api/src/nilai_api/app.py @@ -94,7 +94,7 @@ async def lifespan(app: FastAPI): app.include_router(public.router) app.include_router(private.router, dependencies=[Depends(get_auth_info)]) -app.include_router(pricing.router, dependencies=[Depends(get_auth_info)]) +app.include_router(pricing.router) app.add_middleware( CORSMiddleware, From e7acc29c5e0d7ee91bdce01417eb837df2dae80f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Tue, 20 Jan 2026 11:03:36 +0100 Subject: [PATCH 4/6] fix: use Bearer Token auth for Admin endpoints --- nilai-api/src/nilai_api/routers/pricing.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/nilai-api/src/nilai_api/routers/pricing.py b/nilai-api/src/nilai_api/routers/pricing.py index 64dc2e80..5dcae30a 100644 --- a/nilai-api/src/nilai_api/routers/pricing.py +++ b/nilai-api/src/nilai_api/routers/pricing.py @@ -1,7 +1,8 @@ import logging from typing import Dict -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Security, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from nilai_api.auth import get_auth_info, AuthenticationInfo from nilai_api.config import CONFIG @@ -11,9 +12,12 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/v1/pricing", tags=["Pricing"]) +admin_bearer_scheme = HTTPBearer() -async def verify_admin_token(request: Request) -> None: +async def verify_admin_token( + credentials: HTTPAuthorizationCredentials = Security(admin_bearer_scheme), +) -> None: """Dependency to verify that the request has a valid admin token.""" admin_token = CONFIG.auth.admin_token if not admin_token: @@ -22,13 +26,7 @@ async def verify_admin_token(request: Request) -> None: detail="Admin operations are disabled (no admin token configured)", ) - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - token = auth_header[7:] - else: - token = auth_header - - if token != admin_token: + if credentials.credentials != admin_token: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid admin token", From f66acc3dbfd96a1f7bf5a4daf87400536a3a03e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Tue, 20 Jan 2026 13:35:04 +0100 Subject: [PATCH 5/6] feat: add final pricing and estimated costs --- nilai-api/src/nilai_api/config/config.yaml | 42 +++++++------------- nilai-api/src/nilai_api/credit.py | 2 +- tests/unit/nilai_api/test_pricing_service.py | 35 ++++++++++------ 3 files changed, 38 insertions(+), 41 deletions(-) diff --git a/nilai-api/src/nilai_api/config/config.yaml b/nilai-api/src/nilai_api/config/config.yaml index e8271fc2..4723a41b 100644 --- a/nilai-api/src/nilai_api/config/config.yaml +++ b/nilai-api/src/nilai_api/config/config.yaml @@ -51,39 +51,27 @@ rate_limiting: # LLM Pricing Configuration llm_pricing: default: - prompt_tokens_price: 2.0 - completion_tokens_price: 2.0 - web_search_cost: 0.05 + prompt_tokens_price: 0.15 + completion_tokens_price: 0.45 + web_search_cost: 0.05 models: meta-llama/Llama-3.2-1B-Instruct: - prompt_tokens_price: 3.0 - completion_tokens_price: 3.0 - web_search_cost: 0.05 - meta-llama/Llama-3.2-3B-Instruct: - prompt_tokens_price: 3.0 - completion_tokens_price: 3.0 + prompt_tokens_price: 0.03 + completion_tokens_price: 0.09 web_search_cost: 0.05 meta-llama/Llama-3.1-8B-Instruct: - prompt_tokens_price: 3.0 - completion_tokens_price: 3.0 - web_search_cost: 0.05 - cognitivecomputations/Dolphin3.0-Llama3.1-8B: - prompt_tokens_price: 3.0 - completion_tokens_price: 3.0 - web_search_cost: 0.05 - deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: - prompt_tokens_price: 5.0 - completion_tokens_price: 5.0 - web_search_cost: 0.05 - hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: - prompt_tokens_price: 8.0 - completion_tokens_price: 8.0 + prompt_tokens_price: 0.03 + completion_tokens_price: 0.09 web_search_cost: 0.05 openai/gpt-oss-20b: - prompt_tokens_price: 4.0 - completion_tokens_price: 4.0 + prompt_tokens_price: 0.15 + completion_tokens_price: 0.45 web_search_cost: 0.05 google/gemma-3-27b-it: - prompt_tokens_price: 5.0 - completion_tokens_price: 5.0 + prompt_tokens_price: 0.15 + completion_tokens_price: 0.45 + web_search_cost: 0.05 + Qwen/Qwen3-Coder-30B-A3B-Instruct: + prompt_tokens_price: 0.15 + completion_tokens_price: 0.45 web_search_cost: 0.05 diff --git a/nilai-api/src/nilai_api/credit.py b/nilai-api/src/nilai_api/credit.py index e456f840..65c3daa8 100644 --- a/nilai-api/src/nilai_api/credit.py +++ b/nilai-api/src/nilai_api/credit.py @@ -166,7 +166,7 @@ async def calculator(request: Request, response_data: dict) -> float: _base_llm_meter = create_metering_dependency( credential_extractor=credential_extractor(), - estimated_cost=2.0, + estimated_cost=0.5, cost_calculator=llm_cost_calculator(), public_identifiers=CONFIG.auth.auth_strategy == "nuc", ) diff --git a/tests/unit/nilai_api/test_pricing_service.py b/tests/unit/nilai_api/test_pricing_service.py index 31f122a9..313f3437 100644 --- a/tests/unit/nilai_api/test_pricing_service.py +++ b/tests/unit/nilai_api/test_pricing_service.py @@ -51,8 +51,8 @@ async def test_initialize_from_config(pricing_service, redis_client): assert default_json is not None default_config = LLMPriceConfig.model_validate_json(default_json) - assert default_config.prompt_tokens_price == 2.0 - assert default_config.completion_tokens_price == 2.0 + assert default_config.prompt_tokens_price == 0.15 + assert default_config.completion_tokens_price == 0.45 assert default_config.web_search_cost == 0.05 @@ -86,8 +86,8 @@ async def test_get_price_returns_model_specific_price(pricing_service): # Get price for a model that should have specific pricing price = await pricing_service.get_price("meta-llama/Llama-3.2-1B-Instruct") - assert price.prompt_tokens_price == 3.0 - assert price.completion_tokens_price == 3.0 + assert price.prompt_tokens_price == 0.03 + assert price.completion_tokens_price == 0.09 @pytest.mark.asyncio @@ -97,8 +97,8 @@ async def test_get_price_falls_back_to_default(pricing_service): # Get price for an unknown model price = await pricing_service.get_price("unknown/model") - assert price.prompt_tokens_price == 2.0 - assert price.completion_tokens_price == 2.0 + assert price.prompt_tokens_price == 0.15 + assert price.completion_tokens_price == 0.45 assert price.web_search_cost == 0.05 @@ -132,10 +132,15 @@ async def test_get_all_prices(pricing_service): all_prices = await pricing_service.get_all_prices() assert "default" in all_prices - assert all_prices["default"].prompt_tokens_price == 2.0 - - # Check that model-specific prices are included + assert all_prices["default"].prompt_tokens_price == 0.15 + assert all_prices["default"].completion_tokens_price == 0.45 + assert all_prices["default"].web_search_cost == 0.05 assert "meta-llama/Llama-3.2-1B-Instruct" in all_prices + assert all_prices["meta-llama/Llama-3.2-1B-Instruct"].prompt_tokens_price == 0.03 + assert ( + all_prices["meta-llama/Llama-3.2-1B-Instruct"].completion_tokens_price == 0.09 + ) + assert all_prices["meta-llama/Llama-3.2-1B-Instruct"].web_search_cost == 0.05 @pytest.mark.asyncio @@ -145,13 +150,15 @@ async def test_delete_price(pricing_service): # Add a custom price custom_config = LLMPriceConfig( - prompt_tokens_price=50.0, completion_tokens_price=50.0, web_search_cost=1.0 + prompt_tokens_price=0.15, completion_tokens_price=0.45, web_search_cost=0.05 ) await pricing_service.set_price("custom-model", custom_config) # Verify it exists price = await pricing_service.get_price("custom-model") - assert price.prompt_tokens_price == 50.0 + assert price.prompt_tokens_price == 0.15 + assert price.completion_tokens_price == 0.45 + assert price.web_search_cost == 0.05 # Delete it existed = await pricing_service.delete_price("custom-model") @@ -159,7 +166,9 @@ async def test_delete_price(pricing_service): # Verify it falls back to default now price = await pricing_service.get_price("custom-model") - assert price.prompt_tokens_price == 2.0 + assert price.prompt_tokens_price == 0.15 + assert price.completion_tokens_price == 0.45 + assert price.web_search_cost == 0.05 @pytest.mark.asyncio @@ -215,7 +224,7 @@ async def test_update_existing_price(pricing_service): # Get original price original = await pricing_service.get_price("meta-llama/Llama-3.2-1B-Instruct") - assert original.prompt_tokens_price == 3.0 + assert original.prompt_tokens_price == 0.03 # Update it new_config = LLMPriceConfig( From af6ddb0f5c27784f84f8a7a0210ec96bad8c24f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Tue, 20 Jan 2026 16:16:44 +0100 Subject: [PATCH 6/6] fix: tests --- tests/e2e/test_pricing.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/e2e/test_pricing.py b/tests/e2e/test_pricing.py index 1e98fe52..96171f07 100644 --- a/tests/e2e/test_pricing.py +++ b/tests/e2e/test_pricing.py @@ -79,8 +79,8 @@ def test_get_unknown_model_returns_default(self, http_client): price = response.json() # Should return default pricing - assert price["prompt_tokens_price"] == 2.0 - assert price["completion_tokens_price"] == 2.0 + assert price["prompt_tokens_price"] == 0.15 + assert price["completion_tokens_price"] == 0.45 assert price["web_search_cost"] == 0.05 def test_get_default_price(self, http_client): @@ -89,8 +89,8 @@ def test_get_default_price(self, http_client): assert response.status_code == 200 price = response.json() - assert price["prompt_tokens_price"] == 2.0 - assert price["completion_tokens_price"] == 2.0 + assert price["prompt_tokens_price"] == 0.15 + assert price["completion_tokens_price"] == 0.45 class TestPricingUpdateDelete: @@ -167,7 +167,9 @@ def test_delete_price_with_admin_token(self, admin_http_client, http_client): # Verify it now returns default pricing get_response = http_client.get(f"/v1/pricing/{model_name}") fetched = get_response.json() - assert fetched["prompt_tokens_price"] == 2.0 # Default + assert fetched["prompt_tokens_price"] == 0.15 # Default + assert fetched["completion_tokens_price"] == 0.45 + assert fetched["web_search_cost"] == 0.05 @pytest.mark.skipif( not CONFIG.auth.admin_token, reason="Admin token not configured"