@@ -529,17 +529,20 @@ def _collect_hyperparameters(
529529
530530 losses = getattr (trainer_cfg , "losses" , None )
531531 loss_configs = getattr (losses , "loss_configs" , {}) if losses else {}
532- ppo_cfg = loss_configs .get ("ppo" ) if isinstance (loss_configs , dict ) else None
533- if ppo_cfg is not None :
534- for attr in (
535- "clip_coef" ,
536- "vf_clip_coef" ,
537- "ent_coef" ,
538- "l2_reg_loss_coef" ,
539- "l2_init_loss_coef" ,
540- ):
541- value = getattr (ppo_cfg , attr , None )
542- if value is None :
543- continue
544- hyperparameters [f"ppo_{ attr } " ] = value
532+ if isinstance (loss_configs , dict ):
533+ ppo_actor_cfg = loss_configs .get ("ppo_actor" )
534+ if ppo_actor_cfg is not None :
535+ for attr in ("clip_coef" , "ent_coef" , "norm_adv" , "target_kl" ):
536+ value = getattr (ppo_actor_cfg , attr , None )
537+ if value is None :
538+ continue
539+ hyperparameters [f"ppo_actor_{ attr } " ] = value
540+
541+ ppo_critic_cfg = loss_configs .get ("ppo_critic" )
542+ if ppo_critic_cfg is not None :
543+ for attr in ("vf_coef" , "vf_clip_coef" , "clip_vloss" ):
544+ value = getattr (ppo_critic_cfg , attr , None )
545+ if value is None :
546+ continue
547+ hyperparameters [f"ppo_critic_{ attr } " ] = value
545548 return hyperparameters
0 commit comments