diff --git a/README.md b/README.md
index d6161c0..009431e 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,8 @@ You can either:
 ```
 OPENAI_API_KEY: 'xxxxxxx'
 ANTHROPIC_API_KEY: 'xxxxxxx'
+VERTEX_SERVICE_ACCOUNT_PATH: 'xxxxxxx'
+VERTEX_REGION: 'xxxxxxx'
 ```

 ### To install iverilog {.tabset}
diff --git a/src/mage_rtl/gen_config.py b/src/mage_rtl/gen_config.py
index 85abdeb..0f69e5c 100644
--- a/src/mage_rtl/gen_config.py
+++ b/src/mage_rtl/gen_config.py
@@ -1,12 +1,15 @@
 import os

 import config
+from google.oauth2 import service_account
 from llama_index.core.llms.llm import LLM
 from llama_index.llms.anthropic import Anthropic
 from llama_index.llms.openai import OpenAI
+from llama_index.llms.vertex import Vertex
 from pydantic import BaseModel

 from .log_utils import get_logger
+from .utils import VertexAnthropicWithCredentials

 logger = get_logger(__name__)

@@ -34,27 +37,81 @@ def __getitem__(self, index):


 def get_llm(**kwargs) -> LLM:
-    LLM_func = Anthropic
     cfg = Config(kwargs["cfg_path"])
-    api_key_cfg = ""
-    if kwargs["provider"] == "anthropic":
-        LLM_func = Anthropic
-        api_key_cfg = cfg["ANTHROPIC_API_KEY"]
+    provider: str = kwargs["provider"]
+    provider = provider.lower()
+    if provider == "anthropic":
+        try:
+            llm: LLM = Anthropic(
+                model=kwargs["model"],
+                api_key=cfg["ANTHROPIC_API_KEY"],
+                max_tokens=kwargs["max_token"],
+            )
+        except Exception as e:
+            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
-    elif kwargs["provider"] == "openai":
-        LLM_func = OpenAI
-        api_key_cfg = cfg["OPENAI_API_KEY"]
-        # add more providers if needed
+    elif provider == "openai":
+        try:
+            llm: LLM = OpenAI(
+                model=kwargs["model"],
+                api_key=cfg["OPENAI_API_KEY"],
+                max_tokens=kwargs["max_token"],
+            )
+        except Exception as e:
+            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
+    elif provider == "vertex":
+        logger.warning(
+            "Support for Vertex Gemini LLMs is still experimental; use with caution"
+        )
+        service_account_path = os.path.expanduser(cfg["VERTEX_SERVICE_ACCOUNT_PATH"])
+        if not os.path.exists(service_account_path):
+            raise FileNotFoundError(
+                f"Google Cloud Service Account file not found: {service_account_path}"
+            )
+        try:
+            credentials = service_account.Credentials.from_service_account_file(
+                service_account_path
+            )
+            llm: LLM = Vertex(
+                model=kwargs["model"],
+                project=credentials.project_id,
+                credentials=credentials,
+                max_tokens=kwargs["max_token"],
+            )
+        except Exception as e:
+            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
+    elif provider == "vertexanthropic":
+        service_account_path = os.path.expanduser(cfg["VERTEX_SERVICE_ACCOUNT_PATH"])
+        if not os.path.exists(service_account_path):
+            raise FileNotFoundError(
+                f"Google Cloud Service Account file not found: {service_account_path}"
+            )
+        try:
+            credentials = service_account.Credentials.from_service_account_file(
+                service_account_path,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+            llm: LLM = VertexAnthropicWithCredentials(
+                model=kwargs["model"],
+                project_id=credentials.project_id,
+                credentials=credentials,
+                region=cfg["VERTEX_REGION"],
+                max_tokens=kwargs["max_token"],
+            )
+        except Exception as e:
+            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
+    else:
+        raise ValueError(f"gen_config: Invalid provider: {provider}")
     try:
-        llm: LLM = LLM_func(
-            model=kwargs["model"],
-            api_key=api_key_cfg,
-            max_tokens=kwargs["max_token"],
-        )
         _ = llm.complete("Say 'Hi'")
     except Exception as e:
-        raise Exception("gen_config: Failed to get LLM") from e
+        raise Exception(
+            f"gen_config: Failed to complete LLM chat for {provider}"
+        ) from e
     return llm
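For reference, a minimal sketch of how the new `vertexanthropic` branch is reached. The kwarg names follow `get_llm` above; the config path is illustrative, and `key.cfg` is assumed to contain the `VERTEX_SERVICE_ACCOUNT_PATH` and `VERTEX_REGION` entries from the README hunk:

```python
from mage_rtl.gen_config import get_llm

# Sketch only: assumes key.cfg provides the Vertex keys shown in the README.
llm = get_llm(
    provider="vertexanthropic",          # matched case-insensitively via .lower()
    model="claude-3-7-sonnet@20250219",  # Vertex model ids use '@' before the date
    cfg_path="./key.cfg",                # illustrative path
    max_token=8192,
)
# get_llm already smoke-tests the client with llm.complete("Say 'Hi'").
```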
diff --git a/src/mage_rtl/token_counter.py b/src/mage_rtl/token_counter.py
index da5bd5f..20f186a 100644
--- a/src/mage_rtl/token_counter.py
+++ b/src/mage_rtl/token_counter.py
@@ -8,10 +8,13 @@
 from llama_index.core.llms.llm import LLM
 from llama_index.llms.anthropic import Anthropic
 from llama_index.llms.openai import OpenAI
+from llama_index.llms.vertex import Vertex
 from pydantic import BaseModel
+from vertexai.preview.generative_models import GenerativeModel

 from .gen_config import get_exp_setting
 from .log_utils import get_logger
+from .utils import reformat_json_string

 logger = get_logger(__name__)

@@ -70,14 +73,23 @@ def __str__(self) -> str:
 class TokenCost(BaseModel):
     """Token cost of an LLM call"""

-    in_token_cost_per_token: float
-    out_token_cost_per_token: float
+    in_token_cost_per_token: float = 0.0
+    out_token_cost_per_token: float = 0.0


 token_costs = {
     "claude-3-5-sonnet-20241022": TokenCost(
         in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
     ),
+    "claude-3-5-sonnet@20241022": TokenCost(
+        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
+    ),
+    "claude-3-7-sonnet-20250219": TokenCost(
+        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
+    ),
+    "claude-3-7-sonnet@20250219": TokenCost(
+        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
+    ),
     "gpt-4o-2024-08-06": TokenCost(
         in_token_cost_per_token=2.5 / 1000000, out_token_cost_per_token=10.0 / 1000000
     ),
@@ -93,6 +105,9 @@ class TokenCost(BaseModel):
     "gemini-1.5-pro-002": TokenCost(
         in_token_cost_per_token=1.25 / 1000000, out_token_cost_per_token=5.0 / 1000000
     ),
+    "gemini-2.0-flash-001": TokenCost(
+        in_token_cost_per_token=0.1 / 1000000, out_token_cost_per_token=0.4 / 1000000
+    ),
 }


@@ -105,16 +120,33 @@ def __init__(self, llm: LLM) -> None:
         self.token_cnts_lock = asyncio.Lock()
         self.cur_tag = ""
         self.max_parallel_requests: int = 10
+        self.enable_reformat_json = isinstance(llm, Vertex)
         model = llm.metadata.model_name
         if isinstance(llm, OpenAI):
             self.encoding = tiktoken.encoding_for_model(model)
         elif isinstance(llm, Anthropic):
             self.encoding = llm.tokenizer
+        elif isinstance(llm, Vertex):
+            assert llm.model.startswith(
+                "gemini"
+            ), f"Non-gemini Vertex model is not supported: {llm.model}"
+            assert isinstance(llm._client, GenerativeModel)
+
+            class VertexEncoding:
+                def __init__(self, client: GenerativeModel):
+                    self.client = client
+
+                def encode(self, text: str) -> List[str]:
+                    token_len = self.client.count_tokens(text).total_tokens
+                    return ["placeholder" for _ in range(token_len)]
+
+            self.encoding = VertexEncoding(llm._client)
+            self.activate_structure_output = True
         else:
             raise Exception(f"gen_config: No tokenizer for model {model}")
         logger.info(f"Found tokenizer for model '{model}'")
-        self.token_cost = token_costs[model] if model in token_costs else None
-        if self.token_cost is None:
+        self.token_cost = token_costs[model] if model in token_costs else TokenCost()
+        if self.token_cost == TokenCost():
             logger.warning(
                 f"Cannot find token cost for model '{model}' in record. Won't display cost in USD"
             )
@@ -147,6 +179,8 @@ def count_chat(
         out_token_cnt = self.count(response.message.content)
         token_cnt = TokenCount(in_token_cnt=in_token_cnt, out_token_cnt=out_token_cnt)
         self.token_cnts[self.cur_tag].append(token_cnt)
+        if self.enable_reformat_json:
+            response.message.content = reformat_json_string(response.message.content)
         return (response, token_cnt)

     async def count_achat(
@@ -165,6 +199,8 @@ async def count_achat(
         token_cnt = TokenCount(in_token_cnt=in_token_cnt, out_token_cnt=out_token_cnt)
         async with self.token_cnts_lock:
             self.token_cnts[self.cur_tag].append(token_cnt)
+        if self.enable_reformat_json:
+            response.message.content = reformat_json_string(response.message.content)
         return (response, token_cnt)

     async def count_achat_batch(
@@ -284,11 +320,6 @@ def count_chat(
         )
         response = llm.chat(
             messages,
-            extra_headers=(
-                {"anthropic-beta": "prompt-caching-2024-07-31"}
-                if self.enable_cache
-                else {}
-            ),
             top_p=settings.top_p,
             temperature=settings.temperature,
         )
@@ -309,6 +340,8 @@ def count_chat(
             ),
         )
         self.token_cnts[self.cur_tag].append(token_cnt)
+        if self.enable_reformat_json:
+            response.message.content = reformat_json_string(response.message.content)
         return (response, token_cnt)

     async def count_achat(
@@ -321,11 +354,6 @@ async def count_achat(
         )
         response = await llm.achat(
             messages,
-            extra_headers=(
-                {"anthropic-beta": "prompt-caching-2024-07-31"}
-                if self.enable_cache
-                else {}
-            ),
             top_p=settings.top_p,
             temperature=settings.temperature,
         )
@@ -347,6 +375,8 @@ async def count_achat(
         )
         async with self.token_cnts_lock:
             self.token_cnts[self.cur_tag].append(token_cnt)
+        if self.enable_reformat_json:
+            response.message.content = reformat_json_string(response.message.content)
         return (response, token_cnt)

     def log_token_stats(self) -> None:
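An aside on the `VertexEncoding` shim above: the counter presumably only uses the *length* of the encoded list, so any object whose `encode` returns a list of the right size is sufficient, and Vertex reports the true size via `count_tokens`. A self-contained sketch of the idea, where `FakeClient` is a hypothetical stand-in for the real `GenerativeModel` (which would call the Vertex API):

```python
from typing import List


class FakeClient:
    """Hypothetical stand-in for vertexai's GenerativeModel."""

    class _Response:
        def __init__(self, total_tokens: int):
            self.total_tokens = total_tokens

    def count_tokens(self, text: str):
        # Pretend whitespace-separated words are tokens.
        return FakeClient._Response(len(text.split()))


class VertexEncoding:
    def __init__(self, client):
        self.client = client

    def encode(self, text: str) -> List[str]:
        # Throwaway list whose length equals the model's token count.
        token_len = self.client.count_tokens(text).total_tokens
        return ["placeholder" for _ in range(token_len)]


enc = VertexEncoding(FakeClient())
assert len(enc.encode("one two three")) == 3
```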
diff --git a/src/mage_rtl/utils.py b/src/mage_rtl/utils.py
index 25c35b3..42024e7 100644
--- a/src/mage_rtl/utils.py
+++ b/src/mage_rtl/utils.py
@@ -1,6 +1,66 @@
+import re
+
+import anthropic
+from llama_index.llms.anthropic import Anthropic
+
+
 def add_lineno(file_content: str) -> str:
     lines = file_content.split("\n")
     ret = ""
     for i, line in enumerate(lines):
         ret += f"{i+1}: {line}\n"
     return ret
+
+
+def reformat_json_string(output: str) -> str:
+    # Gemini wraps its structured output in markdown fences such as
+    # ```json ... ``` (or ```xml ... ```); strip the fences with a regex
+    # and return the bare payload.
+    pattern = r"```json(.*?)```"
+    match = re.search(pattern, output, re.DOTALL)
+    if match:
+        return match.group(1).strip()
+
+    pattern = r"```xml(.*?)```"
+    match = re.search(pattern, output, re.DOTALL)
+    if match:
+        return match.group(1).strip()
+
+    return output.strip()
+
+
+class VertexAnthropicWithCredentials(Anthropic):
+    def __init__(self, credentials, **kwargs):
+        """
+        In addition to all parameters accepted by Anthropic, this class accepts a
+        new parameter `credentials` that is passed to the underlying clients.
+        """
+        # Read the parameters that determine the client type so we can reuse
+        # them in the override branch below.
+        region = kwargs.get("region")
+        project_id = kwargs.get("project_id")
+        aws_region = kwargs.get("aws_region")
+
+        # Call the parent initializer; this sets up a default _client and _aclient.
+        super().__init__(**kwargs)
+
+        # If using AnthropicVertex (i.e., region and project_id are provided and
+        # aws_region is None), override _client and _aclient with clients that
+        # receive the additional credentials parameter.
+        if region and project_id and not aws_region:
+            self._client = anthropic.AnthropicVertex(
+                region=region,
+                project_id=project_id,
+                credentials=credentials,  # extra argument
+                timeout=self.timeout,
+                max_retries=self.max_retries,
+                default_headers=kwargs.get("default_headers"),
+            )
+            self._aclient = anthropic.AsyncAnthropicVertex(
+                region=region,
+                project_id=project_id,
+                credentials=credentials,  # extra argument
+                timeout=self.timeout,
+                max_retries=self.max_retries,
+                default_headers=kwargs.get("default_headers"),
+            )
+        # Similar overrides could be added for the aws_region branch if needed.
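A hedged usage sketch of the new wrapper, mirroring the `vertexanthropic` branch in `gen_config.py`; the service-account path and region here are placeholders:

```python
from google.oauth2 import service_account

from mage_rtl.utils import VertexAnthropicWithCredentials, reformat_json_string

credentials = service_account.Credentials.from_service_account_file(
    "service_account.json",  # placeholder path
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
llm = VertexAnthropicWithCredentials(
    model="claude-3-7-sonnet@20250219",
    project_id=credentials.project_id,
    credentials=credentials,
    region="us-east5",  # placeholder region
    max_tokens=8192,
)

# reformat_json_string strips the Gemini-style markdown fences:
assert reformat_json_string('```json\n{"a": 1}\n```') == '{"a": 1}'
```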
diff --git a/tests/test_top_agent.py b/tests/test_top_agent.py
index 0b91392..4c4e04c 100644
--- a/tests/test_top_agent.py
+++ b/tests/test_top_agent.py
@@ -1,6 +1,5 @@
 import argparse
 import json
-import os
 import time
 from datetime import timedelta
 from typing import Any, Dict
@@ -22,8 +21,10 @@

 args_dict = {
-    "provider": "anthropic",
-    "model": "claude-3-5-sonnet-20241022",
+    "provider": "vertexanthropic",
+    "model": "claude-3-7-sonnet@20250219",
+    # "model": "gemini-2.0-flash-001",
+    # "model": "claude-3-7-sonnet-20250219",
     # "model": "gpt-4o-2024-08-06",
     # "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$",
     "filter_instance": "^(Prob011_norgate)$",
@@ -36,7 +37,7 @@
     "top_p": 0.95,
     "max_token": 8192,
     "use_golden_tb_in_mage": True,
-    "key_cfg_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "key.cfg"),
+    "key_cfg_path": "./key.cfg",
 }
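Note that with the now-relative `key_cfg_path`, the test expects `key.cfg` in whatever directory it is launched from. A sketch of a matching `key.cfg`, extrapolated from the README hunk above (all values are placeholders; `VERTEX_SERVICE_ACCOUNT_PATH` goes through `os.path.expanduser`, so `~` is allowed):

```
OPENAI_API_KEY: 'xxxxxxx'
ANTHROPIC_API_KEY: 'xxxxxxx'
VERTEX_SERVICE_ACCOUNT_PATH: '~/service_account.json'
VERTEX_REGION: 'us-east5'
```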