# Reconstructed from patch 1/3 "distributed data collection" (Albert Bou, 2023-02-20).
# File: distributed_collector/distributed_collector.py
import ray
import torch
from abc import ABC
from typing import Iterator
from torch.utils.data import IterableDataset
from tensordict.tensordict import TensorDict, TensorDictBase
from fake_collector import FakeCollector

# Default Ray resource reservation for each remote collector actor.
default_remote_config = {
    "num_cpus": 1,
    "num_gpus": 0.2,
    "memory": 5 * 1024 ** 3,
    "object_store_memory": 2 * 1024 ** 3,
}


class DistributedCollector(IterableDataset, ABC):
    """
    Class to better handle the operations of ensembles of Collectors.

    Spawns one local collector plus ``num_collectors - 1`` remote Ray actors
    and iterates over the batches they produce: in "sync" mode it waits for
    all workers each step and yields their concatenation; in "async" mode it
    yields each batch as soon as any single worker finishes.

    Parameters
    ----------
    collector_class : class
        A collector class.
    collector_params : dict
        Collector class kwargs.
    remote_config : dict
        Ray resource specs for the remote collectors.
    num_collectors : int
        Total number of collectors in the set (including local collector).
    total_frames : int
        Stop iterating once this many frames have been collected.
    communication : str
        Either "sync" or "async".
    """

    def __init__(self,
                 collector_class,
                 collector_params,
                 remote_config=default_remote_config,
                 num_collectors=1,
                 total_frames=1000,
                 communication="sync",  # "sync" or "async"
                 ):

        if communication not in ("sync", "async"):
            # Fixed: original raised a placeholderless f-string mentioning a
            # nonexistent "CollectorSet" class.
            raise ValueError(
                "Communication parameter in DistributedCollector has to be "
                "'sync' or 'async'.")

        self.collected_frames = 0
        self.total_frames = total_frames
        self.collector_class = collector_class
        self.collector_params = collector_params
        self.num_collectors = num_collectors
        self.remote_config = remote_config
        self.communication = communication

        # Create a local instance of the collector class.
        # TODO: actually not used for now, but sometimes it can be interesting
        # to have a local copy of the collector.
        self._local_collector = self._make_collector(
            self.collector_class, collector_params)

        # Create remote instances of the collector class.
        self._remote_collectors = []
        if self.num_collectors > 1:
            self.add_collectors(self.num_collectors - 1, collector_params)

    @staticmethod
    def _make_collector(cls, collector_params):
        """Create a single collector instance."""
        return cls(**collector_params)

    def add_collectors(self, num_collectors, collector_params):
        """Create and add a number of remote collectors to the set."""
        cls = self.collector_class.as_remote(**self.remote_config).remote
        self._remote_collectors.extend(
            self._make_collector(cls, collector_params)
            for _ in range(num_collectors))

    def local_collector(self):
        """Return the local collector."""
        return self._local_collector

    def remote_collectors(self):
        """Return the list of remote collectors."""
        return self._remote_collectors

    def stop(self):
        """Stop all remote collectors."""
        for w in self.remote_collectors():
            # Ask Ray to tear down the remote actor process.
            w.__ray_terminate__.remote()

    def __iter__(self) -> Iterator[TensorDictBase]:
        if self.communication == "sync":
            return self.sync_iterator()
        return self.async_iterator()

    def sync_iterator(self) -> Iterator[TensorDictBase]:
        """Wait for all workers each step and yield the concatenated batch."""
        while self.collected_frames < self.total_frames:

            # Broadcast the latest policy weights to every remote worker.
            policy_weights = {}  # TODO: get latest weights
            latest_weights = ray.put(policy_weights)
            for e in self.remote_collectors():
                e.set_weights.remote(latest_weights)

            # Ask all remote workers for a batch.
            pending_samples = [e.rollout.remote()
                               for e in self.remote_collectors()]

            # Block until every rollout is complete. Fixed: the original
            # busy-polled ray.wait with timeout=0.001 in a while loop; one
            # blocking wait with num_returns=len(...) is equivalent.
            ray.wait(pending_samples, num_returns=len(pending_samples))

            # Retrieve and concatenate the TensorDicts.
            out_td = []
            for r in pending_samples:
                rollouts = ray.get(r)
                # NOTE(review): ray.internal.free is a private API and is
                # deprecated in recent Ray releases — confirm before relying
                # on it; dropping the object refs usually suffices.
                ray.internal.free(r)
                out_td.append(rollouts)
            out_td = torch.cat(out_td)

            self.collected_frames += out_td.numel()

            yield out_td

    def async_iterator(self) -> Iterator[TensorDictBase]:
        """Yield each batch as soon as any single worker finishes."""
        # One in-flight rollout per worker, keyed by its object ref.
        pending_tasks = {}
        for w in self.remote_collectors():
            future = w.rollout.remote()
            pending_tasks[future] = w

        while self.collected_frames < self.total_frames:

            # Invariant: exactly one pending task per remote worker.
            if len(pending_tasks) != len(self.remote_collectors()):
                raise RuntimeError(
                    "Missing pending tasks, something went wrong")

            # Wait for the first worker to finish.
            wait_results = ray.wait(list(pending_tasks.keys()))
            future = wait_results[0][0]
            w = pending_tasks.pop(future)

            # Retrieve the single rollout.
            out_td = ray.get(future)
            # NOTE(review): see sync_iterator — ray.internal.free is private.
            ray.internal.free(future)
            self.collected_frames += out_td.numel()

            # Update this worker's weights before rescheduling it.
            policy_weights = {}  # TODO: get latest weights
            latest_weights = ray.put(policy_weights)
            w.set_weights.remote(latest_weights)

            # Schedule a new collection task on the same worker.
            future = w.rollout.remote()
            pending_tasks[future] = w

            yield out_td


if __name__ == "__main__":

    ray.init()
    distributed_collector = DistributedCollector(
        collector_class=FakeCollector,
        collector_params={
            "num_batches": 100,
            "shape": (2, 10),
        },
        remote_config=default_remote_config,
        num_collectors=3,
        total_frames=1000,
        communication="async",
    )

    counter = 0
    for batch in distributed_collector:
        counter += 1
        print(f"batch {counter}, shape {batch.shape}")
    distributed_collector.stop()
+ """ + w = ray.remote( + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, + object_store_memory=object_store_memory, + resources=resources)(cls) + w.is_remote = True + return w + + @torch.no_grad() + def rollout(self) -> TensorDictBase: + return TensorDict({"observation": torch.randn(self.shape)}, batch_size=self.shape) + + def set_weights(self, policy_weights={}): + """Update the worker actor version with provided weights.""" + pass + + +if __name__ == "__main__": + + collector = FakeCollector() + counter = 0 + for _ in range(10): + counter += 1 + batch = collector.rollout() + print(f"batch {counter}, shape {batch.shape}") diff --git a/distributed_collector/main.py b/distributed_collector/main.py new file mode 100644 index 0000000..4439c1c --- /dev/null +++ b/distributed_collector/main.py @@ -0,0 +1,57 @@ +import ray +from fake_collector import FakeCollector +from distributed_collector import DistributedCollector + + +if __name__ == "__main__": + + # Init locally for now, but in a cluster is essentially the same with more params. 
# Reconstructed from patches 1/3 and 2/3 (Albert Bou, 2023-02-20).
# File: distributed_collector/main.py — demo of sync and async collection.
import ray

from fake_collector import FakeCollector
from distributed_collector import DistributedCollector


if __name__ == "__main__":

    # Init locally for now, but in a cluster it is essentially the same
    # with more params.
    ray.init()

    # Resources reserved for each remote collector actor.
    default_remote_config = {
        "num_cpus": 1,
        "num_gpus": 0.2,
        "memory": 5 * 1024 ** 3,
        "object_store_memory": 2 * 1024 ** 3,
    }

    print("\nTest 1: Collect data test in synchronous mode.\n")

    distributed_collector = DistributedCollector(
        collector_class=FakeCollector,
        collector_params={
            "num_batches": 100,
            "shape": (2, 10),
        },
        remote_config=default_remote_config,
        num_collectors=3,
        total_frames=1000,
        communication="sync",
    )

    counter = 0
    for batch in distributed_collector:
        counter += 1
        print(f"batch {counter}, shape {batch.shape}")
    distributed_collector.stop()

    print("\nTest 2: Collect data test in asynchronous mode.\n")

    distributed_collector = DistributedCollector(
        collector_class=FakeCollector,
        collector_params={
            "num_batches": 100,
            "shape": (2, 10),
        },
        remote_config=default_remote_config,
        num_collectors=3,
        total_frames=1000,
        communication="async",
    )

    counter = 0
    for batch in distributed_collector:
        counter += 1
        print(f"batch {counter}, shape {batch.shape}")
    distributed_collector.stop()

    print("\nSuccess!")