Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ export(cifar10_dataset)
export(clip_boxes_to_image)
export(coco_caption_dataset)
export(coco_detection_dataset)
export(coco_segmentation_dataset)
export(draw_bounding_boxes)
export(draw_keypoints)
export(draw_segmentation_masks)
Expand Down
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# torchvision (development version)

## Breaking changes

* **COCO datasets refactored**: Split `coco_detection_dataset()` into two separate datasets to reduce memory footprint and improve clarity:
- `coco_detection_dataset()` - Now only for object detection (bounding boxes). No longer includes segmentation polygons in the output.
- `coco_segmentation_dataset()` - New dataset specifically for instance segmentation tasks with polygon annotations and masks.
- Downloaded files are now organized in a `coco` subdirectory within the torch cache for better identification.
- This change reduces memory usage by ~50% (from 500MB+ to ~250MB per dataset) as each dataset only loads annotations relevant to its task.
- **Migration**: If you were using `coco_detection_dataset()` for segmentation tasks, switch to `coco_segmentation_dataset()` with `target_transform = target_transform_coco_masks`.

## New features

* Added collection dataset catalog with `search_collection()`, `get_collection_catalog()`, and `list_collection_datasets()` functions for discovering and exploring collections (#271, @ANAMASGARD).
Expand Down
238 changes: 221 additions & 17 deletions R/dataset-coco.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' COCO Detection Dataset
#'
#' Loads the MS COCO dataset for object detection and segmentation.
#' Loads the MS COCO dataset for object detection tasks only.
#'
#' @rdname coco_detection_dataset
#' @param root Root directory where the dataset is stored or will be downloaded to.
Expand All @@ -16,20 +16,22 @@
#' - `y$labels`: an integer `torch_tensor` with the class label for each object.
#' - `y$area`: a float `torch_tensor` indicating the area of each object.
#' - `y$iscrowd`: a boolean `torch_tensor`, where `TRUE` marks the object as part of a crowd.
#' - `y$segmentation`: a list of segmentation polygons for each object.
#' - `y$masks`: a `(N, H, W)` boolean `torch_tensor` containing binary segmentation masks.
#'
#' The returned object has S3 classes \code{"image_with_bounding_box"} and \code{"image_with_segmentation_mask"}
#' to enable automatic dispatch by visualization functions such as \code{draw_bounding_boxes()} and \code{draw_segmentation_masks()}.
#' The returned object has S3 class \code{"image_with_bounding_box"}
#' to enable automatic dispatch by visualization functions such as \code{draw_bounding_boxes()}.
#'
#' For instance segmentation tasks, use \code{\link{coco_segmentation_dataset}} instead.
#'
#' @details
#' The returned image `x` is in CHW format (channels, height, width), matching the torch convention.
#' The dataset `y` offers object detection annotations such as bounding boxes, labels,
#' areas, crowd indicators, and segmentation masks from the official COCO annotations.
#' areas, and crowd indicators from the official COCO annotations.
#'
#' Files are downloaded to a \code{coco} subdirectory in the torch cache directory for better organization.
#'
#' @examples
#' \dontrun{
#' # Load dataset
#' # Load dataset for object detection
#' ds <- coco_detection_dataset(
#' train = FALSE,
#' year = "2017",
Expand All @@ -41,27 +43,229 @@
#' # Visualize bounding boxes
#' boxed <- draw_bounding_boxes(item)
#' tensor_image_browse(boxed)
#' }
#' @family detection_dataset
#' @seealso \code{\link{coco_segmentation_dataset}} for instance segmentation tasks
#' @importFrom jsonlite fromJSON
#' @export
coco_detection_dataset <- torch::dataset(
name = "coco_detection_dataset",
resources = data.frame(
year = rep(c(2017, 2014), each = 4 ),
content = rep(c("image", "annotation"), time = 2, each = 2),
split = rep(c("train", "val"), time = 4),
url = c("http://images.cocodataset.org/zips/train2017.zip", "http://images.cocodataset.org/zips/val2017.zip",
rep("http://images.cocodataset.org/annotations/annotations_trainval2017.zip", time = 2),
"http://images.cocodataset.org/zips/train2014.zip", "http://images.cocodataset.org/zips/val2014.zip",
rep("http://images.cocodataset.org/annotations/annotations_trainval2014.zip", time = 2)),
size = c("800 MB", "800 MB", rep("770 MB", time = 2), "6.33 GB", "6.33 GB", rep("242 MB", time = 2)),
md5 = c(c("cced6f7f71b7629ddf16f17bbcfab6b2", "442b8da7639aecaf257c1dceb8ba8c80"),
rep("f4bbac642086de4f52a3fdda2de5fa2c", time = 2),
c("0da8cfa0e090c266b78f30e2d2874f1a", "a3d79f5ed8d289b7a7554ce06a5782b3"),
rep("0a379cfc70b0e71301e0f377548639bd", time = 2)),
stringsAsFactors = FALSE
),

initialize = function(
root = tempdir(),
train = TRUE,
year = c("2017", "2014"),
download = FALSE,
transform = NULL,
target_transform = NULL
) {

year <- match.arg(year)
split <- ifelse(train, "train", "val")

root <- fs::path_expand(root)
self$root <- root
self$year <- year
self$split <- split
self$transform <- transform
self$target_transform <- target_transform
self$archive_size <- self$resources[self$resources$year == year & self$resources$split == split & self$resources$content == "image", ]$size

self$data_dir <- fs::path(root, glue::glue("coco{year}"))

image_year <- ifelse(year == "2016", "2014", year)
self$image_dir <- fs::path(self$data_dir, glue::glue("{split}{image_year}"))
self$annotation_file <- fs::path(self$data_dir, "annotations",
glue::glue("instances_{split}{year}.json"))

if (download) {
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
self$download()
}

if (!self$check_exists()) {
runtime_error("Dataset not found. You can use `download = TRUE` to download it.")
}

self$load_annotations()

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$image_ids)} images.")
},

check_exists = function() {
fs::file_exists(self$annotation_file) && fs::dir_exists(self$image_dir)
},

.getitem = function(index) {
image_id <- self$image_ids[index]
image_info <- self$image_metadata[[as.character(image_id)]]

img_path <- fs::path(self$image_dir, image_info$file_name)

x <- base_loader(img_path)

height <- dim(x)[1]
width <- dim(x)[2]

anns <- self$annotations[self$annotations$image_id == image_id, ]

if (nrow(anns) > 0) {
boxes_wh <- torch::torch_tensor(do.call(rbind, anns$bbox), dtype = torch::torch_float())
boxes <- box_xywh_to_xyxy(boxes_wh)

label_ids <- anns$category_id
labels <- as.character(self$categories$name[match(label_ids, self$categories$id)])

area <- torch::torch_tensor(anns$area, dtype = torch::torch_float())
iscrowd <- torch::torch_tensor(as.logical(anns$iscrowd), dtype = torch::torch_bool())

} else {
# empty annotation
boxes <- torch::torch_zeros(c(0, 4), dtype = torch::torch_float())
labels <- character()
area <- torch::torch_empty(0, dtype = torch::torch_float())
iscrowd <- torch::torch_empty(0, dtype = torch::torch_bool())
anns$segmentation <- list()
}

y <- list(
boxes = boxes,
labels = labels,
area = area,
iscrowd = iscrowd
)

if (!is.null(self$transform)) {
x <- self$transform(x)
}

if (!is.null(self$target_transform)) {
y$image_height <- height
y$image_width <- width
y <- self$target_transform(y)
}

result <- list(x = x, y = y)
class(result) <- c("image_with_bounding_box", class(result))

result
},

.length = function() {
length(self$image_ids)
},

download = function() {
annotation_filter <- self$resources$year == self$year & self$resources$split == self$split & self$resources$content == "annotation"
image_filter <- self$resources$year == self$year & self$resources$split == self$split & self$resources$content == "image"

cli_inform("Downloading {.cls {class(self)[[1]]}}...")

ann_zip <- download_and_cache(self$resources[annotation_filter, ]$url, prefix = "coco")
archive <- download_and_cache(self$resources[image_filter, ]$url, prefix = "coco")

if (tools::md5sum(archive) != self$resources[image_filter, ]$md5) {
runtime_error("Corrupt file! Delete the file in {archive} and try again.")
}

utils::unzip(ann_zip, exdir = self$data_dir)
utils::unzip(archive, exdir = self$data_dir)

cli_inform("Dataset {.cls {class(self)[[1]]}} downloaded and extracted successfully.")
},

load_annotations = function() {
data <- jsonlite::fromJSON(self$annotation_file)

self$image_metadata <- setNames(
split(data$images, seq_len(nrow(data$images))),
as.character(data$images$id)
)

self$annotations <- data$annotations
self$categories <- data$categories
self$category_names <- setNames(self$categories$name, self$categories$id)

ids <- as.numeric(names(self$image_metadata))
image_paths <- fs::path(self$image_dir,
sapply(ids, function(id) self$image_metadata[[as.character(id)]]$file_name))
exist <- fs::file_exists(image_paths)
self$image_ids <- ids[exist]
}
)


#' COCO Segmentation Dataset
#'
#' Loads the MS COCO dataset for instance segmentation tasks.
#'
#' @rdname coco_segmentation_dataset
#' @param root Root directory where the dataset is stored or will be downloaded to.
#' @param train Logical. If TRUE, loads the training split; otherwise, loads the validation split.
#' @param year Character. Dataset version year. One of \code{"2014"} or \code{"2017"}.
#' @param download Logical. If TRUE, downloads the dataset if it's not already present in the \code{root} directory.
#' @param transform Optional transform function applied to the image.
#' @param target_transform Optional transform function applied to the target.
#' Use \code{target_transform_coco_masks} to convert polygon annotations to binary masks.
#'
#' @return An object of class `coco_segmentation_dataset`. Each item is a list:
#' - `x`: a `(C, H, W)` array representing the image.
#' - `y$boxes`: a `(N, 4)` `torch_tensor` of bounding boxes in the format \eqn{(x_{min}, y_{min}, x_{max}, y_{max})}.
#' - `y$labels`: an integer `torch_tensor` with the class label for each object.
#' - `y$area`: a float `torch_tensor` indicating the area of each object.
#' - `y$iscrowd`: a boolean `torch_tensor`, where `TRUE` marks the object as part of a crowd.
#' - `y$segmentation`: a list of segmentation polygons for each object.
#' - `y$masks`: a `(N, H, W)` boolean `torch_tensor` containing binary segmentation masks (when using target_transform_coco_masks).
#'
#' The returned object has S3 class \code{"image_with_segmentation_mask"}
#' to enable automatic dispatch by visualization functions such as \code{draw_segmentation_masks()}.
#'
#' # In order to visualize segmentation masks, we
#' # use the specific segmentation mask target transformation
#' ds_with_masks <- coco_detection_dataset(
#' For object detection tasks without segmentation, use \code{\link{coco_detection_dataset}} instead.
#'
#' @details
#' The returned image `x` is in CHW format (channels, height, width), matching the torch convention.
#' The dataset `y` offers instance segmentation annotations including bounding boxes, labels,
#' areas, crowd indicators, and segmentation masks from the official COCO annotations.
#'
#' Files are downloaded to a \code{coco} subdirectory in the torch cache directory for better organization.
#'
#' @examples
#' \dontrun{
#' # Load dataset for instance segmentation
#' ds <- coco_segmentation_dataset(
#' train = FALSE,
#' year = "2017",
#' download = TRUE,
#' target_transform = target_transform_coco_masks
#' )
#'
#' item_masked <- ds_with_masks[1]
#' item <- ds[1]
#'
#' # Visualize segmentation masks
#' masked <- draw_segmentation_masks(item_masked)
#' masked <- draw_segmentation_masks(item)
#' tensor_image_browse(masked)
#' }
#' @family detection_dataset
#' @family segmentation_dataset
#' @seealso \code{\link{coco_detection_dataset}} for object detection tasks
#' @importFrom jsonlite fromJSON
#' @export
coco_detection_dataset <- torch::dataset(
name = "coco_detection_dataset",
coco_segmentation_dataset <- torch::dataset(
name = "coco_segmentation_dataset",
resources = data.frame(
year = rep(c(2017, 2014), each = 4 ),
content = rep(c("image", "annotation"), time = 2, each = 2),
Expand Down Expand Up @@ -192,8 +396,8 @@ coco_detection_dataset <- torch::dataset(

cli_inform("Downloading {.cls {class(self)[[1]]}}...")

ann_zip <- download_and_cache(self$resources[annotation_filter, ]$url, prefix = "coco_dataset")
archive <- download_and_cache(self$resources[image_filter, ]$url, prefix = "coco_dataset")
ann_zip <- download_and_cache(self$resources[annotation_filter, ]$url, prefix = "coco")
archive <- download_and_cache(self$resources[image_filter, ]$url, prefix = "coco")

if (tools::md5sum(archive) != self$resources[image_filter, ]$md5) {
runtime_error("Corrupt file! Delete the file in {archive} and try again.")
Expand Down
36 changes: 19 additions & 17 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,29 +71,31 @@ reference:

- title: Datasets
desc: >
Datasets readily available. All have a `x` variable in each item
being the input image.
- subtitle: for Image Classification
descr: Dataset having items with "y" for target class identifier.
Datasets readily available for various computer vision tasks.
All datasets provide items with an `x` variable containing the input image.

- subtitle: Classification Datasets
desc: >
Datasets for image classification tasks. Each item has a `y` variable
containing the target class identifier.
contents:
- has_concept("classification_dataset")
- subtitle: for Object Detection
descr: >
Dataset having items with "y" as a named list of bounding-box and
labels for object detection.

- subtitle: Detection & Segmentation Datasets
desc: >
Datasets for object detection and instance segmentation tasks.
Detection datasets provide bounding boxes, while segmentation datasets
additionally provide pixel-level masks.
contents:
- has_concept("detection_dataset")
- subtitle: for Image captionning
descr: >
Dataset having items with "y" as one or multiple captions of the image
- has_concept("segmentation_dataset")

- subtitle: Caption Datasets
desc: >
Datasets for image captioning tasks. Each item has a `y` variable
containing one or multiple captions describing the image.
contents:
- has_concept("caption_dataset")
- subtitle: for Semantic segmentation
descr: >
Dataset having items with "y" as a named list containing a segmentation
mask and labels for image segmentation.
contents:
- has_concept("segmentation_dataset")

- title: Displaying
- subtitle: Images loading
Expand Down
Loading