diff --git a/README.md b/README.md
index ba897602..008e51ed 100644
--- a/README.md
+++ b/README.md
@@ -104,7 +104,7 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev
 **What you might need to modify**
 * **`gpu_arch`** - Depend on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware.
 * **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`.
-* **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`.
+* **`backend`** - We also support GPU programming languages beyond `cuda`; simply specify `backend=triton` or `backend=hip`. Currently supported backends: `cuda`, `hip`, `triton`, `cute`, `tilelang`.
 
 Check the config fields for comprehensive set of options.
 
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index 18fb3c55..0cba5ed4 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -172,11 +172,11 @@ def main(config: EvalConfig):
     # Use appropriate prompt constructor based on backend
     if config.backend == "cuda":
         custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
-    elif config.backend in ["triton", "tilelang", "cute"]:
+    elif config.backend in ["triton", "tilelang", "cute", "hip"]:
         custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
     else:
         raise ValueError(
-            f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'."
+            f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', 'cute', or 'hip'."
         )
 
     if config.log_prompt:
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index 6962f515..ff25187b 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -198,10 +198,10 @@ def main(config: EvalConfig):
     # Use appropriate prompt constructor based on backend
     if config.backend == "cuda":
         custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
-    elif config.backend in ["triton", "tilelang", "cute"]:
+    elif config.backend in ["triton", "tilelang", "cute", "hip"]:
         custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
     else:
-        raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'.")
+        raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'hip', 'triton', 'tilelang', or 'cute'.")
 
     if config.log_prompt:
         with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py
index 5b476445..99f162b3 100644
--- a/scripts/generate_samples.py
+++ b/scripts/generate_samples.py
@@ -131,11 +131,11 @@ def generate_sample_single(
         custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(
             ref_arch_src
         )
-    elif config.backend in ["triton", "cute", "tilelang"]:
+    elif config.backend in ["triton", "hip", "cute", "tilelang"]:
        custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
     else:
         raise ValueError(
-            f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'cute', or 'tilelang'."
+            f"Unsupported backend: {config.backend}. Must be 'cuda', 'hip', 'triton', 'cute', or 'tilelang'."
         )
     if config.log_prompt:
         prompt_path = os.path.join(
diff --git a/src/eval.py b/src/eval.py
index 4a072c89..dd6196fe 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -370,7 +370,7 @@ def _process_input_tensor(input, device, backend="cuda", precision=torch.float32
     Args:
         input: Input tensor or non-tensor value
         device: Target CUDA device
-        backend: Backend type (e.g., 'cuda', 'triton', 'cute')
+        backend: Backend type (e.g., 'cuda', 'hip', 'triton', 'cute')
         precision: torch.dtype
     Returns:
         Processed tensor on correct device with correct dtype, or original value if not a tensor
@@ -399,7 +399,7 @@ def eval_kernel_against_ref(
     device: Union[torch.device, int] = (
         torch.cuda.current_device() if torch.cuda.is_available() else None
     ),  # have to run on GPU
-    backend: str = "cuda",  # can be 'cuda', 'triton', 'tilelang', or 'cute'
+    backend: str = "cuda",  # can be 'cuda', 'hip', 'triton', 'tilelang', or 'cute'
     precision: torch.dtype = torch.float32,
 ) -> KernelExecResult:
     """
@@ -408,7 +408,7 @@ def eval_kernel_against_ref(
     num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass
     num_perf_trials: run the evalutation many times to take the average
     device: GPU (cuda) device to run the evalutation on
-    backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute'
+    backend: str, one of 'cuda', 'hip', 'triton', 'tilelang', or 'cute'
     precision: torch.dtype for computation (note: tilelang only supports fp16)
     """
     # TODO: check device is busy
@@ -488,7 +488,7 @@ def eval_kernel_against_ref(
                 custom_model_src, entry_point="ModelNew"
             )
         else:
-            # Default CUDA backend
+            # Default CUDA/HIP backend
            ModelNew = load_custom_model(custom_model_src, context, build_dir)
         torch.cuda.synchronize(device=device)  # not sure if this is too much
     except Exception as e:
diff --git a/src/prompt_constructor.py b/src/prompt_constructor.py
index 36cde19f..35a69940 100644
--- a/src/prompt_constructor.py
+++ b/src/prompt_constructor.py
@@ -456,12 +456,6 @@ def prompt_generate_prompt_with_hardware_info(ref_arch_src: str,
 
     return prompt
 
-    return Nonoe
-
-
-
-
-
 def prompt_fix_compile(ref_arch_src, custom_cuda, metadata):
     prompt = PROBLEM_STATEMENT
     prompt += f"""
diff --git a/src/prompt_constructor_multilang.py b/src/prompt_constructor_multilang.py
index 8a520d10..557cf82b 100644
--- a/src/prompt_constructor_multilang.py
+++ b/src/prompt_constructor_multilang.py
@@ -492,6 +492,227 @@ def prompt_fix_correctness_cute(ref_arch_src, custom_kernel, metadata):
     return prompt
 
 
+################################################################################
+# HIP Backend
+################################################################################
+
+HIP_PROBLEM_STATEMENT = """You write custom HIP kernels to replace the pytorch operators in the given architecture to get speedups. \n
+    You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom HIP kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n
+"""
+
+HIP_PROBLEM_INSTRUCTION = """
+Optimize the architecture named Model with custom HIP operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n
+"""
+
+HIP_PROBLEM_STATEMENT_CLEANED = """You write custom HIP kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom HIP kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n
+"""
+
+HIP_PROBLEM_INSTRUCTION_CLEANED = """
+Optimize the architecture named Model with custom HIP operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n
+"""
+
+def prompt_generate_custom_hip(
+    arc_src: str, example_arch_src: str, example_new_arch_src: str
+) -> str:
+    prompt = HIP_PROBLEM_STATEMENT
+
+    if example_arch_src != "" and example_new_arch_src != "":
+        prompt += f"""
+        Here's an example to show you the syntax of inline embedding custom HIP operators in torch: The example given architecture is: \n
+        ``` \n
+        {example_arch_src}
+        ``` \n
+        The example new arch with custom HIP kernels looks like this: 
+        ```
+        {example_new_arch_src}
+        ``` \n
+        """
+
+    prompt += f"""
+    You are given the following architecture: \n
+    ```
+    {arc_src}
+    ```
+    """
+    prompt += HIP_PROBLEM_INSTRUCTION
+    return prompt
+
+
+def prompt_generate_custom_hip_from_prompt_template(ref_arch_src: str) -> str:
+    """
+    Using prompt example (an element-wise addition) for prompt templates
+    The most basic form of example just to show LLM the task and the expected output format
+    """
+    arch = ref_arch_src
+    # These are strictly defined for now
+
+    # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom HIP kernels)
+    example_arch_path = os.path.join(
+        REPO_TOP_PATH, f"src/prompts/model_ex_add.py"
+    )
+    example_new_arch_path = os.path.join(
+        REPO_TOP_PATH, f"src/prompts/model_new_ex_add_hip.py"
+    )
+
+    if not os.path.exists(example_arch_path):
+        raise FileNotFoundError(
+            f"Example architecture file not found: {example_arch_path}"
+        )
+    if not os.path.exists(example_new_arch_path):
+        raise FileNotFoundError(
+            f"Example new architecture file not found: {example_new_arch_path}"
+        )
+
+    example_arch = read_file(example_arch_path)
+    example_new_arch = read_file(example_new_arch_path)
+
+    return prompt_generate_custom_hip(arch, example_arch, example_new_arch)
+
+
+def prompt_generate_prompt_with_hardware_info_from_template_hip(ref_arch_src: str, gpu_name: str) -> str:
+    """
+    Similar to prompt_generate_custom_hip_from_prompt_template,
+    but with hardware information for the given GPU
+    """
+
+    arch = ref_arch_src
+    # These are strictly defined for now
+
+    # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom HIP kernels)
+    example_arch_path = os.path.join(
+        REPO_TOP_PATH, f"src/prompts/model_ex_add.py"
+    )
+    example_new_arch_path = os.path.join(
+        REPO_TOP_PATH, f"src/prompts/model_new_ex_add_hip.py"
+    )
+
+    gpu_spec_file_path = os.path.join(REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py")
+
+    example_arch = read_file(example_arch_path)
+    example_new_arch = read_file(example_new_arch_path)
+    gpu_spec_info = read_file(gpu_spec_file_path)
+
+    return prompt_generate_prompt_with_hardware_info_hip(
+        ref_arch_src=arch,
+        gpu_name=gpu_name,
+        example_arch_src=example_arch,
+        example_new_arch_src=example_new_arch,
+        gpu_spec_info_src=gpu_spec_info
+    )
+
+
+
+def prompt_generate_prompt_with_hardware_info_hip(ref_arch_src: str,
+                                                  gpu_name: str,
+                                                  example_arch_src: str,
+                                                  example_new_arch_src: str,
+                                                  gpu_spec_info_src: str) -> str:
+    """
+    Generate a prompt with hardware information for the given GPU
+    gpu_spec_info_src: str of the gpu spec src file
+    """
+
+    local_dict = {}
+    exec(gpu_spec_info_src, {}, local_dict)
+
+    GPU_SPEC_INFO = local_dict.get('GPU_SPEC_INFO')
+    GPU_DEFINITIONS = local_dict.get('GPU_DEFINITIONS')
+    GPU_BEST_PRACTICES = local_dict.get('GPU_BEST_PRACTICES')
+
+    if not GPU_SPEC_INFO or not GPU_DEFINITIONS or not GPU_BEST_PRACTICES:
+        raise ValueError("GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src")
+
+    assert gpu_name in GPU_SPEC_INFO, f"GPU name {gpu_name} not found in GPU_SPEC_INFO"
+
+    prompt = HIP_PROBLEM_STATEMENT
+
+    if example_arch_src != "" and example_new_arch_src != "":
+        prompt += f"""
+        Here's an example to show you the syntax of inline embedding custom HIP operators in torch: The example given architecture is: \n
+        ``` \n
+        {example_arch_src}
+        ``` \n
+        The example new arch with custom HIP kernels looks like this: 
+        ```
+        {example_new_arch_src}
+        ``` \n
+        """
+
+    curr_gpu_spec_info = GPU_SPEC_INFO[gpu_name]
+
+    gpu_architecture = curr_gpu_spec_info.get("GPU Architecture")
+    prompt += f"""
+    Here is some information about the underlying hardware that you should keep in mind. \n\n
+The GPU that will run the kernel is AMD {gpu_name}, {gpu_architecture} architecture.\n\n"""
+
+    for key, value in curr_gpu_spec_info.items():
+        if key == "GPU Architecture":
+            continue
+        prompt += f"""- We have {value} of {key}.\n"""
+
+    prompt += f"""\n\n
+Here are some concepts about the GPU architecture that could be helpful: \n\n"""
+    for key, value in GPU_DEFINITIONS.items():
+        prompt += f"""- {key}: {value}\n"""
+
+    prompt += f"""\n\n
+Here are some best practices for writing HIP kernels on GPU: \n\n"""
+    for best_practice in GPU_BEST_PRACTICES:
+        prompt += f"""- {best_practice}\n"""
+
+
+    prompt += f"""
+    You are given the following architecture: \n
+    ```
+    {ref_arch_src}
+    ```
+    """
+
+    prompt += HIP_PROBLEM_INSTRUCTION
+    return prompt
+
+
+def prompt_fix_compile_hip(ref_arch_src, custom_hip_kernel, metadata):
+    prompt = HIP_PROBLEM_STATEMENT
+    prompt += f"""
+    With the following architecture:
+    ```
+    {ref_arch_src}
+    ```
+    You generated the following solution and it failed to compile:
+    ```
+    {custom_hip_kernel}
+    ```
+    Here's the metadata of the compilation error:
+    ```
+    {metadata}
+    ```
+
+    Please fix the compilation error in the new model code. Please output the corrected code in codeblocks.
+    """
+    return prompt
+
+
+def prompt_fix_correctness_hip(ref_arch_src, custom_hip_kernel, metadata):
+    prompt = HIP_PROBLEM_STATEMENT
+    prompt += f"""
+    With the following architecture:
+    ```
+    {ref_arch_src}
+    ```
+    You generated the following solution and it failed correctness:
+    ```
+    {custom_hip_kernel}
+    ```
+    Here's the metadata of the correctness error:
+    ```
+    {metadata}
+    ```
+    Please consider how your custom HIP kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks.
+    """
+    return prompt
+
+
 ################################################################################
 # Unified API
 ################################################################################
@@ -502,7 +723,7 @@ def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str:
 
     Args:
         ref_arch_src: Reference architecture source code
-        backend: One of 'triton', 'tilelang', 'cute'
+        backend: One of 'triton', 'tilelang', 'cute', 'hip'
 
     Returns:
         Prompt string for the specified backend
@@ -515,6 +736,8 @@ def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str:
         return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src)
     elif backend_lower == "cute":
         return prompt_generate_custom_cute_from_prompt_template(ref_arch_src)
+    elif backend_lower == "hip":
+        return prompt_generate_custom_hip_from_prompt_template(ref_arch_src)
     else:
         raise ValueError(
             f"Unsupported backend: {backend}. Must be one of: 'triton', 'tilelang', 'cute'"
diff --git a/src/prompts/hardware/gpu_specs.py b/src/prompts/hardware/gpu_specs.py
index dcf60c7f..357337cb 100644
--- a/src/prompts/hardware/gpu_specs.py
+++ b/src/prompts/hardware/gpu_specs.py
@@ -118,7 +118,91 @@
         "Maximum number of thread blocks per SM": "32",
         "Shared memory capacity per SM": "164 KB",
         "Maximum shared memory per thread block": "163 KB",
-    }
+    },
+    "MI300X": {
+        "GPU Architecture": "CDNA3",
+        "GPU Memory": "192GB",
+        "Memory Bandwidth": "5.3 TB/s",
+        "FP64 TFLOPS": "81.7",
+        "FP64 Matrix Core TFLOPS": "163.4",
+        "FP32 TFLOPS": "163.4",
+        "TF32 Matrix Core TFLOPS": "653.7 (1307.4 with sparsity)",
+        "BFLOAT16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "2614.9 (5229.8 with sparsity)",
+        "INT8 Matrix Core TOPS": "2614.9 (5229.8 with sparsity)",
+        "Number of CU": "304",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "64 KB",
+    },
+    "MI325X": {
+        "GPU Architecture": "CDNA3",
+        "GPU Memory": "256GB",
+        "Memory Bandwidth": "6TB/s",
+        "FP64 TFLOPS": "81.7",
+        "FP64 Matrix Core TFLOPS": "163.4",
+        "FP32 TFLOPS": "163.4",
+        "TF32 Matrix Core TFLOPS": "653.7 (1307.4 with sparsity)",
+        "BFLOAT16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "2614.9 (5229.8 with sparsity)",
+        "INT8 Matrix Core TOPS": "2614.9 (5229.8 with sparsity)",
+        "Number of CU": "304",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "64 KB",
+    },
+    "MI350X": {
+        "GPU Architecture": "CDNA4",
+        "GPU Memory": "288GB",
+        "Memory Bandwidth": "8TB/s",
+        "FP64 TFLOPS": "72.1",
+        "FP64 Matrix Core TFLOPS": "72.1",
+        "FP32 TFLOPS": "144.2",
+        "BFLOAT16 Matrix Core TFLOPS": "2300 (4600 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "2300 (4600 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "4600",
+        "MXFP6, MXFP4 Matrix Core TFLOPS": "9200",
+        "INT8 Matrix Core TOPS": "4600 (9200 with sparsity)",
+        "Number of CU": "256",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "160 KB",
+    },
+    "MI355X": {
+        "GPU Architecture": "CDNA4",
+        "GPU Memory": "288GB",
+        "Memory Bandwidth": "8TB/s",
+        "FP64 TFLOPS": "78.6",
+        "FP64 Matrix Core TFLOPS": "78.6",
+        "FP32 TFLOPS": "157.3",
+        "BFLOAT16 Matrix Core TFLOPS": "2500 (5000 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "2500 (5000 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "5000",
+        "MXFP6, MXFP4 Matrix Core TFLOPS": "10000",
+        "INT8 Matrix Core TOPS": "5000 (10000 with sparsity)",
+        "Number of CU": "256",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "160 KB",
+    },
 }
 
 # Basic GPU concept definitions
diff --git a/src/prompts/model_new_ex_add_hip.py b/src/prompts/model_new_ex_add_hip.py
new file mode 100644
index 00000000..fa66cf03
--- /dev/null
+++ b/src/prompts/model_new_ex_add_hip.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.cpp_extension import load_inline
+
+import os
+os.environ["CXX"] = "hipcc"
+
+elementwise_add_cpp_source = """
+#include <torch/extension.h>
+
+__global__ void elementwise_add_kernel(const float* a, const float* b, float* out, int size) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < size) {
+        out[idx] = a[idx] + b[idx];
+    }
+}
+
+torch::Tensor elementwise_add_hip(torch::Tensor a, torch::Tensor b) {
+    auto size = a.numel();
+    auto out = torch::zeros_like(a);
+
+    const int block_size = 256;
+    const int num_blocks = (size + block_size - 1) / block_size;
+
+    elementwise_add_kernel<<<num_blocks, block_size>>>(a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), size);
+
+    return out;
+}
+"""
+
+elementwise_add = load_inline(
+    name="elementwise_add",
+    cpp_sources=elementwise_add_cpp_source,
+    functions=["elementwise_add_hip"],
+    verbose=True,
+)
+
+class ModelNew(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.elementwise_add = elementwise_add
+
+    def forward(self, a, b):
+        return self.elementwise_add.elementwise_add_hip(a, b)
\ No newline at end of file
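---

For context on how these pieces fit together, here is a minimal usage sketch (not part of the diff). It assumes the repo's existing module layout (`src.eval`, `src.prompt_constructor_multilang`) and that `eval_kernel_against_ref` takes the reference and candidate sources as its first two positional arguments; only the `backend="hip"` plumbing is taken directly from this PR, and the LLM call is left as a user-supplied callable.

```python
import torch

from src.eval import eval_kernel_against_ref
from src.prompt_constructor_multilang import get_prompt_for_backend


def generate_and_eval_hip(ref_arch_src: str, query_llm):
    """Build the HIP prompt, ask an LLM for a ModelNew, then evaluate it on a ROCm device."""
    # 1. Build a HIP-flavored prompt for the reference architecture
    #    (dispatches to prompt_generate_custom_hip_from_prompt_template).
    prompt = get_prompt_for_backend(ref_arch_src, backend="hip")

    # 2. query_llm is any callable that maps a prompt string to generated source code:
    #    a ModelNew module embedding inline HIP kernels, as in model_new_ex_add_hip.py.
    custom_hip_src = query_llm(prompt)

    # 3. Check correctness and timing against the PyTorch reference.
    return eval_kernel_against_ref(
        ref_arch_src,       # reference architecture source (assumed first positional arg)
        custom_hip_src,     # generated ModelNew source
        backend="hip",
        precision=torch.float32,
    )
```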