diff --git a/README.md b/README.md index 0177d3f..d6161c0 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,7 @@ Run arguments can be set in the file like: ``` args_dict = { + "provider": "anthropic", "model": "claude-3-5-sonnet-20241022", # "model": "gpt-4o-2024-08-06", # "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$", @@ -148,18 +149,23 @@ args_dict = { "n": 1, "temperature": 0.85, "top_p": 0.95, + "max_token": 8192, "use_golden_tb_in_mage": True, + "key_cfg_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "key.cfg"), } ``` Where each argument means: -1. model: The LLM model used. Support for gpt-4o and claude has been verified. -2. filter_instance: A RegEx style instance name filter. -3. type_benchmark: Support running verilog_eval_v1 or verilog_eval_v2 -4. path_benchmark: Where the benchmark repo is cloned -5. run_identifier: Unique name to disguish different runs -6. n: Number of repeated run to execute -7. temperature: Argument for LLM generation randomness. Usually between [0, 1] -8. top_p: Argument for LLM generation randomness. Usually between [0, 1] +1. provider: The API provider of the LLM model used. e.g. anthropic-->claude, openai-->gpt-4o +2. model: The LLM model used. Support for gpt-4o and claude has been verified. +3. filter_instance: A RegEx style instance name filter. +4. type_benchmark: Support running verilog_eval_v1 or verilog_eval_v2 +5. path_benchmark: Where the benchmark repo is cloned +6. run_identifier: Unique name to distinguish different runs +7. n: Number of repeated runs to execute +8. temperature: Argument for LLM generation randomness. Usually between [0, 1] +9. top_p: Argument for LLM generation randomness. Usually between [0, 1] +10. max_token: Maximum number of tokens the model is allowed to generate in its output. +11. key_cfg_path: Path to your key.cfg file. 
Defaulted to be under MAGE/tests ## Development Guide diff --git a/src/mage_rtl/gen_config.py b/src/mage_rtl/gen_config.py index 7fbf60f..85abdeb 100644 --- a/src/mage_rtl/gen_config.py +++ b/src/mage_rtl/gen_config.py @@ -34,19 +34,28 @@ def __getitem__(self, index): def get_llm(**kwargs) -> LLM: - err_msgs = [] - for LLM_func in [OpenAI, Anthropic]: - try: - llm: LLM = LLM_func(**kwargs) - _ = llm.complete("Say 'Hi'") - break - except Exception as e: - err_msgs.append(str(e)) - else: - raise Exception( - f"gen_config: Failed to get LLM. Error msgs include:\n" - + "\n".join(err_msgs) + LLM_func = Anthropic + cfg = Config(kwargs["cfg_path"]) + api_key_cfg = "" + if kwargs["provider"] == "anthropic": + LLM_func = Anthropic + api_key_cfg = cfg["ANTHROPIC_API_KEY"] + elif kwargs["provider"] == "openai": + LLM_func = OpenAI + api_key_cfg = cfg["OPENAI_API_KEY"] + # add more providers if needed + + try: + llm: LLM = LLM_func( + model=kwargs["model"], + api_key=api_key_cfg, + max_tokens=kwargs["max_token"], ) + _ = llm.complete("Say 'Hi'") + + except Exception as e: + raise Exception("gen_config: Failed to get LLM") from e + return llm diff --git a/tests/test_top_agent.py b/tests/test_top_agent.py index f017d17..0b91392 100644 --- a/tests/test_top_agent.py +++ b/tests/test_top_agent.py @@ -1,5 +1,6 @@ import argparse import json +import os import time from datetime import timedelta from typing import Any, Dict @@ -12,7 +13,7 @@ TypeBenchmarkFile, get_benchmark_contents, ) -from mage_rtl.gen_config import Config, get_llm, set_exp_setting +from mage_rtl.gen_config import get_llm, set_exp_setting from mage_rtl.log_utils import get_logger from mage_rtl.sim_reviewer import sim_review_golden_benchmark from mage_rtl.token_counter import TokenCount @@ -21,6 +22,7 @@ args_dict = { + "provider": "anthropic", "model": "claude-3-5-sonnet-20241022", # "model": "gpt-4o-2024-08-06", # "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$", @@ -32,7 +34,9 @@ "n": 1, 
"temperature": 0.85, "top_p": 0.95, + "max_token": 8192, "use_golden_tb_in_mage": True, + "key_cfg_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "key.cfg"), } @@ -148,8 +152,13 @@ def run_round(args: argparse.Namespace, llm: LLM): def main(): args = argparse.Namespace(**args_dict) - cfg = Config("./key.cfg") - llm = get_llm(model=args.model, api_key=cfg["ANTHROPIC_API_KEY"], max_tokens=8192) + + llm = get_llm( + model=args.model, + cfg_path=args.key_cfg_path, + max_token=args.max_token, + provider=args.provider, + ) identifier_head = args.run_identifier n = args.n set_exp_setting(temperature=args.temperature, top_p=args.top_p)