Check if data generators still use most efficient methods #115

Description

@matobler

Make sure our data generators use the most efficient methods. In particular, check whether torchvision functions lead to better performance. Below is a rough timing sketch for comparing decode paths, followed by some sample code I developed earlier with ideas that could be adapted (if they turn out to be more efficient). CroppedImageDataset and ImageDataset could be combined into one class; a merged sketch follows the listing.
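
As a first step for the efficiency check, a timing harness along these lines could compare PIL-based decoding against torchvision.io.read_image. This is only a sketch: image_paths is a placeholder to point at real files, and time_decode, load_pil, and load_torchvision are hypothetical names, not part of our codebase.

import time

from PIL import Image
import torchvision.transforms.v2.functional as F_v2
from torchvision.io import read_image


def time_decode(image_paths, loader, repeats=3):
    # Best wall-clock time over several repeats to reduce noise
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        for path in image_paths:
            loader(path)
        best = min(best, time.perf_counter() - start)
    return best


def load_pil(path):
    # PIL decode, then conversion to a float tensor scaled to [0, 1]
    return F_v2.pil_to_tensor(Image.open(path).convert("RGB")).float() / 255.0


def load_torchvision(path):
    # torchvision's native decoder, with the same normalization as the code below
    return read_image(path).float() / 255.0


# image_paths = [...]  # placeholder: fill in with real files before timing
# print("PIL:", time_decode(image_paths, load_pil))
# print("torchvision.io:", time_decode(image_paths, load_torchvision))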

from torch.utils.data import Dataset
import torchvision.transforms.v2 as T_v2
import torchvision.transforms.v2.functional as F_v2
from torchvision.io import read_image


class CroppedImageDataset(Dataset):
    def __init__(self, image_paths, bounding_boxes, width=None, height=None, pad=False, transform=None):
        """
        Args:
            image_paths (list of str): List of file paths to images.
            bounding_boxes (list of tuples): List of bounding box coordinates in the format (x, y, w, h) scaled between 0 and 1.
            width (int, optional): Desired width of the output cropped image.
            height (int, optional): Desired height of the output cropped image.
            pad (bool, optional): Use black padding to maintain the aspect ratio.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_paths = image_paths
        self.bounding_boxes = bounding_boxes
        self.width = width
        self.height = height
        self.pad = pad
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        bbox = self.bounding_boxes[idx]
        
        # Load image using torchvision.io.read_image (returns a tensor)
        image = read_image(img_path).float() / 255.0  # Normalize to [0, 1]
        img_height, img_width = image.shape[1], image.shape[2]
        
        # Convert the scaled (0-1) bounding box to absolute pixel coordinates
        x_scaled, y_scaled, w_scaled, h_scaled = bbox
        x = int(x_scaled * img_width)
        y = int(y_scaled * img_height)
        w = int(w_scaled * img_width)
        h = int(h_scaled * img_height)

        # Crop using torchvision.transforms.v2.functional.crop (top, left, height, width)
        image = F_v2.crop(image, y, x, h, w)

        # Resize if both width and height are specified
        if self.width and self.height:
            if self.pad:
                # Pad with black to preserve the aspect ratio before resizing
                image = Resize_with_pad(self.height, self.width)(image)
            else:
                image = T_v2.Resize((self.height, self.width))(image)
        
        # Apply additional transforms if any
        if self.transform:
            image = self.transform(image)
        
        return image
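
For reference, usage could look like this; the paths and boxes are made-up placeholders, not files from this repo:

from torch.utils.data import DataLoader

paths = ["img_0001.jpg", "img_0002.jpg"]  # hypothetical files
boxes = [(0.10, 0.20, 0.50, 0.40), (0.00, 0.00, 1.00, 1.00)]  # (x, y, w, h) in [0, 1]

dataset = CroppedImageDataset(paths, boxes, width=224, height=224, pad=True)
loader = DataLoader(dataset, batch_size=2, num_workers=2)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 3, 224, 224])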


class ImageDataset(Dataset):
    def __init__(self, image_paths, width=None, height=None, pad=False, transform=None):
        """
        Args:
            image_paths (list of str): List of file paths to images.
            width (int, optional): Desired width of the output cropped image.
            height (int, optional): Desired height of the output cropped image.
            pad (bool, optional): Use black padding to maintain the aspect ratio.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_paths = image_paths
        self.width = width
        self.height = height
        self.pad = pad
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]

        # Load image using torchvision.io.read_image (returns a tensor)
        image = read_image(img_path).float() / 255.0  # Normalize to [0, 1]

        # Resize if width and/or height are specified; derive a missing dimension
        # from the image's aspect ratio using locals so self is not mutated
        width, height = self.width, self.height
        if width and not height:
            height = int(image.shape[1] / image.shape[2] * width)
        elif height and not width:
            width = int(image.shape[2] / image.shape[1] * height)
        if width and height:
            if self.pad:
                # Pad with black to preserve the aspect ratio before resizing
                image = Resize_with_pad(height, width)(image)
            else:
                image = T_v2.Resize((height, width))(image)

        # Apply additional transforms if any
        if self.transform:
            image = self.transform(image)

        return image
      


class Resize_with_pad:
    def __init__(self, h=768, w=1024):
        self.w = w
        self.h = h

    def __call__(self, image):
        # Expects a CHW tensor
        c, h_1, w_1 = image.size()
        ratio_f = self.w / self.h
        ratio_1 = w_1 / h_1

        # Pad with black only if the original and target aspect ratios
        # differ beyond a small margin, then resize; always return a tensor
        if round(ratio_1, 2) != round(ratio_f, 2):
            if ratio_1 < ratio_f:
                # Image is too narrow for the target ratio: pad left/right
                wp = int(ratio_f * h_1 - w_1) // 2
                image = T_v2.Pad((wp, 0, wp, 0), 0, "constant")(image)
            else:
                # Image is too tall for the target ratio: pad top/bottom
                hp = int(w_1 / ratio_f - h_1) // 2
                image = T_v2.Pad((0, hp, 0, hp), 0, "constant")(image)

        return T_v2.Resize([self.h, self.w])(image)
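
As suggested above, the two Dataset classes could be merged by making the bounding boxes optional. This is a minimal sketch of that idea, reusing the imports and Resize_with_pad from the listing; the class name and the convention that bounding_boxes=None means no cropping are assumptions about the merged interface:

class UnifiedImageDataset(Dataset):
    """Combines CroppedImageDataset and ImageDataset: crops only when boxes are given."""

    def __init__(self, image_paths, bounding_boxes=None, width=None, height=None,
                 pad=False, transform=None):
        self.image_paths = image_paths
        self.bounding_boxes = bounding_boxes  # None means no cropping
        self.width = width
        self.height = height
        self.pad = pad
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = read_image(self.image_paths[idx]).float() / 255.0
        img_height, img_width = image.shape[1], image.shape[2]

        # Optional crop from scaled (0-1) box coordinates
        if self.bounding_boxes is not None:
            x_s, y_s, w_s, h_s = self.bounding_boxes[idx]
            image = F_v2.crop(image, int(y_s * img_height), int(x_s * img_width),
                              int(h_s * img_height), int(w_s * img_width))

        # Derive a missing output dimension from the (possibly cropped) aspect ratio
        width, height = self.width, self.height
        if width and not height:
            height = int(image.shape[1] / image.shape[2] * width)
        elif height and not width:
            width = int(image.shape[2] / image.shape[1] * height)

        if width and height:
            if self.pad:
                image = Resize_with_pad(height, width)(image)
            else:
                image = T_v2.Resize((height, width))(image)

        if self.transform:
            image = self.transform(image)
        return image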

Labels: enhancement (New feature or request)