⚡ Optimize sample_blocks with incremental global state updates #7
Conversation
Avoid reconstructing the global state in every iteration of the sampling loop. Use incremental updates with `jax.lax.dynamic_update_slice` to reduce memory allocations and improve performance for large models. Added `get_block_location` helper function to `thrml/block_management.py`. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
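The pattern described in the commit message can be sketched in isolation (a minimal illustration with hypothetical shapes and names, not the actual thrml code):

```python
import jax
import jax.numpy as jnp

# Hypothetical global state: 8 nodes with 3 features each.
global_state = jnp.zeros((8, 3))

# A freshly sampled block of 2 nodes whose rows start at index 4.
new_block_state = jnp.ones((2, 3))
start_index = 4

# Incremental update: overwrite only this block's rows instead of
# re-concatenating every block's state on each iteration.
start_indices = (start_index,) + (0,) * (global_state.ndim - 1)
global_state = jax.lax.dynamic_update_slice(global_state, new_block_state, start_indices)
```

Because `dynamic_update_slice` is a functional in-place write, XLA can often perform it without allocating a fresh buffer for the whole state.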
👋 Jules, reporting for duty! I'm here to lend a hand with this pull request. When you start a review, I'll add a 👀 emoji to each comment to let you know I've read it. I'll focus on feedback directed at me and do my best to stay out of conversations between you and other bots or reviewers to keep the noise down. I'll push a commit with your requested changes shortly after. There might be a delay between these steps, but rest assured I'm on the job! For more direct control, you can switch me to Reactive Mode: when it is on, I will only act on comments that specifically mention me. New to Jules? Learn more at jules.google/docs. For security, I will only act on instructions from the user who triggered this task.
**Summary of Changes**

Hello @igor-holt, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant optimization to the block sampling process by changing how the global state is managed. Instead of repeatedly rebuilding the entire global state, the system now initializes it once and applies incremental updates. This approach drastically reduces memory allocation pressure and improves the scalability and performance of the sampling loop, especially for models with a large number of blocks or extensive state sizes.
Code Review
This pull request introduces a nice optimization to `sample_blocks` by avoiding the repeated reconstruction of `global_state`. The change to perform an incremental update is well implemented, and the addition of the `get_block_location` helper function is logical.

My review includes a couple of suggestions to improve maintainability: one regarding test coverage for the new helper function, and another to improve the clarity of a variable name in the update logic. Overall, this is a solid improvement.
```python
def get_block_location(block: Block, spec: BlockSpec) -> tuple[int, int]:
    """
    Get the global state location (sd_index, start_index) for a block.

    Assumes block nodes are contiguous in global state.

    **Arguments:**

    - `block`: The [`thrml.Block`][] whose location is needed.
    - `spec`: The [`thrml.BlockSpec`][] that defines the mapping.

    **Returns:**

    Tuple ``(sd_index, start_index)`` where

    * *sd_index* is the position inside the global list returned by
      [`thrml.block_state_to_global`][], and
    * *start_index* is the starting index in the node dimension for the block.
    """
    if not block.nodes:
        raise ValueError("Empty block")
    node = block.nodes[0]
    return spec.node_global_location_map[node]
```
The new public function `get_block_location` is a great addition for this optimization. However, it currently lacks unit tests. Please consider adding tests for this function in `tests/test_block_management.py` to ensure its correctness and prevent future regressions. A new test method in the `TestBlocks` class could verify that it returns the correct `sd_index` and `start_index` for each block in the test configurations.
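A sketch of such a test, using minimal stand-ins for `Block` and `BlockSpec` (the real classes in `thrml/block_management.py` will differ; the fixture data here is purely illustrative):

```python
from dataclasses import dataclass

# Hypothetical stand-ins for thrml's Block and BlockSpec.
@dataclass
class Block:
    nodes: list

@dataclass
class BlockSpec:
    node_global_location_map: dict

def get_block_location(block, spec):
    """Return the (sd_index, start_index) recorded for the block's first node."""
    if not block.nodes:
        raise ValueError("Empty block")
    return spec.node_global_location_map[block.nodes[0]]

# The test would check each block's location against the spec's mapping.
spec = BlockSpec(node_global_location_map={"a": (0, 0), "c": (0, 4)})
assert get_block_location(Block(nodes=["a", "b"]), spec) == (0, 0)
assert get_block_location(Block(nodes=["c"]), spec) == (0, 4)
```

In the real test suite, the mapping would come from the existing test configurations rather than a hand-built dict, and the empty-block `ValueError` path should be covered as well.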
```python
def update_leaf(global_leaf, update_leaf):
    start_indices = (start_ind,) + (0,) * (global_leaf.ndim - 1)
    return jax.lax.dynamic_update_slice(global_leaf, update_leaf, start_indices)
```
In the nested function `update_leaf`, the second argument is also named `update_leaf`. This shadows the function name and can be confusing. To improve clarity, please consider renaming the argument to something more descriptive, like `new_leaf_slice`.
```suggestion
def update_leaf(global_leaf, new_leaf_slice):
    start_indices = (start_ind,) + (0,) * (global_leaf.ndim - 1)
    return jax.lax.dynamic_update_slice(global_leaf, new_leaf_slice, start_indices)
```
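For reference, `jax.lax.dynamic_update_slice(operand, update, start_indices)` returns a copy of `operand` with `update` written starting at `start_indices` (indices are clamped so the update always fits inside the operand); a quick standalone sketch:

```python
import jax
import jax.numpy as jnp

operand = jnp.zeros((5, 2))
update = jnp.array([[1.0, 2.0]])

# Write the 1x2 update starting at row 3, column 0.
result = jax.lax.dynamic_update_slice(operand, update, (3, 0))
```

This is why the leaf updater only needs a start index along the node dimension and zeros for every trailing dimension.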
💡 **What:** Optimized `sample_blocks` in `thrml/block_sampling.py` to avoid reconstructing `global_state` from scratch in every iteration of the sampling loop. Instead, `global_state` is constructed once before the loop and updated incrementally using `jax.lax.dynamic_update_slice` as blocks are sampled. A helper function `get_block_location` was added to `thrml/block_management.py` to efficiently retrieve the start index of a block within the global state arrays.

🎯 **Why:** Reconstructing the global state involves concatenating arrays for all blocks, which is an O(N) operation, where N is the total number of blocks (or total state size). For models with a large number of blocks or a large state size, this becomes a significant bottleneck and increases memory allocation pressure. Incremental updates reduce this overhead, making the sampling loop more scalable.
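The contrast between the two approaches can be sketched like this (schematic only; the block layout and shapes are hypothetical, not the real thrml data structures):

```python
import jax
import jax.numpy as jnp

num_blocks, block_size, features = 4, 2, 3
block_states = [jnp.zeros((block_size, features)) for _ in range(num_blocks)]

# Old approach: O(num_blocks) re-concatenation on every sampling step,
# allocating a fresh global array even when only one block changed.
rebuilt_every_step = jnp.concatenate(block_states, axis=0)

# New approach: concatenate once up front, then patch in place.
global_state = jnp.concatenate(block_states, axis=0)
sampled_block = jnp.ones((block_size, features))
start_index = 1 * block_size  # block 1 begins at row 2
global_state = jax.lax.dynamic_update_slice(global_state, sampled_block, (start_index, 0))
```

The `start_index` here plays the role that `get_block_location` computes from the spec in the actual change.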
📊 **Measured Improvement:** Benchmarking showed a consistent performance improvement for large models. While the speedup is modest for this specific benchmark size, the primary benefit is the reduction in memory allocation frequency (from O(Groups) allocations per step to O(1) plus incremental updates), which is crucial for scaling to very large models where memory bandwidth and allocation are the limiting factors. Verified correctness with the existing test suite (`tests/test_block_sampling.py`, `tests/test_block_management.py`).

PR created automatically by Jules for task 12139521339407399545, started by @igor-holt.