Binary file modified .gitignore
Binary file not shown.
145 changes: 122 additions & 23 deletions scripts/eval_from_generations.py
@@ -33,6 +33,13 @@
# Modal support
import modal

# ThunderKittens compilation utilities
from scripts.tk_compile import (
compile_thunderkittens_cuda,
compile_cuda_on_modal,
prepare_kernel_src_with_cuda
)
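
# --- Reviewer note (illustrative, not part of this diff): the interfaces
# below are reconstructed from the call sites later in this file, as a quick
# reference. The real implementations live in scripts/tk_compile.py and may
# differ in detail.
#
#   def compile_thunderkittens_cuda(cuda_src_path: str, module_name: str,
#                                   build_dir: str, verbose: bool = False,
#                                   repo_top_path: str | None = None) -> str:
#       """Compile a .cu file locally; return the path to the built module."""
#
#   def compile_cuda_on_modal(cuda_src: str, cuda_module_name: str,
#                             gpu_arch: list[str]) -> str:
#       """Compile CUDA source inside the Modal container; return the module path."""
#
#   def prepare_kernel_src_with_cuda(kernel_src: str, cuda_module_path: str,
#                                    cuda_module_name: str) -> str:
#       """Prepend an import shim so kernel_src imports the compiled module."""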

"""
Batch Evaluation from Existing Generations

@@ -60,20 +67,56 @@
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
.apt_install("git",
"gcc-10",
"g++-10",
"clang"
)
.pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
    .add_local_dir(
        KERNEL_BENCH_PATH,
        remote_path="/root/KernelBench"
    )
    .add_local_python_source("src")
)
# ThunderKittens support: build a custom TK image when the ThunderKittens directory exists locally
THUNDERKITTENS_LOCAL_PATH = os.path.join(REPO_TOP_DIR, "ThunderKittens")
SRC_PATH = os.path.join(REPO_TOP_DIR, "src")

if os.path.isdir(THUNDERKITTENS_LOCAL_PATH):
# ThunderKittens image with TK environment and mounting
image = (
modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
.apt_install("git",
"gcc-10",
"g++-10",
"clang"
)
.pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
.pip_install("pybind11") # Ensure pybind11 is available for ThunderKittens compilation
.env({
"THUNDERKITTENS_ROOT": "/root/ThunderKittens",
"THUNDERKITTENS_PATH": "/root/ThunderKittens",
"TORCH_CUDA_ARCH_LIST": "9.0",
"CXX": "g++-10",
"CC": "gcc-10",
})
.add_local_dir(
THUNDERKITTENS_LOCAL_PATH,
remote_path="/root/ThunderKittens",
copy=True
)
.add_local_dir(
KERNEL_BENCH_PATH,
remote_path="/root/KernelBench"
)
.add_local_dir(SRC_PATH, remote_path="/root/src")
.add_local_python_source("scripts")
)
else:
# Standard image
image = (
modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
.apt_install("git",
"gcc-10",
"g++-10",
"clang"
)
.pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
.add_local_dir(
KERNEL_BENCH_PATH,
remote_path="/root/KernelBench"
)
.add_local_python_source("src")
)
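
# --- Reviewer sketch (assumption, not this PR's compile path): how the env
# vars set in the ThunderKittens image above are typically consumed when a TK
# kernel is built with torch.utils.cpp_extension. TORCH_CUDA_ARCH_LIST, CXX,
# and CC are picked up by PyTorch's build machinery automatically; the include
# path layout is assumed.
#
#   import os
#   from torch.utils.cpp_extension import load
#
#   tk_root = os.environ["THUNDERKITTENS_ROOT"]  # /root/ThunderKittens in the image
#   tk_kernels = load(
#       name="tk_kernels",
#       sources=["kernel.cu"],
#       extra_include_paths=[os.path.join(tk_root, "include")],
#       extra_cuda_cflags=["-std=c++20", "--expt-relaxed-constexpr"],
#       verbose=True,
#   )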


class EvalConfig(Config):
@@ -176,13 +219,19 @@ def evaluate_single_sample_modal(
verbose: bool = False,
backend: str = "cuda",
precision: str = "fp32",
cuda_src: str | None = None,
cuda_module_name: str = "tk_kernels",
):
"""
Evaluate a single sample on Modal GPU with automatic retries for GPU attachment failures
and proper GPU corruption handling via stop_fetching_inputs()

If cuda_src is provided, it will be compiled first and the kernel_src will be modified
to import the compiled module.
"""
from src.eval import eval_kernel_against_ref, get_torch_dtype_from_string
from src.utils import set_gpu_arch
from scripts.tk_compile import compile_cuda_on_modal, prepare_kernel_src_with_cuda
import torch
import time
import modal.experimental
@@ -206,6 +255,14 @@

set_gpu_arch(gpu_arch)

# If CUDA source provided, compile it first
if cuda_src:
cuda_module_path = compile_cuda_on_modal(cuda_src, cuda_module_name, gpu_arch)

# Modify kernel_src to import the compiled module
kernel_src = prepare_kernel_src_with_cuda(kernel_src, cuda_module_path, cuda_module_name)
print(f"[Modal] Modified kernel source to use compiled module at {cuda_module_path}")

gpu_corrupted = False
try:
result = eval_kernel_against_ref(
@@ -275,20 +332,33 @@ def fetch_ref_arch_from_problem_id(
return ref_arch_src


def fetch_kernel_from_disk(
run_dir: str, level: int, problem_id: int, sample_id: int
) -> str | None:
) -> tuple[str | None, str | None]:
"""
Fetch kernel file from disk (stored in runs/{run_name})
Fetch kernel files from disk (stored in runs/{run_name})
Returns: (kernel_py_src, cuda_src_path) tuple
- kernel_py_src: Python kernel source code (or None if not found)
- cuda_src_path: Path to .cu file if it exists (or None)
"""
kernel_path = os.path.join(
run_dir, f"level_{level}_problem_{problem_id}_sample_{sample_id}_kernel.py"
)
cuda_path = os.path.join(
run_dir, f"level_{level}_problem_{problem_id}_sample_{sample_id}_kernel.cu"
)

kernel_py_src = None
if os.path.exists(kernel_path):
return read_file(kernel_path)
else:
return None
kernel_py_src = read_file(kernel_path)

cuda_src_path = None
if os.path.exists(cuda_path):
cuda_src_path = cuda_path

return (kernel_py_src, cuda_src_path)
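
# Usage sketch (hypothetical run directory and IDs):
#   py_src, cu_path = fetch_kernel_from_disk("runs/demo", level=1, problem_id=3, sample_id=0)
# cu_path is None for plain Python kernels; when set, the .cu file is compiled
# before evaluation (see evaluate_single_sample below).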


def evaluate_single_sample(
@@ -309,11 +379,31 @@ def evaluate_single_sample(

# fetch kernel from disk
# Add database support in the future
kernel_src = fetch_kernel_from_disk(run_dir, configs.level, problem_id, sample_id)
kernel_py_src, cuda_src_path = fetch_kernel_from_disk(run_dir, configs.level, problem_id, sample_id)

assert (
kernel_src is not None
kernel_py_src is not None
), f"Kernel not found for problem {problem_id} sample {sample_id}"

# For local evaluation, if CUDA source exists, compile it first
kernel_src = kernel_py_src
if cuda_src_path:
# Create build directory
cuda_build_dir = os.path.join(
configs.kernel_eval_build_dir, configs.run_name, f"{problem_id}", f"{sample_id}", "cuda_build"
)

# Compile CUDA module
cuda_module_path = compile_thunderkittens_cuda(
cuda_src_path=cuda_src_path,
module_name="tk_kernels",
build_dir=cuda_build_dir,
verbose=configs.verbose,
repo_top_path=REPO_TOP_DIR
)

# Modify kernel_src to import the compiled module
kernel_src = prepare_kernel_src_with_cuda(kernel_src, cuda_module_path, "tk_kernels")

build_dir = os.path.join(
configs.kernel_eval_build_dir, configs.run_name, f"{problem_id}", f"{sample_id}"
@@ -466,17 +556,24 @@ def batch_eval_modal(
ref_arch_src = fetch_ref_arch_from_problem_id(
curr_level_dataset, problem_id, config.dataset_src
)
kernel_src = fetch_kernel_from_disk(run_dir, config.level, problem_id, sample_id)
kernel_py_src, cuda_src_path = fetch_kernel_from_disk(run_dir, config.level, problem_id, sample_id)

if kernel_src is None:
if kernel_py_src is None:
print(f"[WARNING] Kernel not found for problem {problem_id} sample {sample_id}")
work_items.append(None)
else:
# Read CUDA source if it exists
cuda_src = None
if cuda_src_path:
cuda_src = read_file(cuda_src_path)
print(f"[INFO] Found CUDA source for problem {problem_id} sample {sample_id}: {cuda_src_path}")

work_items.append({
'problem_id': problem_id,
'sample_id': sample_id,
'ref_arch_src': ref_arch_src,
'kernel_src': kernel_src,
'kernel_src': kernel_py_src,
'cuda_src': cuda_src,
})

# Submit all evaluations in parallel using Modal
@@ -505,6 +602,8 @@ def batch_eval_modal(
verbose=config.verbose,
backend=config.backend,
precision=config.precision,
cuda_src=item.get('cuda_src'),
cuda_module_name="tk_kernels",
)
futures.append(future)

66 changes: 56 additions & 10 deletions scripts/generate_and_eval_single_sample.py
@@ -13,6 +13,7 @@
from src.utils import (
create_inference_server_from_presets,
extract_first_code,
extract_cuda_and_python_code,
query_server,
read_file,
set_gpu_arch,
@@ -197,12 +198,16 @@ def main(config: EvalConfig):
include_hardware = include_hardware.lower() in ["true", "1", "yes"]
config.include_hardware_info = include_hardware

supported_backends = {"cuda", "triton", "tilelang", "cute"}
supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
backend = config.backend.lower()
if backend not in supported_backends:
raise ValueError(
f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
)

# ThunderKittens uses fp32 by default
if backend == "thunderkittens":
config.precision = "fp32"

if backend == "tilelang":
config.precision = "fp16" # tilelang only operates with fp16
@@ -246,22 +251,63 @@ def main(config: EvalConfig):
f.write(custom_prompt)

# Query server with constructed prompt
custom_kernel = inference_server(custom_prompt)
custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"])

# check LLM is able to generate custom kernel code
assert (
custom_kernel is not None
), f"Custom {config.backend} kernel code generation failed"

# this should be optional
custom_kernel_response = inference_server(custom_prompt)

# For ThunderKittens, extract both CUDA and Python code
cuda_src = None
if backend == "thunderkittens":
cuda_code, python_code = extract_cuda_and_python_code(custom_kernel_response)
if cuda_code is None or python_code is None:
# Fallback to single code extraction
print("[WARNING] Could not extract separate CUDA and Python code blocks, falling back to single extraction")
custom_kernel = extract_first_code(custom_kernel_response, ["python", "cpp"])
assert custom_kernel is not None, f"Custom {config.backend} kernel code generation failed"
else:
custom_kernel = python_code
cuda_src = cuda_code
print(f"[INFO] Extracted CUDA code ({len(cuda_src)} chars) and Python code ({len(custom_kernel)} chars)")
else:
custom_kernel = extract_first_code(custom_kernel_response, ["python", "cpp"])
# check LLM is able to generate custom kernel code
assert (
custom_kernel is not None
), f"Custom {config.backend} kernel code generation failed"

# Log generated files
if config.log:
if cuda_src:
with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.cu"), "w") as f:
f.write(cuda_src)
with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
f.write(custom_kernel)

# 3. Evaluate Kernel
# NOTE: no need to wrap around process here as only a single sample
# see batch eval for examples of process isolation

# For ThunderKittens with separate CUDA file, compile it first
if backend == "thunderkittens" and cuda_src:
from scripts.eval_from_generations import compile_thunderkittens_cuda, prepare_kernel_src_with_cuda

# Create temporary CUDA file
temp_cuda_file = os.path.join(config.logdir, f"temp_kernel_level_{config.level}_problem_{config.problem_id}.cu")
os.makedirs(os.path.dirname(temp_cuda_file), exist_ok=True)
with open(temp_cuda_file, 'w') as f:
f.write(cuda_src)

# Compile CUDA module
cuda_build_dir = os.path.join(config.logdir, f"cuda_build_level_{config.level}_problem_{config.problem_id}")
cuda_module_path = compile_thunderkittens_cuda(
cuda_src_path=temp_cuda_file,
module_name="tk_kernels",
build_dir=cuda_build_dir,
verbose=config.verbose
)

# Modify kernel_src to import the compiled module
custom_kernel = prepare_kernel_src_with_cuda(custom_kernel, cuda_module_path, "tk_kernels")

kernel_exec_result = eval_kernel_against_ref(
ref_arch_src,
custom_kernel,