diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py index 364c5afe5..3960687e6 100644 --- a/baselines/run_baseline_parallel.py +++ b/baselines/run_baseline_parallel.py @@ -8,6 +8,7 @@ from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.callbacks import CheckpointCallback from argparse_pokemon import * +import os def make_env(rank, env_conf, seed=0): """ @@ -39,8 +40,10 @@ def _init(): } env_config = change_env(env_config, args) - - num_cpu = 44 #64 #46 # Also sets the number of episodes per training iteration + + # Set the number of cpus dynamically + num_cpu = os.cpu_count() + env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)]) checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, @@ -62,4 +65,4 @@ def _init(): model = PPO('CnnPolicy', env, verbose=1, n_steps=ep_length, batch_size=512, n_epochs=1, gamma=0.999) for i in range(learn_steps): - model.learn(total_timesteps=(ep_length)*num_cpu*1000, callback=checkpoint_callback) \ No newline at end of file + model.learn(total_timesteps=(ep_length)*num_cpu*1000, callback=checkpoint_callback)