Skip to content

⚡ Optimize sample_with_observation to avoid concatenation#5

Draft
igor-holt wants to merge 1 commit intomainfrom
jules-optimize-sampling-scan-11813998359968873900
Draft

⚡ Optimize sample_with_observation to avoid concatenation#5
igor-holt wants to merge 1 commit intomainfrom
jules-optimize-sampling-scan-11813998359968873900

Conversation

@igor-holt
Copy link
Owner

💡 What:
Optimized thrml/block_sampling.py's sample_with_observation function.
Instead of running a scan loop for N-1 samples and then concatenating the initial observation (which requires a full copy of the results array), the scan loop now runs for N samples.
Inside the scan loop:

  1. The current state is observed first.
  2. The transition step (_run_blocks) is conditionally executed using jax.lax.cond only if i < n_samples - 1.

🎯 Why:
The previous implementation used jnp.concatenate to prepend the initial observation to the scan results. For large sample counts or large state sizes, this concatenation involves allocating a new array and copying all data, which is inefficient.
By integrating the initial observation into the scan, JAX can allocate the full result buffer once and fill it in place.

📊 Measured Improvement:
Benchmarking showed a ~16% speedup (from ~11.5ms to ~9.6ms) on a test case with 2000 samples and 1000 nodes.
The optimized implementation avoids the O(N) memory copy at the end of sampling.
Existing tests pass.


PR created automatically by Jules for task 11813998359968873900 started by @igor-holt

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
@google-labs-jules
Copy link

👋 Jules, reporting for duty! I'm here to lend a hand with this pull request.

When you start a review, I'll add a 👀 emoji to each comment to let you know I've read it. I'll focus on feedback directed at me and will do my best to stay out of conversations between you and other bots or reviewers to keep the noise down.

I'll push a commit with your requested changes shortly after. Please note there might be a delay between these steps, but rest assured I'm on the job!

For more direct control, you can switch me to Reactive Mode. When this mode is on, I will only act on comments where you specifically mention me with @jules. You can find this option in the Pull Request section of your global Jules UI settings. You can always switch back!

New to Jules? Learn more at jules.google/docs.


For security, I will only act on instructions from the user who triggered this task.

@gemini-code-assist
Copy link

Summary of Changes

Hello @igor-holt, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the efficiency of the block sampling process by refactoring how initial observations are handled within the JAX scan. The core change eliminates a performance bottleneck related to array concatenation, leading to faster execution and reduced memory overhead, particularly beneficial for large-scale simulations.

Highlights

  • Performance Optimization: The sample_with_observation function in thrml/block_sampling.py has been optimized to avoid inefficient jnp.concatenate operations, resulting in a ~16% speedup and elimination of O(N) memory copies for large sample counts.
  • JAX Scan Integration: The initial observation is now integrated directly into the jax.lax.scan loop, allowing JAX to allocate the full result buffer once and fill it in place, rather than prepending it after the scan.
  • Conditional Execution: Conditional logic for stepping the state within the scan loop is now handled using jax.lax.cond, ensuring that the transition step (_run_blocks) is only executed when necessary.
Changelog
  • thrml/block_sampling.py
    • Modified sample_with_observation to remove the explicit warmup_observation calculation and subsequent concatenation.
    • Updated the body_fn of jax.lax.scan to observe the current state first and then conditionally advance the state using jax.lax.cond.
    • Adjusted the keys and outer_iters inputs for jax.lax.scan to cover the full range of n_samples.
    • Changed the initial carry for jax.lax.scan to use observation_carry_init directly.
Activity
  • No human activity has occurred on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly optimizes sample_with_observation by removing an inefficient jnp.concatenate operation and instead using a conditional step inside the jax.lax.scan loop. The implementation is sound and achieves the intended performance improvement. My review includes one comment on a pre-existing correctness issue in the modified code block that should be addressed.

Comment on lines 469 to 472
if schedule.n_samples <= 1:
mem, warmup_observation = f_observe(program, warmup_state, state_clamp, observation_carry_init, jnp.array(0))
warmup_observation = jax.tree.map(lambda x: x[None], warmup_observation)
return mem, warmup_observation

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block incorrectly handles the case where schedule.n_samples is 0. It returns one sample instead of zero, which violates the function's docstring contract that samples should be a PyTree whose leading axis has size schedule.n_samples.

To fix this, you should handle only the n_samples == 1 case here. The n_samples == 0 case will then fall through to the jax.lax.scan, which correctly produces an empty output for zero iterations.

Suggested change
if schedule.n_samples <= 1:
mem, warmup_observation = f_observe(program, warmup_state, state_clamp, observation_carry_init, jnp.array(0))
warmup_observation = jax.tree.map(lambda x: x[None], warmup_observation)
return mem, warmup_observation
if schedule.n_samples == 1:
mem, warmup_observation = f_observe(program, warmup_state, state_clamp, observation_carry_init, jnp.array(0))
warmup_observation = jax.tree.map(lambda x: x[None], warmup_observation)
return mem, warmup_observation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant