⚡ Optimize sample_with_observation to avoid concatenation#5
⚡ Optimize sample_with_observation to avoid concatenation#5
Conversation
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
|
👋 Jules, reporting for duty! I'm here to lend a hand with this pull request. When you start a review, I'll add a 👀 emoji to each comment to let you know I've read it. I'll focus on feedback directed at me and will do my best to stay out of conversations between you and other bots or reviewers to keep the noise down. I'll push a commit with your requested changes shortly after. Please note there might be a delay between these steps, but rest assured I'm on the job! For more direct control, you can switch me to Reactive Mode. When this mode is on, I will only act on comments where you specifically mention me with New to Jules? Learn more at jules.google/docs. For security, I will only act on instructions from the user who triggered this task. |
Summary of ChangesHello @igor-holt, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the efficiency of the block sampling process by refactoring how initial observations are handled within the JAX scan. The core change eliminates a performance bottleneck related to array concatenation, leading to faster execution and reduced memory overhead, particularly beneficial for large-scale simulations. Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly optimizes sample_with_observation by removing an inefficient jnp.concatenate operation and instead using a conditional step inside the jax.lax.scan loop. The implementation is sound and achieves the intended performance improvement. My review includes one comment on a pre-existing correctness issue in the modified code block that should be addressed.
| if schedule.n_samples <= 1: | ||
| mem, warmup_observation = f_observe(program, warmup_state, state_clamp, observation_carry_init, jnp.array(0)) | ||
| warmup_observation = jax.tree.map(lambda x: x[None], warmup_observation) | ||
| return mem, warmup_observation |
There was a problem hiding this comment.
This block incorrectly handles the case where schedule.n_samples is 0. It returns one sample instead of zero, which violates the function's docstring contract that samples should be a PyTree whose leading axis has size schedule.n_samples.
To fix this, you should handle only the n_samples == 1 case here. The n_samples == 0 case will then fall through to the jax.lax.scan, which correctly produces an empty output for zero iterations.
| if schedule.n_samples <= 1: | |
| mem, warmup_observation = f_observe(program, warmup_state, state_clamp, observation_carry_init, jnp.array(0)) | |
| warmup_observation = jax.tree.map(lambda x: x[None], warmup_observation) | |
| return mem, warmup_observation | |
| if schedule.n_samples == 1: | |
| mem, warmup_observation = f_observe(program, warmup_state, state_clamp, observation_carry_init, jnp.array(0)) | |
| warmup_observation = jax.tree.map(lambda x: x[None], warmup_observation) | |
| return mem, warmup_observation |
💡 What:
Optimized
thrml/block_sampling.py'ssample_with_observationfunction.Instead of running a scan loop for N-1 samples and then concatenating the initial observation (which requires a full copy of the results array), the scan loop now runs for N samples.
Inside the scan loop:
_run_blocks) is conditionally executed usingjax.lax.condonly ifi < n_samples - 1.🎯 Why:
The previous implementation used
jnp.concatenateto prepend the initial observation to the scan results. For large sample counts or large state sizes, this concatenation involves allocating a new array and copying all data, which is inefficient.By integrating the initial observation into the scan, JAX can allocate the full result buffer once and fill it in place.
📊 Measured Improvement:
Benchmarking showed a ~16% speedup (from ~11.5ms to ~9.6ms) on a test case with 2000 samples and 1000 nodes.
The optimized implementation avoids the O(N) memory copy at the end of sampling.
Existing tests pass.
PR created automatically by Jules for task 11813998359968873900 started by @igor-holt