Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 40 additions & 44 deletions qubalab/objects/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,33 @@
class Classification(object):
"""
Simple class to store the names and color of a classification.

Each Classification with the same names is the same object, retrieved from a cache.
Therefore updating the color of a Classification will update all similarly classified objects.
"""

_cached_classifications = {}
_cached_classifications: dict[tuple[str], Classification] = {}

def __new__(
cls, names: Union[str, tuple[str]], color: Optional[tuple[int, int, int]] = None
):
if isinstance(names, str):
names = (names,)
elif isinstance(names, list):
names = tuple(names)
if names is None:
return None
if not isinstance(names, tuple):
raise TypeError("names should be str or tuple[str]")

classification = Classification._cached_classifications.get(names)
if classification is None:
classification = super().__new__(cls)
Classification._cached_classifications[names] = classification

if color is not None:
classification.color = color
return classification

def __init__(
self,
Expand All @@ -21,22 +45,19 @@ def __init__(
"""
if isinstance(names, str):
names = (names,)
elif isinstance(names, list):
names = tuple(names)
if not isinstance(names, tuple):
raise TypeError("names should be a tuple, list or string")
self._names = names
self._color = (
self._color: tuple = (
tuple(random.randint(0, 255) for _ in range(3)) if color is None else color
)

@property
def name(self) -> str:
"""
The name of the classification.
"""
return ": ".join(self._names)

@property
def names(self) -> tuple[str]:
"""
The name of the classification.
The names of the classification.
"""
return self._names

Expand All @@ -47,49 +68,24 @@ def color(self) -> tuple[int, int, int]:
"""
return self._color ## todo: pylance type hints problem

@staticmethod
def get_cached_classification(
name: Optional[Union[str, tuple[str]]],
color: Optional[tuple[int, int, int]] = None,
) -> Optional[Classification]:
@color.setter
def color(self, value: tuple[int, int, int]) -> None:
"""
Return a classification by looking at an internal cache.

If no classification with the provided name is present in the cache, a
new classification is created and the cache is updated.

This is useful if you want to avoid creating multiple classifications with the
same name and use only one instead.

:param name: the name of the classification (can be None)
:param color: the RGB color (each component between 0 and 255) of the classification.
Can be None to use a random color. This is only used if the cache doesn't
already contain a classification with the provided name
:return: a classification with the provided name, but not always with the provided color
if a classification with the same name already existed in the cache. If the provided
name is None, None is also returned
Change the color of the classification.
:param value: the new 8-bit RGB color
"""
if name is None:
return None
if isinstance(name, str):
name = (name,)
name = ": ".join(name)
classification = Classification._cached_classifications.get(name)
if classification is None:
classification = Classification(name, color)
Classification._cached_classifications[classification.name] = classification
return classification
self._color = value

def __str__(self):
return f"Classification {self.name} of color {self.color}"
return f"Classification {self.names} of color {self.color}"

def __repr__(self):
return f"Classification('{self.name}', {self.color})"
return f"Classification('{self.names}', {self.color})"

def __eq__(self, other):
if isinstance(other, Classification):
return self.name == other.name and self.color == other.color
return (self is other) or (self.names == other.names)
return False

def __hash__(self):
return hash(self.name)
return hash(self.names)
4 changes: 1 addition & 3 deletions qubalab/objects/image_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,7 @@ def create_from_label_image(

feature = cls(
geometry=geometry,
classification=Classification.get_cached_classification(
classification_name
),
classification=Classification(classification_name),
measurements={"Label": float(label)} if include_labels else None,
object_type=object_type,
)
Expand Down
46 changes: 29 additions & 17 deletions tests/objects/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,69 @@


def test_name():
expected_name = "name"
classification = Classification(expected_name)
expected_names = ("name",)
classification = Classification(expected_names)

name = classification.name
names = classification.names

assert expected_name == name
assert expected_names == names


def test_color():
expected_color = (2, 20, 56)
classification = Classification(None, expected_color)
classification = Classification("name", expected_color)

color = classification.color

assert expected_color == color


def test_cache_when_None_name_provided():
classification = Classification.get_cached_classification(None)

assert classification == None
def test_None_when_names_is_None():
classification = Classification(None)
assert classification is None


def test_cache_when_empty():
name = "name"
color = (2, 20, 56)

classification = Classification.get_cached_classification(name, color)
classification = Classification(name, color)

assert classification == Classification(name, color)
assert classification is Classification(name, color)


def test_cache_when_not_empty_and_same_name():
cached_name = "name"
cached_color = (2, 20, 56)
other_name = cached_name
other_color = (4, 65, 7)
cached_classification = Classification.get_cached_classification(cached_name, cached_color)
cached_classification = Classification(cached_name, cached_color)

classification = Classification.get_cached_classification(other_name, other_color)
classification = Classification(other_name, other_color)

assert classification != Classification(other_name, other_color) and classification == cached_classification
assert (
classification is Classification(other_name, other_color)
and classification is cached_classification
)


def test_cache_when_not_empty_and_different_name():
cached_name = "name"
cached_color = (2, 20, 56)
other_name = "other name"
other_color = (4, 65, 7)
cached_classification = Classification.get_cached_classification(cached_name, cached_color)
cached_classification = Classification(cached_name, cached_color)

classification = Classification(other_name, other_color)

assert (
classification == Classification(other_name, other_color)
and classification != cached_classification
)

classification = Classification.get_cached_classification(other_name, other_color)

assert classification == Classification(other_name, other_color) and classification != cached_classification
def test_names_input():
names = ("a", "b")
class1 = Classification(names)
class2 = Classification(list(names))
assert class1 is class2
22 changes: 9 additions & 13 deletions tests/objects/test_image_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def test_classification_when_created_from_label_image_and_classification_name_pr
)

assert all(
feature.classification.name == expected_classification_name
feature.classification.names == (expected_classification_name,)
for feature in features
)

Expand Down Expand Up @@ -457,7 +457,7 @@ def test_classification_when_created_from_label_image_and_classification_dict_pr

assert all(
feature.classification is None
or feature.classification.name in expected_classification_names
or feature.classification.names[0] in expected_classification_names
for feature in features
)

Expand Down Expand Up @@ -580,7 +580,7 @@ def test_classification_when_created_from_binary_image_and_classification_name_p
)

assert all(
feature.classification.name == expected_classification_name
feature.classification.names == (expected_classification_name,)
for feature in features
)

Expand Down Expand Up @@ -609,7 +609,7 @@ def test_classification_when_created_from_binary_image_and_classification_dict_p
)

assert all(
feature.classification.name in expected_classification_names
feature.classification.names[0] in expected_classification_names
for feature in features
)

Expand All @@ -618,7 +618,7 @@ def test_classification_when_set_after_creation():
expected_classification = Classification("name", (1, 1, 1))
image_feature = ImageFeature(None)
image_feature.classification = {
"name": expected_classification.name,
"names": expected_classification.names,
"color": expected_classification.color,
}

Expand All @@ -630,9 +630,9 @@ def test_classification_when_set_after_creation():
def test_name_when_set_after_creation():
expected_name = "name"
image_feature = ImageFeature(None)
image_feature.name = expected_name
image_feature.names = expected_name

name = image_feature.name
name = image_feature.names

assert name == expected_name

Expand Down Expand Up @@ -696,11 +696,7 @@ def test_imagefeature_handles_classification_names():
"""
feature = geojson_features_from_string(string)
ifeature = ImageFeature.create_from_feature(feature)
assert (
ifeature.classification.names
== feature["properties"]["classification"]["names"]
)
assert ifeature.classification.name == ": ".join(
assert ifeature.classification.names == tuple(
feature["properties"]["classification"]["names"]
)

Expand All @@ -711,6 +707,6 @@ def test_imagefeature_handles_classification_name():
"""
feature = geojson_features_from_string(string)
ifeature = ImageFeature.create_from_feature(feature)
assert ifeature.classification.name == ": ".join(
assert ifeature.classification.names == tuple(
feature["properties"]["classification"]["names"]
)