From d88d2eb2e560665adb2405436e5283b3d3855b7a Mon Sep 17 00:00:00 2001
From: Pavel Ostyakov
Date: Wed, 20 Mar 2019 12:33:01 +0300
Subject: [PATCH] Fix ignoring short videos in VideoDataset

Build self.filenames from only the videos that have at least
sequence_length frames, so that it stays aligned with frame_counts and
start_index, which are likewise appended to only for kept videos.

Run the empty-list check after that filtering and against the filtered
list (self.filenames), so that a dataset in which every video was
skipped raises ValueError instead of carrying on with no usable videos.

---
 pytorch/nvvl/dataset.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/pytorch/nvvl/dataset.py b/pytorch/nvvl/dataset.py
index 2f8e591..431944d 100644
--- a/pytorch/nvvl/dataset.py
+++ b/pytorch/nvvl/dataset.py
@@ -186,7 +186,6 @@ class VideoDataset(torch.utils.data.Dataset):
     def __init__(self, filenames, sequence_length, device_id=0,
                  get_label=None, processing=None, log_level="warn"):
         self.ffi = lib._ffi
-        self.filenames = filenames
         self.sequence_length = sequence_length
         self.device_id = device_id
         self.get_label = get_label if get_label is not None else lambda x,y,z: None
@@ -203,9 +202,6 @@ def __init__(self, filenames, sequence_length, device_id=0,
             print("Invalid log level", log_level, "using warn.", file=sys.stderr)
             log_level = lib.LogLevel_Warn
 
-        if not filenames:
-            raise ValueError("Empty filenames list given to VideoDataset")
-
         if sequence_length < 1:
             raise ValueError("Sequence length must be at least 1")
 
@@ -214,17 +210,22 @@ def __init__(self, filenames, sequence_length, device_id=0,
         self.total_frames = 0
         self.frame_counts = []
         self.start_index = []
+        self.filenames = []
         for f in filenames:
             count = lib.nvvl_frame_count(self.loader, str.encode(f));
             if count < self.sequence_length:
                 print("NVVL WARNING: Ignoring", f, "because it only has", count,
                       "frames and the sequence length is", self.sequence_length)
                 continue
+            self.filenames.append(f)
             count = count - self.sequence_length + 1
             self.frame_counts.append(count)
             self.total_frames += count
             self.start_index.append(self.total_frames) # purposefully off by one for bisect to work
 
+        if not self.filenames:
+            raise ValueError("Empty filenames list given to VideoDataset")
+
         size = lib.nvvl_video_size(self.loader)
         self.width = size.width
         self.height = size.height