-
Notifications
You must be signed in to change notification settings - Fork 4
Shard chains by default #72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
JAX in general follows data placement to infer which device to use for computations. init() broke this pattern because it inferred the device from the default device in cases where it couldn't access the data device due to tracing. Now in those cases it will instead raise an error and ask the caller to specify the target device.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements automatic chain sharding for improved CPU parallelization by default. The changes introduce a new device configuration system that automatically distributes MCMC chains across multiple CPU devices when beneficial, while maintaining backward compatibility with existing code.
Changes:
- Added
target_platformparameter tomcmcstep.initfor explicit platform specification when device inference is not possible - Introduced device-related parameters (
num_chains,num_chain_devices,num_data_devices,devices) to theBartclass interface - Modified
mc_gbartto automatically configure chain sharding on CPU, using negativemc_coresvalues to override automatic platform detection - Renamed parameter
init_kwtobart_kwargsinmc_gbartfor clearer separation of concerns
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| src/bartz/mcmcstep/_state.py | Added target_platform parameter and _parse_target_platform function to handle platform inference logic |
| src/bartz/_interface.py | Added device configuration parameters to Bart, implemented process_device_settings function, changed ndpost to property |
| src/bartz/BART/_gbart.py | Changed default maxdepth to 8, replaced init_kw with bart_kwargs, added process_mc_cores and get_platform for automatic chain sharding |
| src/bartz/BART/init.py | New init file to make BART a proper package |
| src/bartz/debug.py | Updated imports to reflect package restructuring |
| tests/test_BART.py | Updated tests to use new bart_kwargs interface and handle new sharding defaults |
| tests/test_mcmcloop.py | Added target_platform argument to test helper function |
| benchmarks/speed.py | Added backward compatibility handling for target_platform parameter |
Comments suppressed due to low confidence (2)
src/bartz/BART/_gbart.py:480
- Spelling error: "that that" should be "than that".
src/bartz/BART/_gbart.py:3 - The copyright year is 2026, which is in the future. This should be 2025 or use a range like 2024-2025 consistent with other files in the project.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
I did some testing on my laptop and sharding the chains makes the MCMC faster, it seems able to parallelize better. So I decided that the interface should by default shard along chains on cpu, as I want to provide defaults set for performance.
This PR changes the interface at the
mcmcstep.initandBartlevels, using it internally inmc_gbartto set automatically the number of shard devices. I think the lower level interfaces should instead remain more explicit and require the user to set sharding deliberately instead of making so many automated choices.The new interface in
initsolves the problem that the default device may not correspond to the device actually used in the MCMC. Now if inferring the device is not possible, the caller is forced to make it explicit.The new interface in
Bartwill make it possible to use sharding without passing arguments directly toinit.The adaptations in
mc_gbartwill prompt the user to set the number of cpu devices.