4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- `tilebox-datasets`: Added `create_dataset` method to `Client` to create a new dataset.

## [0.45.0] - 2025-11-17

### Added
49 changes: 47 additions & 2 deletions tilebox-datasets/tests/data/datasets.py
@@ -2,21 +2,32 @@
from dataclasses import replace
from functools import lru_cache

from google.protobuf.descriptor_pb2 import FileDescriptorProto, FileDescriptorSet
from google.protobuf.descriptor_pb2 import FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet
from hypothesis.strategies import (
DrawFn,
booleans,
composite,
integers,
just,
lists,
none,
one_of,
sampled_from,
text,
uuids,
)

from tests.example_dataset.example_dataset_pb2 import DESCRIPTOR_PROTO
from tilebox.datasets.data.datasets import AnnotatedType, Dataset, DatasetGroup, FieldAnnotation, ListDatasetsResponse
from tilebox.datasets.data.datasets import (
AnnotatedType,
Dataset,
DatasetGroup,
DatasetKind,
DatasetType,
Field,
FieldAnnotation,
ListDatasetsResponse,
)
from tilebox.datasets.message_pool import register_once


@@ -28,6 +39,40 @@ def field_annotations(draw: DrawFn) -> FieldAnnotation:
return FieldAnnotation(description, example_value)


@composite
def fields(draw: DrawFn) -> Field:
"""A hypothesis strategy for generating random fields"""
name = draw(text(alphabet=string.ascii_lowercase + "_", min_size=3, max_size=25))
field_type = draw(
one_of(
just(FieldDescriptorProto.Type.TYPE_STRING),
just(FieldDescriptorProto.Type.TYPE_BYTES),
just(FieldDescriptorProto.Type.TYPE_BOOL),
just(FieldDescriptorProto.Type.TYPE_INT64),
just(FieldDescriptorProto.Type.TYPE_UINT64),
just(FieldDescriptorProto.Type.TYPE_DOUBLE),
just(FieldDescriptorProto.Type.TYPE_MESSAGE),
)
)
type_name = f".datasets.v1.{name}" if field_type == FieldDescriptorProto.Type.TYPE_MESSAGE else None
label = draw(
one_of(just(FieldDescriptorProto.Label.LABEL_OPTIONAL), just(FieldDescriptorProto.Label.LABEL_REPEATED))
)
descriptor = FieldDescriptorProto(name=name, type=field_type, type_name=type_name, label=label)

annotation = draw(field_annotations())
queryable = draw(booleans())
return Field(descriptor, annotation, queryable)


@composite
def dataset_types(draw: DrawFn) -> DatasetType:
"""A hypothesis strategy for generating random dataset types"""
kind = draw(sampled_from(DatasetKind) | none())
dataset_fields = draw(lists(fields(), min_size=1, max_size=5))
return DatasetType(kind, dataset_fields)


@lru_cache
def example_dataset_type() -> AnnotatedType:
descriptor = FileDescriptorProto.FromString(DESCRIPTOR_PROTO)
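For orientation, a minimal sketch of what the new `fields()` and `dataset_types()` strategies generate; `.example()` is hypothesis' helper for interactive exploration and is not meant for use inside tests.

```python
# Quick interactive check of the new strategies (hypothesis' .example() is for
# exploration only, not for use inside tests).
from tests.data.datasets import dataset_types, fields

print(fields().example())         # a random Field with descriptor, annotation and queryable flag
print(dataset_types().example())  # a random DatasetType with a kind and 1-5 fields
```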
30 changes: 28 additions & 2 deletions tilebox-datasets/tests/data/test_datasets.py
@@ -1,14 +1,40 @@
from hypothesis import given

from tests.data.datasets import annotated_types, dataset_groups, datasets, field_annotations, list_datasets_responses
from tilebox.datasets.data.datasets import AnnotatedType, Dataset, DatasetGroup, FieldAnnotation, ListDatasetsResponse
from tests.data.datasets import (
annotated_types,
dataset_groups,
dataset_types,
datasets,
field_annotations,
fields,
list_datasets_responses,
)
from tilebox.datasets.data.datasets import (
AnnotatedType,
Dataset,
DatasetGroup,
DatasetType,
Field,
FieldAnnotation,
ListDatasetsResponse,
)


@given(field_annotations())
def test_field_annotations_to_message_and_back(annotation: FieldAnnotation) -> None:
assert FieldAnnotation.from_message(annotation.to_message()) == annotation


@given(fields())
def test_fields_to_message_and_back(field: Field) -> None:
assert Field.from_message(field.to_message()) == field


@given(dataset_types())
def test_dataset_types_to_message_and_back(dataset_type: DatasetType) -> None:
assert DatasetType.from_message(dataset_type.to_message()) == dataset_type


@given(annotated_types())
def test_annotated_types_to_message_and_back(annotated_type: AnnotatedType) -> None:
assert AnnotatedType.from_message(annotated_type.to_message()) == annotated_type
38 changes: 38 additions & 0 deletions tilebox-datasets/tilebox/datasets/aio/client.py
@@ -5,6 +5,7 @@
from tilebox.datasets.aio.dataset import DatasetClient
from tilebox.datasets.client import Client as BaseClient
from tilebox.datasets.client import token_from_env
from tilebox.datasets.data.datasets import DatasetKind, FieldDict
from tilebox.datasets.datasets.v1.collections_pb2_grpc import CollectionServiceStub
from tilebox.datasets.datasets.v1.data_access_pb2_grpc import DataAccessServiceStub
from tilebox.datasets.datasets.v1.data_ingestion_pb2_grpc import DataIngestionServiceStub
@@ -32,10 +33,47 @@ def __init__(self, *, url: str = "https://api.tilebox.com", token: str | None =
)
self._client = BaseClient(service)

async def create_dataset(
self,
kind: DatasetKind,
code_name: str,
fields: list[FieldDict],
*,
name: str | None = None,
description: str | None = None,
) -> DatasetClient:
"""Create a new dataset.

Args:
kind: The kind of the dataset.
code_name: The code name of the dataset.
fields: The fields of the dataset.
name: The name of the dataset. Defaults to the code name.
description: A short description of the dataset. Optional.

Returns:
The created dataset.
"""
if name is None:
name = code_name
if description is None:
description = ""

return await self._client.create_dataset(kind, code_name, fields, name, description, DatasetClient)

async def datasets(self) -> Group:
"""Fetch all available datasets."""
return await self._client.datasets(DatasetClient)

async def dataset(self, slug: str) -> DatasetClient:
"""Get a dataset by its slug, e.g. `open_data.copernicus.sentinel1_sar`.

Args:
slug: The slug of the dataset.

Returns:
The dataset if it exists.
"""
return await self._client.dataset(slug, DatasetClient)

async def _dataset_by_id(self, dataset_id: str | UUID) -> DatasetClient:
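For context, a minimal usage sketch of the new async `create_dataset` method. It assumes the async client class is exposed as `Client` in `tilebox.datasets.aio.client`; the dataset and field names are illustrative, not taken from this diff.

```python
# A sketch of the new create_dataset call against the async client; the import
# path of Client and the example field names are assumptions, not part of this PR.
import asyncio
from datetime import datetime

from tilebox.datasets.aio.client import Client
from tilebox.datasets.data.datasets import DatasetKind


async def main() -> None:
    client = Client(token="...")  # or rely on token_from_env when no token is passed

    dataset = await client.create_dataset(
        DatasetKind.TEMPORAL,
        "my_measurements",
        fields=[
            {"name": "acquired_at", "type": datetime, "description": "acquisition time"},
            {"name": "value", "type": float},
        ],
        name="My Measurements",
        description="Example dataset created via the new create_dataset API",
    )
    print(dataset)


asyncio.run(main())
```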
22 changes: 10 additions & 12 deletions tilebox-datasets/tilebox/datasets/client.py
@@ -7,7 +7,7 @@
from promise import Promise

from _tilebox.grpc.channel import parse_channel_info
from tilebox.datasets.data.datasets import Dataset, DatasetGroup, ListDatasetsResponse
from tilebox.datasets.data.datasets import Dataset, DatasetGroup, DatasetKind, FieldDict, ListDatasetsResponse
from tilebox.datasets.group import Group
from tilebox.datasets.message_pool import register_once
from tilebox.datasets.service import TileboxDatasetService
@@ -26,8 +26,16 @@ class Client:
def __init__(self, service: TileboxDatasetService) -> None:
self._service = service

def create_dataset( # noqa: PLR0913
self, kind: DatasetKind, code_name: str, fields: list[FieldDict], name: str, summary: str, dataset_type: type[T]
) -> Promise[T]:
return (
self._service.create_dataset(kind, code_name, fields, name, summary)
.then(_ensure_registered)
.then(lambda dataset: dataset_type(self._service, dataset))
)

def datasets(self, dataset_type: type[T]) -> Promise[Group]:
"""Fetch all available datasets."""
return (
self._service.list_datasets()
.then(_log_server_message)
@@ -40,16 +48,6 @@ def datasets(self, dataset_type: type[T]) -> Promise[Group]:
)

def dataset(self, slug: str, dataset_type: type[T]) -> Promise[T]:
"""
Get a dataset by its slug, e.g. `open_data.copernicus.sentinel1_sar`.

Args:
slug: The slug of the dataset

Returns:
The dataset if it exists.
"""

return (
self._service.get_dataset_by_slug(slug)
.then(_ensure_registered)
131 changes: 129 additions & 2 deletions tilebox-datasets/tilebox/datasets/data/datasets.py
@@ -1,12 +1,29 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import TypedDict, get_args, get_origin
from uuid import UUID

from google.protobuf.descriptor_pb2 import FileDescriptorSet
import numpy as np
from google.protobuf import duration_pb2, timestamp_pb2
from google.protobuf.descriptor_pb2 import FieldDescriptorProto, FileDescriptorSet
from shapely import Geometry
from typing_extensions import NotRequired, Required

from tilebox.datasets.datasets.v1 import core_pb2, dataset_type_pb2, datasets_pb2
from tilebox.datasets.datasets.v1 import core_pb2, dataset_type_pb2, datasets_pb2, well_known_types_pb2
from tilebox.datasets.uuid import uuid_message_to_optional_uuid, uuid_message_to_uuid, uuid_to_uuid_message


class DatasetKind(Enum):
TEMPORAL = dataset_type_pb2.DATASET_KIND_TEMPORAL
"""A dataset that contains a timestamp field."""
SPATIOTEMPORAL = dataset_type_pb2.DATASET_KIND_SPATIOTEMPORAL
"""A dataset that contains a timestamp field and a geometry field."""


_dataset_kind_int_to_enum = {kind.value: kind for kind in DatasetKind}


@dataclass(frozen=True)
class FieldAnnotation:
description: str
@@ -20,6 +37,116 @@ def to_message(self) -> dataset_type_pb2.FieldAnnotation:
return dataset_type_pb2.FieldAnnotation(description=self.description, example_value=self.example_value)


class FieldDict(TypedDict):
name: Required[str]
type: Required[
type[str]
| type[list[str]]
| type[bytes]
| type[list[bytes]]
| type[bool]
| type[list[bool]]
| type[int]
| type[list[int]]
| type[np.uint64]
| type[list[np.uint64]]
| type[float]
| type[list[float]]
| type[timedelta]
| type[list[timedelta]]
| type[datetime]
| type[list[datetime]]
| type[UUID]
| type[list[UUID]]
| type[Geometry]
| type[list[Geometry]]
]
description: NotRequired[str]
example_value: NotRequired[str]


_TYPE_INFO: dict[type, tuple[FieldDescriptorProto.Type.ValueType, str | None]] = {
str: (FieldDescriptorProto.TYPE_STRING, None),
bytes: (FieldDescriptorProto.TYPE_BYTES, None),
bool: (FieldDescriptorProto.TYPE_BOOL, None),
int: (FieldDescriptorProto.TYPE_INT64, None),
np.uint64: (FieldDescriptorProto.TYPE_UINT64, None),
float: (FieldDescriptorProto.TYPE_DOUBLE, None),
timedelta: (FieldDescriptorProto.TYPE_MESSAGE, f".{duration_pb2.Duration.DESCRIPTOR.full_name}"),
datetime: (FieldDescriptorProto.TYPE_MESSAGE, f".{timestamp_pb2.Timestamp.DESCRIPTOR.full_name}"),
UUID: (FieldDescriptorProto.TYPE_MESSAGE, f".{well_known_types_pb2.UUID.DESCRIPTOR.full_name}"),
Geometry: (FieldDescriptorProto.TYPE_MESSAGE, f".{well_known_types_pb2.Geometry.DESCRIPTOR.full_name}"),
}


@dataclass(frozen=True)
class Field:
descriptor: FieldDescriptorProto
annotation: FieldAnnotation
queryable: bool

@classmethod
def from_message(cls, field: dataset_type_pb2.Field) -> "Field":
return cls(
descriptor=field.descriptor,
annotation=FieldAnnotation.from_message(field.annotation),
queryable=field.queryable,
)

@classmethod
def from_dict(cls, field: FieldDict) -> "Field":
origin = get_origin(field["type"])
if origin is list:
label = FieldDescriptorProto.Label.LABEL_REPEATED
args = get_args(field["type"])
inner_type = args[0] if args else field["type"]
else:
label = FieldDescriptorProto.Label.LABEL_OPTIONAL
inner_type = field["type"]

(field_type, field_type_name) = _TYPE_INFO[inner_type]

return cls(
descriptor=FieldDescriptorProto(
name=field["name"],
type=field_type,
type_name=field_type_name,
label=label,
),
annotation=FieldAnnotation(
description=field.get("description", ""),
example_value=field.get("example_value", ""),
),
queryable=False,
)

def to_message(self) -> dataset_type_pb2.Field:
return dataset_type_pb2.Field(
descriptor=self.descriptor,
annotation=self.annotation.to_message(),
queryable=self.queryable,
)


@dataclass(frozen=True)
class DatasetType:
kind: DatasetKind | None
fields: list[Field]

@classmethod
def from_message(cls, dataset_type: dataset_type_pb2.DatasetType) -> "DatasetType":
return cls(
kind=_dataset_kind_int_to_enum.get(dataset_type.kind, None),
fields=[Field.from_message(f) for f in dataset_type.fields],
)

def to_message(self) -> dataset_type_pb2.DatasetType:
return dataset_type_pb2.DatasetType(
kind=self.kind.value if self.kind else dataset_type_pb2.DATASET_KIND_UNSPECIFIED,
fields=[f.to_message() for f in self.fields],
)


@dataclass(frozen=True)
class AnnotatedType:
descriptor_set: FileDescriptorSet
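To make the `FieldDict`-to-descriptor mapping above concrete, a short sketch of `Field.from_dict` under the `_TYPE_INFO` table introduced in this file; the field names are made up for illustration.

```python
# Sketch of the Field.from_dict type mapping; field names are illustrative.
from datetime import datetime

from google.protobuf.descriptor_pb2 import FieldDescriptorProto

from tilebox.datasets.data.datasets import Field, FieldDict

scalar: FieldDict = {"name": "value", "type": float, "description": "a measurement"}
repeated: FieldDict = {"name": "acquired_at", "type": list[datetime]}

# float maps to a TYPE_DOUBLE field with an optional label
f1 = Field.from_dict(scalar)
assert f1.descriptor.type == FieldDescriptorProto.TYPE_DOUBLE
assert f1.descriptor.label == FieldDescriptorProto.Label.LABEL_OPTIONAL

# list[datetime] maps to a repeated google.protobuf.Timestamp message field
f2 = Field.from_dict(repeated)
assert f2.descriptor.type == FieldDescriptorProto.TYPE_MESSAGE
assert f2.descriptor.type_name == ".google.protobuf.Timestamp"
assert f2.descriptor.label == FieldDescriptorProto.Label.LABEL_REPEATED
```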