45 changes: 33 additions & 12 deletions scripts/benchmark_eval_analysis.py
@@ -40,7 +40,7 @@ def patch(eval_results, dataset):
"""
Patch the eval results with the dataset
"""
for pid in range(1, len(dataset) + 1):
for pid in dataset.get_problem_ids():
if str(pid) not in eval_results:
eval_results[str(pid)] = {
"sample_id": 0,
@@ -136,19 +136,40 @@ def analyze_greedy_eval(run_name, hardware, baseline, level):
)

# Extract the speedup values
is_correct = np.array([entry["correctness"] for entry in eval_results.values()])
baseline_speed = np.array(
[entry["mean"] for entry in baseline_results[f"level{level}"].values()]
)
actual_speed = np.array([entry["runtime"] for entry in eval_results.values()])
is_correct_list = []
baseline_speed_list = []
actual_speed_list = []

# Sort problem IDs to ensure consistent order
sorted_pids = sorted(dataset.get_problem_ids())

for pid in sorted_pids:
# Get eval result
if str(pid) not in eval_results:
print(f"Warning: Problem {pid} not found in eval results")
continue
eval_entry = eval_results[str(pid)]

# Get baseline result
problem_path = dataset.get_problem_by_id(pid)
problem_name = os.path.basename(problem_path)

if problem_name not in baseline_results[f"level{level}"]:
print(f"Warning: Problem {problem_name} not found in baseline results")
continue

baseline_entry = baseline_results[f"level{level}"][problem_name]

is_correct_list.append(eval_entry["correctness"])
actual_speed_list.append(eval_entry["runtime"])
baseline_speed_list.append(baseline_entry["mean"])

is_correct = np.array(is_correct_list)
baseline_speed = np.array(baseline_speed_list)
actual_speed = np.array(actual_speed_list)
n = len(is_correct)

assert (
len(baseline_speed) == n
), "Baseline speedup values do not match the number of eval results"
assert (
len(actual_speed) == n
), "Actual speedup values do not match the number of eval results"
print(f"Aligned {n} problems for analysis")

# Calculate the metrics
gmsr_correct = geometric_mean_speed_ratio_correct_only(
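
The alignment loop above (and the rest of this PR) assumes the level dataset is an object keyed by logical, 1-indexed problem IDs rather than a plain 0-indexed list of paths. A minimal sketch of the interface the new code relies on, get_problem_ids() and get_problem_by_id(); the class body and ID parsing here are an illustrative assumption, not the actual src/dataset.py implementation:

import os

class KernelBenchDataset:
    """Hypothetical sketch: maps logical 1-indexed problem IDs to reference architecture paths."""

    def __init__(self, problem_paths: list[str]):
        # Problem files are named "<id>_<name>.py"; parse the leading ID from each path.
        self._problems = {int(os.path.basename(p).split("_")[0]): p for p in problem_paths}

    def __len__(self) -> int:
        return len(self._problems)

    def get_problem_ids(self) -> list[int]:
        # Logical IDs in sorted order; they are not guaranteed to be contiguous.
        return sorted(self._problems)

    def get_problem_by_id(self, problem_id: int) -> str:
        # Path to the reference architecture file for this problem.
        return self._problems[problem_id]
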
25 changes: 10 additions & 15 deletions scripts/eval_from_generations.py
@@ -257,10 +257,7 @@ def fetch_ref_arch_from_problem_id(
problem_name = curr_problem_row["name"][0]

elif dataset_src == "local":
problem_idx_in_dataset = (
problem_id - 1
) # due to dataset list being 0-indexed locally
ref_arch_path = dataset[problem_idx_in_dataset]
ref_arch_path = dataset.get_problem_by_id(problem_id)

problem_name = os.path.basename(ref_arch_path)
ref_arch_src = read_file(ref_arch_path)
@@ -764,17 +761,18 @@ def main(config: EvalConfig):
curr_level_dataset = construct_kernelbench_dataset(config.level)

num_problems_in_level = len(curr_level_dataset)
all_problem_ids = curr_level_dataset.get_problem_ids() if config.dataset_src == "local" else list(range(1, num_problems_in_level + 1))

if config.subset == (None, None):
problem_id_range = range(1, num_problems_in_level)
problem_ids_to_run = all_problem_ids
else:
assert (
config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level
), f"Subset range {config.subset} out of range for Level {config.level}"
problem_id_range = range(config.subset[0], config.subset[1])
start, end = config.subset
problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end]
if not problem_ids_to_run:
print(f"Warning: No problems found in subset range {config.subset}")

print(
f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_id_range}"
f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_ids_to_run}"
)

run_dir = os.path.join(config.runs_dir, config.run_name)
@@ -784,22 +782,19 @@
# single_eval_example(config, curr_level_dataset, run_dir, eval_file_path)

total_work = []
for problem_id in range(
problem_id_range.start, problem_id_range.stop + 1
): # end index is inclusive
for problem_id in problem_ids_to_run:
for sample_id in range(config.num_samples_per_problem):
if not check_if_eval_exists_local(problem_id, sample_id, eval_file_path):
total_work.append((problem_id, sample_id))

print(
f"Start evaluation on {len(total_work)} unevaluated samples"
f" in range: {problem_id_range}"
f" in range: {problem_ids_to_run}"
)
# Build Cache on CPU as that is faster (only for local mode)
if config.build_cache and config.eval_mode == "local":
compile.batch_compile(total_work, config.to_dict())

# Batch Eval on multiple GPUs in parallel
batch_eval(total_work, config, curr_level_dataset, run_dir, eval_file_path)

# Calculate pass@k metrics if multiple samples per problem were evaluated
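
Because problem IDs are now taken from get_problem_ids() and need not be contiguous, the subset filter above selects by ID value rather than by list position. A small illustration with made-up IDs:

all_problem_ids = [1, 2, 5, 7, 9]   # hypothetical output of dataset.get_problem_ids()
start, end = (2, 7)                 # config.subset, inclusive on both ends
problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end]
assert problem_ids_to_run == [2, 5, 7]
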
5 changes: 1 addition & 4 deletions scripts/generate_and_eval_single_sample.py
@@ -139,10 +139,7 @@ def main(config: EvalConfig):
problem_name = curr_problem_row["name"][0]

elif config.dataset_src == "local":
problem_idx_in_dataset = (
config.problem_id - 1
) # due to dataset list being 0-indexed locally
ref_arch_path = curr_level_dataset[problem_idx_in_dataset]
ref_arch_path = curr_level_dataset.get_problem_by_id(config.problem_id)

problem_name = os.path.basename(ref_arch_path)
ref_arch_src = read_file(ref_arch_path)
7 changes: 4 additions & 3 deletions scripts/generate_and_eval_single_sample_modal.py
@@ -13,7 +13,7 @@

from datasets import load_dataset

#from src.dataset import construct_kernelbench_dataset
from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
@@ -148,6 +148,8 @@ def main(config: EvalConfig):
if config.dataset_src == "huggingface":
dataset = load_dataset(config.dataset_name)
curr_level_dataset = dataset[f"level_{config.level}"]
elif config.dataset_src == "local":
curr_level_dataset = construct_kernelbench_dataset(config.level)

if config.log:
os.makedirs(config.logdir, exist_ok=True)
@@ -168,8 +170,7 @@
problem_name = curr_problem_row["name"][0]

elif config.dataset_src == "local":
problem_idx_in_dataset = config.problem_id - 1 # due to dataset list being 0-indexed locally
ref_arch_path = curr_level_dataset[problem_idx_in_dataset]
ref_arch_path = curr_level_dataset.get_problem_by_id(config.problem_id)

problem_name = os.path.basename(ref_arch_path)
ref_arch_src = read_file(ref_arch_path)
38 changes: 5 additions & 33 deletions scripts/generate_baseline_time.py
@@ -7,7 +7,7 @@
set_seed,
fetch_ref_arch_from_problem_id,
)
from src.dataset import construct_problem_dataset_from_problem_dir
from src.dataset import construct_kernelbench_dataset, KernelBenchDataset, fetch_ref_arch_from_dataset
from src.utils import read_file
import os
import json
@@ -46,32 +46,6 @@
TIMING_DIR = os.path.join(REPO_TOP_PATH, "results", "timing")


def fetch_ref_arch_from_dataset(dataset: list[str],
problem_id: int) -> tuple[str, str, str]:
"""
Fetch the reference architecture from the problem directory
problem_id should be logical index (1-indexed), matching the problem_id in the problem_name

Returns:
ref_arch_path: str, the path to the reference architecture
ref_arch_name: str, the name of the reference architecture
ref_arch_src: str, the source code of the reference architecture
"""
ref_arch_path = None

for file in dataset:
if file.split("/")[-1].split("_")[0] == str(problem_id):
ref_arch_path = file
break
if ref_arch_path is None:
raise ValueError(f"No reference architecture found for problem_id {problem_id}")

ref_arch_src = read_file(ref_arch_path)

ref_arch_name = ref_arch_path.split("/")[-1]
return (ref_arch_path, ref_arch_name, ref_arch_src)


def measure_program_time(
ref_arch_name: str,
ref_arch_src: str,
@@ -143,12 +117,11 @@ def record_baseline_times(use_torch_compile: bool = False,
json_results = {}

for level in [1, 2, 3]:
PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
dataset = construct_kernelbench_dataset(level)
json_results[f"level{level}"] = {}

num_problems = len(dataset)
for problem_id in tqdm(range(1, num_problems + 1)):
for problem_id in tqdm(dataset.get_problem_ids()):
ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(dataset, problem_id)
runtime_stats = measure_program_time(
ref_arch_name=ref_arch_name,
@@ -174,8 +147,7 @@ def test_measure_particular_program(level_num: int, problem_id: int):
"""
device = torch.device("cuda:0")

PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level_num))
dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
dataset = construct_kernelbench_dataset(level_num)

ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(dataset, problem_id)

@@ -249,7 +221,7 @@ def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
level_num, problem_id, with_name=True
)
ref_arch_name = ref_arch_name.split("/")[-1]
ref_arch_name = os.path.basename(ref_arch_name)
context = {}
Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
ref_arch_src, context
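
The per-script copy of fetch_ref_arch_from_dataset deleted above is now imported from src.dataset (here and in the Modal variant below). A sketch of what the shared helper presumably looks like when rewritten against the ID-based dataset API, adapted from the removed code rather than taken from the actual src/dataset.py:

import os
from src.utils import read_file  # same helper these scripts already import

def fetch_ref_arch_from_dataset(dataset, problem_id: int) -> tuple[str, str, str]:
    """Return (ref_arch_path, ref_arch_name, ref_arch_src) for a logical 1-indexed problem_id."""
    # Direct ID lookup replaces the old filename-prefix scan over a list of paths.
    ref_arch_path = dataset.get_problem_by_id(problem_id)
    ref_arch_name = os.path.basename(ref_arch_path)
    ref_arch_src = read_file(ref_arch_path)
    return ref_arch_path, ref_arch_name, ref_arch_src
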
34 changes: 4 additions & 30 deletions scripts/generate_baseline_time_modal.py
@@ -7,7 +7,7 @@
set_seed,
fetch_ref_arch_from_problem_id,
)
from src.dataset import construct_problem_dataset_from_problem_dir
from src.dataset import construct_kernelbench_dataset, KernelBenchDataset, fetch_ref_arch_from_dataset
from src.utils import read_file
import os
import json
@@ -126,31 +126,6 @@ def write_batch_to_json(entries_to_write: list, f_path: str):

print(f"[INFO] Wrote {len(entries_to_write)} entries to {f_path}")

def fetch_ref_arch_from_dataset(dataset: list[str],
problem_id: int) -> tuple[str, str, str]:
"""
Fetch the reference architecture from the problem directory
problem_id should be logical index (1-indexed), matching the problem_id in the problem_name

Returns:
ref_arch_path: str, the path to the reference architecture
ref_arch_name: str, the name of the reference architecture
ref_arch_src: str, the source code of the reference architecture
"""
ref_arch_path = None

for file in dataset:
if file.split("/")[-1].split("_")[0] == str(problem_id):
ref_arch_path = file
break
if ref_arch_path is None:
raise ValueError(f"No reference architecture found for problem_id {problem_id}")

ref_arch_src = read_file(ref_arch_path)

ref_arch_name = ref_arch_path.split("/")[-1]
return (ref_arch_path, ref_arch_name, ref_arch_src)

@app.cls(image=image, scaledown_window=5)
class EvalFunc:

@@ -229,10 +204,9 @@ def record_baseline_times(config: BaselineConfig,
json_results = []

level = config.level
PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
dataset = construct_kernelbench_dataset(level)
num_problems = len(dataset)
total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in list(range(1, num_problems + 1))]
total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in dataset.get_problem_ids()]

with tqdm(total=len(total_work), desc="Processing batches") as pbar:
while len(total_work) > 0:
@@ -353,7 +327,7 @@ def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
level_num, problem_id, with_name=True
)
ref_arch_name = ref_arch_name.split("/")[-1]
ref_arch_name = os.path.basename(ref_arch_name)
context = {}
Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
ref_arch_src, context
24 changes: 10 additions & 14 deletions scripts/generate_samples.py
@@ -46,7 +46,7 @@ def __init__(self):
self.subset = (
None,
None,
) # (problem_id, problem_name), these are the logical index
) # (start_id, end_id), both inclusive - logical 1-indexed IDs

self.run_name = REQUIRED # name of the run

@@ -112,10 +112,7 @@ def generate_sample_single(
problem_name = curr_problem_row["name"][0]

elif config.dataset_src == "local":
problem_idx_in_dataset = (
work.problem_id - 1
) # due to dataset list being 0-indexed locally
ref_arch_path = dataset[problem_idx_in_dataset]
ref_arch_path = dataset.get_problem_by_id(work.problem_id)

problem_name = os.path.basename(ref_arch_path)
ref_arch_src = read_file(ref_arch_path)
@@ -224,17 +221,18 @@ def main(config: GenerationConfig):
curr_level_dataset = construct_kernelbench_dataset(config.level)

num_problems_in_level = len(curr_level_dataset)
all_problem_ids = curr_level_dataset.get_problem_ids() if config.dataset_src == "local" else list(range(1, num_problems_in_level + 1))

if config.subset == (None, None):
problem_id_range = range(1, num_problems_in_level)
problem_ids_to_run = all_problem_ids
else:
assert (
config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level
), f"Subset range {config.subset} out of range for Level {config.level}"
problem_id_range = range(config.subset[0], config.subset[1])
start, end = config.subset
problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end]
if not problem_ids_to_run:
print(f"Warning: No problems found in subset range {config.subset}")

print(
f"Generating {config.num_samples} sample(s) each for level {config.level} problems: {problem_id_range}"
f"Generating {config.num_samples} sample(s) each for level {config.level} problems: {problem_ids_to_run}"
)

# set up run directory
@@ -253,9 +251,7 @@
problems_to_run = []
total_problems = 0
already_completed = 0
for problem_id in range(
problem_id_range.start, problem_id_range.stop + 1
): # end index is inclusive
for problem_id in problem_ids_to_run:
for sample_id in range(config.num_samples):
total_problems += 1
if not check_kernel_exists(run_dir, config.level, problem_id, sample_id):
9 changes: 4 additions & 5 deletions scripts/inspect_baseline.py
@@ -10,7 +10,7 @@
set_seed,
fetch_ref_arch_from_problem_id,
)
from src.dataset import construct_problem_dataset_from_problem_dir
from src.dataset import construct_kernelbench_dataset
import os, sys
import logging
import json
@@ -93,15 +93,14 @@ def emit(self, record):
separator("")

def fetch_ref_arch_from_level_problem_id(level_num, problem_id, with_name=False):
PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level_num))
dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
dataset = construct_kernelbench_dataset(level_num)
return fetch_ref_arch_from_problem_id(problem_id, dataset, with_name)

def inspect_torch_compile_triton(level_num, problem_id):
ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
level_num, problem_id, with_name=True
)
ref_arch_name = ref_arch_name.split("/")[-1]
ref_arch_name = os.path.basename(ref_arch_name)
context = {}
Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
ref_arch_src, context
@@ -116,7 +115,7 @@ def inspect_baseline_torch_compile(level_num, problem_id):
level_num, problem_id, with_name=True
)

ref_arch_name = ref_arch_name.split("/")[-1]
ref_arch_name = os.path.basename(ref_arch_name)
context = {}
Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
ref_arch_src, context