1 change: 1 addition & 0 deletions requirements.txt
@@ -19,6 +19,7 @@ pydra_config
pytest
ninja
cupy-cuda12x
nsight-python

# Numerics
einops
50 changes: 0 additions & 50 deletions scripts/generate_baseline_time.py
@@ -243,53 +243,3 @@ def test_measure_particular_program(level_num: int, problem_id: int):
# get_time(2, 43, torch_compile=True)




################################################################################
# Deprecated
################################################################################


def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
raise DeprecationWarning("Use New measure_program_time instead")
ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
level_num, problem_id, with_name=True
)
ref_arch_name = ref_arch_name.split("/")[-1]
context = {}
Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
ref_arch_src, context
)
try:
with torch.no_grad():
torch.cuda.synchronize(device=device)
set_seed(42)
inputs = get_inputs()
set_seed(42)
init_inputs = get_init_inputs()
inputs = [
x.cuda(device=device) if isinstance(x, torch.Tensor) else x
for x in inputs
]
init_inputs = [
x.cuda(device=device) if isinstance(x, torch.Tensor) else x
for x in init_inputs
]
model = Model(*init_inputs)

if torch_compile:
model = torch.compile(model)
print("Compiled model Done")
model = model.cuda(device=device)
torch.cuda.synchronize(device=device)
elapsed_times = time_execution_with_cuda_event(
model, *inputs, num_trials=num_trials, verbose=False, device=device
)
runtime_stats = get_timing_stats(elapsed_times, device=device)
# json_results[f"level{level_num}"][ref_arch_name] = runtime_stats
print(f"{ref_arch_name} {runtime_stats}")
return (ref_arch_name, runtime_stats)
except Exception as e:
print(f"[Eval] Error in Measuring Performance: {e}")


68 changes: 13 additions & 55 deletions scripts/generate_baseline_time_modal.py
@@ -2,11 +2,10 @@
import numpy as np
from src.eval import (
load_original_model_and_inputs,
time_execution_with_cuda_event,
get_timing_stats,
set_seed,
fetch_ref_arch_from_problem_id,
)
from src.timing import get_timing_function, get_timing_stats
from src.dataset import construct_problem_dataset_from_problem_dir
from src.utils import read_file
import os
@@ -160,6 +159,7 @@ def measure_program_time(
ref_arch_name: str,
ref_arch_src: str,
num_trials: int = 100,
timing_method: str="cuda_event",
use_torch_compile: bool = False,
torch_compile_backend: str="inductor",
torch_compile_options: str="default",
@@ -199,9 +199,16 @@ def measure_program_time(
print(f"Using PyTorch Eager Execution on {ref_arch_name}")

model = model.cuda(device=device)
timing_func = get_timing_function(timing_method)
torch.cuda.synchronize(device=device)
elapsed_times = time_execution_with_cuda_event(
model, *inputs, num_trials=num_trials, verbose=verbose, device=device
elapsed_times = timing_func(
model,
inputs,
num_warmup=3, # or any default you prefer
num_trials=num_trials,
discard_first=1, # or 0 to include first trial
verbose=verbose,
device=device,
)
runtime_stats = get_timing_stats(elapsed_times, device=device)
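For context, the pattern this hunk introduces is: select a timing backend by name with get_timing_function, call it with the input list (not unpacked), then summarize with get_timing_stats. A minimal consolidated sketch, assuming model and inputs are already prepared as in the surrounding code:

# Usage sketch of the new timing API; num_warmup / discard_first values
# mirror the defaults chosen in this script.
from src.timing import get_timing_function, get_timing_stats

timing_func = get_timing_function("cuda_event")   # select backend by name
elapsed_times = timing_func(
    model,              # callable under test
    inputs,             # inputs passed as a list, not unpacked
    num_warmup=3,
    num_trials=100,
    discard_first=1,
    verbose=False,
    device=torch.device("cuda:0"),
)
runtime_stats = get_timing_stats(elapsed_times, device=torch.device("cuda:0"))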

@@ -247,7 +254,8 @@ def record_baseline_times(config: BaselineConfig,
ref_arch_name,
ref_arch_src,
config.num_trials,
use_torch_compile,
timing_method,
use_torch_compile,
torch_compile_backend,
torch_compile_options,
torch.device(f"cuda:0"),
@@ -341,53 +349,3 @@ def main(config: BaselineConfig):
# get_time(2, 43, torch_compile=True)




################################################################################
# Deprecated
################################################################################


def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
raise DeprecationWarning("Use New measure_program_time instead")
ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
level_num, problem_id, with_name=True
)
ref_arch_name = ref_arch_name.split("/")[-1]
context = {}
Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
ref_arch_src, context
)
try:
with torch.no_grad():
torch.cuda.synchronize(device=device)
set_seed(42)
inputs = get_inputs()
set_seed(42)
init_inputs = get_init_inputs()
inputs = [
x.cuda(device=device) if isinstance(x, torch.Tensor) else x
for x in inputs
]
init_inputs = [
x.cuda(device=device) if isinstance(x, torch.Tensor) else x
for x in init_inputs
]
model = Model(*init_inputs)

if torch_compile:
model = torch.compile(model)
print("Compiled model Done")
model = model.cuda(device=device)
torch.cuda.synchronize(device=device)
elapsed_times = time_execution_with_cuda_event(
model, *inputs, num_trials=num_trials, verbose=False, device=device
)
runtime_stats = get_timing_stats(elapsed_times, device=device)
# json_results[f"level{level_num}"][ref_arch_name] = runtime_stats
print(f"{ref_arch_name} {runtime_stats}")
return (ref_arch_name, runtime_stats)
except Exception as e:
print(f"[Eval] Error in Measuring Performance: {e}")


13 changes: 7 additions & 6 deletions scripts/get_baseline_time_single_problem.py
@@ -2,16 +2,17 @@
import numpy as np
from src.eval import (
load_original_model_and_inputs,
time_execution_with_cuda_event,
get_timing_stats,
set_seed,
fetch_ref_arch_from_problem_id,
)

from src.timing import get_timing_function, get_timing_stats

def measure_program_time(
ref_arch_name: str,
ref_arch_src: str,
num_trials: int = 100,
timing_method: str="cuda_event",
use_torch_compile: bool = False,
torch_compile_backend: str="inductor",
torch_compile_options: str="default",
@@ -52,8 +53,9 @@ def measure_program_time(

model = model.cuda(device=device)
torch.cuda.synchronize(device=device)
elapsed_times = time_execution_with_cuda_event(
model, *inputs, num_trials=num_trials, verbose=verbose, device=device
timing_func = get_timing_function(timing_method)
elapsed_times = timing_func(
model, inputs, num_warmup=3, num_trials=num_trials, discard_first=1, verbose=verbose, device=device
)
runtime_stats = get_timing_stats(elapsed_times, device=device)

@@ -87,5 +89,4 @@ def get_inputs():
def get_init_inputs():
return [] # No special initialization inputs needed
"""
print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=False))
print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=True))
print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=False, timing_method="cuda_event"))
3 changes: 1 addition & 2 deletions scripts/inspect_baseline.py
@@ -5,11 +5,10 @@
import numpy as np
from src.eval import (
load_original_model_and_inputs,
time_execution_with_cuda_event,
get_timing_stats,
set_seed,
fetch_ref_arch_from_problem_id,
)
from src.timing import get_timing_function, get_timing_stats
from src.dataset import construct_problem_dataset_from_problem_dir
import os, sys
import logging
9 changes: 6 additions & 3 deletions scripts/inspect_triton.py
@@ -21,10 +21,9 @@
from src.eval import (
load_custom_model,
load_original_model_and_inputs,
time_execution_with_cuda_event,
get_timing_stats,
set_seed,
)
from src.timing import get_timing_function, get_timing_stats

def fetch_ref_arch_from_dataset(dataset: list[str],
problem_id: int) -> tuple[str, str, str]:
@@ -123,6 +122,8 @@ def run_profile_and_save_trace(dataset: list[str], problem_id: int, num_trials=1
def get_torch_compile_triton(level_num, problem_id):
"""
Get the triton code generated by torch compile for a particular problem

Made this during ICML rebuttal in March 2025
"""
ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(
dataset, problem_id, with_name=True
@@ -167,7 +168,9 @@ def get_torch_compile_triton(level_num, problem_id):


torch.cuda.synchronize(device=device)
elapsed_times = time_execution_with_cuda_event(
timing_method = "cuda_event" # use cuda event for timing here
time_func_cuda_event = get_timing_function(timing_method)
elapsed_times = time_func_cuda_event(
model, inputs, num_trials=1, verbose=False, device=device
)
runtime_stats = get_timing_stats(elapsed_times, device=device)
138 changes: 138 additions & 0 deletions src/profile.py
@@ -0,0 +1,138 @@
####
# Profiling Related Functions
# TODO: @kesavan @simon @arya
####

import torch
import pandas as pd

# wrapper with tool to measure hardware metric


# Check if nsight-python is available
# To patch nsight-python to support multiple metrics and fix the "cannot insert Annotation" bug
try:
import nsight
NSIGHT_AVAILABLE = True

# Patch 1: Multiple metrics support
_orig_ncu_init = nsight.collection.ncu.NCUCollector.__init__
nsight.collection.ncu.NCUCollector.__init__ = lambda self, metric="gpu__time_duration.sum", *a, **kw: \
_orig_ncu_init(self, ",".join(metric) if isinstance(metric, (list, tuple)) else metric, *a, **kw)

# Patch 2: Extract all metrics from comma-separated string
_orig_extract = nsight.extraction.extract_df_from_report
def _patched_extract(path, metric, *a, **kw):
if "," not in metric:
return _orig_extract(path, metric, *a, **kw)
rows = []
for m in metric.split(","):
df = _orig_extract(path, m.strip(), *a, **kw)
if df is not None and not df.empty:
df = df.copy()
df['Metric Name'] = m.strip()
rows.extend(df.to_dict('records'))
return pd.DataFrame(rows) if rows else None
nsight.extraction.extract_df_from_report = _patched_extract

# Patch 3: Fix "cannot insert Annotation" bug
_orig_agg = nsight.transformation.aggregate_data
def _patched_agg(df, func, norm, progress):
_orig_gb = pd.DataFrame.groupby
pd.DataFrame.groupby = lambda self, *a, **kw: _orig_gb(self, *a, **{**kw, 'as_index': False})
try:
return _orig_agg(df, func, norm, progress)
finally:
pd.DataFrame.groupby = _orig_gb
nsight.transformation.aggregate_data = _patched_agg

except ImportError:
NSIGHT_AVAILABLE = False


def profile_with_nsight(func, metrics=None, *args, **kwargs):
"""Profile a PyTorch function. Returns {metric_name: value}."""
if not NSIGHT_AVAILABLE:
raise RuntimeError("nsight-python not available")

metrics = [metrics] if isinstance(metrics, str) else (metrics or ['sm__cycles_active.avg'])

@nsight.analyze.kernel(metric=metrics, runs=1, configs=[(0,)],
combine_kernel_metrics=lambda a, b: (a or 0) + (b or 0))
def profiled(_):
with nsight.annotate("kernel"):
return func(*args, **kwargs)

try:
result = profiled()
if result is None:
return {m: None for m in metrics}

df = result.to_dataframe()
if df is None or df.empty:
return {m: None for m in metrics}

if 'Metric Name' in df.columns:
return {row['Metric Name']: float(row['AvgValue']) for _, row in df.iterrows()}
return {metrics[0]: float(df['AvgValue'].iloc[0])}
except Exception as e:
print(f"Error profiling: {e}")
return {m: None for m in metrics}


# example function to profile with nsight
def example_ncu_python_profile():
# Test with simple kernel
def test_kernel(x, y):
"""Simple matmul kernel."""
return x @ y

print("Creating test tensors...")
a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")

print("Running nsight profiling...")
metric_values = profile_with_nsight(
test_kernel,
['sm__cycles_active.avg', 'sm__cycles_elapsed.sum', "smsp__inst_executed_pipe_tensor_op_hmma.sum"],
a, b
)

print("\nProfiling results:")
for metric_name, value in metric_values.items():
print(f" {metric_name}: {value}")
return


def check_ncu_available() -> bool:
from shutil import which
return which('ncu') is not None


# Guard the decorator so importing this module does not fail with a NameError
# when nsight-python is missing (the import above is wrapped in try/except).
if NSIGHT_AVAILABLE:
    @nsight.analyze.kernel
    def benchmark_matmul(n):
        """Standard benchmark following nsight-python docs."""
        a = torch.randn(n, n, device="cuda")
        b = torch.randn(n, n, device="cuda")
        with nsight.annotate("matmul"):
            c = a @ b
        return c


if __name__ == "__main__":
if not check_ncu_available():
print("ncu not found in PATH. Install Nsight Compute.")
exit(1)

if not torch.cuda.is_available():
print("CUDA is not available.")
exit(1)


# test the example_ncu_python_profile
example_ncu_python_profile()



# pytorch profiler
# migrate from old repo during ICML / caesar repo
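The trailing TODO mentions migrating a PyTorch-profiler path from the old repo; nothing in this diff implements it, but a minimal sketch of such a helper (hypothetical name profile_with_torch_profiler, not part of this PR) could look like:

def profile_with_torch_profiler(func, *args, **kwargs):
    """Hypothetical sketch: summarize CUDA kernel time with torch.profiler."""
    from torch.profiler import profile, ProfilerActivity

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        func(*args, **kwargs)
        torch.cuda.synchronize()
    # Per-kernel self CUDA time, sorted by total time.
    return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)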
1 change: 0 additions & 1 deletion src/timing.py
@@ -90,7 +90,6 @@ def get_timing_function(
NOTE: we have a WIP blogpost on this topic covering the various timing approaches
"""


def time_execution_with_cuda_event(
kernel_fn: callable,
args: list[Any],
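For readers unfamiliar with the "cuda_event" method referenced above, the underlying pattern is a pair of torch.cuda.Event objects; a generic sketch of that pattern (not the exact body of time_execution_with_cuda_event):

import torch

def time_once_with_cuda_events(kernel_fn, args, device):
    """Generic sketch: time one call in milliseconds with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize(device=device)
    start.record()
    kernel_fn(*args)
    end.record()
    torch.cuda.synchronize(device=device)  # ensure the kernel has finished
    return start.elapsed_time(end)  # milliseconds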