Guozheng Ma* · Lu Li* · Zilin Wang
Li Shen · Pierre-Luc Bacon · DaCheng Tao
This repository contains the source code required to reproduce the DeepMind Control experiments presented in our paper.
conda env create -f deps/environment.yaml
pip install -U "jax[cuda12]==0.4.30" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# To execute multiple runs on a single GPU, we recommend setting this variable.
export XLA_PYTHON_CLIENT_PREALLOCATE=false
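Why this matters: by default JAX preallocates a large share (75%) of GPU memory when the first JAX operation runs, so a second process on the same GPU can fail with an out-of-memory error. Disabling preallocation, as above, is the setting we recommend. As a possible alternative, which this repository does not prescribe, you can keep preallocation enabled and cap the share each process takes with `XLA_PYTHON_CLIENT_MEM_FRACTION`. The sketch below assumes three concurrent runs; `NUM_RUNS` and the 90% budget are illustrative values, not settings from our experiments.

```shell
# Assumption: three concurrent runs share one GPU (hypothetical number).
NUM_RUNS=3

# Split roughly 90% of GPU memory evenly across the runs, leaving some headroom.
MEM_FRACTION=$(awk -v n="$NUM_RUNS" 'BEGIN { printf "%.2f", 0.9 / n }')

# The fraction only takes effect while preallocation is enabled, so this
# replaces (rather than complements) XLA_PYTHON_CLIENT_PREALLOCATE=false.
export XLA_PYTHON_CLIENT_PREALLOCATE=true
export XLA_PYTHON_CLIENT_MEM_FRACTION="$MEM_FRACTION"

echo "Each run may preallocate ${MEM_FRACTION} of GPU memory"
```

Use one approach or the other. On-demand allocation is simpler, while a fixed fraction gives each run a predictable memory budget.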
Below is an example of how to train a SAC agent using the SimBa network with a sparsity level of 0.8 on the humanoid-run environment.
python run.py \
--config_name base_sac \
--overrides seed=0 \
--overrides updates_per_interaction_step=2 \
--overrides actor_sparsity=0.8 \
--overrides actor_num_blocks=1 \
--overrides actor_hidden_dim=128 \
--overrides critic_sparsity=0.8 \
--overrides critic_num_blocks=2 \
--overrides critic_hidden_dim=512 \
--overrides env_name=humanoid-run
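To repeat this run over several seeds on one GPU, a small wrapper along the following lines may help. It is a minimal sketch rather than a script shipped with this repository: `SEEDS`, `DRY_RUN`, `train_cmd`, and the log file names are hypothetical, and only the `run.py` arguments are taken from the example above. It assumes bash (or another shell that word-splits unquoted variables). By default it only prints the commands; set `DRY_RUN=0` to launch them in the background.

```shell
# Seeds to sweep over and a dry-run switch (both hypothetical, not run.py options).
SEEDS="0 1 2"
DRY_RUN="${DRY_RUN:-1}"   # 1 = only print the commands, 0 = launch them

train_cmd() {
  # $1 is the seed; every other override is copied from the example command.
  echo python run.py \
    --config_name base_sac \
    --overrides "seed=$1" \
    --overrides updates_per_interaction_step=2 \
    --overrides actor_sparsity=0.8 \
    --overrides actor_num_blocks=1 \
    --overrides actor_hidden_dim=128 \
    --overrides critic_sparsity=0.8 \
    --overrides critic_num_blocks=2 \
    --overrides critic_hidden_dim=512 \
    --overrides env_name=humanoid-run
}

for seed in $SEEDS; do
  cmd=$(train_cmd "$seed")
  if [ "$DRY_RUN" = "1" ]; then
    echo "$cmd"
  else
    # Disable preallocation per process so the runs can share one GPU,
    # and send each run's output to its own log file.
    env XLA_PYTHON_CLIENT_PREALLOCATE=false $cmd > "seed_${seed}.log" 2>&1 &
  fi
done

# Block until all background runs (if any) have finished.
wait
```

Whether several runs actually fit on one GPU depends on its memory, so it is worth checking the printed commands and starting with a small number of seeds.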
We would like to thank the authors of the SimBa codebase and JaxPruner; our implementation builds on their repositories.
If you have any questions, please raise an issue or send an email to Lu (lu.li@mila.quebec) and Guozheng (guozheng001@e.ntu.edu.sg).