@@ -734,42 +734,63 @@ def Q_update(self, recurrent=True, monte_carlo=False, policy=True, verbose=False
734734 self .update_target_network (source = self .Q2_network , target = self .Q2_target , tau = self .polyak )
735735 self .update_target_network (source = self .policy_network , target = self .policy_target , tau = self .polyak )
736736
737- def save_network (self , save_path ):
738- '''
739- Saves networks to directory specified by save_path
740- :param save_path: directory to save networks to
741- '''
742-
743- torch .save (self .policy_network , os .path .join (save_path , "policy_network.pth" ))
744- torch .save (self .Q1_network , os .path .join (save_path , "Q1_network.pth" ))
745- torch .save (self .Q2_network , os .path .join (save_path , "Q2_network.pth" ))
746-
747- torch .save (self .policy_target , os .path .join (save_path , "policy_target.pth" ))
748- torch .save (self .Q1_target , os .path .join (save_path , "Q1_target.pth" ))
749- torch .save (self .Q2_target , os .path .join (save_path , "Q2_target.pth" ))
750-
751- def load_network (self , load_path , load_target_networks = False ):
752- '''
753- Loads netoworks from directory specified by load_path.
754- :param load_path: directory to load networks from
755- :param load_target_networks: whether to load target networks
756- '''
757-
758- self .policy_network = torch .load (os .path .join (load_path , "policy_network.pth" ))
759- self .policy_network_opt = Adam (self .policy_network .parameters (), lr = self .pol_learning_rate )
760-
761- self .Q1_network = torch .load (os .path .join (load_path , "Q1_network.pth" ))
762- self .Q1_network_opt = Adam (self .Q1_network .parameters (), lr = self .val_learning_rate )
763-
764- self .Q2_network = torch .load (os .path .join (load_path , "Q2_network.pth" ))
765- self .Q2_etwork_opt = Adam (self .Q2_network .parameters (), lr = self .val_learning_rate )
766-
def save_ckpt(self, save_path, additional_info=None):
    '''
    Creates a full checkpoint (networks, optimizers, memory buffers) and saves it to the specified path.
    :param save_path: path to save the checkpoint to
    :param additional_info: additional information to save (Python dictionary)
    '''
    # everything with a state_dict() — networks, their targets, and the optimizers
    stateful = ("policy_network", "Q1_network", "Q2_network",
                "policy_target", "Q1_target", "Q2_target",
                "policy_network_opt", "Q1_network_opt", "Q2_network_opt")
    ckpt = {name: getattr(self, name).state_dict() for name in stateful}

    # caller-supplied metadata; default to an empty dict rather than storing None
    ckpt["additional_info"] = {} if additional_info is None else additional_info

    ### save buffers — stored as plain Python objects, no state_dict involved
    buffer_names = ("memory", "values", "states", "next_states", "actions", "rewards",
                    "dones", "sequences", "next_sequences", "all_returns")
    for name in buffer_names:
        ckpt[name] = getattr(self, name)

    ### save the checkpoint
    torch.save(ckpt, save_path)
def load_ckpt(self, load_path, load_target_networks=True):
    '''
    Loads a full checkpoint (networks, optimizers, memory buffers) from the specified path.
    :param load_path: path to load the checkpoint from
    :param load_target_networks: whether to load the target networks as well
    :return: the raw checkpoint dictionary (so callers can read "additional_info")
    '''
    # The checkpoint holds optimizer state and arbitrary Python buffer objects,
    # so full unpickling is required: torch >= 2.6 defaults weights_only=True,
    # which would reject non-tensor buffer contents.
    try:
        ckpt = torch.load(load_path, weights_only=False)
    except TypeError:
        # older torch versions do not accept the weights_only keyword
        ckpt = torch.load(load_path)

    ### load networks
    self.policy_network.load_state_dict(ckpt["policy_network"])
    self.Q1_network.load_state_dict(ckpt["Q1_network"])
    self.Q2_network.load_state_dict(ckpt["Q2_network"])

    ### load target networks
    if load_target_networks:
        self.policy_target.load_state_dict(ckpt["policy_target"])
        self.Q1_target.load_state_dict(ckpt["Q1_target"])
        self.Q2_target.load_state_dict(ckpt["Q2_target"])

    ### load optimizers
    self.policy_network_opt.load_state_dict(ckpt["policy_network_opt"])
    self.Q1_network_opt.load_state_dict(ckpt["Q1_network_opt"])
    self.Q2_network_opt.load_state_dict(ckpt["Q2_network_opt"])

    ### load buffers (restored as-is; must match the names used by save_ckpt)
    for buffer in ("memory", "values", "states", "next_states", "actions", "rewards", "dones",
                   "sequences", "next_sequences", "all_returns"):
        setattr(self, buffer, ckpt[buffer])

    return ckpt
773794
774795 def reset_weights (self , policy = True ):
775796 '''
0 commit comments