diff --git a/comet/models/predict_writer.py b/comet/models/predict_writer.py index 7d95989..6323445 100644 --- a/comet/models/predict_writer.py +++ b/comet/models/predict_writer.py @@ -97,14 +97,14 @@ def flatten_predictions(predictions): files = sorted(os.listdir(self.output_dir)) pred = flatten_predictions( [ - flatten_predictions(torch.load(os.path.join(self.output_dir, f))) + flatten_predictions(torch.load(os.path.join(self.output_dir, f), weights_only=False)) for f in files if "pred" in f ] ) indices = flatten( [ - flatten(torch.load(os.path.join(self.output_dir, f))[0]) + flatten(torch.load(os.path.join(self.output_dir, f), weights_only=False)[0]) for f in files if "batch_indices" in f ]