diff --git a/qubalab/objects/classification.py b/qubalab/objects/classification.py index e36eef3..dd59f3a 100644 --- a/qubalab/objects/classification.py +++ b/qubalab/objects/classification.py @@ -6,9 +6,33 @@ class Classification(object): """ Simple class to store the names and color of a classification. + + Each Classification with the same names is the same object, retrieved from a cache. + Therefore updating the color of a Classification will update all similarly classified objects. """ - _cached_classifications = {} + _cached_classifications: dict[tuple[str], Classification] = {} + + def __new__( + cls, names: Union[str, tuple[str]], color: Optional[tuple[int, int, int]] = None + ): + if isinstance(names, str): + names = (names,) + elif isinstance(names, list): + names = tuple(names) + if names is None: + return None + if not isinstance(names, tuple): + raise TypeError("names should be str or tuple[str]") + + classification = Classification._cached_classifications.get(names) + if classification is None: + classification = super().__new__(cls) + Classification._cached_classifications[names] = classification + + if color is not None: + classification.color = color + return classification def __init__( self, @@ -21,22 +45,19 @@ def __init__( """ if isinstance(names, str): names = (names,) + elif isinstance(names, list): + names = tuple(names) + if not isinstance(names, tuple): + raise TypeError("names should be a tuple, list or string") self._names = names - self._color = ( + self._color: tuple = ( tuple(random.randint(0, 255) for _ in range(3)) if color is None else color ) - @property - def name(self) -> str: - """ - The name of the classification. - """ - return ": ".join(self._names) - @property def names(self) -> tuple[str]: """ - The name of the classification. + The names of the classification. """ return self._names @@ -47,49 +68,24 @@ def color(self) -> tuple[int, int, int]: """ return self._color ## todo: pylance type hints problem - @staticmethod - def get_cached_classification( - name: Optional[Union[str, tuple[str]]], - color: Optional[tuple[int, int, int]] = None, - ) -> Optional[Classification]: + @color.setter + def color(self, value: tuple[int, int, int]) -> None: """ - Return a classification by looking at an internal cache. - - If no classification with the provided name is present in the cache, a - new classification is created and the cache is updated. - - This is useful if you want to avoid creating multiple classifications with the - same name and use only one instead. - - :param name: the name of the classification (can be None) - :param color: the RGB color (each component between 0 and 255) of the classification. - Can be None to use a random color. This is only used if the cache doesn't - already contain a classification with the provided name - :return: a classification with the provided name, but not always with the provided color - if a classification with the same name already existed in the cache. If the provided - name is None, None is also returned + Change the color of the classification. + :param value: the new 8-bit RGB color """ - if name is None: - return None - if isinstance(name, str): - name = (name,) - name = ": ".join(name) - classification = Classification._cached_classifications.get(name) - if classification is None: - classification = Classification(name, color) - Classification._cached_classifications[classification.name] = classification - return classification + self._color = value def __str__(self): - return f"Classification {self.name} of color {self.color}" + return f"Classification {self.names} of color {self.color}" def __repr__(self): - return f"Classification('{self.name}', {self.color})" + return f"Classification('{self.names}', {self.color})" def __eq__(self, other): if isinstance(other, Classification): - return self.name == other.name and self.color == other.color + return (self is other) or (self.names == other.names) return False def __hash__(self): - return hash(self.name) + return hash(self.names) diff --git a/qubalab/objects/image_feature.py b/qubalab/objects/image_feature.py index e78f706..a321d90 100644 --- a/qubalab/objects/image_feature.py +++ b/qubalab/objects/image_feature.py @@ -222,9 +222,7 @@ def create_from_label_image( feature = cls( geometry=geometry, - classification=Classification.get_cached_classification( - classification_name - ), + classification=Classification(classification_name), measurements={"Label": float(label)} if include_labels else None, object_type=object_type, ) diff --git a/tests/objects/test_classification.py b/tests/objects/test_classification.py index dacb356..1d22fe9 100644 --- a/tests/objects/test_classification.py +++ b/tests/objects/test_classification.py @@ -2,36 +2,35 @@ def test_name(): - expected_name = "name" - classification = Classification(expected_name) + expected_names = ("name",) + classification = Classification(expected_names) - name = classification.name + names = classification.names - assert expected_name == name + assert expected_names == names def test_color(): expected_color = (2, 20, 56) - classification = Classification(None, expected_color) + classification = Classification("name", expected_color) color = classification.color assert expected_color == color -def test_cache_when_None_name_provided(): - classification = Classification.get_cached_classification(None) - - assert classification == None +def test_None_when_names_is_None(): + classification = Classification(None) + assert classification is None def test_cache_when_empty(): name = "name" color = (2, 20, 56) - classification = Classification.get_cached_classification(name, color) + classification = Classification(name, color) - assert classification == Classification(name, color) + assert classification is Classification(name, color) def test_cache_when_not_empty_and_same_name(): @@ -39,11 +38,14 @@ def test_cache_when_not_empty_and_same_name(): cached_color = (2, 20, 56) other_name = cached_name other_color = (4, 65, 7) - cached_classification = Classification.get_cached_classification(cached_name, cached_color) + cached_classification = Classification(cached_name, cached_color) - classification = Classification.get_cached_classification(other_name, other_color) + classification = Classification(other_name, other_color) - assert classification != Classification(other_name, other_color) and classification == cached_classification + assert ( + classification is Classification(other_name, other_color) + and classification is cached_classification + ) def test_cache_when_not_empty_and_different_name(): @@ -51,8 +53,18 @@ def test_cache_when_not_empty_and_different_name(): cached_color = (2, 20, 56) other_name = "other name" other_color = (4, 65, 7) - cached_classification = Classification.get_cached_classification(cached_name, cached_color) + cached_classification = Classification(cached_name, cached_color) + + classification = Classification(other_name, other_color) + + assert ( + classification == Classification(other_name, other_color) + and classification != cached_classification + ) - classification = Classification.get_cached_classification(other_name, other_color) - assert classification == Classification(other_name, other_color) and classification != cached_classification +def test_names_input(): + names = ("a", "b") + class1 = Classification(names) + class2 = Classification(list(names)) + assert class1 is class2 diff --git a/tests/objects/test_image_feature.py b/tests/objects/test_image_feature.py index 5a32d50..963f8e2 100644 --- a/tests/objects/test_image_feature.py +++ b/tests/objects/test_image_feature.py @@ -423,7 +423,7 @@ def test_classification_when_created_from_label_image_and_classification_name_pr ) assert all( - feature.classification.name == expected_classification_name + feature.classification.names == (expected_classification_name,) for feature in features ) @@ -457,7 +457,7 @@ def test_classification_when_created_from_label_image_and_classification_dict_pr assert all( feature.classification is None - or feature.classification.name in expected_classification_names + or feature.classification.names[0] in expected_classification_names for feature in features ) @@ -580,7 +580,7 @@ def test_classification_when_created_from_binary_image_and_classification_name_p ) assert all( - feature.classification.name == expected_classification_name + feature.classification.names == (expected_classification_name,) for feature in features ) @@ -609,7 +609,7 @@ def test_classification_when_created_from_binary_image_and_classification_dict_p ) assert all( - feature.classification.name in expected_classification_names + feature.classification.names[0] in expected_classification_names for feature in features ) @@ -618,7 +618,7 @@ def test_classification_when_set_after_creation(): expected_classification = Classification("name", (1, 1, 1)) image_feature = ImageFeature(None) image_feature.classification = { - "name": expected_classification.name, + "names": expected_classification.names, "color": expected_classification.color, } @@ -630,9 +630,9 @@ def test_classification_when_set_after_creation(): def test_name_when_set_after_creation(): expected_name = "name" image_feature = ImageFeature(None) - image_feature.name = expected_name + image_feature.names = expected_name - name = image_feature.name + name = image_feature.names assert name == expected_name @@ -696,11 +696,7 @@ def test_imagefeature_handles_classification_names(): """ feature = geojson_features_from_string(string) ifeature = ImageFeature.create_from_feature(feature) - assert ( - ifeature.classification.names - == feature["properties"]["classification"]["names"] - ) - assert ifeature.classification.name == ": ".join( + assert ifeature.classification.names == tuple( feature["properties"]["classification"]["names"] ) @@ -711,6 +707,6 @@ def test_imagefeature_handles_classification_name(): """ feature = geojson_features_from_string(string) ifeature = ImageFeature.create_from_feature(feature) - assert ifeature.classification.name == ": ".join( + assert ifeature.classification.names == tuple( feature["properties"]["classification"]["names"] )