diff --git a/src/tracksdata/constants.py b/src/tracksdata/constants.py index f74672fc..923ae5a9 100644 --- a/src/tracksdata/constants.py +++ b/src/tracksdata/constants.py @@ -18,6 +18,10 @@ class DefaultAttrKeys: Default key for time information. MASK : str Default key for node masks. + BBOX : str + Default key for node bounding boxes. + For a 2D image, the bounding box is a tuple of (x_start, y_start, x_end, y_end). + For a 3D image, the bounding box is a tuple of (x_start, y_start, z_start, x_end, y_end, z_end). SOLUTION : str Default key for solution information. TRACK_ID : str @@ -50,6 +54,7 @@ class DefaultAttrKeys: NODE_ID = "node_id" T = "t" MASK = "mask" + BBOX = "bbox" SOLUTION = "solution" TRACK_ID = "track_id" diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index 42cac025..70899ac4 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -120,8 +120,9 @@ def _init_node_attrs(self, graph: BaseGraph, axis_names: list[str]) -> None: """ Initialize the node attributes for the graph. """ - if DEFAULT_ATTR_KEYS.MASK not in graph.node_attr_keys: - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + for attr_key in [DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX]: + if attr_key not in graph.node_attr_keys: + graph.add_node_attr_key(attr_key, None) # initialize the attribute keys for attr_key in axis_names + self.attrs_keys(): @@ -288,6 +289,7 @@ def _nodes_per_time( attrs[prop] = getattr(obj, prop) attrs[DEFAULT_ATTR_KEYS.MASK] = Mask(obj.image, obj.bbox) + attrs[DEFAULT_ATTR_KEYS.BBOX] = np.asarray(obj.bbox, dtype=int) attrs[DEFAULT_ATTR_KEYS.T] = t nodes_data.append(attrs) diff --git a/src/tracksdata/nodes/_test/test_regionprops.py b/src/tracksdata/nodes/_test/test_regionprops.py index afcf7baf..7fe8466c 100644 --- a/src/tracksdata/nodes/_test/test_regionprops.py +++ b/src/tracksdata/nodes/_test/test_regionprops.py @@ -252,9 +252,11 @@ def test_regionprops_spacing() -> None: # Check that nodes were added (spacing affects internal calculations) nodes_df = graph.node_attrs() + assert len(nodes_df) == 1 assert "area" in nodes_df.columns assert DEFAULT_ATTR_KEYS.MASK in nodes_df.columns + assert nodes_df[DEFAULT_ATTR_KEYS.BBOX].to_numpy().ndim == 2 def test_regionprops_empty_labels() -> None: