Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ Run arguments can be set in the file like:

```
args_dict = {
"provider": "anthropic",
"model": "claude-3-5-sonnet-20241022",
# "model": "gpt-4o-2024-08-06",
# "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$",
Expand All @@ -148,18 +149,23 @@ args_dict = {
"n": 1,
"temperature": 0.85,
"top_p": 0.95,
"max_token": 8192,
"use_golden_tb_in_mage": True,
"key_cfg_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "key.cfg"),
}
```
Where each argument means:
1. model: The LLM model used. Support for gpt-4o and claude has been verified.
2. filter_instance: A RegEx style instance name filter.
3. type_benchmark: Support running verilog_eval_v1 or verilog_eval_v2
4. path_benchmark: Where the benchmark repo is cloned
5. run_identifier: Unique name to distinguish different runs
6. n: Number of repeated run to execute
7. temperature: Argument for LLM generation randomness. Usually between [0, 1]
8. top_p: Argument for LLM generation randomness. Usually between [0, 1]
1. provider: The API provider of the LLM model used, e.g. anthropic → claude, openai → gpt-4o
2. model: The LLM model used. Support for gpt-4o and claude has been verified.
3. filter_instance: A RegEx style instance name filter.
4. type_benchmark: Support running verilog_eval_v1 or verilog_eval_v2
5. path_benchmark: Where the benchmark repo is cloned
6. run_identifier: Unique name to distinguish different runs
7. n: Number of repeated run to execute
8. temperature: Argument for LLM generation randomness. Usually between [0, 1]
9. top_p: Argument for LLM generation randomness. Usually between [0, 1]
10. max_token: Maximum number of tokens the model is allowed to generate in its output.
11. key_cfg_path: Path to your key.cfg file. Defaults to key.cfg under MAGE/tests


## Development Guide
Expand Down
33 changes: 21 additions & 12 deletions src/mage_rtl/gen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,28 @@ def __getitem__(self, index):


def get_llm(**kwargs) -> LLM:
    """Instantiate an LLM client for the configured provider and smoke-test it.

    Expected kwargs:
        provider: Provider name, currently "anthropic" or "openai".
        model: Model identifier passed to the provider's client.
        cfg_path: Path to the key.cfg file holding API keys.
        max_token: Maximum number of output tokens the model may generate.

    Returns:
        A ready-to-use LLM instance that has answered a trivial prompt.

    Raises:
        ValueError: If ``provider`` is not one of the supported providers.
        Exception: If client construction or the smoke-test completion fails
            (the underlying error is chained as the cause).
    """
    # Map each supported provider to its client class and the name of the
    # API-key entry expected in the config file.
    # add more providers if needed
    providers = {
        "anthropic": (Anthropic, "ANTHROPIC_API_KEY"),
        "openai": (OpenAI, "OPENAI_API_KEY"),
    }
    provider = kwargs["provider"]
    if provider not in providers:
        # Fail fast instead of silently defaulting to Anthropic with an
        # empty API key, which would only surface as a confusing auth error.
        raise ValueError(
            f"gen_config: Unknown provider '{provider}'. "
            f"Supported providers: {sorted(providers)}"
        )
    llm_cls, key_name = providers[provider]
    cfg = Config(kwargs["cfg_path"])

    try:
        llm: LLM = llm_cls(
            model=kwargs["model"],
            api_key=cfg[key_name],
            max_tokens=kwargs["max_token"],
        )
        # Cheap round-trip to verify the key/model combination actually works
        # before the caller starts a long benchmark run.
        _ = llm.complete("Say 'Hi'")
    except Exception as e:
        raise Exception("gen_config: Failed to get LLM") from e

    return llm


Expand Down
15 changes: 12 additions & 3 deletions tests/test_top_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
import os
import time
from datetime import timedelta
from typing import Any, Dict
Expand All @@ -12,7 +13,7 @@
TypeBenchmarkFile,
get_benchmark_contents,
)
from mage_rtl.gen_config import Config, get_llm, set_exp_setting
from mage_rtl.gen_config import get_llm, set_exp_setting
from mage_rtl.log_utils import get_logger
from mage_rtl.sim_reviewer import sim_review_golden_benchmark
from mage_rtl.token_counter import TokenCount
Expand All @@ -21,6 +22,7 @@


args_dict = {
"provider": "anthropic",
"model": "claude-3-5-sonnet-20241022",
# "model": "gpt-4o-2024-08-06",
# "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$",
Expand All @@ -32,7 +34,9 @@
"n": 1,
"temperature": 0.85,
"top_p": 0.95,
"max_token": 8192,
"use_golden_tb_in_mage": True,
"key_cfg_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "key.cfg"),
}


Expand Down Expand Up @@ -148,8 +152,13 @@ def run_round(args: argparse.Namespace, llm: LLM):

def main():
args = argparse.Namespace(**args_dict)
cfg = Config("./key.cfg")
llm = get_llm(model=args.model, api_key=cfg["ANTHROPIC_API_KEY"], max_tokens=8192)

llm = get_llm(
model=args.model,
cfg_path=args.key_cfg_path,
max_token=args.max_token,
provider=args.provider,
)
identifier_head = args.run_identifier
n = args.n
set_exp_setting(temperature=args.temperature, top_p=args.top_p)
Expand Down