Padded/sliced dynamic batch, cache improvements and parallel model precompiled load Improvements to MIGraphX EP #207
Conversation
This commit addresses the issue where input_shapes could be empty when first loading a model, which would cause model caching to be skipped incorrectly. Key changes:
- Fixed incorrect iteration using session_input_names.size() as an index into the input_tensor vector. Now directly iterates over input_tensor.
- Changed the loop to include all dimensions (starting from j=0 instead of j=1), as skipping dimension 0 could result in empty shapes for low-rank tensors.
- Added validation to check for null shapes and zero dimension sizes before processing.
- Added a check for has_dim_value() to handle symbolic/dynamic dimensions properly, only including dimensions with concrete positive values.
This ensures model caching works correctly when valid input shapes are available.
This reverts commit 60c8226.
…ted shapes. let MIGraphX treat this as a static shape
…itializers with regards to batch size
Updated 32 log statements in the compute function from LOGS_DEFAULT(VERBOSE) to LOGS_DEFAULT(INFO) for better visibility of compute-related operations during inference.
Use this call from previously generated bits to encapsulate model compilation and the parameters needed to ensure an MIGraphX program is properly compiled and set up accordingly
Cleans up the compute thread and keeps all inputs and pieces clear as to what's going to be run through the MIGraphX API.
Remove this from the compute so we can encapsulate things in a reasonable way
Make this a separate call that takes in the input context, program, and parameter shape and name information so we can populate the items needed based on the MIGraphX program to perform a run_async later. Doing this as part of cleaning up the compute function to further optimize later
Capture this in a separate call so we get an idea of how input shapes are handled during compute and the modes
Reuse this and remove a bunch of redundant repeated code
Store a compiled (or preloaded-from-disk) MIGraphX program into a map indexed by batch size. Use this as the program in the compute method if an incoming batch size matches one we wanted to run. If this fails, fall back to the preload from disk, and if that fails, compile the model in the compute thread
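For illustration, a minimal sketch of that lookup order (hypothetical helper and member names, not the EP's actual code):

```cpp
#include <cstddef>
#include <optional>
#include <unordered_map>
#include <utility>

#include <migraphx/migraphx.hpp>  // migraphx::program from the MIGraphX C++ API

// Hypothetical helpers standing in for the EP's disk-cache load and compile steps.
std::optional<migraphx::program> try_load_mxr_from_disk(std::size_t batch);
migraphx::program compile_for_batch(std::size_t batch);

// Hypothetical cache keyed by batch size; the real member names in the EP differ.
std::unordered_map<std::size_t, migraphx::program> batch_programs;

migraphx::program& get_program_for_batch(std::size_t batch) {
  // 1. Prefer a program already compiled (or preloaded from disk) for this batch size.
  if (auto it = batch_programs.find(batch); it != batch_programs.end()) {
    return it->second;
  }
  // 2. Fall back to loading a precompiled .mxr file from the disk cache.
  if (auto prog = try_load_mxr_from_disk(batch)) {
    return batch_programs.emplace(batch, std::move(*prog)).first->second;
  }
  // 3. Last resort: compile the model in the compute thread.
  return batch_programs.emplace(batch, compile_for_batch(batch)).first->second;
}
```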
- Tighten lock around run_async
- Remove O(n) lookup with find and use unordered_set instead
- Use optional to help tighten up the lock
Should improve runtime from O(N²) to O(N) for running through the outputs and checking them
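As an illustration of the change described (hypothetical names, not the EP's actual code):

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Before: for each output name, std::find over a vector is O(N), so checking
// all N outputs is O(N^2). After: build an unordered_set once and membership
// checks are O(1) on average, giving O(N) overall.
bool all_outputs_known(const std::vector<std::string>& output_names,
                       const std::vector<std::string>& session_outputs) {
  std::unordered_set<std::string> known(session_outputs.begin(), session_outputs.end());
  for (const auto& name : output_names) {
    if (known.count(name) == 0) {
      return false;
    }
  }
  return true;
}
```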
Do this so that we can manage updates between inferences for things like dynamic batch or sequence length more effectively. Right now we were coarsely recompiling based on any mismatch and always checking input shapes. In these cases, instead of checking all N inputs we should check the symbolic dimensions for updates
Move this out into a separate call to ensure we're tracking whether we get a dynamic batch size as well as other symbolic dimensions in the model that we detect at compile time
- Implement dynamic batching using power-of-two batch sizes (1, 2, 4, 8, 16, 32, 64, etc.)
- Pre-compile models for all power-of-two batch sizes up to max_dynamic_batch
- Pad input tensors to the nearest power of two when needed
- Slice output tensors back to the original batch size
- Fix shape ordering consistency between compilation and runtime (alphabetical for hash, MIGraphX order for caching)
- Fix ultra-fast path to apply dynamic batch logic to all inputs, not just the first
- Implement GPU buffer allocation and padding for inputs
- Fix output slicing to handle ALL outputs, including pre-allocated ones
Key components:
- allocate_and_pad_inputs(): Allocates GPU buffers and pads input data
- free_padded_inputs(): Cleans up padded buffers after execution
- Modified execute_ultra_fast_path(), execute_fast_path(), execute_standard_path() to use padded inputs
- Fixed run_migraphx_program() to slice all 13 outputs correctly
This enables efficient dynamic batching without recompilation by reusing cached power-of-two batch models. Also adds fixes to handle cases where the batch matches a power of two, as well as the ultra-fast path comparison for cached models so that input shape orders aren't violated
Do this so that we save overhead on some of the buffers we've already allocated, reusing them when the same padded batch input comes in
- Cache sliced outputs and remove constant MIGraphX API calls that add overhead to each run.
Items deprecated through optimizations and caching
Remove this from compile so it's clear we can maintain this as a separate piece/pass before inference is run
Ensure we just load mxr and not rely on the auto detection
Add parallel loading of mxr files to cache on startup
…mpile. This is due to thread safety during MIGraphX compile
apwojcik left a comment:
Great work!
```cpp
// Allocate and pad each input
mgx_state->padded_input_buffers.reserve(mgx_state->cached_inputs.size());

for (const auto& cached_inp : mgx_state->cached_inputs) {
```
just be careful here where all input shapes may not actually have a batch dimension. There's not a clean way of checking this unless you do some sort of symbolic tracing. You could have a heuristic where you look at the first dimension of each input, then count number of occurrences of the value at the first dimension. The ones which don't match the most frequent occurrence probably don't use the batch dimension.
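A minimal sketch of the heuristic described above (hypothetical, not code from this PR):

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Treat the most frequent first-dimension value across inputs as the batch
// dimension; inputs whose first dimension differs probably don't carry it.
std::int64_t most_likely_batch_dim(const std::vector<std::vector<std::int64_t>>& input_shapes) {
  std::map<std::int64_t, int> counts;
  for (const auto& shape : input_shapes) {
    if (!shape.empty()) {
      ++counts[shape.front()];
    }
  }
  std::int64_t best = -1;
  int best_count = 0;
  for (const auto& [dim, count] : counts) {
    if (count > best_count) {
      best = dim;
      best_count = count;
    }
  }
  return best;  // -1 if no usable shapes were seen
}
```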
This should be handled. I only allow for dynamic batch by using the has_only_dynamic_batch_dimension(), used on line 3837. We only then extract shapes before we handle everything correctly to ensure we're not getting into weight scenarios for now.
This assumes other inputs are static right now. If we run into a model that needs dynamic batch plus something like a sequence length that changes between runs, then we'll need to cover that differently. Otherwise the model will appear to have different shapes each run, fail the hash lookup for the same inputs, and trigger a recompile if the lookup misses both the cache and the disk.
```cpp
static void pad_input_tensor(const void* src_data, void* dst_data,
                             std::size_t original_batch, std::size_t padded_batch,
                             std::size_t element_size_bytes, std::size_t elements_per_batch,
                             hipStream_t stream) {
  std::size_t bytes_per_batch = element_size_bytes * elements_per_batch;

  // Copy original data
  HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data,
                                original_batch * bytes_per_batch,
                                hipMemcpyDeviceToDevice, stream));

  // Pad with last batch element replicated
  if (original_batch > 0 && padded_batch > original_batch) {
    const char* last_batch = static_cast<const char*>(src_data) + (original_batch - 1) * bytes_per_batch;
    char* pad_start = static_cast<char*>(dst_data) + original_batch * bytes_per_batch;

    for (std::size_t i = original_batch; i < padded_batch; ++i) {
      HIP_CALL_THROW(hipMemcpyAsync(pad_start, last_batch, bytes_per_batch,
                                    hipMemcpyDeviceToDevice, stream));
      pad_start += bytes_per_batch;
    }
  }
}
```
How much effort would it be to write and use a GPU kernel to do this? It would be more efficient.
We could, but that would mean we need to add another kernel to onnxruntime or find a way to handle this inside MIGraphX before this point. Right now we treat an entire compiled graph as a MIGraphX kernel, which lets us handle all the pieces internal to MIGraphX.
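For reference, a hypothetical sketch of the kind of HIP kernel being discussed here (not part of this PR):

```cpp
#include <hip/hip_runtime.h>
#include <cstddef>

// Replicates the last original batch element into the padded region in a single
// kernel launch, instead of one hipMemcpyAsync per padded batch element.
__global__ void replicate_last_batch(const char* src, char* dst,
                                     std::size_t original_batch,
                                     std::size_t padded_batch,
                                     std::size_t bytes_per_batch) {
  const char* last = src + (original_batch - 1) * bytes_per_batch;
  char* pad_start = dst + original_batch * bytes_per_batch;
  std::size_t pad_bytes = (padded_batch - original_batch) * bytes_per_batch;
  for (std::size_t i = blockIdx.x * static_cast<std::size_t>(blockDim.x) + threadIdx.x;
       i < pad_bytes;
       i += static_cast<std::size_t>(gridDim.x) * blockDim.x) {
    pad_start[i] = last[i % bytes_per_batch];
  }
}

// Example launch (assumes dst already holds the original, unpadded data):
// hipLaunchKernelGGL(replicate_last_batch, dim3(256), dim3(256), 0, stream,
//                    static_cast<const char*>(src_data), static_cast<char*>(dst_data),
//                    original_batch, padded_batch, bytes_per_batch);
```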
```cpp
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};

OrtArenaCfg* default_memory_arena_cfg{nullptr};
size_t max_dynamic_batch{static_cast<size_t>(0)};
```
the static cast looks redundant here
I think it's due to 0 defaulting to int instead of size_t, which is unsigned
Merging this in, then cherry-picking to 7.2 so we have coverage for 7.2.1 and this changeset
Merged 5ff8193 into rocm7.1_internal_testing
DO NOT DELETE THIS BRANCH as this branch is still in use
…ecompiled load Improvements to MIGraphX EP (#207): Set of changes to enable dynamic batch functionality in MIGraphX EP using a padded/slice approach, with added caching for models and faster paths for static cases.
Description
Using an AI-generated description to better summarize the changes
Summary of Changes
Dynamic Batch Support with Power-of-Two Padding
The major feature addition is dynamic batch support: input batch sizes are padded to the nearest power of two (1, 2, 4, 8, 16, etc.), enabling model reuse across different batch sizes without recompilation.
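As an illustration, a sketch of how a batch size might be rounded up to a precompiled power-of-two model (illustrative only, not the PR's find_nearest_power_of_two_batch implementation):

```cpp
#include <cstddef>

// Round an incoming batch size up to the next power of two, capped at
// max_dynamic_batch, so it can be served by a precompiled model.
std::size_t nearest_power_of_two_batch(std::size_t batch, std::size_t max_dynamic_batch) {
  std::size_t p = 1;
  while (p < batch && p < max_dynamic_batch) {
    p <<= 1;
  }
  return p;
}

// Examples: batch 3 -> 4, batch 8 -> 8, batch 20 -> 32 (given max_dynamic_batch >= 32).
```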
Three-Tier Execution Path Architecture
The compute function was refactored into three distinct paths (a condensed sketch follows the list):
Ultra-fast path (execute_ultra_fast_path): When shapes are identical to the last run - just rebinds memory pointers and executes directly
Fast path (execute_fast_path): When a cached program exists for the shape hash - retrieves from cache and runs
Standard path (execute_standard_path): Full shape checking with potential recompilation
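A condensed, self-contained sketch of that dispatch order (illustrative types and names, not the EP's actual state):

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Illustrative shape and state types; the real EP state holds much more.
using InputShapes = std::vector<std::vector<std::int64_t>>;

struct DemoState {
  InputShapes last_input_shapes;
  std::unordered_map<std::size_t, int /* stands in for a compiled program */> cached_programs;
};

std::size_t hash_shapes(const InputShapes& shapes) {
  std::size_t h = 0;
  for (const auto& s : shapes)
    for (auto d : s)
      h = h * 31 + std::hash<std::int64_t>{}(d);
  return h;
}

// Dispatch order mirroring the three paths described above.
enum class Path { UltraFast, Fast, Standard };

Path choose_path(const DemoState& state, const InputShapes& current) {
  if (current == state.last_input_shapes)
    return Path::UltraFast;  // shapes unchanged: just rebind pointers and run
  if (state.cached_programs.count(hash_shapes(current)) != 0)
    return Path::Fast;       // a program is cached for this shape hash
  return Path::Standard;     // full shape check, possibly recompile
}
```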
Precompilation at Compile Time
Moved model compilation from runtime to the Compile() phase:
precompile_all_dynamic_batch_models() - Precompiles all power-of-two batch sizes during initialization
precompile_static_model() - Precompiles static models during initialization
handle_precompilation_decision() - Logic to determine if precompilation is possible
Extensive Caching Infrastructure
Program cache: cached_programs map storing compiled programs by shape hash
Shape caching: cached_mgx_param_shapes and cached_mgx_output_shapes to avoid repeated MIGraphX API calls
Input/output caching: CachedInputParam and CachedOutputParam structures with pre-computed indices and shapes
Buffer reuse: Padded input buffers and temporary output buffers are kept for reuse
Input/Output Handling Refactoring
handle_program_input_outputs() - Centralized function to bind inputs and allocate outputs
populate_ultra_fast_caches() - Prepares optimized caches for ultra-fast path
build_input_shapes_in_cached_order() - Ensures consistent shape ordering for cache lookups
allocate_and_pad_inputs() - Handles GPU memory allocation and padding for dynamic batching
Output Slicing for Padded Batches
When running with a padded batch size larger than the original, outputs are stored in temporary buffers and sliced back to the original batch size.
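A sketch of the slicing step, assuming outputs were first written to a padded temporary buffer (illustrative, not the PR's exact code):

```cpp
#include <hip/hip_runtime.h>
#include <cstddef>

// Copies only the first `original_batch` batches from the padded temporary
// output buffer back into the user-visible output buffer.
hipError_t slice_output_to_original_batch(const void* padded_output, void* final_output,
                                          std::size_t original_batch,
                                          std::size_t bytes_per_batch,
                                          hipStream_t stream) {
  return hipMemcpyAsync(final_output, padded_output,
                        original_batch * bytes_per_batch,
                        hipMemcpyDeviceToDevice, stream);
}
```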
Parallel Loading / Sequential Compilation
To improve load time while respecting MIGraphX thread safety (see the sketch after this list):
Parallel loading from disk cache
Serialized compilation for cache misses (due to thread safety during MIGraphX compile)
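A rough sketch of that split (hypothetical helpers; actual MXR handling and program types omitted):

```cpp
#include <future>
#include <string>
#include <vector>

// Hypothetical helpers standing in for the EP's MXR deserialization and compile steps.
bool load_mxr_from_disk(const std::string& path);    // safe to call from multiple threads
bool compile_model(const std::string& model_path);   // not thread-safe during MIGraphX compile

void warm_cache(const std::vector<std::string>& cached_mxr_paths,
                const std::vector<std::string>& uncached_models) {
  // Precompiled .mxr files can be loaded from the disk cache in parallel.
  std::vector<std::future<bool>> loads;
  loads.reserve(cached_mxr_paths.size());
  for (const auto& path : cached_mxr_paths) {
    loads.emplace_back(std::async(std::launch::async, load_mxr_from_disk, path));
  }
  for (auto& f : loads) {
    f.get();
  }

  // Cache misses are compiled one at a time, since MIGraphX compilation
  // is not thread-safe in this context.
  for (const auto& model : uncached_models) {
    compile_model(model);
  }
}
```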
New Helper Functions
has_only_dynamic_batch_dimension() - Checks if only the batch dimension is dynamic
extract_base_shapes_from_graph() - Extracts non-batch dimensions from graph definition
find_nearest_power_of_two_batch() - Finds the appropriate padded batch size
generate_power_of_two_batch_sizes() - Generates the list of batch sizes to precompile
get_input_name_map() / get_program_parameter_options() - Input processing utilities
Key Benefits
Reduced runtime compilation - Models are compiled at initialization instead of first inference
Better batch size handling - Arbitrary batch sizes reuse pre-compiled power-of-two models
Faster repeated inferences - Ultra-fast path skips shape checking when inputs are unchanged
Improved memory management - Buffer reuse reduces allocation overhead
Motivation and Context
Customer ask since their models require us to add a dynamic batch mode
Offloads compile to leverage our MXR caching
Has additional fixes and speedups as part of this body of work
Will cherry-pick to ROCm 7.2 once this is in