From 64ff6709be2b917feebe6373458f3f65a067a0f6 Mon Sep 17 00:00:00 2001 From: Ruofan Kong Date: Thu, 12 Aug 2021 00:04:13 -0700 Subject: [PATCH] No Case: Fixed RLLib Sample collector API types. --- rllib/evaluation/collectors/sample_collector.py | 3 ++- rllib/evaluation/collectors/simple_list_collector.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index 415eac9592a7..6024aafc1180 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -59,6 +59,7 @@ def __init__(self, @abstractmethod def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, + env_id: EnvID, policy_id: PolicyID, t: int, init_obs: TensorType) -> None: """Adds an initial obs (after reset) to this collector. @@ -204,7 +205,7 @@ def postprocess_episode(self, episode: MultiAgentEpisode, is_done: bool = False, check_dones: bool = False, - build: bool = False) -> Optional[MultiAgentBatch]: + build: bool = False) -> Optional[Union[SampleBatch, MultiAgentBatch]]: """Postprocesses all agents' trajectories in a given episode. Generates (single-trajectory) SampleBatches for all Policies/Agents and diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 3c4838627380..0876b322cdc9 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -3,7 +3,7 @@ import logging import math import numpy as np -from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.rllib.evaluation.collectors.sample_collector import SampleCollector @@ -612,7 +612,7 @@ def postprocess_episode( episode: MultiAgentEpisode, is_done: bool = False, check_dones: bool = False, - build: bool = False) -> Union[None, SampleBatch, MultiAgentBatch]: + build: bool = False) -> Optional[Union[SampleBatch, MultiAgentBatch]]: episode_id = episode.episode_id policy_collector_group = episode.batch_builder