Conversation

@samos123
Contributor

  • Use shared memory between the JAX client and the Pathways proxy for data-heavy transfers, e.g. device_puts.
  • Increase the ThreadPoolExecutor worker count from Python's default cap of 32 to 192.
  • Remove the memory limit from the pathways head main container.

Callers of deserialize should set concurrent_restore_gb as large as possible without OOMing; otherwise the GCS read and device_put won't happen in parallel. The default of 32GB is too low to achieve optimal performance with Pathways.

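For context, a minimal Python sketch of the thread-pool part of the change (the 192 figure is from the PR description; everything else is illustrative):

```python
import os
from concurrent.futures import ThreadPoolExecutor

# Since Python 3.8, ThreadPoolExecutor defaults to
# min(32, os.cpu_count() + 4) workers, so I/O fan-out is capped
# at 32 even on large hosts.
default_workers = min(32, (os.cpu_count() or 1) + 4)

# Raising max_workers lets many GCS reads and device_puts overlap.
pool = ThreadPoolExecutor(max_workers=192)
```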
- # This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
- _PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
+ # This image extends GRPC timeout for long context models.
+ _PATHWAYS_IMAGE_TAG = "shm_proxy"
Contributor


Could you double-check with Shauray that this binary includes the patch that extends the GRPC timeout? Or is it no longer needed?

# The flag below is needed for better H2D performance.
# Rule of thumb: 3x the shard size. So 128GB to be safe.
# Decrease if you start running out of host memory on TPU VMs.
"--tpu_premapped_buffer_size=137438953472",
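As a sanity check on the flag value above (simple arithmetic, not from the PR):

```python
# 137438953472 bytes is exactly 128 GiB, matching the
# "3x the shard size, so 128GB to be safe" rule of thumb above.
GIB = 1024 ** 3
buffer_bytes = 128 * GIB
assert buffer_bytes == 137438953472
```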
Contributor


Let's use 1/4 of the machine type's host memory and round up to the next power of 2:

https://github.com/apple/axlearn/blob/main/axlearn/cloud/gcp/system_characteristics.py#L494-L499
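A hypothetical helper for that sizing rule (the function name and example memory values are assumptions, not axlearn code):

```python
def premapped_buffer_gib(host_memory_gib: int) -> int:
    """Return 1/4 of host memory, rounded up to the next power of 2 (GiB)."""
    quarter = host_memory_gib / 4
    size = 1
    while size < quarter:
        size *= 2
    return size

# e.g. a host with 512 GiB of memory -> 128 GiB premapped buffer
print(premapped_buffer_gib(512))
```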

  self._loop_thread.start()
- self._single_thread_pool = ThreadPoolExecutor(1)
+ self._single_thread_pool = ThreadPoolExecutor(max_workers=1)
+ self._multi_thread_pool = ThreadPoolExecutor(max_workers=192)
Contributor


Can we make this a config flag? It depends on how many CPUs we allocate to the head pod: https://github.com/apple/axlearn/blob/main/axlearn/cloud/gcp/pathways_utils.py#L317
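The config-flag idea could look roughly like this (a hypothetical sketch; the dataclass and field name are assumptions, not the axlearn config API):

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class HeadConfig:
    # Should track the CPUs allocated to the pathways head pod
    # rather than hardcoding 192.
    max_io_workers: int = 192

def build_pool(cfg: HeadConfig) -> ThreadPoolExecutor:
    return ThreadPoolExecutor(max_workers=cfg.max_io_workers)

pool = build_pool(HeadConfig(max_io_workers=64))
```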

mem_req = f"{self.config.pathways_head_mem}Gi"
resources = {
    "requests": {"cpu": cpu_req, "memory": mem_req},
    "limits": {"cpu": cpu_req, "memory": mem_req},
}
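To illustrate the PR's "remove memory limit" change: in Kubernetes, a request reserves capacity for scheduling, while a limit sets a hard cap at which the container is OOM-killed. A sketch with placeholder values (not the PR's actual diff):

```python
cpu_req = "16"     # illustrative values, not from the PR
mem_req = "128Gi"

# Request memory for scheduling, but set no memory limit so the
# container can use spare node memory instead of hitting a fixed cap.
resources = {
    "requests": {"cpu": cpu_req, "memory": mem_req},
    "limits": {"cpu": cpu_req},  # no "memory" key
}
```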
Contributor


For my education, what's the effect of having "request" and not "limit"?

@github-actions

This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.

github-actions bot added the stale label Dec 27, 2025