Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pytorch/nvvl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ class VideoDataset(torch.utils.data.Dataset):
def __init__(self, filenames, sequence_length, device_id=0,
get_label=None, processing=None, log_level="warn"):
self.ffi = lib._ffi
self.filenames = filenames
self.sequence_length = sequence_length
self.device_id = device_id
self.get_label = get_label if get_label is not None else lambda x,y,z: None
Expand All @@ -203,9 +202,6 @@ def __init__(self, filenames, sequence_length, device_id=0,
print("Invalid log level", log_level, "using warn.", file=sys.stderr)
log_level = lib.LogLevel_Warn

if not filenames:
raise ValueError("Empty filenames list given to VideoDataset")

if sequence_length < 1:
raise ValueError("Sequence length must be at least 1")

Expand All @@ -214,17 +210,22 @@ def __init__(self, filenames, sequence_length, device_id=0,
self.total_frames = 0
self.frame_counts = []
self.start_index = []
self.filenames = []
for f in filenames:
count = lib.nvvl_frame_count(self.loader, str.encode(f));
if count < self.sequence_length:
print("NVVL WARNING: Ignoring", f, "because it only has", count,
"frames and the sequence length is", self.sequence_length)
continue
self.filenames.append(f)
count = count - self.sequence_length + 1
self.frame_counts.append(count)
self.total_frames += count
self.start_index.append(self.total_frames) # purposefully off by one for bisect to work

if not filenames:
raise ValueError("Empty filenames list given to VideoDataset")

size = lib.nvvl_video_size(self.loader)
self.width = size.width
self.height = size.height
Expand Down