diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index 696633d96459..cdf34d7e52e0 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -234,6 +234,8 @@ def execution_plan(workers: WorkerSet,
         # which means that we don't need the sampling pipeline setup
         for batch in input_reader.get_all():
             local_replay_buffer.add_batch(batch)
+        config["bc_iters"] = input_reader.total_iterations_count
+        workers.local_worker().policy_map['default_policy'].update_config(config)
     else:
         parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
         num_async = config.get("parallel_rollouts_num_async")
diff --git a/rllib/agents/sac/sac.py b/rllib/agents/sac/sac.py
index 87f6c21254ae..b0b3389ce018 100644
--- a/rllib/agents/sac/sac.py
+++ b/rllib/agents/sac/sac.py
@@ -84,6 +84,8 @@
     "normalize_actions": True,
     # Number of iterations to perform in the Behavior Cloning Pretraining
    "bc_iters": None,
+    # Number of epochs to perform in the Behavior Cloning Pretraining
+    "bc_epochs": 1,

     # === Learning ===
     # Disable setting done=True at end of episode. This should be set to True
diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py
index c7a14bd3dbdd..17a3a36e7433 100644
--- a/rllib/agents/sac/sac_tf_policy.py
+++ b/rllib/agents/sac/sac_tf_policy.py
@@ -228,8 +228,11 @@ def sac_actor_critic_loss(
     # Should be True only for debugging purposes (e.g. test cases)!
     deterministic = policy.config["_deterministic_loss"]
     bc_iters = policy.config["bc_iters"]
-    bc_iters_const = (tf.constant(bc_iters, dtype=policy.global_step.dtype)
-                      if bc_iters else None)
+    bc_iters_const = tf1.placeholder_with_default(
+        tf.constant(bc_iters, dtype=policy.global_step.dtype),
+        shape=None,
+        name="bc_iters_const")
+    policy.bc_iters_const = bc_iters_const

     # Get the base model output from the train batch.
     model_out_t, _ = model({
diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py
index 1479499b9274..c9be9473ec23 100644
--- a/rllib/policy/tf_policy.py
+++ b/rllib/policy/tf_policy.py
@@ -211,6 +211,7 @@ def __init__(self,
                 self.dist_class is not None:
             self._log_likelihood = self.dist_class(
                 self._dist_inputs, self.model).logp(self._action_input)
+        self.bc_iters_const: Optional[tf.Tensor] = None

     def variables(self):
         """Return the list of all savable variables for this policy."""
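
The loss code that consumes bc_iters_const is outside these hunks, but the switch from a plain tf.constant to tf1.placeholder_with_default only pays off if the behavior-cloning threshold is read inside the graph (e.g. compared against policy.global_step) and can still be overridden at run time, after the offline input reader has reported its iteration count. Below is a minimal, self-contained sketch of that pattern; the scalar losses, the default of 100, and the tf.cond gate are illustrative stand-ins, not the actual SAC loss.

import tensorflow.compat.v1 as tf1

tf1.disable_eager_execution()

# A global step variable, standing in for policy.global_step.
global_step = tf1.train.get_or_create_global_step()

# As in the diff: the default comes from a config value baked in at graph
# construction, but the threshold can be overridden through feed_dict later.
bc_iters_const = tf1.placeholder_with_default(
    tf1.constant(100, dtype=global_step.dtype),  # hypothetical default
    shape=None,
    name="bc_iters_const")

# Hypothetical scalar losses; in SAC these would be the BC log-prob loss and
# the entropy-regularized actor loss.
bc_loss = tf1.constant(1.0, name="bc_loss")
actor_loss = tf1.constant(2.0, name="actor_loss")

# Optimize the BC objective while global_step < bc_iters_const, then switch.
loss = tf1.cond(
    tf1.less(global_step, bc_iters_const),
    lambda: bc_loss,
    lambda: actor_loss)

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print(sess.run(loss))                                 # 1.0: default threshold
    print(sess.run(loss, feed_dict={bc_iters_const: 0}))  # 2.0: threshold overridden

Because the placeholder falls back to its default when nothing is fed, existing call sites keep working unchanged, while feeding policy.bc_iters_const lets the trainer adjust the BC horizon after graph construction; exposing the tensor on the policy (and declaring it in TFPolicy.__init__) appears to be what enables that feed path.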