
Conversation

@lbluque (Contributor) commented Dec 18, 2025

The next ray.serve release will include the option to use a function to determine batch sizes based on input data (ray-project/ray#59059).

This PR updates our code to use the total number of atoms to determine batch sizes in the BatchPredictServer.

@lbluque lbluque added the enhancement (New feature or request) and patch (Patch version release) labels Dec 18, 2025
@meta-cla meta-cla bot added the cla signed label Dec 18, 2025
@lbluque lbluque marked this pull request as draft December 18, 2025 23:43
@lbluque lbluque marked this pull request as ready for review January 6, 2026 20:59
@lbluque lbluque requested a review from rayg1234 January 6, 2026 21:00
@lbluque lbluque requested review from kjmichel and mshuaibii January 8, 2026 18:43

- @serve.batch
+ @serve.batch(
+     batch_size_fn=lambda batch: sum(sample.natoms.sum() for sample in batch).item()
Contributor


Naive question: how is @serve.batch working here? How does batch_size_fn get incorporated?

Contributor Author


Great question! That's all implemented in Ray. The TL;DR is that the serve.batch decorator instantiates a BatchQueue class that receives the batch_size_fn.

If you're curious you can see the implementation of the decorator here: https://github.com/ray-project/ray/blob/d4817998ee8476c138e9106280ecefdf1e59ba6b/python/ray/serve/batching.py#L677
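To make the idea concrete, here is a toy, pure-Python sketch of how a queue could use a user-supplied batch_size_fn to decide when a batch is full. This is not Ray's actual BatchQueue; the class and method names are illustrative only.

```python
# Illustrative sketch only -- NOT Ray's actual BatchQueue implementation.
# It shows the core idea: the queue accumulates requests and flushes a batch
# once a user-supplied batch_size_fn reports the batch has reached capacity.
from typing import Any, Callable, List, Optional


class ToyBatchQueue:
    def __init__(
        self,
        max_batch_size: int,
        batch_size_fn: Callable[[List[Any]], int],
    ) -> None:
        self.max_batch_size = max_batch_size
        self.batch_size_fn = batch_size_fn
        self._pending: List[Any] = []

    def put(self, request: Any) -> Optional[List[Any]]:
        """Enqueue a request; return a full batch once batch_size_fn says so."""
        self._pending.append(request)
        if self.batch_size_fn(self._pending) >= self.max_batch_size:
            batch, self._pending = self._pending, []
            return batch
        return None


# With batch_size_fn=len this reduces to ordinary count-based batching.
queue = ToyBatchQueue(max_batch_size=3, batch_size_fn=len)
assert queue.put("a") is None
assert queue.put("b") is None
assert queue.put("c") == ["a", "b", "c"]
```

Swapping in a size function like the atom-count lambda in this PR turns the same machinery into capacity-based batching.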

@lbluque lbluque requested a review from mshuaibii January 9, 2026 20:54
  self,
  predict_unit: MLIPPredictUnit,
- max_batch_size: int = 16,
+ max_batch_size: int = 512,
Contributor


It looks like this went from 16 -> 512 here and 32 -> 512 in _batch_serve.py.

Is this interpreted differently now? Before batch size meant the number of structures in the batch, but is it now compared directly to the output of batch_size_fn (which is the number of atoms across all structures)?

Contributor Author


Yes, that's correct! The default batch size is now 512 total atoms across all structures in the batch. This is approximate, though: a batch is cut as soon as its total number of atoms is larger than 512.
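The batching semantics described above can be sketched as a greedy split: structures accumulate until the running atom total passes the limit, so a batch can slightly overshoot it. This is an illustrative sketch of the behavior, not the code in this PR; the function name is hypothetical.

```python
# Illustrative sketch of the atom-count batching semantics (not the PR's code).
# A batch is closed as soon as its running atom total exceeds max_atoms, so the
# last structure added can push the batch somewhat past the limit.
from typing import List


def split_into_batches(
    natoms_per_structure: List[int], max_atoms: int = 512
) -> List[List[int]]:
    batches: List[List[int]] = []
    current: List[int] = []
    total = 0
    for natoms in natoms_per_structure:
        current.append(natoms)
        total += natoms
        if total > max_atoms:  # batch breaks once the total is larger than the cap
            batches.append(current)
            current, total = [], 0
    if current:  # flush any leftover structures
        batches.append(current)
    return batches


# Two 300-atom structures total 600 > 512, so the first batch closes there.
assert split_into_batches([300, 300, 100]) == [[300, 300], [100]]
```

This is why the limit is "approximate": measured in atoms, batches land near 512 rather than exactly at it.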

@lbluque lbluque requested a review from kjmichel January 10, 2026 00:10

Labels

cla signed, enhancement (New feature or request), patch (Patch version release)

4 participants