Improve pathways checkpoint load times #1345
base: main
Conversation
* Utilize shared memory between the JAX client and the Pathways proxy for data-heavy transfers, e.g. `device_put`s.
* Increase the `ThreadPoolExecutor` thread count from 32 (the Python default) to 192.
* Remove the memory limit from the Pathways head main container.

Callers should use a `concurrent_restore_gb` as large as possible short of OOM; otherwise the GCS read and `device_put` won't happen in parallel. The default of 32 GB is too low to achieve optimal performance with Pathways.
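The thread-count change in the second bullet can be sketched as below. This is an illustrative snippet, not the PR's actual code; the default-worker formula is the documented `concurrent.futures` behavior for Python 3.8+.

```python
from concurrent.futures import ThreadPoolExecutor

# ThreadPoolExecutor defaults max_workers to min(32, os.cpu_count() + 4),
# which caps concurrent GCS reads and device_puts at 32 on large hosts.
# Raising it to 192 (the PR's choice) lets many restore shards overlap.
pool = ThreadPoolExecutor(max_workers=192)
```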
- # This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
+ # This image extends GRPC timeout for long context models.
- _PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
+ _PATHWAYS_IMAGE_TAG = "shm_proxy"
Could you double check with Shauray that this binary includes the patch extending the GRPC timeout? Or is it no longer needed?
+ # The flag below is needed for better H2D performance.
+ # Rule of thumb: 3x the shard size. So 128GB to be safe.
+ # Decrease if you start running out of host memory on TPU VMs.
+ "--tpu_premapped_buffer_size=137438953472",
Let's use 1/4 of the machine type's host memory, rounded up to a power of 2:
https://github.com/apple/axlearn/blob/main/axlearn/cloud/gcp/system_characteristics.py#L494-L499
  self._loop_thread.start()
- self._single_thread_pool = ThreadPoolExecutor(1)
+ self._single_thread_pool = ThreadPoolExecutor(max_workers=1)
+ self._multi_thread_pool = ThreadPoolExecutor(max_workers=192)
Can we make this a config flag? It depends on how many CPUs we allocate to the head pod: https://github.com/apple/axlearn/blob/main/axlearn/cloud/gcp/pathways_utils.py#L317
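One way to do what the reviewer suggests is sketched below. The field names (`pathways_head_cpu`, `restore_threads`) and the 2x-CPU default are hypothetical; the actual config machinery lives in axlearn's `pathways_utils.py`.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class HeadConfig:
    pathways_head_cpu: int = 96  # CPUs requested for the head pod
    restore_threads: int = 0     # 0 means "derive from the CPU request"

def make_restore_pool(cfg: HeadConfig) -> ThreadPoolExecutor:
    # Default to 2 threads per allocated CPU so I/O-bound GCS reads overlap.
    workers = cfg.restore_threads or 2 * cfg.pathways_head_cpu
    return ThreadPoolExecutor(max_workers=workers)
```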
axlearn/cloud/gcp/pathways_utils.py
Outdated
mem_req = f"{self.config.pathways_head_mem}Gi"
resources = {
    "requests": {"cpu": cpu_req, "memory": mem_req},
    "limits": {"cpu": cpu_req, "memory": mem_req},
For my education, what's the effect of having "request" and not "limit"?
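To illustrate the distinction the question is asking about, here is a requests-only resources dict (values are illustrative, not from this PR):

```python
# With "requests" but no "limits", Kubernetes reserves the requested amount
# for scheduling, but the container may burst above it when the node has
# spare capacity; memory use is then bounded only by node pressure, and the
# pod's QoS class becomes Burstable rather than Guaranteed.
resources = {
    "requests": {"cpu": "96", "memory": "128Gi"},
    # no "limits" key: the pod can use spare node capacity
}
```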
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the
Callers of `deserialize` should use a `concurrent_restore_gb` as large as possible short of OOM; otherwise the GCS read and `device_put` won't happen in parallel. The default of 32 GB is too low to achieve optimal performance with Pathways.