diff --git a/scripts/benchmark_eval_analysis.py b/scripts/benchmark_eval_analysis.py
index e2bea005..428e240f 100644
--- a/scripts/benchmark_eval_analysis.py
+++ b/scripts/benchmark_eval_analysis.py
@@ -40,7 +40,7 @@ def patch(eval_results, dataset):
     """
     Patch the eval results with the dataset
     """
-    for pid in range(1, len(dataset) + 1):
+    for pid in dataset.get_problem_ids():
         if str(pid) not in eval_results:
             eval_results[str(pid)] = {
                 "sample_id": 0,
@@ -136,19 +136,40 @@ def analyze_greedy_eval(run_name, hardware, baseline, level):
     )
 
     # Extract the speedup values
-    is_correct = np.array([entry["correctness"] for entry in eval_results.values()])
-    baseline_speed = np.array(
-        [entry["mean"] for entry in baseline_results[f"level{level}"].values()]
-    )
-    actual_speed = np.array([entry["runtime"] for entry in eval_results.values()])
+    is_correct_list = []
+    baseline_speed_list = []
+    actual_speed_list = []
+
+    # Sort problem IDs to ensure consistent order
+    sorted_pids = sorted(dataset.get_problem_ids())
+
+    for pid in sorted_pids:
+        # Get eval result
+        if str(pid) not in eval_results:
+            print(f"Warning: Problem {pid} not found in eval results")
+            continue
+        eval_entry = eval_results[str(pid)]
+
+        # Get baseline result
+        problem_path = dataset.get_problem_by_id(pid)
+        problem_name = os.path.basename(problem_path)
+
+        if problem_name not in baseline_results[f"level{level}"]:
+            print(f"Warning: Problem {problem_name} not found in baseline results")
+            continue
+
+        baseline_entry = baseline_results[f"level{level}"][problem_name]
+
+        is_correct_list.append(eval_entry["correctness"])
+        actual_speed_list.append(eval_entry["runtime"])
+        baseline_speed_list.append(baseline_entry["mean"])
+
+    is_correct = np.array(is_correct_list)
+    baseline_speed = np.array(baseline_speed_list)
+    actual_speed = np.array(actual_speed_list)
     n = len(is_correct)
-    assert (
-        len(baseline_speed) == n
-    ), "Baseline speedup values do not match the number of eval results"
-    assert (
-        len(actual_speed) == n
-    ), "Actual speedup values do not match the number of eval results"
+    print(f"Aligned {n} problems for analysis")
 
     # Calculate the metrics
     gmsr_correct = geometric_mean_speed_ratio_correct_only(
diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py
index 2e39e3be..a973187f 100644
--- a/scripts/eval_from_generations.py
+++ b/scripts/eval_from_generations.py
@@ -257,10 +257,7 @@ def fetch_ref_arch_from_problem_id(
         problem_name = curr_problem_row["name"][0]
 
     elif dataset_src == "local":
-        problem_idx_in_dataset = (
-            problem_id - 1
-        )  # due to dataset list being 0-indexed locally
-        ref_arch_path = dataset[problem_idx_in_dataset]
+        ref_arch_path = dataset.get_problem_by_id(problem_id)
         problem_name = os.path.basename(ref_arch_path)
         ref_arch_src = read_file(ref_arch_path)
 
@@ -764,17 +761,18 @@ def main(config: EvalConfig):
     curr_level_dataset = construct_kernelbench_dataset(config.level)
 
     num_problems_in_level = len(curr_level_dataset)
+    all_problem_ids = curr_level_dataset.get_problem_ids() if config.dataset_src == "local" else list(range(1, num_problems_in_level + 1))
 
     if config.subset == (None, None):
-        problem_id_range = range(1, num_problems_in_level)
+        problem_ids_to_run = all_problem_ids
     else:
-        assert (
-            config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level
-        ), f"Subset range {config.subset} out of range for Level {config.level}"
-        problem_id_range = range(config.subset[0], config.subset[1])
+        start, end = config.subset
+        problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end]
+        if not problem_ids_to_run:
+            print(f"Warning: No problems found in subset range {config.subset}")
 
     print(
-        f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_id_range}"
+        f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_ids_to_run}"
     )
 
     run_dir = os.path.join(config.runs_dir, config.run_name)
@@ -784,22 +782,19 @@ def main(config: EvalConfig):
     # single_eval_example(config, curr_level_dataset, run_dir, eval_file_path)
 
     total_work = []
-    for problem_id in range(
-        problem_id_range.start, problem_id_range.stop + 1
-    ):  # end index is inclusive
+    for problem_id in problem_ids_to_run:
         for sample_id in range(config.num_samples_per_problem):
             if not check_if_eval_exists_local(problem_id, sample_id, eval_file_path):
                 total_work.append((problem_id, sample_id))
 
     print(
         f"Start evaluation on {len(total_work)} unevaluated samples"
-        f" in range: {problem_id_range}"
+        f" in range: {problem_ids_to_run}"
     )
 
     # Build Cache on CPU as that is faster (only for local mode)
     if config.build_cache and config.eval_mode == "local":
         compile.batch_compile(total_work, config.to_dict())
-
     # Batch Eval on multiple GPUs in parallel
     batch_eval(total_work, config, curr_level_dataset, run_dir, eval_file_path)
 
     # Calculate pass@k metrics if multiple samples per problem were evaluated
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index 18fb3c55..0b86964a 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -139,10 +139,7 @@ def main(config: EvalConfig):
         problem_name = curr_problem_row["name"][0]
 
     elif config.dataset_src == "local":
-        problem_idx_in_dataset = (
-            config.problem_id - 1
-        )  # due to dataset list being 0-indexed locally
-        ref_arch_path = curr_level_dataset[problem_idx_in_dataset]
+        ref_arch_path = curr_level_dataset.get_problem_by_id(config.problem_id)
         problem_name = os.path.basename(ref_arch_path)
         ref_arch_src = read_file(ref_arch_path)
 
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index 6962f515..471cee9a 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -13,7 +13,7 @@
 
 from datasets import load_dataset
 
-#from src.dataset import construct_kernelbench_dataset
+from src.dataset import construct_kernelbench_dataset
 from src.eval import eval_kernel_against_ref
 from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
 from src.prompt_constructor_multilang import get_prompt_for_backend
@@ -148,6 +148,8 @@ def main(config: EvalConfig):
     if config.dataset_src == "huggingface":
         dataset = load_dataset(config.dataset_name)
         curr_level_dataset = dataset[f"level_{config.level}"]
+    elif config.dataset_src == "local":
+        curr_level_dataset = construct_kernelbench_dataset(config.level)
 
     if config.log:
         os.makedirs(config.logdir, exist_ok=True)
@@ -168,8 +170,7 @@ def main(config: EvalConfig):
         problem_name = curr_problem_row["name"][0]
 
     elif config.dataset_src == "local":
-        problem_idx_in_dataset = config.problem_id - 1 # due to dataset list being 0-indexed locally
-        ref_arch_path = curr_level_dataset[problem_idx_in_dataset]
+        ref_arch_path = curr_level_dataset.get_problem_by_id(config.problem_id)
         problem_name = os.path.basename(ref_arch_path)
         ref_arch_src = read_file(ref_arch_path)
 
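
As context for the new alignment loop in benchmark_eval_analysis.py above: once is_correct, baseline_speed, and actual_speed are aligned per problem, the speedup metric reduces to a masked geometric mean. A minimal, self-contained sketch; the standalone helper below is illustrative only and is not the repo's geometric_mean_speed_ratio_correct_only:

import numpy as np

def geo_mean_speedup_correct_only(is_correct, baseline_speed, actual_speed) -> float:
    """Geometric mean of (baseline runtime / actual runtime) over correct problems only."""
    is_correct = np.asarray(is_correct, dtype=bool)
    ratios = np.asarray(baseline_speed, dtype=float) / np.asarray(actual_speed, dtype=float)
    correct_ratios = ratios[is_correct]
    if correct_ratios.size == 0:
        return 0.0  # no correct kernels, so no meaningful speedup
    return float(np.exp(np.log(correct_ratios).mean()))

# Three aligned problems, two correct: speedups 2.0x and 2.0x -> geometric mean 2.0
print(geo_mean_speedup_correct_only([True, False, True], [2.0, 1.0, 4.0], [1.0, 1.0, 2.0]))
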
diff --git a/scripts/generate_baseline_time.py b/scripts/generate_baseline_time.py
index 5a68ea08..95fca7ad 100644
--- a/scripts/generate_baseline_time.py
+++ b/scripts/generate_baseline_time.py
@@ -7,7 +7,7 @@
     set_seed,
     fetch_ref_arch_from_problem_id,
 )
-from src.dataset import construct_problem_dataset_from_problem_dir
+from src.dataset import construct_kernelbench_dataset, KernelBenchDataset, fetch_ref_arch_from_dataset
 from src.utils import read_file
 import os
 import json
@@ -46,32 +46,6 @@
 
 TIMING_DIR = os.path.join(REPO_TOP_PATH, "results", "timing")
 
-def fetch_ref_arch_from_dataset(dataset: list[str],
-                                problem_id: int) -> tuple[str, str, str]:
-    """
-    Fetch the reference architecture from the problem directory
-    problem_id should be logical index (1-indexed), matching the problem_id in the problem_name
-
-    Returns:
-        ref_arch_path: str, the path to the reference architecture
-        ref_arch_name: str, the name of the reference architecture
-        ref_arch_src: str, the source code of the reference architecture
-    """
-    ref_arch_path = None
-
-    for file in dataset:
-        if file.split("/")[-1].split("_")[0] == str(problem_id):
-            ref_arch_path = file
-            break
-    if ref_arch_path is None:
-        raise ValueError(f"No reference architecture found for problem_id {problem_id}")
-
-    ref_arch_src = read_file(ref_arch_path)
-
-    ref_arch_name = ref_arch_path.split("/")[-1]
-    return (ref_arch_path, ref_arch_name, ref_arch_src)
-
-
 def measure_program_time(
     ref_arch_name: str,
     ref_arch_src: str,
@@ -143,12 +117,11 @@ def record_baseline_times(use_torch_compile: bool = False,
     json_results = {}
 
     for level in [1, 2, 3]:
-        PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
-        dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
+        dataset = construct_kernelbench_dataset(level)
         json_results[f"level{level}"] = {}
 
         num_problems = len(dataset)
-        for problem_id in tqdm(range(1, num_problems + 1)):
+        for problem_id in tqdm(dataset.get_problem_ids()):
             ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(dataset, problem_id)
             runtime_stats = measure_program_time(
                 ref_arch_name=ref_arch_name,
@@ -174,8 +147,7 @@ def test_measure_particular_program(level_num: int, problem_id: int):
     """
     device = torch.device("cuda:0")
 
-    PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level_num))
-    dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
+    dataset = construct_kernelbench_dataset(level_num)
 
     ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(dataset, problem_id)
 
@@ -249,7 +221,7 @@ def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
     ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
         level_num, problem_id, with_name=True
     )
-    ref_arch_name = ref_arch_name.split("/")[-1]
+    ref_arch_name = os.path.basename(ref_arch_name)
     context = {}
     Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
         ref_arch_src, context
diff --git a/scripts/generate_baseline_time_modal.py b/scripts/generate_baseline_time_modal.py
index a0039193..f7a579fa 100644
--- a/scripts/generate_baseline_time_modal.py
+++ b/scripts/generate_baseline_time_modal.py
@@ -7,7 +7,7 @@
     set_seed,
     fetch_ref_arch_from_problem_id,
 )
-from src.dataset import construct_problem_dataset_from_problem_dir
+from src.dataset import construct_kernelbench_dataset, KernelBenchDataset, fetch_ref_arch_from_dataset
 from src.utils import read_file
 import os
 import json
@@ -126,31 +126,6 @@ def write_batch_to_json(entries_to_write: list, f_path: str):
 
     print(f"[INFO] Wrote {len(entries_to_write)} entries to {f_path}")
 
-def fetch_ref_arch_from_dataset(dataset: list[str],
-                                problem_id: int) -> tuple[str, str, str]:
-    """
-    Fetch the reference architecture from the problem directory
-    problem_id should be logical index (1-indexed), matching the problem_id in the problem_name
-
-    Returns:
-        ref_arch_path: str, the path to the reference architecture
-        ref_arch_name: str, the name of the reference architecture
-        ref_arch_src: str, the source code of the reference architecture
-    """
-    ref_arch_path = None
-
-    for file in dataset:
-        if file.split("/")[-1].split("_")[0] == str(problem_id):
-            ref_arch_path = file
-            break
-    if ref_arch_path is None:
-        raise ValueError(f"No reference architecture found for problem_id {problem_id}")
-
-    ref_arch_src = read_file(ref_arch_path)
-
-    ref_arch_name = ref_arch_path.split("/")[-1]
-    return (ref_arch_path, ref_arch_name, ref_arch_src)
-
 @app.cls(image=image, scaledown_window=5)
 class EvalFunc:
 
@@ -229,10 +204,9 @@ def record_baseline_times(config: BaselineConfig,
     json_results = []
 
     level = config.level
-    PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
-    dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
+    dataset = construct_kernelbench_dataset(level)
     num_problems = len(dataset)
-    total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in list(range(1, num_problems + 1))]
+    total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in dataset.get_problem_ids()]
 
     with tqdm(total=len(total_work), desc="Processing batches") as pbar:
         while len(total_work) > 0:
@@ -353,7 +327,7 @@ def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
     ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
         level_num, problem_id, with_name=True
     )
-    ref_arch_name = ref_arch_name.split("/")[-1]
+    ref_arch_name = os.path.basename(ref_arch_name)
     context = {}
     Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
         ref_arch_src, context
diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py
index 5b476445..630d6bf7 100644
--- a/scripts/generate_samples.py
+++ b/scripts/generate_samples.py
@@ -46,7 +46,7 @@ def __init__(self):
         self.subset = (
             None,
            None,
-        )  # (problem_id, problem_name), these are the logical index
+        )  # (start_id, end_id), both inclusive - logical 1-indexed IDs
 
         self.run_name = REQUIRED  # name of the run
 
@@ -112,10 +112,7 @@ def generate_sample_single(
         problem_name = curr_problem_row["name"][0]
 
     elif config.dataset_src == "local":
-        problem_idx_in_dataset = (
-            work.problem_id - 1
-        )  # due to dataset list being 0-indexed locally
-        ref_arch_path = dataset[problem_idx_in_dataset]
+        ref_arch_path = dataset.get_problem_by_id(work.problem_id)
         problem_name = os.path.basename(ref_arch_path)
         ref_arch_src = read_file(ref_arch_path)
 
@@ -224,17 +221,18 @@ def main(config: GenerationConfig):
     curr_level_dataset = construct_kernelbench_dataset(config.level)
 
     num_problems_in_level = len(curr_level_dataset)
+    all_problem_ids = curr_level_dataset.get_problem_ids() if config.dataset_src == "local" else list(range(1, num_problems_in_level + 1))
 
     if config.subset == (None, None):
-        problem_id_range = range(1, num_problems_in_level)
+        problem_ids_to_run = all_problem_ids
     else:
-        assert (
-            config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level
-        ), f"Subset range {config.subset} out of range for Level {config.level}"
-        problem_id_range = range(config.subset[0], config.subset[1])
+        start, end = config.subset
+        problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end]
+        if not problem_ids_to_run:
+            print(f"Warning: No problems found in subset range {config.subset}")
 
     print(
-        f"Generating {config.num_samples} sample(s) each for level {config.level} problems: {problem_id_range}"
+        f"Generating {config.num_samples} sample(s) each for level {config.level} problems: {problem_ids_to_run}"
     )
 
     # set up run directory
@@ -253,9 +251,7 @@ def main(config: GenerationConfig):
     problems_to_run = []
     total_problems = 0
     already_completed = 0
-    for problem_id in range(
-        problem_id_range.start, problem_id_range.stop + 1
-    ):  # end index is inclusive
+    for problem_id in problem_ids_to_run:
         for sample_id in range(config.num_samples):
             total_problems += 1
             if not check_kernel_exists(run_dir, config.level, problem_id, sample_id):
diff --git a/scripts/inspect_baseline.py b/scripts/inspect_baseline.py
index e7811f64..90bd7f2f 100644
--- a/scripts/inspect_baseline.py
+++ b/scripts/inspect_baseline.py
@@ -10,7 +10,7 @@
     set_seed,
     fetch_ref_arch_from_problem_id,
 )
-from src.dataset import construct_problem_dataset_from_problem_dir
+from src.dataset import construct_kernelbench_dataset
 import os, sys
 import logging
 import json
@@ -93,15 +93,14 @@ def emit(self, record):
     separator("")
 
 def fetch_ref_arch_from_level_problem_id(level_num, problem_id, with_name=False):
-    PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level_num))
-    dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
+    dataset = construct_kernelbench_dataset(level_num)
     return fetch_ref_arch_from_problem_id(problem_id, dataset, with_name)
 
 def inspect_torch_compile_triton(level_num, problem_id):
     ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
         level_num, problem_id, with_name=True
     )
-    ref_arch_name = ref_arch_name.split("/")[-1]
+    ref_arch_name = os.path.basename(ref_arch_name)
     context = {}
     Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
         ref_arch_src, context
@@ -116,7 +115,7 @@ def inspect_baseline_torch_compile(level_num, problem_id):
         level_num, problem_id, with_name=True
     )
 
-    ref_arch_name = ref_arch_name.split("/")[-1]
+    ref_arch_name = os.path.basename(ref_arch_name)
     context = {}
     Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
         ref_arch_src, context
diff --git a/scripts/inspect_triton.py b/scripts/inspect_triton.py
index 4f13c8af..0170dada 100644
--- a/scripts/inspect_triton.py
+++ b/scripts/inspect_triton.py
@@ -26,36 +26,23 @@
     set_seed,
 )
 
-def fetch_ref_arch_from_dataset(dataset: list[str],
-                                problem_id: int) -> tuple[str, str, str]:
-    """
-    Fetch the reference architecture from the problem directory
-    problem_id should be logical index (1-indexed), matching the problem_id in the problem_name
+from src.dataset import construct_kernelbench_dataset, KernelBenchDataset, fetch_ref_arch_from_dataset
 
-    Returns:
-        ref_arch_path: str, the path to the reference architecture
-        ref_arch_name: str, the name of the reference architecture
-        ref_arch_src: str, the source code of the reference architecture
-    """
-    ref_arch_path = None
-
-    for file in dataset:
-        if file.split("/")[-1].split("_")[0] == str(problem_id):
-            ref_arch_path = file
-            break
-    if ref_arch_path is None:
-        raise ValueError(f"No reference architecture found for problem_id {problem_id}")
-
-    ref_arch_src = read_file(ref_arch_path)
-
-    ref_arch_name = ref_arch_path.split("/")[-1]
-    return (ref_arch_path, ref_arch_name, ref_arch_src)
-
-
-def run_profile_and_save_trace(dataset: list[str], problem_id: int, num_trials=10):
-    """
-    Helper function to get Torch Profile of a problem
-    # TODO: Fix up this function
+
+def run_profile_and_save_trace(
+    dataset: KernelBenchDataset,
+    problem_id: int,
+    num_trials: int = 10
+) -> None:
+    """Helper function to get Torch Profile of a problem.
+
+    Args:
+        dataset: KernelBenchDataset object
+        problem_id: Problem ID to profile
+        num_trials: Number of profiling trials to run (default: 10)
+
+    Note:
+        Saves trace files to 'trace_non_compiled.json' and 'trace_compiled.json'
     """
     ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(
         dataset, problem_id
     )
@@ -120,12 +107,19 @@ def run_profile_and_save_trace(dataset: list[str], problem_id: int, num_trials=1
     # except Exception as e:
     #     print(f"[Eval] Error in Measuring Performance: {e}")
 
-def get_torch_compile_triton(level_num, problem_id):
-    """
-    Get the triton code generated by torch compile for a particular problem
+def get_torch_compile_triton(level_num: int, problem_id: int) -> str:
+    """Get the triton code generated by torch compile for a particular problem.
+
+    Args:
+        level_num: KernelBench level (1, 2, or 3)
+        problem_id: Problem ID to inspect
+
+    Returns:
+        str: Name of the reference architecture
     """
+    dataset = construct_kernelbench_dataset(level_num)
     ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(
-        dataset, problem_id, with_name=True
+        dataset, problem_id
     )
     context = {}
     # import pdb; pdb.set_trace()
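
The new run_profile_and_save_trace docstring documents that traces land in 'trace_non_compiled.json' and 'trace_compiled.json'. For orientation, a minimal standalone sketch of producing such a Chrome trace with torch.profiler (not the repo function itself, which also handles the compiled variant and trial setup):

import torch
from torch.profiler import profile, ProfilerActivity

def save_chrome_trace(model, inputs, trace_path: str, num_trials: int = 10) -> None:
    """Sketch: profile num_trials forward passes and export a Chrome trace."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, record_shapes=True) as prof:
        with torch.no_grad():
            for _ in range(num_trials):
                model(*inputs)
    prof.export_chrome_trace(trace_path)  # open in chrome://tracing or Perfetto

# e.g. save_chrome_trace(model, inputs, "trace_non_compiled.json")
#      save_chrome_trace(torch.compile(model), inputs, "trace_compiled.json")
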
{problem_name}") diff --git a/scripts/verify_bench.py b/scripts/verify_bench.py index 5fdc6862..2ad79395 100644 --- a/scripts/verify_bench.py +++ b/scripts/verify_bench.py @@ -71,37 +71,43 @@ def run(Model, NewModel, get_inputs, get_init_inputs, seed=1012): return check_correctness(Model, NewModel, get_inputs, get_init_inputs, seed) -def run_all(directory): - print(f"Running {directory}") +from src.dataset import construct_kernelbench_dataset + +def run_all(level): + print(f"Running Level {level}") + dataset = construct_kernelbench_dataset(level) total = 0 passed = 0 fail_tests = [] - abs_path = os.path.abspath(directory) - for filename in os.listdir(abs_path): - if filename.endswith(".py"): - total += 1 - module_name = filename[:-3] # Remove .py extension - try: - # Dynamically import the module - spec = importlib.util.spec_from_file_location( - module_name, os.path.join(abs_path, filename) - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # Get the required attributes from the module - Model = getattr(module, "Model") - get_inputs = getattr(module, "get_inputs") - get_init_inputs = getattr(module, "get_init_inputs") - assert run(Model, Model, get_inputs, get_init_inputs) - passed += 1 - except Exception as e: - fail_tests.append(module_name) - print(f"{directory}: {passed}/{total} passed") + + for problem_id in dataset.get_problem_ids(): + problem_path = dataset.get_problem_by_id(problem_id) + filename = os.path.basename(problem_path) + + total += 1 + module_name = filename[:-3] # Remove .py extension + try: + # Dynamically import the module + spec = importlib.util.spec_from_file_location( + module_name, problem_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Get the required attributes from the module + Model = getattr(module, "Model") + get_inputs = getattr(module, "get_inputs") + get_init_inputs = getattr(module, "get_init_inputs") + assert run(Model, Model, get_inputs, get_init_inputs) + passed += 1 + except Exception as e: + print(f"Failed {module_name}: {e}") + fail_tests.append(module_name) + print(f"Level {level}: {passed}/{total} passed") if len(fail_tests) > 0: print(f"Failed tests: {fail_tests}") if __name__ == "__main__": - run_all(KERNEL_BENCH_PATH + "/level1") - run_all(KERNEL_BENCH_PATH + "/level2") - run_all(KERNEL_BENCH_PATH + "/level3") + run_all(1) + run_all(2) + run_all(3) diff --git a/src/dataset.py b/src/dataset.py index cb429dc1..a674d8a9 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -38,6 +38,154 @@ def get_code_hash(problem_src: str) -> str: return hashlib.md5(cleaned_problem_src.encode()).hexdigest() + +def check_id_matches_name(problem_id: int, problem_name: str) -> bool: + """Check if the problem_id matches the ID in the problem_name. + + Args: + problem_id: The problem ID to check + problem_name: Path to the problem file + + Returns: + bool: True if the ID matches the filename prefix + + Raises: + ValueError: If filename doesn't follow the expected format + """ + basename = os.path.basename(problem_name) + parts = basename.split('_') + + if len(parts) < 2: + raise ValueError( + f"Problem filename '{basename}' doesn't follow expected format '_.py'" + ) + + try: + file_id = int(parts[0]) + except ValueError: + raise ValueError( + f"Problem filename '{basename}' doesn't start with a numeric ID" + ) + + return problem_id == file_id + + +class KernelBenchDataset(): + """Dataset object for easy access to problems by IDs and iteration over problems. 
diff --git a/src/dataset.py b/src/dataset.py
index cb429dc1..a674d8a9 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -38,6 +38,154 @@ def get_code_hash(problem_src: str) -> str:
     return hashlib.md5(cleaned_problem_src.encode()).hexdigest()
 
 
+def check_id_matches_name(problem_id: int, problem_name: str) -> bool:
+    """Check if the problem_id matches the ID in the problem_name.
+
+    Args:
+        problem_id: The problem ID to check
+        problem_name: Path to the problem file
+
+    Returns:
+        bool: True if the ID matches the filename prefix
+
+    Raises:
+        ValueError: If filename doesn't follow the expected format
+    """
+    basename = os.path.basename(problem_name)
+    parts = basename.split('_')
+
+    if len(parts) < 2:
+        raise ValueError(
+            f"Problem filename '{basename}' doesn't follow expected format '<id>_<name>.py'"
+        )
+
+    try:
+        file_id = int(parts[0])
+    except ValueError:
+        raise ValueError(
+            f"Problem filename '{basename}' doesn't start with a numeric ID"
+        )
+
+    return problem_id == file_id
+
+
+class KernelBenchDataset():
+    """Dataset object for easy access to problems by IDs and iteration over problems.
+
+    Args:
+        dataset_name: Name of the dataset
+        level: KernelBench level (1, 2, or 3)
+        use_subset: Whether to use the subset_dataset instead of full dataset
+        dataset: List of problem file paths for the full dataset
+        subset_dataset: List of problem file paths for a subset
+    """
+
+    def __init__(
+        self,
+        dataset_name: str,
+        level: int,
+        use_subset: bool = False,
+        dataset: list[str] = None,
+        subset_dataset: list[str] = None
+    ):
+        if level not in [1, 2, 3]:
+            raise ValueError(f"level must be 1, 2, or 3, got {level}")
+
+        self.dataset_name = dataset_name
+        self.level = level
+        self.use_subset = use_subset
+
+        # Avoid mutable default arguments
+        if dataset is None:
+            dataset = []
+        if subset_dataset is None:
+            subset_dataset = []
+
+        if use_subset:
+            self.problems = subset_dataset
+        else:
+            self.problems = dataset
+
+    def get_problem_by_id(self, problem_id: int) -> str:
+        """Get problem path by its ID (1-indexed logical index).
+
+        Args:
+            problem_id: The problem ID to search for
+
+        Returns:
+            str: Path to the problem file
+
+        Raises:
+            ValueError: If problem ID not found in dataset
+        """
+        for problem in self.problems:
+            if check_id_matches_name(problem_id, problem):
+                return problem
+        raise ValueError(f"Problem ID {problem_id} not found in dataset")
+
+    def get_problem_ids(self) -> list[int]:
+        """Get list of all problem IDs in the dataset.
+
+        Returns:
+            list[int]: Sorted list of problem IDs extracted from filenames
+        """
+        return sorted([int(os.path.basename(problem).split('_')[0]) for problem in self.problems])
+
+    def __len__(self) -> int:
+        """Return the number of problems in the dataset."""
+        return len(self.problems)
+
+    def __getitem__(self, index: int) -> str:
+        """Get problem by index (0-indexed, for backward compatibility).
+
+        Args:
+            index: Zero-based index into the problems list
+
+        Returns:
+            str: Path to the problem file
+        """
+        return self.problems[index]
+
+    def __iter__(self):
+        """Iterate over problem paths in the dataset."""
+        return iter(self.problems)
+
+    def __repr__(self) -> str:
+        """Return string representation of the dataset."""
+        subset_str = " (subset)" if self.use_subset else ""
+        return (
+            f"KernelBenchDataset(name='{self.dataset_name}', "
+            f"level={self.level}, problems={len(self.problems)}{subset_str})"
+        )
+
+
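
A short usage sketch of the class above, assuming level-1 problem files named in the '<id>_<name>.py' convention and built through construct_kernelbench_dataset (updated further down in this diff):

from src.dataset import construct_kernelbench_dataset

dataset = construct_kernelbench_dataset(level=1)
print(dataset)                         # KernelBenchDataset(name='KernelBench_Level_1', level=1, problems=...)
print(len(dataset))                    # number of problem files
print(dataset.get_problem_ids()[:3])   # e.g. [1, 2, 3]

path = dataset.get_problem_by_id(1)    # lookup by the 1-indexed logical ID in the filename
first = dataset[0]                     # 0-indexed access kept for backward compatibility
for problem_path in dataset:           # iteration yields problem file paths
    pass
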
+def fetch_ref_arch_from_dataset(
+    dataset: "KernelBenchDataset",
+    problem_id: int
+) -> tuple[str, str, str]:
+    """Fetch the reference architecture from the dataset.
+
+    This is a shared utility function to avoid duplication across scripts.
+
+    Args:
+        dataset: KernelBenchDataset object
+        problem_id: Logical index (1-indexed), matching the problem_id in the problem_name
+
+    Returns:
+        tuple containing:
+            - ref_arch_path: Path to the reference architecture
+            - ref_arch_name: Name of the reference architecture file
+            - ref_arch_src: Source code of the reference architecture
+    """
+    from .utils import read_file
+
+    ref_arch_path = dataset.get_problem_by_id(problem_id)
+    ref_arch_src = read_file(ref_arch_path)
+    ref_arch_name = os.path.basename(ref_arch_path)
+    return (ref_arch_path, ref_arch_name, ref_arch_src)
+
+
 def construct_problem_dataset_from_problem_dir(problem_dir: str) -> list[str]:
     """
     Construct a list of relative paths to all the python files in the problem directory
@@ -57,10 +205,15 @@ def construct_problem_dataset_from_problem_dir(problem_dir: str) -> list[str]:
     return DATASET
 
 
-def construct_kernelbench_dataset(level: int) -> list[str]:
-    return construct_problem_dataset_from_problem_dir(
+def construct_kernelbench_dataset(level: int) -> KernelBenchDataset:
+    dataset_list = construct_problem_dataset_from_problem_dir(
         os.path.join(KERNEL_BENCH_PATH, f"level{level}")
     )
+    return KernelBenchDataset(
+        dataset_name=f"KernelBench_Level_{level}",
+        level=level,
+        dataset=dataset_list
+    )
 
 
 KERNELBENCH_LEVEL_1_DATASET = construct_kernelbench_dataset(level=1)
diff --git a/src/eval.py b/src/eval.py
index 4a072c89..7e73479e 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -21,7 +21,7 @@
 import torch.nn as nn
 from pydantic import BaseModel
 
-from . import utils
+from . import utils, dataset
 
 REPO_TOP_PATH = os.path.abspath(
     os.path.join(
@@ -46,7 +46,11 @@ def fetch_ref_arch_from_problem_id(problem_id, problems, with_name=False) -> str
     if isinstance(problem_id, str):
         problem_id = int(problem_id)
 
-    problem_path = problems[problem_id]
+    if hasattr(problems, "get_problem_by_id"):
+        problem_path = problems.get_problem_by_id(problem_id)
+    else:
+        # Fallback for old list-based API: problem_id is 1-indexed but lists are 0-indexed
+        problem_path = problems[problem_id - 1]
 
     # problem_path = os.path.join(REPO_ROOT_PATH, problem)
     if not os.path.exists(problem_path):
@@ -60,9 +64,8 @@ def fetch_ref_arch_from_problem_id(problem_id, problems, with_name=False) -> str
 
 
 def fetch_ref_arch_from_level_problem_id(level, problem_id, with_name=False):
-    PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
-    dataset = utils.construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
-    return fetch_ref_arch_from_problem_id(problem_id, dataset, with_name)
+    kb_dataset = dataset.construct_kernelbench_dataset(level)
+    return fetch_ref_arch_from_problem_id(problem_id, kb_dataset, with_name)
 
 
 def set_seed(seed: int):
@@ -884,7 +887,13 @@ def fetch_baseline_time(
     with open(baseline_time_filepath, "r") as f:
         baseline_json = json.load(f)
 
-    problem_name = dataset[problem_id].split("/")[-1]
+    if hasattr(dataset, "get_problem_by_id"):
+        problem_path = dataset.get_problem_by_id(problem_id)
+    else:
+        # Fallback for old list-based API: problem_id is 1-indexed but lists are 0-indexed
+        problem_path = dataset[problem_id - 1]
+
+    problem_name = os.path.basename(problem_path)
     baseline_time = baseline_json[level_name].get(problem_name, None)
 
     return baseline_time
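
The hasattr checks above keep older call sites that still pass a plain list of file paths working alongside the new KernelBenchDataset. The dispatch pattern in isolation, as a sketch (resolve_problem_path is an illustrative name, not a function in src/eval.py, and the example path in the comments is hypothetical):

def resolve_problem_path(problems, problem_id: int) -> str:
    """Sketch: accept either a KernelBenchDataset or a plain list of problem paths."""
    if hasattr(problems, "get_problem_by_id"):
        # New API: ID-based lookup, independent of list ordering
        return problems.get_problem_by_id(problem_id)
    # Old API: plain 0-indexed list, so shift the 1-indexed problem_id
    return problems[problem_id - 1]

# resolve_problem_path(construct_kernelbench_dataset(1), 5)
# resolve_problem_path(["path/to/level1/1_Example_problem.py"], 1)
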