@Gattocrucco
Collaborator

I did some testing on my laptop, and sharding the chains makes the MCMC faster; it seems to parallelize better. So I decided that the interface should shard along chains by default on CPU, since I want the defaults to be set for performance.
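To make "sharding along chains" concrete, here is a minimal sketch in plain JAX. It is not the bartz implementation: it assumes four forced host CPU devices and a made-up per-chain state array, and the names `num_chains`, `state`, and `step` are illustrative only.

```python
import os

# Assumption: ask XLA for 4 host (CPU) devices. This must happen before JAX
# initializes its backends, so it comes before the import.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'
).strip()

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_chains = 4  # illustrative; must be divisible by the number of devices used

# A 1-D mesh of CPU devices whose only axis is named 'chains'.
devices = np.array(jax.devices('cpu')[:num_chains])
mesh = Mesh(devices, axis_names=('chains',))

# Split the leading (chain) axis across the mesh; other axes are not split.
sharding = NamedSharding(mesh, PartitionSpec('chains'))

# Made-up per-chain state: one row of 100 values per chain.
state = jax.device_put(jnp.zeros((num_chains, 100)), sharding)


@jax.jit
def step(state):
    # Stand-in for one MCMC step; each device works on the chains it holds.
    return state + 1.0


new_state = step(state)

# Each device holds one chain's slice of the state.
for shard in state.addressable_shards:
    print(shard.device, shard.data.shape)
```

Because the computation follows the placement of `state`, the jitted step runs on all four devices without any further device arguments.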

This PR changes the interface at the mcmcstep.init and Bart levels, and uses it internally in mc_gbart to automatically set the number of shard devices. I think the lower-level interfaces should instead remain more explicit and require the user to set sharding deliberately, rather than making so many automated choices.

The new interface in init solves the problem that the default device may not correspond to the device actually used in the MCMC. Now if inferring the device is not possible, the caller is forced to make it explicit.

The new interface in Bart will make it possible to use sharding without passing arguments directly to init.

The adaptations in mc_gbart will prompt the user to set the number of cpu devices.
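For context, JAX exposes a single CPU device unless told otherwise. One common way to get more is the XLA flag `--xla_force_host_platform_device_count`, which has to be set before JAX initializes its backends. A minimal sketch, where the count of 4 is an arbitrary choice and not something mc_gbart prescribes:

```python
import os

# Set the flag before importing jax; once the backends are initialized,
# changing XLA_FLAGS has no effect on the number of CPU devices.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'
).strip()

import jax

cpu_devices = jax.devices('cpu')
print(len(cpu_devices), 'CPU devices visible to JAX')
```

The same flag can also be exported in the shell before starting Python, which avoids depending on import order.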

JAX in general follows data placement to infer which device to use for
computations. init() broke this pattern because it inferred the device
from the default device in cases where it couldn't access the data
device due to tracing. Now in those cases it will instead raise an error
and ask the caller to specify the target device.
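The behavior described above can be sketched roughly as follows. This is an illustration of the pattern, not the code in `_state.py`: `infer_platform` is a hypothetical helper, and only the idea of an explicit `target_platform` argument comes from this PR.

```python
import jax
import jax.numpy as jnp


def infer_platform(x, target_platform=None):
    """Hypothetical helper: follow data placement, else require an explicit platform."""
    if target_platform is not None:
        # The caller said where the computation will run; trust it.
        return target_platform
    if isinstance(x, jax.core.Tracer):
        # Under jit/vmap tracing the concrete placement of x is not available,
        # so refuse to guess from the default device.
        raise ValueError(
            'cannot infer the device from a traced array; pass target_platform'
        )
    # Concrete jax array: read the platform from where the data actually lives.
    platforms = {d.platform for d in x.devices()}
    if len(platforms) != 1:
        raise ValueError(f'array spans multiple platforms: {platforms}')
    return platforms.pop()


@jax.jit
def traced_without_hint(x):
    infer_platform(x)  # x is a tracer here, so this raises at trace time
    return x


@jax.jit
def traced_with_hint(x):
    infer_platform(x, target_platform='cpu')  # explicit, so tracing is fine
    return x


x = jnp.zeros(3)
print(infer_platform(x))  # platform of the device holding x

try:
    traced_without_hint(x)
except ValueError as exc:
    print('error:', exc)

traced_with_hint(x)
```

The point of raising instead of falling back to the default device is that a wrong guess would silently configure the MCMC for a device it does not run on.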
@Gattocrucco Gattocrucco marked this pull request as ready for review January 12, 2026 23:39
@Gattocrucco Gattocrucco requested a review from Copilot January 12, 2026 23:39
Contributor

Copilot AI left a comment

Pull request overview

This PR implements automatic chain sharding for improved CPU parallelization by default. The changes introduce a new device configuration system that automatically distributes MCMC chains across multiple CPU devices when beneficial, while maintaining backward compatibility with existing code.

Changes:

  • Added target_platform parameter to mcmcstep.init for explicit platform specification when device inference is not possible
  • Introduced device-related parameters (num_chains, num_chain_devices, num_data_devices, devices) to the Bart class interface
  • Modified mc_gbart to automatically configure chain sharding on CPU, using negative mc_cores values to override automatic platform detection
  • Renamed parameter init_kw to bart_kwargs in mc_gbart for clearer separation of concerns

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| src/bartz/mcmcstep/_state.py | Added target_platform parameter and _parse_target_platform function to handle platform inference logic |
| src/bartz/_interface.py | Added device configuration parameters to Bart, implemented process_device_settings function, changed ndpost to property |
| src/bartz/BART/_gbart.py | Changed default maxdepth to 8, replaced init_kw with bart_kwargs, added process_mc_cores and get_platform for automatic chain sharding |
| src/bartz/BART/\_\_init\_\_.py | New \_\_init\_\_ file to make BART a proper package |
| src/bartz/debug.py | Updated imports to reflect package restructuring |
| tests/test_BART.py | Updated tests to use new bart_kwargs interface and handle new sharding defaults |
| tests/test_mcmcloop.py | Added target_platform argument to test helper function |
| benchmarks/speed.py | Added backward compatibility handling for target_platform parameter |

Comments suppressed due to low confidence (2)

src/bartz/BART/_gbart.py:480

  • Spelling error: "that that" should be "than that".

src/bartz/BART/_gbart.py:3

  • The copyright year is 2026, which is in the future. This should be 2025 or use a range like 2024-2025 consistent with other files in the project.

@Gattocrucco Gattocrucco merged commit a587b36 into main Jan 13, 2026
10 checks passed
@Gattocrucco Gattocrucco deleted the shard-chains-default branch January 13, 2026 00:17