Make sure our data generators provide the best performance. In particular, check whether torchvision functions might lead to better performance. Here is some sample code I developed earlier with ideas that could be adapted (if they turn out to be more efficient). CroppedImageDataset and ImageDataset can be combined into one.
from torch.utils.data import Dataset
import torchvision.transforms.v2 as T_v2
import torchvision.transforms.v2.functional as F_v2
from torchvision.io import read_image
class CroppedImageDataset(Dataset):
    def __init__(self, image_paths, bounding_boxes, width=None, height=None, pad=False, transform=None):
        """
        Args:
            image_paths (list of str): List of file paths to images.
            bounding_boxes (list of tuples): List of bounding box coordinates in the format (x, y, w, h), scaled between 0 and 1.
            width (int, optional): Desired width of the output cropped image.
            height (int, optional): Desired height of the output cropped image.
            pad (bool, optional): Use black padding to maintain the aspect ratio.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_paths = image_paths
        self.bounding_boxes = bounding_boxes
        self.width = width
        self.height = height
        self.pad = pad
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        bbox = self.bounding_boxes[idx]

        # Load image using torchvision.io.read_image (returns a uint8 tensor)
        image = read_image(img_path).float() / 255.0  # Normalize to [0, 1]
        img_height, img_width = image.shape[1], image.shape[2]

        # Convert bounding box coordinates from relative [0, 1] values to absolute pixels
        x_scaled, y_scaled, w_scaled, h_scaled = bbox
        x = int(x_scaled * img_width)
        y = int(y_scaled * img_height)
        w = int(w_scaled * img_width)
        h = int(h_scaled * img_height)

        # Crop image using torchvision.transforms.v2.functional.crop
        image = F_v2.crop(image, y, x, h, w)

        # Resize if width and height are specified
        if self.width and self.height:
            if self.pad:
                # Pad to the target aspect ratio before resizing
                resize_transform = Resize_with_pad(self.height, self.width)
            else:
                resize_transform = T_v2.Resize((self.height, self.width))
            image = resize_transform(image)

        # Apply additional transforms if any
        if self.transform:
            image = self.transform(image)

        return image
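For reference, a rough usage sketch for CroppedImageDataset wrapped in a DataLoader; the file names, boxes, and num_workers value below are placeholders chosen for illustration, not values from this repo. Moving decoding/cropping into worker processes is usually where most of the data-generator time can be recovered.

from torch.utils.data import DataLoader

# Hypothetical inputs: normalized (x, y, w, h) boxes matching the docstring convention above
image_paths = ["img_0001.jpg", "img_0002.jpg"]
bounding_boxes = [(0.10, 0.20, 0.50, 0.40),
                  (0.05, 0.05, 0.30, 0.60)]

dataset = CroppedImageDataset(image_paths, bounding_boxes, width=224, height=224, pad=True)
# num_workers > 0 runs decoding and cropping in background processes
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for batch in loader:
    pass  # each batch is a stacked tensor of shape (B, C, 224, 224)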
class ImageDataset(Dataset):
    def __init__(self, image_paths, width=None, height=None, pad=False, transform=None):
        """
        Args:
            image_paths (list of str): List of file paths to images.
            width (int, optional): Desired width of the output image.
            height (int, optional): Desired height of the output image.
            pad (bool, optional): Use black padding to maintain the aspect ratio.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_paths = image_paths
        self.width = width
        self.height = height
        self.pad = pad
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]

        # Load image using torchvision.io.read_image (returns a uint8 tensor)
        image = read_image(img_path).float() / 255.0  # Normalize to [0, 1]

        # Resize if width and/or height are specified
        if self.width or self.height:
            # Work on local copies so __getitem__ does not mutate the dataset state
            width, height = self.width, self.height
            # If only one dimension is given, derive the other from the image aspect ratio
            if width and not height:
                height = int(image.shape[1] / image.shape[2] * width)
            elif not width and height:
                width = int(image.shape[2] / image.shape[1] * height)
            if self.pad:
                # Pad to the target aspect ratio before resizing
                resize_transform = Resize_with_pad(height, width)
            else:
                resize_transform = T_v2.Resize((height, width))
            image = resize_transform(image)

        # Apply additional transforms if any
        if self.transform:
            image = self.transform(image)

        return image
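To make the "best performance" comparison concrete, a small timing loop like the one below could measure the throughput of each generator variant. images_per_second is a hypothetical helper written for this sketch (it is not part of the codebase), and the numbers will of course depend on disk speed, image decoder, and worker count.

import time
from torch.utils.data import DataLoader

def images_per_second(dataset, batch_size=32, num_workers=4, n_batches=50):
    """Rough throughput measurement for a dataset wrapped in a DataLoader."""
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    start = time.perf_counter()
    seen = 0
    for i, batch in enumerate(loader):
        seen += batch.shape[0]
        if i + 1 >= n_batches:
            break
    return seen / (time.perf_counter() - start)

# e.g. compare the torchvision-based ImageDataset against the current generator:
# print(images_per_second(ImageDataset(image_paths, width=512, height=512)))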
class Resize_with_pad:
    def __init__(self, h=768, w=1024):
        self.w = w
        self.h = h

    def __call__(self, image):
        c, h_1, w_1 = image.size()
        ratio_f = self.w / self.h
        ratio_1 = w_1 / h_1

        # Check whether the original and target aspect ratios differ beyond a small margin
        if round(ratio_1, 2) != round(ratio_f, 2):
            # Pad the shorter side to match the target aspect ratio
            hp = int(w_1 / ratio_f - h_1)
            wp = int(ratio_f * h_1 - w_1)
            if hp > 0 and wp < 0:
                hp = hp // 2
                rp = T_v2.Compose([T_v2.Pad((0, hp, 0, hp), 0, "constant"),
                                   T_v2.Resize([self.h, self.w])])
                return rp(image)
            elif hp < 0 and wp > 0:
                wp = wp // 2
                rp = T_v2.Compose([T_v2.Pad((wp, 0, wp, 0), 0, "constant"),
                                   T_v2.Resize([self.h, self.w])])
                return rp(image)

        # Aspect ratios already match (or the padding rounds to zero): just resize
        resize = T_v2.Resize([self.h, self.w])
        return resize(image)
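Since CroppedImageDataset and ImageDataset can be combined into one, a possible shape for the merged class could look like the sketch below, with bounding_boxes made optional. GenericImageDataset is a hypothetical name used only for illustration, not a drop-in implementation.

class GenericImageDataset(Dataset):
    """Sketch of a single dataset covering both the cropped and uncropped cases.

    When bounding_boxes is None the behaviour falls back to ImageDataset above;
    otherwise each image is cropped first, as in CroppedImageDataset.
    """
    def __init__(self, image_paths, bounding_boxes=None, width=None, height=None,
                 pad=False, transform=None):
        self.image_paths = image_paths
        self.bounding_boxes = bounding_boxes
        self.width = width
        self.height = height
        self.pad = pad
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = read_image(self.image_paths[idx]).float() / 255.0
        if self.bounding_boxes is not None:
            # Same relative-to-absolute conversion as in CroppedImageDataset
            x_s, y_s, w_s, h_s = self.bounding_boxes[idx]
            img_h, img_w = image.shape[1], image.shape[2]
            image = F_v2.crop(image, int(y_s * img_h), int(x_s * img_w),
                              int(h_s * img_h), int(w_s * img_w))
        if self.width and self.height:
            if self.pad:
                image = Resize_with_pad(self.height, self.width)(image)
            else:
                image = T_v2.Resize((self.height, self.width))(image)
        if self.transform:
            image = self.transform(image)
        return image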