diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index 4e1b18c608..73180a0b4c 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -377,8 +377,8 @@ def _create_data_loader(self, dataset: Dataset) -> DataLoader: dataset=batches_dataset, batch_size=None, num_workers=num_workers, - multiprocessing_context="spawn", - persistent_workers=True, + multiprocessing_context="spawn" if num_workers > 0 else None, + persistent_workers=num_workers > 0, ) def run_train_step(