Use number of atoms for batch size #1690
base: main
Conversation
```diff
- @serve.batch
+ @serve.batch(
+     batch_size_fn=lambda batch: sum(sample.natoms.sum() for sample in batch).item()
```
Naive question: how is @serve.batch working here? How does batch_size_fn get incorporated?
Great question! That's all implemented in Ray. The TL;DR is that a BatchQueue class that receives the batch_size_fn is instantiated in the serve.batch decorator.
If you're curious, you can see the implementation of the decorator here: https://github.com/ray-project/ray/blob/d4817998ee8476c138e9106280ecefdf1e59ba6b/python/ray/serve/batching.py#L677
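To make the mechanism concrete, here is a toy sketch (not Ray's actual code; all names besides `batch_size_fn` and `max_batch_size` are hypothetical) of how a serve.batch-style decorator factory can thread a user-supplied batch_size_fn through to a queue that accumulates requests:

```python
class BatchQueue:
    """Accumulates items; 'full' is defined by a user-supplied size function."""

    def __init__(self, max_batch_size, batch_size_fn):
        self.max_batch_size = max_batch_size
        # Fall back to counting items when no custom size function is given,
        # mirroring the classic "batch size = number of requests" behavior.
        self.batch_size_fn = batch_size_fn or len
        self.items = []

    def put(self, item):
        self.items.append(item)

    def full(self):
        # The batch's "size" is whatever batch_size_fn reports,
        # e.g. total atoms rather than number of structures.
        return self.batch_size_fn(self.items) >= self.max_batch_size


def batch(max_batch_size=16, batch_size_fn=None):
    """Decorator factory: the queue is created once, when the decorator runs."""

    def decorator(handler):
        queue = BatchQueue(max_batch_size, batch_size_fn)

        def wrapper(item):
            queue.put(item)
            if queue.full():
                out, queue.items = handler(queue.items), []
                return out
            return None  # in a real server, the request would await its batch

        return wrapper

    return decorator


@batch(max_batch_size=3)
def handle(batch_items):
    # Process the whole batch at once.
    return [x * 2 for x in batch_items]
```

The key point is that batch_size_fn is captured at decoration time and consulted by the queue on every enqueue, so swapping "count structures" for "count atoms" requires no change to the handler itself.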
```diff
      self,
      predict_unit: MLIPPredictUnit,
-     max_batch_size: int = 16,
+     max_batch_size: int = 512,
```
It looks like this went from 16 -> 512 here and 32 -> 512 in _batch_serve.py.
Is this interpreted differently now? Before batch size meant the number of structures in the batch, but is it now compared directly to the output of batch_size_fn (which is the number of atoms across all structures)?
Yes, that's correct! The default batch size is now 512 total atoms across all structures in the batch. Note that this is approximate: the queue breaks a batch as soon as the total number of atoms grows larger than 512, so the last structure added can overshoot the limit.
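As a minimal illustration of that greedy behavior (a standalone sketch, not Ray's implementation; `Sample` and `greedy_batches` are hypothetical names), a batch closes as soon as the size function exceeds the cap, so totals can overshoot 512:

```python
from dataclasses import dataclass


@dataclass
class Sample:
    natoms: int  # hypothetical: number of atoms in one structure


def greedy_batches(samples, max_batch_size, batch_size_fn):
    """Yield batches, closing each one as soon as batch_size_fn exceeds the cap."""
    batch = []
    for sample in samples:
        batch.append(sample)
        if batch_size_fn(batch) > max_batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # flush the final, possibly undersized batch


samples = [Sample(natoms=n) for n in (100, 300, 200, 50, 400)]
for b in greedy_batches(
    samples,
    max_batch_size=512,
    batch_size_fn=lambda batch: sum(s.natoms for s in batch),
):
    print([s.natoms for s in b], "total atoms:", sum(s.natoms for s in b))
```

With these inputs, the first batch closes at 600 total atoms (100 + 300 + 200 overshoots 512 on the third structure), and the remaining two structures form a second batch of 450 atoms.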
The next ray.serve release will include the option to use a function to determine batch sizes based on input data (ray-project/ray#59059). This PR edits our code to use the total number of atoms to determine batch sizes in the BatchPredictServer.