
Conversation

Collaborator

@TedThemistokleous commented Jan 20, 2026

Description

  • Adds dynamic batch padding using the max_dynamic_batch flag in the EP API
  • Better handling and detection of dynamic symbolic shape inputs
  • Pads the batch and slices outputs for batch sizes that fall between powers of two
  • Adds caching of models for static input shapes
  • Adds ultra-fast, fast, and standard execution paths
  • Performs threaded load or serialized precompile of input models
  • Reduces per-run overhead on the ultra-fast/fast paths so Run performance approaches the MIGraphX driver

Using an AI-generated description below to better summarize the changes.

Summary of Changes

  1. Dynamic Batch Support with Power-of-Two Padding
    The major feature addition is dynamic batch support: input batch sizes are padded up to the next power of two (1, 2, 4, 8, 16, etc.), enabling model reuse across different batch sizes without recompilation (see the sketch after this list).
  2. Three-Tier Execution Path Architecture
    The compute function was refactored into three distinct paths:
    Ultra-fast path (execute_ultra_fast_path): When shapes are identical to the last run - just rebinds memory pointers and executes directly
    Fast path (execute_fast_path): When a cached program exists for the shape hash - retrieves from cache and runs
    Standard path (execute_standard_path): Full shape checking with potential recompilation
  3. Precompilation at Compile Time
    Moved model compilation from runtime to the Compile() phase:
    precompile_all_dynamic_batch_models() - Precompiles all power-of-two batch sizes during initialization
    precompile_static_model() - Precompiles static models during initialization
    handle_precompilation_decision() - Logic to determine if precompilation is possible
  4. Extensive Caching Infrastructure
    Program cache: cached_programs map storing compiled programs by shape hash
    Shape caching: cached_mgx_param_shapes and cached_mgx_output_shapes to avoid repeated MIGraphX API calls
    Input/output caching: CachedInputParam and CachedOutputParam structures with pre-computed indices and shapes
    Buffer reuse: Padded input buffers and temporary output buffers are kept for reuse
  5. Input/Output Handling Refactoring
    handle_program_input_outputs() - Centralized function to bind inputs and allocate outputs
    populate_ultra_fast_caches() - Prepares optimized caches for ultra-fast path
    build_input_shapes_in_cached_order() - Ensures consistent shape ordering for cache lookups
    allocate_and_pad_inputs() - Handles GPU memory allocation and padding for dynamic batching
  6. Output Slicing for Padded Batches
    When running with a padded batch size larger than the original, outputs are stored in temporary buffers and sliced back to the original batch size.
  7. Parallel Loading / Sequential Compilation
    To improve load time while respecting MIGraphX thread safety:
    Parallel loading from disk cache
    Serialized compilation for cache misses (due to thread safety during MIGraphX compile)
  8. New Helper Functions
    has_only_dynamic_batch_dimension() - Checks if only the batch dimension is dynamic
    extract_base_shapes_from_graph() - Extracts non-batch dimensions from graph definition
    find_nearest_power_of_two_batch() - Finds the appropriate padded batch size
    generate_power_of_two_batch_sizes() - Generates the list of batch sizes to precompile
    get_input_name_map() / get_program_parameter_options() - Input processing utilities

Key Benefits

  • Reduced runtime compilation - Models are compiled at initialization instead of first inference
  • Better batch size handling - Arbitrary batch sizes reuse pre-compiled power-of-two models
  • Faster repeated inferences - Ultra-fast path skips shape checking when inputs are unchanged
  • Improved memory management - Buffer reuse reduces allocation overhead
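
The sketch referenced in item 1 above: a minimal, hypothetical illustration of how helpers such as generate_power_of_two_batch_sizes() and find_nearest_power_of_two_batch() could behave. The names come from this PR, but these bodies are assumptions, not the EP's actual implementation.

#include <cstddef>
#include <vector>

// Sketch only: generate the power-of-two batch sizes to precompile,
// e.g. max_dynamic_batch = 20 -> {1, 2, 4, 8, 16, 32}.
static std::vector<std::size_t> generate_power_of_two_batch_sizes(std::size_t max_dynamic_batch) {
  std::vector<std::size_t> sizes;
  std::size_t b = 1;
  while (b < max_dynamic_batch) {
    sizes.push_back(b);
    b *= 2;
  }
  sizes.push_back(b);  // first power of two >= max_dynamic_batch
  return sizes;
}

// Sketch only: smallest power of two >= the incoming batch size, used to pick
// which precompiled program an arbitrary batch runs on.
static std::size_t find_nearest_power_of_two_batch(std::size_t batch) {
  std::size_t padded = 1;
  while (padded < batch) {
    padded *= 2;
  }
  return padded;
}

For example, a request with batch 6 would run on the precompiled batch-8 program: inputs are padded from 6 to 8 and outputs sliced back to 6.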

Motivation and Context

Customer ask, since their models require us to add a dynamic batch mode.
Offloads compilation to leverage our MXR caching.
Includes additional fixes and speedups as part of this body of work.

Will cherry-pick to ROCm 7.2 once this is in.

This commit addresses the issue where input_shapes could be empty when first
loading a model, which would cause model caching to be skipped incorrectly.

Key changes:
- Fixed incorrect iteration using session_input_names.size() as index into
  input_tensor vector. Now directly iterates over input_tensor.
- Changed loop to include all dimensions (starting from j=0 instead of j=1),
  as skipping dimension 0 could result in empty shapes for low-rank tensors.
- Added validation to check for null shapes and zero dimension sizes before
  processing.
- Added check for has_dim_value() to handle symbolic/dynamic dimensions
  properly, only including dimensions with concrete positive values.

This ensures model caching works correctly when valid input shapes are available.
…ted shapes. Let MIGraphX treat this as a static shape.
Updated 32 log statements in the compute function from LOGS_DEFAULT(VERBOSE)
to LOGS_DEFAULT(INFO) for better visibility of compute-related operations
during inference.
Use this call from previously generated bits to encapsulate model compilation and the parameters needed to ensure an MIGraphX program is properly compiled and set up accordingly.
Cleans up the compute thread and makes it clear what inputs and pieces are going to be run through the MIGraphX API.
Remove this from the compute function so we can encapsulate things in a reasonable way.
Make this a separate call that takes in the input context, program, and parameter shape and name information so we can populate the items needed, based off the MIGraphX program, to perform a run_async later.

Doing this as part of cleaning up the compute function to further optimize later
Capture this in a separate call so we get an idea of how input shapes are handled during compute across the different modes.
Reuse this and remove a bunch of redundant, repeated code.
Store a compiled or disk-preloaded MIGraphX program in a map indexed by batch size. Use it as the program in the compute method when an incoming batch size matches one we prepared.

If the lookup fails, fall back to preloading from disk, and if that fails, compile the model in the compute thread.
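
A rough sketch of that lookup-with-fallback flow, assuming a map keyed by padded batch size; load_precompiled_from_disk and compile_for_batch are hypothetical stand-ins passed in as callables, not the EP's actual helpers.

#include <cstddef>
#include <functional>
#include <optional>
#include <unordered_map>

#include <migraphx/migraphx.hpp>

// Sketch only: in-memory cache hit, then disk preload, then compile in the
// compute thread as a last resort.
migraphx::program& get_program_for_batch(
    std::unordered_map<std::size_t, migraphx::program>& cached_programs,
    std::size_t padded_batch,
    const std::function<std::optional<migraphx::program>(std::size_t)>& load_precompiled_from_disk,
    const std::function<migraphx::program(std::size_t)>& compile_for_batch) {
  if (auto it = cached_programs.find(padded_batch); it != cached_programs.end()) {
    return it->second;  // precompiled during Compile()
  }
  if (std::optional<migraphx::program> prog = load_precompiled_from_disk(padded_batch)) {
    return cached_programs.emplace(padded_batch, std::move(*prog)).first->second;
  }
  // Slow path: compile in the compute thread and cache the result.
  return cached_programs.emplace(padded_batch, compile_for_batch(padded_batch)).first->second;
}
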
- Tighten lock around run_async
- Replace the O(n) lookup using find with an unordered_set
- Use optional to help tighten up the lock
Should improve runtime from O(N²) to O(N) for running through outputs and checking them
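
A minimal sketch of the pattern described above, with illustrative names (EpState, required_outputs) that are not the EP's actual types.

#include <mutex>
#include <optional>
#include <string>
#include <unordered_set>

// Sketch only: an unordered_set gives average O(1) membership checks in place
// of a linear std::find over a vector, and std::optional lets the needed state
// be copied out under a short-lived lock so run_async can be issued afterwards.
struct EpState {
  std::mutex mutex;
  std::unordered_set<std::string> required_outputs;  // replaces vector + std::find
  int program_id = 0;                                 // stand-in for the cached program
};

std::optional<int> snapshot_if_output_required(EpState& state, const std::string& name) {
  std::lock_guard<std::mutex> guard(state.mutex);
  if (state.required_outputs.count(name) == 0) {
    return std::nullopt;    // not an output we need to bind
  }
  return state.program_id;  // lock released at scope exit; caller runs async outside it
}
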
Do this so that we can manage updates between inferences, for things like dynamic batch or sequence length, more effectively. Right now we were coarsely recompiling on any mismatch and always checking input shapes. In these cases, instead of checking all N inputs, we should check only the symbolic dimensions for updates.
Move this out into a separate call to ensure we're tracking whether we get a dynamic batch size, as well as other symbolic dimensions in the model that we detect at compile time.
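
A hypothetical sketch of that idea: record the symbolic (input, dimension) pairs once at compile time, then compare only those between runs rather than every input shape. The names here are illustrative, not the EP's.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: track the dimensions detected as symbolic (e.g. the dynamic
// batch) and compare just their current values against the previous run.
struct SymbolicDim {
  std::size_t input_index;
  std::size_t dim_index;
  int64_t last_value = -1;
};

bool symbolic_dims_changed(std::vector<SymbolicDim>& tracked,
                           const std::vector<std::vector<int64_t>>& current_shapes) {
  bool changed = false;
  for (auto& d : tracked) {
    const int64_t v = current_shapes[d.input_index][d.dim_index];
    if (v != d.last_value) {
      d.last_value = v;  // remember for the next inference
      changed = true;
    }
  }
  return changed;
}
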
- Implement dynamic batching using power-of-two batch sizes (1, 2, 4, 8, 16, 32, 64, etc.)
- Pre-compile models for all power-of-two batch sizes up to max_dynamic_batch
- Pad input tensors to nearest power-of-two when needed
- Slice output tensors back to original batch size
- Fix shape ordering consistency between compilation and runtime (alphabetical for hash, MIGraphX order for caching)
- Fix ultra-fast path to apply dynamic batch logic to all inputs, not just the first
- Implement GPU buffer allocation and padding for inputs
- Fix output slicing to handle ALL outputs including pre-allocated ones

Key components:
- allocate_and_pad_inputs(): Allocates GPU buffers and pads input data
- free_padded_inputs(): Cleans up padded buffers after execution
- Modified execute_ultra_fast_path(), execute_fast_path(), execute_standard_path() to use padded inputs
- Fixed run_migraphx_program() to slice all 13 outputs correctly

This enables efficient dynamic batching without recompilation by reusing cached power-of-two batch models.
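
As a rough illustration of the output slicing mentioned above, a sketch of copying only the original batches back from a temporary padded output buffer; the helper name and signature are assumptions, mirroring the pad_input_tensor() helper quoted later in this review.

#include <cstddef>
#include <hip/hip_runtime.h>

// Sketch only: the program ran at padded_batch, but ORT expects original_batch,
// so copy just the first original_batch batches into the real output buffer.
// HIP_CALL_THROW is the EP's existing error-checking macro.
static void slice_output_tensor(const void* padded_output, void* dst_output,
                                std::size_t original_batch,
                                std::size_t element_size_bytes,
                                std::size_t elements_per_batch,
                                hipStream_t stream) {
  const std::size_t bytes_per_batch = element_size_bytes * elements_per_batch;
  HIP_CALL_THROW(hipMemcpyAsync(dst_output, padded_output,
                                original_batch * bytes_per_batch,
                                hipMemcpyDeviceToDevice, stream));
}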

Adding other fixes to handle cases where the batch matches a power of two, as well as the ultra-fast path comparison for cached models, so that input shape ordering isn't violated.
Do this so that we save overhead on buffers we've already allocated, reusing them when the same padded batch input comes in.
- Cache sliced outputs and remove repeated MIGraphX API calls that add overhead to each run.
Items deprecated through optimizations and caching.
Remove this from compile so it's clear we can maintain this as a separate piece/pass before inference is run.
Ensure we just load MXR files and don't rely on auto-detection.
Add parallel loading of MXR files into the cache on startup.
…mpile. This is due to thread safety during MIGraphX compile
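
A minimal sketch of the load/compile split described by these commits: MXR files already on disk are loaded in parallel, while cache misses are compiled one at a time behind a mutex because MIGraphX compilation is not thread-safe. mxr_path_for_batch and compile_batch_model are hypothetical callables, not the EP's real helpers, and the load call assumes the MIGraphX C++ API's migraphx::load(filename).

#include <cstddef>
#include <filesystem>
#include <functional>
#include <mutex>
#include <optional>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include <migraphx/migraphx.hpp>

void preload_dynamic_batch_models(
    const std::vector<std::size_t>& batch_sizes,
    std::unordered_map<std::size_t, migraphx::program>& cache,
    const std::function<std::string(std::size_t)>& mxr_path_for_batch,
    const std::function<migraphx::program(std::size_t)>& compile_batch_model) {
  std::mutex cache_mutex;    // guards the shared program map
  std::mutex compile_mutex;  // serializes MIGraphX compilation
  std::vector<std::thread> workers;

  for (std::size_t batch : batch_sizes) {
    workers.emplace_back([&, batch] {
      std::optional<migraphx::program> prog;
      const std::string path = mxr_path_for_batch(batch);
      if (std::filesystem::exists(path)) {
        prog = migraphx::load(path.c_str());  // disk loads can run in parallel
      } else {
        std::lock_guard<std::mutex> guard(compile_mutex);
        prog = compile_batch_model(batch);    // compile cache misses one at a time
      }
      std::lock_guard<std::mutex> guard(cache_mutex);
      cache.emplace(batch, std::move(*prog));
    });
  }
  for (auto& t : workers) t.join();
}
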
@TedThemistokleous TedThemistokleous self-assigned this Jan 20, 2026
@TedThemistokleous changed the title from "padded batch, cache improvements and parallel model precompiled load Improvements to MIGraphX EP" to "Padded/sliced dynamic batch, cache improvements and parallel model precompiled load Improvements to MIGraphX EP" on Jan 20, 2026
Collaborator

@apwojcik left a comment

Great work!

// Allocate and pad each input
mgx_state->padded_input_buffers.reserve(mgx_state->cached_inputs.size());

for (const auto& cached_inp : mgx_state->cached_inputs) {


just be careful here where all input shapes may not actually have a batch dimension. There's not a clean way of checking this unless you do some sort of symbolic tracing. You could have a heuristic where you look at the first dimension of each input, then count number of occurrences of the value at the first dimension. The ones which don't match the most frequent occurrence probably don't use the batch dimension.
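
A sketch of the heuristic described above, purely illustrative:

#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Count how often each value appears as dimension 0 across all inputs; inputs
// whose first dimension doesn't match the most frequent value probably don't
// carry the batch dimension.
std::vector<bool> guess_inputs_with_batch_dim(
    const std::vector<std::vector<int64_t>>& input_shapes) {
  std::map<int64_t, int> first_dim_counts;
  for (const auto& shape : input_shapes) {
    if (!shape.empty()) ++first_dim_counts[shape[0]];
  }
  int64_t likely_batch = -1;
  int best_count = 0;
  for (const auto& [dim, count] : first_dim_counts) {
    if (count > best_count) {
      best_count = count;
      likely_batch = dim;
    }
  }
  std::vector<bool> has_batch(input_shapes.size(), false);
  for (std::size_t i = 0; i < input_shapes.size(); ++i) {
    has_batch[i] = !input_shapes[i].empty() && input_shapes[i][0] == likely_batch;
  }
  return has_batch;
}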

Collaborator Author

This should be handled. I only allow dynamic batch by using has_only_dynamic_batch_dimension(), used on line 3837. Only then do we extract shapes, before we handle everything correctly, to ensure we're not getting into weight scenarios for now.

This assumes the other inputs are static right now. If we run into a model that needs dynamic batch plus something like a sequence length that changes between runs, we'll need to cover that differently. Otherwise the model will appear to have different shapes each run, fail the hash lookups that let it reuse the same inputs, and trigger a recompile if the lookup fails from cache and then from disk.

Comment on lines +1252 to +1274
static void pad_input_tensor(const void* src_data, void* dst_data,
                             std::size_t original_batch, std::size_t padded_batch,
                             std::size_t element_size_bytes, std::size_t elements_per_batch,
                             hipStream_t stream) {
  std::size_t bytes_per_batch = element_size_bytes * elements_per_batch;

  // Copy original data
  HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data,
                                original_batch * bytes_per_batch,
                                hipMemcpyDeviceToDevice, stream));

  // Pad with last batch element replicated
  if (original_batch > 0 && padded_batch > original_batch) {
    const char* last_batch = static_cast<const char*>(src_data) + (original_batch - 1) * bytes_per_batch;
    char* pad_start = static_cast<char*>(dst_data) + original_batch * bytes_per_batch;

    for (std::size_t i = original_batch; i < padded_batch; ++i) {
      HIP_CALL_THROW(hipMemcpyAsync(pad_start, last_batch, bytes_per_batch,
                                    hipMemcpyDeviceToDevice, stream));
      pad_start += bytes_per_batch;
    }
  }
}


How much effort would it be to write and use a GPU kernel to do this? It would be more efficient.

Collaborator Author

We could, but that would mean adding another kernel to onnxruntime, or finding a way to add this in MIGraphX before this gets handled. Right now we treat an entire compiled graph as an MIGraphX kernel, which lets us handle all the pieces internal to MIGraphX.
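
For reference, a rough sketch of what such a padding kernel could look like if it were ever added (not part of this PR); it replaces only the replication loop, with the initial copy of the original data still done separately, and it assumes original_batch > 0 as guarded in pad_input_tensor().

#include <cstddef>
#include <hip/hip_runtime.h>

// Hypothetical: one grid-stride kernel fills the padded region by replicating
// the last original batch element, instead of issuing (padded_batch -
// original_batch) separate hipMemcpyAsync calls.
__global__ void pad_with_last_batch(const char* src, char* dst,
                                    std::size_t original_batch,
                                    std::size_t padded_batch,
                                    std::size_t bytes_per_batch) {
  const std::size_t pad_bytes = (padded_batch - original_batch) * bytes_per_batch;
  const char* last_batch = src + (original_batch - 1) * bytes_per_batch;
  char* pad_start = dst + original_batch * bytes_per_batch;
  for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < pad_bytes;
       i += static_cast<std::size_t>(blockDim.x) * gridDim.x) {
    pad_start[i] = last_batch[i % bytes_per_batch];  // replicate last batch element
  }
}

// Launched on the same stream as the preceding copy of the original data, e.g.:
//   hipLaunchKernelGGL(pad_with_last_batch, dim3(256), dim3(256), 0, stream,
//                      static_cast<const char*>(src_data),
//                      static_cast<char*>(dst_data),
//                      original_batch, padded_batch, bytes_per_batch);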

ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};

OrtArenaCfg* default_memory_arena_cfg{nullptr};
size_t max_dynamic_batch{static_cast<size_t>(0)};


The static_cast looks redundant here.

Collaborator Author

I think it's due to 0 defaulting to int instead of size_t, which is unsigned.
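
For what it's worth, brace initialization accepts the literal 0 here without a cast, since a constant expression that fits in size_t is not a narrowing conversion:

#include <cstddef>

size_t with_cast{static_cast<size_t>(0)};
size_t without_cast{0};  // compiles cleanly; 0 is a constant that fits in size_t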

@TedThemistokleous
Collaborator Author

Merging this in, then cherry-picking to 7.2 so we have coverage for 7.2.1 and this changeset for now.

@TedThemistokleous TedThemistokleous merged commit 5ff8193 into rocm7.1_internal_testing Jan 24, 2026
5 of 7 checks passed
@TedThemistokleous
Collaborator Author

DO NOT DELETE THIS BRANCH as this branch is still in use

TedThemistokleous added a commit that referenced this pull request Jan 24, 2026
…ecompiled load Improvements to MIGraphX EP (#207)

Set of changes to enable dynamic batch functionality in the MIGraphX EP using a pad/slice approach, with added caching for models and faster paths for static cases.
