From 2249bef6dcff279b459c4c858e67e540f2dfe6a1 Mon Sep 17 00:00:00 2001
From: Simon Guo
Date: Tue, 16 Dec 2025 07:41:24 +0000
Subject: [PATCH 1/3] start metric profile branch

---
 requirements.txt |  1 +
 src/profile.py   | 49 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+)
 create mode 100644 src/profile.py

diff --git a/requirements.txt b/requirements.txt
index 253e57da..5f2ae387 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,7 @@ pydra_config
 pytest
 ninja
 cupy-cuda12x
+nsight-python
 
 # Numerics
 einops
diff --git a/src/profile.py b/src/profile.py
new file mode 100644
index 00000000..3459e106
--- /dev/null
+++ b/src/profile.py
@@ -0,0 +1,49 @@
+####
+# Profiling Related Functions
+# TODO: @kesavan @simon @arya
+####
+
+import torch
+import nsight
+
+# Wrappers around tools to measure hardware metrics
+
+
+# nsight-python
+# https://docs.nvidia.com/nsight-python/overview
+
+def check_ncu_available() -> bool:
+    """Check if ncu is in PATH. Returns True if found, False otherwise."""
+    from shutil import which
+    return which('ncu') is not None
+
+
+# Note you need access to hardware counter
+# you also need to have ncu installed and point to the right path
+# sudo env "PATH=$PATH" $(which python) src/profile.py
+@nsight.analyze.kernel
+def benchmark_matmul(n):
+    """
+    The simplest possible benchmark.
+    We create two matrices and multiply them.
+    follow example from nsight-python docs
+    """
+    # Create two NxN matrices on GPU
+    a = torch.randn(n, n, device="cuda")
+    b = torch.randn(n, n, device="cuda")
+
+    # Mark the kernel we want to profile
+    with nsight.annotate("matmul"):
+        c = a @ b
+
+    return c
+
+if __name__ == "__main__":
+    if not check_ncu_available():
+        print("ncu not found in PATH. Please install ncu and point to path.")
+        exit(1)
+    # Run the benchmark
+    result = benchmark_matmul(1024)
+
+# pytorch profiler
+# migrate from old repo during ICML / caesar repo
\ No newline at end of file

From 845dcbb835c045bdf80b827d2d614600ac04f249 Mon Sep 17 00:00:00 2001
From: Simon Guo
Date: Tue, 16 Dec 2025 17:56:13 +0000
Subject: [PATCH 2/3] clean up remaining places for cuda events

---
 scripts/generate_baseline_time.py           | 50 -------------
 scripts/generate_baseline_time_modal.py     | 71 +++++-------------
 scripts/get_baseline_time_single_problem.py | 13 ++--
 scripts/inspect_baseline.py                 |  3 +-
 scripts/inspect_triton.py                   | 11 +++-
 src/timing.py                               |  1 -
 6 files changed, 31 insertions(+), 118 deletions(-)

diff --git a/scripts/generate_baseline_time.py b/scripts/generate_baseline_time.py
index 0a1f608b..47975ebd 100644
--- a/scripts/generate_baseline_time.py
+++ b/scripts/generate_baseline_time.py
@@ -243,53 +243,3 @@ def test_measure_particular_program(level_num: int, problem_id: int):
 
 
 # get_time(2, 43, torch_compile=True)
-
-
-################################################################################
-# Deprecated
-################################################################################
-
-
-def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
-    raise DeprecationWarning("Use New measure_program_time instead")
-    ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
-        level_num, problem_id, with_name=True
-    )
-    ref_arch_name = ref_arch_name.split("/")[-1]
-    context = {}
-    Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
-        ref_arch_src, context
-    )
-    try:
-        with torch.no_grad():
-            torch.cuda.synchronize(device=device)
-            set_seed(42)
-            inputs = get_inputs()
-            set_seed(42)
-            init_inputs = get_init_inputs()
-            inputs = [
-                x.cuda(device=device) if isinstance(x, torch.Tensor) else x
-                for x in inputs
-            ]
-            init_inputs = [
-                x.cuda(device=device) if isinstance(x, torch.Tensor) else x
-                for x in init_inputs
-            ]
-            model = Model(*init_inputs)
-
-            if torch_compile:
-                model = torch.compile(model)
-                print("Compiled model Done")
-            model = model.cuda(device=device)
-            torch.cuda.synchronize(device=device)
-            elapsed_times = time_execution_with_cuda_event(
-                model, *inputs, num_trials=num_trials, verbose=False, device=device
-            )
-            runtime_stats = get_timing_stats(elapsed_times, device=device)
-            # json_results[f"level{level_num}"][ref_arch_name] = runtime_stats
-            print(f"{ref_arch_name} {runtime_stats}")
-            return (ref_arch_name, runtime_stats)
-    except Exception as e:
-        print(f"[Eval] Error in Measuring Performance: {e}")
-
-
diff --git a/scripts/generate_baseline_time_modal.py b/scripts/generate_baseline_time_modal.py
index a0039193..c85824d1 100644
--- a/scripts/generate_baseline_time_modal.py
+++ b/scripts/generate_baseline_time_modal.py
@@ -2,11 +2,10 @@
 import numpy as np
 from src.eval import (
     load_original_model_and_inputs,
-    time_execution_with_cuda_event,
-    get_timing_stats,
     set_seed,
     fetch_ref_arch_from_problem_id,
 )
+from src.timing import get_timing_function, get_timing_stats
 from src.dataset import construct_problem_dataset_from_problem_dir
 from src.utils import read_file
 import os
@@ -160,6 +159,7 @@ def measure_program_time(
     ref_arch_name: str,
     ref_arch_src: str,
     num_trials: int = 100,
+    timing_method: str="cuda_event",
     use_torch_compile: bool = False,
     torch_compile_backend: str="inductor",
     torch_compile_options: str="default",
@@ -199,9 +199,19 @@
         print(f"Using PyTorch Eager Execution on {ref_arch_name}")
 
     model = model.cuda(device=device)
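+    # NOTE: all timing functions returned by get_timing_function are assumed to
+    # share one interface (see src/timing.py): fn(kernel_fn, args, num_warmup=...,
+    # num_trials=..., discard_first=..., verbose=..., device=...) -> per-trial times.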
+    timing_func = get_timing_function(timing_method)
     torch.cuda.synchronize(device=device)
-    elapsed_times = time_execution_with_cuda_event(
-        model, *inputs, num_trials=num_trials, verbose=verbose, device=device
+    elapsed_times = timing_func(
+        model,
+        inputs,
+        num_warmup=3,  # or any default you prefer
+        num_trials=num_trials,
+        discard_first=1,  # or 0 to include the first trial
+        verbose=verbose,
+        device=device,
     )
     runtime_stats = get_timing_stats(elapsed_times, device=device)
 
@@ -247,7 +257,8 @@
         ref_arch_name,
         ref_arch_src,
         config.num_trials,
-        use_torch_compile,
+        timing_method,
+        use_torch_compile,
         torch_compile_backend,
         torch_compile_options,
         torch.device(f"cuda:0"),
@@ -341,53 +352,3 @@ def main(config: BaselineConfig):
 
 
 # get_time(2, 43, torch_compile=True)
-
-
-################################################################################
-# Deprecated
-################################################################################
-
-
-def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False):
-    raise DeprecationWarning("Use New measure_program_time instead")
-    ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id(
-        level_num, problem_id, with_name=True
-    )
-    ref_arch_name = ref_arch_name.split("/")[-1]
-    context = {}
-    Model, get_init_inputs, get_inputs = load_original_model_and_inputs(
-        ref_arch_src, context
-    )
-    try:
-        with torch.no_grad():
-            torch.cuda.synchronize(device=device)
-            set_seed(42)
-            inputs = get_inputs()
-            set_seed(42)
-            init_inputs = get_init_inputs()
-            inputs = [
-                x.cuda(device=device) if isinstance(x, torch.Tensor) else x
-                for x in inputs
-            ]
-            init_inputs = [
-                x.cuda(device=device) if isinstance(x, torch.Tensor) else x
-                for x in init_inputs
-            ]
-            model = Model(*init_inputs)
-
-            if torch_compile:
-                model = torch.compile(model)
-                print("Compiled model Done")
-            model = model.cuda(device=device)
-            torch.cuda.synchronize(device=device)
-            elapsed_times = time_execution_with_cuda_event(
-                model, *inputs, num_trials=num_trials, verbose=False, device=device
-            )
-            runtime_stats = get_timing_stats(elapsed_times, device=device)
-            # json_results[f"level{level_num}"][ref_arch_name] = runtime_stats
-            print(f"{ref_arch_name} {runtime_stats}")
-            return (ref_arch_name, runtime_stats)
-    except Exception as e:
-        print(f"[Eval] Error in Measuring Performance: {e}")
-
-
diff --git a/scripts/get_baseline_time_single_problem.py b/scripts/get_baseline_time_single_problem.py
index 476fe2fd..e6f8155a 100644
--- a/scripts/get_baseline_time_single_problem.py
+++ b/scripts/get_baseline_time_single_problem.py
@@ -2,16 +2,17 @@
 import numpy as np
 from src.eval import (
     load_original_model_and_inputs,
-    time_execution_with_cuda_event,
-    get_timing_stats,
     set_seed,
     fetch_ref_arch_from_problem_id,
 )
+from src.timing import get_timing_function, get_timing_stats
+
 
 def measure_program_time(
     ref_arch_name: str,
     ref_arch_src: str,
     num_trials: int = 100,
+    timing_method: str="cuda_event",
     use_torch_compile: bool = False,
     torch_compile_backend: str="inductor",
     torch_compile_options: str="default",
@@ -52,8 +53,9 @@ def measure_program_time(
     model = model.cuda(device=device)
 
     torch.cuda.synchronize(device=device)
-    elapsed_times = time_execution_with_cuda_event(
-        model, *inputs, num_trials=num_trials, verbose=verbose, device=device
+    timing_func = get_timing_function(timing_method)
+    elapsed_times = timing_func(
+        model, inputs, num_warmup=3, num_trials=num_trials, discard_first=1, verbose=verbose, device=device
     )
     runtime_stats = get_timing_stats(elapsed_times, device=device)
 
@@ -87,5 +89,4 @@ def get_inputs():
 def get_init_inputs():
     return []  # No special initialization inputs needed
 """
-    print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=False))
-    print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=True))
\ No newline at end of file
+    print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=False, timing_method="cuda_event"))
\ No newline at end of file
diff --git a/scripts/inspect_baseline.py b/scripts/inspect_baseline.py
index e7811f64..a39716fd 100644
--- a/scripts/inspect_baseline.py
+++ b/scripts/inspect_baseline.py
@@ -5,11 +5,10 @@
 import numpy as np
 from src.eval import (
     load_original_model_and_inputs,
-    time_execution_with_cuda_event,
-    get_timing_stats,
     set_seed,
     fetch_ref_arch_from_problem_id,
 )
+from src.timing import get_timing_function, get_timing_stats
 from src.dataset import construct_problem_dataset_from_problem_dir
 import os, sys
 import logging
diff --git a/scripts/inspect_triton.py b/scripts/inspect_triton.py
index 4f13c8af..f664b6d6 100644
--- a/scripts/inspect_triton.py
+++ b/scripts/inspect_triton.py
@@ -21,10 +21,9 @@
 from src.eval import (
     load_custom_model,
     load_original_model_and_inputs,
-    time_execution_with_cuda_event,
-    get_timing_stats,
     set_seed,
 )
+from src.timing import get_timing_function, get_timing_stats
 
 
 def fetch_ref_arch_from_dataset(dataset: list[str], problem_id: int) -> tuple[str, str, str]:
@@ -123,6 +122,8 @@ def run_profile_and_save_trace(dataset: list[str], problem_id: int, num_trials=1
 def get_torch_compile_triton(level_num, problem_id):
     """
     Get the triton code generated by torch compile for a particular problem
+
+    Made this during ICML rebuttal in March 2025
     """
     ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(
         dataset, problem_id, with_name=True
@@ -167,7 +168,9 @@ def get_torch_compile_triton(level_num, problem_id):
 
     torch.cuda.synchronize(device=device)
 
-    elapsed_times = time_execution_with_cuda_event(
-        model, *inputs, num_trials=1, verbose=False, device=device
+    timing_method = "cuda_event"  # use CUDA events for timing here
+    time_func_cuda_event = get_timing_function(timing_method)
+    elapsed_times = time_func_cuda_event(
+        model, inputs, num_trials=1, verbose=False, device=device
     )
     runtime_stats = get_timing_stats(elapsed_times, device=device)
diff --git a/src/timing.py b/src/timing.py
index 8a36522b..7f192625 100644
--- a/src/timing.py
+++ b/src/timing.py
@@ -90,7 +90,6 @@ def get_timing_function(
 
     NOTE: we have a WIP blogpost on this topic covering the various timing approaches
     """
-
 def time_execution_with_cuda_event(
     kernel_fn: callable,
     args: list[Any],

From b3387a42a301cea6d1d31dd2dd3265e170b15b5f Mon Sep 17 00:00:00 2001
From: Kesavan Ramakrishnan
Date: Wed, 17 Dec 2025 21:19:51 +0000
Subject: [PATCH 3/3] add in nsight-python metric evaluation script and example
 usage, have to patch some nsight python functions to get this to work

---
 src/profile.py | 135 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 115 insertions(+), 20 deletions(-)

diff --git a/src/profile.py b/src/profile.py
index 3459e106..8a99f802 100644
--- a/src/profile.py
+++ b/src/profile.py
@@ -4,46 +4,141 @@
 ####
 
 import torch
-import nsight
+import pandas as pd
 
 # Wrappers around tools to measure hardware metrics
 
 
-# nsight-python
-# https://docs.nvidia.com/nsight-python/overview
+# Check whether nsight-python is available.
+# We patch nsight-python to support multiple metrics and to fix the "cannot insert Annotation" bug.
+try:
+    import nsight
+    NSIGHT_AVAILABLE = True
+
+    # Patch 1: Multiple metrics support
+    _orig_ncu_init = nsight.collection.ncu.NCUCollector.__init__
+    nsight.collection.ncu.NCUCollector.__init__ = lambda self, metric="gpu__time_duration.sum", *a, **kw: \
+        _orig_ncu_init(self, ",".join(metric) if isinstance(metric, (list, tuple)) else metric, *a, **kw)
+
+    # Patch 2: Extract all metrics from a comma-separated string
+    _orig_extract = nsight.extraction.extract_df_from_report
+    def _patched_extract(path, metric, *a, **kw):
+        if "," not in metric:
+            return _orig_extract(path, metric, *a, **kw)
+        rows = []
+        for m in metric.split(","):
+            df = _orig_extract(path, m.strip(), *a, **kw)
+            if df is not None and not df.empty:
+                df = df.copy()
+                df['Metric Name'] = m.strip()
+                rows.extend(df.to_dict('records'))
+        return pd.DataFrame(rows) if rows else None
+    nsight.extraction.extract_df_from_report = _patched_extract
+
+    # Patch 3: Fix the "cannot insert Annotation" bug
+    _orig_agg = nsight.transformation.aggregate_data
+    def _patched_agg(df, func, norm, progress):
+        _orig_gb = pd.DataFrame.groupby
+        pd.DataFrame.groupby = lambda self, *a, **kw: _orig_gb(self, *a, **{**kw, 'as_index': False})
+        try:
+            return _orig_agg(df, func, norm, progress)
+        finally:
+            pd.DataFrame.groupby = _orig_gb
+    nsight.transformation.aggregate_data = _patched_agg
+
+except ImportError:
+    NSIGHT_AVAILABLE = False
+
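+# A sketch of what the patches above enable (illustrative; the metric names and
+# the report layout here are assumptions, not upstream nsight-python API):
+#   NCUCollector(metric=["sm__cycles_active.avg", "dram__bytes.sum"])
+# joins the list into the comma-separated string that `ncu --metrics` expects,
+# and _patched_extract() then splits the report back into one row per metric,
+# tagged through the 'Metric Name' column.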
+
+def profile_with_nsight(func, metrics=None, *args, **kwargs):
+    """Profile a PyTorch function. Returns {metric_name: value}."""
+    if not NSIGHT_AVAILABLE:
+        raise RuntimeError("nsight-python not available")
+
+    metrics = [metrics] if isinstance(metrics, str) else (metrics or ['sm__cycles_active.avg'])
+
+    @nsight.analyze.kernel(metric=metrics, runs=1, configs=[(0,)],
+                           combine_kernel_metrics=lambda a, b: (a or 0) + (b or 0))
+    def profiled(_):
+        with nsight.annotate("kernel"):
+            return func(*args, **kwargs)
+
+    try:
+        result = profiled()
+        if result is None:
+            return {m: None for m in metrics}
+
+        df = result.to_dataframe()
+        if df is None or df.empty:
+            return {m: None for m in metrics}
+
+        if 'Metric Name' in df.columns:
+            return {row['Metric Name']: float(row['AvgValue']) for _, row in df.iterrows()}
+        return {metrics[0]: float(df['AvgValue'].iloc[0])}
+    except Exception as e:
+        print(f"Error profiling: {e}")
+        return {m: None for m in metrics}
+
+
+# Example: profile a simple kernel with nsight
+def example_ncu_python_profile():
+    # A simple matmul kernel to profile
+    def test_kernel(x, y):
+        """Simple matmul kernel."""
+        return x @ y
+
+    print("Creating test tensors...")
+    a = torch.randn(256, 256, device="cuda")
+    b = torch.randn(256, 256, device="cuda")
+
+    print("Running nsight profiling...")
+    metric_values = profile_with_nsight(
+        test_kernel,
+        ['sm__cycles_active.avg', 'sm__cycles_elapsed.sum', 'smsp__inst_executed_pipe_tensor_op_hmma.sum'],
+        a, b
+    )
+
+    print("\nProfiling results:")
+    for metric_name, value in metric_values.items():
+        print(f"  {metric_name}: {value}")
+    return
+
 
 def check_ncu_available() -> bool:
-    """Check if ncu is in PATH. Returns True if found, False otherwise."""
     from shutil import which
     return which('ncu') is not None
 
 
-# Note you need access to hardware counter
-# you also need to have ncu installed and point to the right path
-# sudo env "PATH=$PATH" $(which python) src/profile.py
-@nsight.analyze.kernel
+@(nsight.analyze.kernel if NSIGHT_AVAILABLE else (lambda f: f))  # no-op when nsight is missing
 def benchmark_matmul(n):
-    """
-    The simplest possible benchmark.
-    We create two matrices and multiply them.
-    follow example from nsight-python docs
-    """
-    # Create two NxN matrices on GPU
+    """Standard benchmark following the nsight-python docs."""
     a = torch.randn(n, n, device="cuda")
     b = torch.randn(n, n, device="cuda")
-
-    # Mark the kernel we want to profile
     with nsight.annotate("matmul"):
         c = a @ b
-
     return c
+
 
 if __name__ == "__main__":
     if not check_ncu_available():
-        print("ncu not found in PATH. Please install ncu and point to path.")
+        print("ncu not found in PATH. Install Nsight Compute.")
+        exit(1)
+
+    if not torch.cuda.is_available():
+        print("CUDA is not available.")
         exit(1)
-    # Run the benchmark
-    result = benchmark_matmul(1024)
+
+
+    # Run the example nsight profile
+    example_ncu_python_profile()
+
+
 
 # pytorch profiler
 # migrate from old repo during ICML / caesar repo
\ No newline at end of file