-
Notifications
You must be signed in to change notification settings - Fork 443
Feat: Update RL on Multi-Host TPUs tutorial for clarity and structure #2890
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,9 +14,9 @@ | |
| limitations under the License. | ||
| --> | ||
|
|
||
| # Reinforcement Learning on multi-host TPUs | ||
| # Reinforcement Learning on Multi-Host TPUs | ||
|
|
||
| This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using [Pathways for orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) on multi-host TPU-VMs such as `v5p-128`. | ||
| This tutorial provides step-by-step instructions for setting up the environment and training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using [Pathways for orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) on multi-host TPU-VMs, such as `v5p-128`. | ||
|
|
||
| We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities: | ||
|
|
||
|
|
@@ -26,58 +26,73 @@ We utilize two RL algorithms, implemented via the Tunix library, to enhance the | |
|
|
||
| For efficient model inference and response generation during this process, we rely on the vLLM library. | ||
|
|
||
| Let's get started! | ||
| ## Table of Contents | ||
|
|
||
| ## Create virtual environment and Install MaxText dependencies | ||
| Follow instructions in [Install MaxText](../../install_maxtext.md), but | ||
| recommend creating the virtual environment outside the `maxtext` directory. | ||
| - [Prerequisites](#prerequisites) | ||
| - [Setup Environment Variables](#setup-environment-variables) | ||
| - [Get Your Model Checkpoint](#get-your-model-checkpoint) | ||
| - [Build and Upload MaxText Docker Image](#build-and-upload-maxtext-docker-image) | ||
| - [Submit your RL workload via Pathways](#submit-your-rl-workload-via-pathways) | ||
| - [Managing Workloads](#managing-workloads) | ||
|
|
||
| ## Prerequisites | ||
|
|
||
| ## Setup environment variables | ||
| Before starting, ensure you have: | ||
| - Access to a Google Cloud Project with TPU quotas. | ||
| - A Hugging Face account with an access token for downloading models. | ||
| - Permissions for Google Artifact Registry (Artifact Registry Writer role). | ||
| - XPK installed (follow [official documentation](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/installation.md)). | ||
| - A Pathways-ready GKE cluster (see [create GKE cluster](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster)). | ||
|
|
||
| Setup following environment variables: | ||
| ## Setup Environment Variables | ||
|
|
||
| Set up the following environment variables. Replace placeholders with your actual values. | ||
|
|
||
| ```bash | ||
| # -- Model configuration -- | ||
| export HF_MODEL='llama3.1-70b-Instruct' | ||
| export MODEL='llama3.1-70b' | ||
| export HF_MODEL='llama3.1-70b-Instruct' # Hugging Face model name for checkpoint conversion | ||
| export MODEL='llama3.1-70b' # MaxText model name for training | ||
| export TOKENIZER='meta-llama/Llama-3.1-70B-Instruct' | ||
| export HF_TOKEN=<Hugging Face access token> | ||
|
|
||
| # -- MaxText configuration -- | ||
| export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory | ||
| export RUN_NAME=llama-3-70b-grpo | ||
| export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/0/items | ||
| export WORKLOAD=llama-3-70b-grpo | ||
| export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD}/0/items | ||
|
|
||
| # -- Workload configuration -- | ||
| export WORKLOAD=${RUN_NAME} | ||
| export TPU_TYPE='v5p-128' | ||
| export TPU_TYPE='v5p-128' | ||
| export TPU_CLUSTER=<cluster name> | ||
| export PROJECT_ID=<GCP project ID> | ||
| export ZONE=<zone name> | ||
| export CLOUD_IMAGE_NAME=<your artifact registry image> | ||
| ``` | ||
|
|
||
| ## Get your model checkpoint | ||
| ## Get Your Model Checkpoint | ||
|
|
||
| You can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText. | ||
|
|
||
| First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. \ | ||
| For example, converting a Llama3.1-70B model scanned checkpoint using `--lazy_load_tensors=true` will use around 200GB of RAM and completes in ~10 mins. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket. | ||
| First, ensure you have the necessary dependencies installed (PyTorch for the conversion script). Then, run the conversion script on a CPU machine. For large models, use the `--lazy_load_tensors` flag to reduce memory usage during conversion. | ||
|
|
||
| For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 mins. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket. | ||
|
|
||
| ```bash | ||
| python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
|
|
||
| # using --lazy_load_tensors=true here will reduce the memory usage. eg, Llama3.1-70B conversion takes around 86GB of RAM | ||
| python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \ | ||
| model_name=${HF_MODEL} \ | ||
| hf_access_token=${HF_TOKEN} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD} \ | ||
| scan_layers=true checkpoint_storage_use_ocdbt=false checkpoint_storage_use_zarr3=false \ | ||
| skip_jax_distributed_system=true --lazy_load_tensors=true | ||
| ``` | ||
|
|
||
| ## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-inference dependencies | ||
| Before building the Docker image, authenticate to [Google Artifact Registry](https://docs.cloud.google.com/artifact-registry/docs/docker/authentication#gcloud-helper) for permission to push your images and other access. | ||
| This command downloads the Hugging Face model and converts it to MaxText format, saving it to the specified GCS bucket. | ||
|
|
||
| ## Build and Upload MaxText Docker Image | ||
|
|
||
| Before building the Docker image, authenticate to [Google Artifact Registry](https://docs.cloud.google.com/artifact-registry/docs/docker/authentication#gcloud-helper) for permission to push your images. | ||
|
|
||
| ```bash | ||
| # Authenticate your user account for gcloud CLI access | ||
| gcloud auth login | ||
|
|
@@ -88,65 +103,116 @@ gcloud auth configure-docker | |
| docker run hello-world | ||
| ``` | ||
|
|
||
| You can install the required dependencies using either of the following two options: | ||
| ### Option 1: Install Stable Releases | ||
|
|
||
| ### Option 1: Installing stable releases of tunix and vllm-tpu | ||
| Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-inference installed. | ||
| Run the following script to create a Docker image with stable releases of MaxText, Tunix, vLLM, and tpu-inference dependencies. This installs `vllm-tpu` which provides TPU inference for vLLM with unified JAX and PyTorch support. The build process takes approximately 10-15 minutes. | ||
|
|
||
| In addition to MaxText dependencies, primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby providing TPU inference for vLLM, with unified JAX and PyTorch support. This build process takes approximately 10 to 15 minutes. | ||
|
|
||
| ``` | ||
| ```bash | ||
| bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training | ||
| ``` | ||
|
|
||
| You can also use `bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API. | ||
| For experimental features (such as improved pathwaysutils resharding API), use: | ||
|
|
||
| ```bash | ||
| bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training-experimental | ||
| ``` | ||
|
|
||
| ### Option 2: Install from locally git cloned repositories | ||
|
|
||
| You can also locally git clone [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), [vllm](https://github.com/vllm-project/vllm.git) and then use the following command to build a docker image using them: | ||
| ``` | ||
| You can also locally clone the [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), and [vllm](https://github.com/vllm-project/vllm.git) repositories and then build using MaxText's local sources. | ||
|
|
||
| **Note:** Clone these repositories as siblings of the `maxtext` directory (e.g., in the same parent directory). After cloning, run the build from inside the `maxtext` repository so it picks up the local sources: | ||
|
|
||
| ```bash | ||
| bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local | ||
| ``` | ||
|
|
||
| ### Upload the dependency docker image along with MaxText code | ||
| > **Note:** You will need the [**Artifact Registry Writer**](https://docs.cloud.google.com/artifact-registry/docs/access-control#permissions) role to push Docker images to your project's Artifact Registry and to allow the cluster to pull them during workload execution. If you don't have this permission, contact your project administrator to grant you this role through "Google Cloud Console -> IAM -> Grant access". | ||
| ``` | ||
| ### Upload the Docker Image | ||
|
|
||
| > **Note:** You will need the [**Artifact Registry Writer**](https://docs.cloud.google.com/artifact-registry/docs/access-control#permissions) role to push Docker images to your project's Artifact Registry. Contact your project administrator if you don't have this permission. | ||
|
|
||
| ```bash | ||
| bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME} | ||
| ``` | ||
|
|
||
| ## Submit your RL workload via Pathways | ||
|
|
||
| Please create a pathways ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), and you can submit the `train_rl.py` script via XPK. You can install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/installation.md). | ||
| See the **Troubleshooting** section for concise instructions on how to retry or resume a failed workload. | ||
|
|
||
| Ensure you have a Pathways-ready GKE cluster (as mentioned in Prerequisites) and submit the `train_rl.py` script via XPK. | ||
|
|
||
| ### Submit GRPO workload | ||
| ``` | ||
| xpk workload create-pathways --workload $WORKLOAD \ | ||
| --docker-image <path/to/gcr.io> --cluster $TPU_CLUSTER \ | ||
| --docker-image gcr.io/$PROJECT_ID/$CLOUD_IMAGE_NAME --cluster $TPU_CLUSTER \ | ||
| --tpu-type=$TPU_TYPE --num-slices=1 --zone=$ZONE \ | ||
| --project=$PROJECT_ID --priority=high \ | ||
| --command "TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ | ||
| --command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ | ||
| python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ | ||
| model_name=${MODEL} \ | ||
| tokenizer_path=${TOKENIZER} \ | ||
| load_parameters_path=${MAXTEXT_CKPT_PATH} \ | ||
| run_name=${RUN_NAME} \ | ||
| run_name=${WORKLOAD} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY} \ | ||
| hf_access_token=$HF_TOKEN" | ||
| hf_access_token=${HF_TOKEN}" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to set
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Technically no, as the code block currently ignores that flag. However, I suggest keeping it since it's consistent with our docs and other examples. It causes no issues, and the implementation might be updated to use it later anyway. |
||
| ``` | ||
|
|
||
| ### Submit GSPO workload | ||
| ``` | ||
| xpk workload create-pathways --workload $WORKLOAD \ | ||
| --docker-image <path/to/gcr.io> --cluster $TPU_CLUSTER \ | ||
| --docker-image gcr.io/$PROJECT_ID/$CLOUD_IMAGE_NAME --cluster $TPU_CLUSTER \ | ||
| --tpu-type=$TPU_TYPE --num-slices=1 --zone=$ZONE \ | ||
| --project=$PROJECT_ID --priority=high \ | ||
| --command "TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ | ||
| --command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ | ||
| python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ | ||
| model_name=${MODEL} \ | ||
| tokenizer_path=${TOKENIZER} \ | ||
| load_parameters_path=${MAXTEXT_CKPT_PATH} \ | ||
| run_name=${RUN_NAME} \ | ||
| run_name=${WORKLOAD} \ | ||
| base_output_directory=${BASE_OUTPUT_DIRECTORY} \ | ||
| hf_access_token=$HF_TOKEN \ | ||
| hf_access_token=${HF_TOKEN} \ | ||
| loss_algo=gspo-token" | ||
| ``` | ||
|
|
||
| ## Managing Workloads | ||
|
|
||
| - **Monitor workload status**: Check Pathways job status: | ||
| ```bash | ||
| kubectl get pathwaysjob | ||
| ``` | ||
| Check pod status: | ||
| ```bash | ||
| kubectl get pods | ||
| ``` | ||
| - **Delete a workload**: To remove a failed or unwanted Pathways job, use XPK: | ||
| ```bash | ||
| xpk workload delete \ | ||
| --workload $WORKLOAD \ | ||
| --cluster $TPU_CLUSTER \ | ||
| --project $PROJECT_ID | ||
| ``` | ||
|
|
||
| ## Troubleshooting | ||
|
|
||
| - **Authentication Issues**: Ensure your `HF_TOKEN` environment variable is set correctly and has access to the required models. | ||
| - **Resource Quotas**: Verify you have sufficient TPU quotas in your GCP project. | ||
| - **Docker Build Failures**: Check that all dependencies are correctly installed and authentication is configured. | ||
| - **Workload Failures**: Review the logs for specific error messages and ensure all environment variables are properly set. | ||
| - **Workload retry / resume**: | ||
| - **Retry (fresh run)**: Use a unique workload name to avoid overwriting outputs: | ||
| ```bash | ||
| export WORKLOAD=${WORKLOAD}-retry1 | ||
| export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD}/0/items | ||
| ``` | ||
| Re-submit the XPK workload. If "workload already exists" error occurs, pick a new name or list jobs: | ||
| ```bash | ||
| kubectl get pathwaysjob | ||
| ``` | ||
| - **Resume from checkpoint**: Keep the same `WORKLOAD` and set the checkpoint path: | ||
| ```bash | ||
| export load_parameters_path=${MAXTEXT_CKPT_PATH}/checkpoint-0000 | ||
| ``` | ||
| Then re-submit the job. | ||
| - **Tip**: Verify checkpoint exists in GCS with read access before resuming. | ||
|
|
||
| For more detailed troubleshooting, refer to the [MaxText documentation](https://maxtext.readthedocs.io) and [XPK documentation](https://github.com/AI-Hypercomputer/xpk). | ||
Uh oh!
There was an error while loading. Please reload this page.