diff --git a/extract_thinker/llm.py b/extract_thinker/llm.py
index 95c9971..d2bbe66 100644
--- a/extract_thinker/llm.py
+++ b/extract_thinker/llm.py
@@ -34,6 +34,11 @@ class LLM:
     MIN_THINKING_BUDGET = 1200  # Minimum thinking budget
     DEFAULT_OUTPUT_TOKENS = 32000
 
+    # Model-specific token limits
+    MODEL_TOKEN_LIMITS = {
+        "gpt-4o": 12000,  # Reasonable middle ground for GPT-4o
+    }
+
     def __init__(
         self,
         model: str,
@@ -41,7 +46,7 @@ def __init__(
         backend: LLMEngine = LLMEngine.DEFAULT
     ):
         """Initialize LLM with specified backend.
-        
+
         Args:
             model: The model name (e.g. "gpt-4", "claude-3")
             token_limit: Optional maximum tokens
@@ -69,7 +74,8 @@ def __init__(
             from pydantic_ai import Agent
             from pydantic_ai.models import KnownModelName
             from typing import cast
-            
+            import asyncio
+
             self.client = None
             self.agent = Agent(
                 cast(KnownModelName, self.model)
@@ -100,6 +106,11 @@ def _get_pydantic_ai():
                 "Please install it with `pip install pydantic-ai`."
             )
 
+    def _get_model_max_tokens(self) -> int:
+        """Get the maximum output tokens allowed for the current model."""
+        # Models without a specific cap fall back to the general MAX_TOKEN_LIMIT
+        return self.MODEL_TOKEN_LIMITS.get(self.model, self.MAX_TOKEN_LIMIT)
+
     def load_router(self, router: Router) -> None:
         """Load a LiteLLM router for model fallbacks."""
         if self.backend != LLMEngine.DEFAULT:
@@ -108,7 +119,7 @@ def load_router(self, router: Router) -> None:
 
     def set_temperature(self, temperature: float) -> None:
         """Set the temperature for LLM requests.
-        
+
         Args:
             temperature (float): Temperature value between 0 and 1
         """
@@ -116,7 +127,7 @@ def set_temperature(self, temperature: float) -> None:
 
     def set_thinking(self, is_thinking: bool) -> None:
         """Set whether the LLM should handle thinking.
-        
+
         Args:
             is_thinking (bool): Whether to enable thinking
         """
@@ -125,10 +136,10 @@ def set_thinking(self, is_thinking: bool) -> None:
 
     def set_dynamic(self, is_dynamic: bool) -> None:
         """Set whether the LLM should handle dynamic content.
-        
+
         When dynamic is True, the LLM will attempt to parse and validate JSON
         responses. This is useful for handling structured outputs like masking mappings.
-        
+
         Args:
             is_dynamic (bool): Whether to enable dynamic content handling
         """
@@ -136,28 +147,28 @@ def set_dynamic(self, is_dynamic: bool) -> None:
 
     def set_page_count(self, page_count: int) -> None:
         """Set the page count to calculate token limits for thinking.
-        
+
         Each page is assumed to have DEFAULT_PAGE_TOKENS tokens (text + image).
         Thinking budget is calculated as DEFAULT_THINKING_RATIO of the content tokens.
-        
+
         Args:
             page_count (int): Number of pages in the document
         """
         if page_count <= 0:
             raise ValueError("Page count must be a positive integer")
-        
+
         self.page_count = page_count
-        
+
         # Calculate content tokens
         content_tokens = min(page_count * self.DEFAULT_PAGE_TOKENS, self.MAX_TOKEN_LIMIT)
-        
+
         # Calculate thinking budget (1/3 of content tokens)
         thinking_tokens = int(page_count * self.DEFAULT_PAGE_TOKENS * self.DEFAULT_THINKING_RATIO)
-        
+
         # Apply min/max constraints
         thinking_tokens = max(thinking_tokens, self.MIN_THINKING_BUDGET)
         thinking_tokens = min(thinking_tokens, self.MAX_THINKING_BUDGET)
-        
+
         # Update token limit and thinking budget
         self.thinking_token_limit = content_tokens
         self.thinking_budget = thinking_tokens
@@ -172,9 +183,16 @@ def request(
             # Combine messages into a single prompt
             combined_prompt = " ".join([m["content"] for m in messages])
             try:
-                result = asyncio.run(
+                # Get the current event loop, creating one if none exists
+                try:
+                    loop = asyncio.get_event_loop()
+                except RuntimeError:
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+
+                result = loop.run_until_complete(
                     self.agent.run(
-                        combined_prompt, 
+                        combined_prompt,
                         result_type=response_model if response_model else str
                     )
                 )
@@ -182,11 +200,6 @@ def request(
                 return result.data
             except Exception as e:
                 raise ValueError(f"Failed to extract from source: {str(e)}")
 
-        # Uncomment the following lines if you need to calculate max_tokens
-        # contents = map(lambda message: message['content'], messages)
-        # all_contents = ' '.join(contents)
-        # max_tokens = num_tokens_from_string(all_contents)
-
         # if is sync, response model is None if dynamic true and used for dynamic parsing after llm request
         request_model = None if self.is_dynamic else response_model
@@ -214,7 +227,7 @@ def request(
             content = response.choices[0].message.content
             if self.is_dynamic:
                 return extract_thinking_json(content, response_model)
-            
+
             return content
 
     def _request_with_router(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
@@ -234,19 +247,24 @@ def _request_with_router(self, messages: List[Dict[str, str]], response_model: O
             "max_completion_tokens": max_tokens,
         }
         if self.is_thinking:
-            if litellm.supports_reasoning(self.model):
-                # Add thinking parameter for supported models
-                thinking_param = {
-                    "type": "enabled",
-                    "budget_tokens": self.thinking_budget
-                }
-                params["thinking"] = thinking_param
-            else:
-                print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
+            # Add thinking parameter for supported models
+            params["thinking"] = {
+                "type": "enabled",
+                "budget_tokens": self.thinking_budget
+            }
+            try:
+                return self.router.completion(**params)
+            except Exception as e:
+                # If the thinking parameter causes an error, retry without it
+                if "property 'thinking' is unsupported" in str(e):
+                    print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
+                    params.pop("thinking", None)
+                else:
+                    raise
 
         return self.router.completion(**params)
 
     def _request_direct(self, messages: List[Dict[str, str]], 
                          response_model: Optional[str]) -> Any:
         """Handle direct request with or without thinking parameter"""
         max_tokens = self.DEFAULT_OUTPUT_TOKENS
@@ -260,10 +278,10 @@ def _request_direct(self, messages: List[Dict[str, str]], response_model: Option
             "temperature": self.temperature,
             "response_model": response_model,
             "max_retries": 1,
-            "max_completion_tokens": max_tokens,
+            "max_tokens": min(max_tokens, self._get_model_max_tokens()),  # <- capped max tokens here
             "timeout": self.TIMEOUT,
         }
-        
+
         if self.is_thinking:
             if litellm.supports_reasoning(self.model):
                 # Try with thinking parameter
@@ -279,50 +297,46 @@ def _request_direct(self, messages: List[Dict[str, str]], response_model: Option
 
     def raw_completion(self, messages: List[Dict[str, str]]) -> str:
         """Make raw completion request without response model."""
-        if self.backend == LLMEngine.PYDANTIC_AI:
-            # Combine messages into a single prompt
-            combined_prompt = " ".join([m["content"] for m in messages])
-            try:
-                result = asyncio.run(
-                    self.agent.run(
-                        combined_prompt,
-                        result_type=str
-                    )
-                )
-                return result.data
-            except Exception as e:
-                raise ValueError(f"Failed to extract from source: {str(e)}")
-
-        max_tokens = self.DEFAULT_OUTPUT_TOKENS
-        if self.token_limit is not None:
-            max_tokens = self.token_limit
-        elif self.is_thinking:
-            max_tokens = self.thinking_token_limit
-
-        params = {
-            "model": self.model,
-            "messages": messages,
-            "max_completion_tokens": max_tokens,
-        }
-
-        if self.is_thinking:
-            if litellm.supports_reasoning(self.model):
-                # Add thinking parameter for supported models
-                thinking_param = {
-                    "type": "enabled",
-                    "budget_tokens": self.thinking_budget
-                }
-                params["thinking"] = thinking_param
-            else:
-                print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
-
-        if self.router:
-            raw_response = self.router.completion(**params)
-        else:
-            raw_response = litellm.completion(**params)
-
+        max_tokens = self._get_model_max_tokens()  # <- capped max tokens here
+
+        if self.router:
+            raw_response = self.router.completion(
+                model=self.model,
+                messages=messages,
+                max_tokens=max_tokens,
+            )
+        elif self.is_thinking:
+            # Add thinking parameter for supported models
+            thinking_param = {
+                "type": "enabled",
+                "budget_tokens": self.thinking_budget
+            }
+            try:
+                raw_response = litellm.completion(
+                    model=self.model,
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    thinking=thinking_param,
+                )
+            except Exception as e:
+                # If the thinking parameter causes an error, retry without it
+                if "property 'thinking' is unsupported" in str(e):
+                    print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
+                    raw_response = litellm.completion(
+                        model=self.model,
+                        messages=messages,
+                        max_tokens=max_tokens,
+                    )
+                else:
+                    raise
+        else:
+            raw_response = litellm.completion(
+                model=self.model,
+                messages=messages,
+                max_tokens=max_tokens,
+            )
        return raw_response.choices[0].message.content
 
     def set_timeout(self, timeout_ms: int) -> None:
         """Set the timeout value for LLM requests in milliseconds."""
-        self.TIMEOUT = timeout_ms
\ No newline at end of file
+        self.TIMEOUT = timeout_ms
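
A minimal usage sketch of the behavior this patch introduces. The import path follows the repository layout; the model name, page count, and message contents are illustrative assumptions, and a real call still requires provider credentials:

    from extract_thinker.llm import LLM

    llm = LLM("gpt-4o")
    # Per-model cap from MODEL_TOKEN_LIMITS applies to "gpt-4o";
    # any other model falls back to the general MAX_TOKEN_LIMIT.
    print(llm._get_model_max_tokens())  # -> 12000

    llm.set_thinking(True)
    llm.set_page_count(4)  # derives thinking_budget from document size
    # If the provider rejects the `thinking` parameter, the request is
    # retried without it instead of failing outright.
    text = llm.raw_completion([{"role": "user", "content": "Summarize this."}])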