Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d748003
Add intensity masking script
constantinpape Sep 2, 2025
de13181
Add stardist postprocessing script
constantinpape Sep 3, 2025
2d929c4
Fixed output path
schilling40 Sep 3, 2025
cab37c0
Implement splitting of non-convex objects
constantinpape Sep 8, 2025
43bfda9
Merge branch 'intensity-masking' of https://github.com/computational-…
constantinpape Sep 8, 2025
663d82e
Improve train-val splits
constantinpape Sep 8, 2025
e698bb1
Implement SGN detection
constantinpape Sep 9, 2025
0534db2
Merge branch 'intensity-masking' of https://github.com/computational-…
constantinpape Sep 9, 2025
5d6955e
Implement la-vision WS prototype
constantinpape Sep 11, 2025
f797bc7
Merge branch 'intensity-masking' of https://github.com/computational-…
constantinpape Sep 11, 2025
66ed39b
Update to sgn detection training
constantinpape Sep 11, 2025
7507d91
Implement more debugging for detection model
constantinpape Sep 11, 2025
8c68d19
Add otof import scripts
constantinpape Sep 12, 2025
49fb49f
Update SGN Subtype analysis
constantinpape Sep 16, 2025
7d92250
Merge branch 'intensity-masking' of https://github.com/computational-…
constantinpape Sep 16, 2025
94774a0
Update subtype analysis
constantinpape Sep 17, 2025
c0d2820
Merge branch 'master' of https://github.com/computational-cell-analyt…
constantinpape Sep 18, 2025
8dc5a4c
Update SGN detection model training
constantinpape Sep 18, 2025
cf7d738
Update sgn detection training
constantinpape Sep 19, 2025
920230b
Merge branch 'intensity-masking' of https://github.com/computational-…
constantinpape Sep 20, 2025
0b7e4b2
Update sgn detection training
constantinpape Sep 20, 2025
0c8b97f
Update sgn training
constantinpape Sep 20, 2025
a6b2de8
Implement IHC grid search WIP
constantinpape Sep 24, 2025
4473d9d
Merge branch 'intensity-masking' of https://github.com/computational-…
constantinpape Sep 24, 2025
4566e99
Update SGN subtype scripts
constantinpape Oct 1, 2025
a9780d1
Minor updates
constantinpape Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion flamingo_tools/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import zarr
from elf.io import open_file

from .s3_utils import get_s3_path

try:
from zarr.abc.store import Store
except ImportError:
Expand Down Expand Up @@ -67,7 +69,9 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]:
return x


def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) -> np.typing.ArrayLike:
def read_image_data(
input_path: Union[str, Store], input_key: Optional[str], from_s3: bool = False
) -> np.typing.ArrayLike:
"""Read flamingo image data, stored in various formats.
Args:
Expand All @@ -76,10 +80,16 @@ def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) ->
Access via S3 is only supported for a zarr container.
input_key: The key (= internal path) for a zarr or n5 container.
Set it to None if the data is stored in a tif file.
from_s3: Whether to read the data from S3.
Returns:
The data, loaded either as a numpy mem-map, a numpy array, or a zarr / n5 array.
"""
if from_s3:
assert input_key is not None
s3_store, fs = get_s3_path(input_path)
return zarr.open(s3_store, mode="r")[input_key]

if input_key is None:
input_ = read_tif(input_path)
elif isinstance(input_path, str):
Expand Down
17 changes: 10 additions & 7 deletions flamingo_tools/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from concurrent import futures
from functools import partial
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,10 +60,13 @@ def _get_bounding_box_and_center(table, seg_id, resolution, shape, dilation):
for bmin, bmax, sh in zip(bb_min, bb_max, shape)
)

if isinstance(resolution, float):
resolution = (resolution,) * 3

center = (
int(row.anchor_z.item() / resolution),
int(row.anchor_y.item() / resolution),
int(row.anchor_x.item() / resolution),
int(row.anchor_z.item() / resolution[0]),
int(row.anchor_y.item() / resolution[1]),
int(row.anchor_x.item() / resolution[2]),
)

return bb, center
Expand Down Expand Up @@ -307,7 +310,7 @@ def compute_object_measures(
image_key: Optional[str] = None,
segmentation_key: Optional[str] = None,
n_threads: Optional[int] = None,
resolution: float = 0.38,
resolution: Union[float, Tuple[float, ...]] = 0.38,
force: bool = False,
feature_set: str = "default",
s3_flag: bool = False,
Expand Down Expand Up @@ -359,8 +362,8 @@ def compute_object_measures(
table = table[table["component_labels"].isin(component_list)]

# Then, open the volumes.
image = read_image_data(image_path, image_key)
segmentation = read_image_data(segmentation_path, segmentation_key)
image = read_image_data(image_path, image_key, from_s3=s3_flag)
segmentation = read_image_data(segmentation_path, segmentation_key, from_s3=s3_flag)

measures = compute_object_measures_impl(
image, segmentation, n_threads, resolution, table=table, feature_set=feature_set,
Expand Down
137 changes: 136 additions & 1 deletion flamingo_tools/segmentation/postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import math
import multiprocessing as mp
import threading
from concurrent import futures
from typing import Callable, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import elf.parallel as parallel
import numpy as np
Expand All @@ -15,6 +16,9 @@
from scipy.spatial import distance
from scipy.spatial import cKDTree, ConvexHull
from skimage import measure
from skimage.filters import gaussian
from skimage.feature import peak_local_max
from skimage.segmentation import find_boundaries, watershed
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

Expand Down Expand Up @@ -732,3 +736,134 @@ def filter_cochlea_volume(
combined_dilated[combined_dilated > 0] = 1

return combined_dilated


def split_nonconvex_objects(
segmentation: np.typing.ArrayLike,
output: np.typing.ArrayLike,
segmentation_table: pd.DataFrame,
min_size: int,
resolution: Union[float, Sequence[float]],
height_map: Optional[np.typing.ArrayLike] = None,
component_labels: Optional[List[int]] = None,
n_threads: Optional[int] = None,
) -> Dict[int, List[int]]:
"""Split noncovex objects into multiple parts inplace.

Args:
segmentation:
output:
segmentation_table:
min_size:
resolution:
height_map:
component_labels:
n_threads:
"""
if isinstance(resolution, float):
resolution = [resolution] * 3
assert len(resolution) == 3
resolution = np.array(resolution)

lock = threading.Lock()
offset = len(segmentation_table)

def split_object(object_id):
nonlocal offset

row = segmentation_table[segmentation_table.label_id == object_id]
if row.n_pixels.values[0] < min_size:
# print(object_id, ": min-size")
return [object_id]

bb_min = np.array([
row.bb_min_z.values[0], row.bb_min_y.values[0], row.bb_min_x.values[0],
]) / resolution
bb_max = np.array([
row.bb_max_z.values[0], row.bb_max_y.values[0], row.bb_max_x.values[0],
]) / resolution

bb_min = np.maximum(bb_min.astype(int) - 1, np.array([0, 0, 0]))
bb_max = np.minimum(bb_max.astype(int) + 1, np.array(list(segmentation.shape)))
bb = tuple(slice(mi, ma) for mi, ma in zip(bb_min, bb_max))

# This is due to segmentation artifacts.
bb_shape = bb_max - bb_min
if (bb_shape > 500).any():
print(object_id, "has a too large shape:", bb_shape)
return [object_id]

seg = segmentation[bb]
mask = ~find_boundaries(seg)
dist = distance_transform_edt(mask, sampling=resolution)

seg_mask = seg == object_id
dist[~seg_mask] = 0
dist = gaussian(dist, (0.6, 1.2, 1.2))
maxima = peak_local_max(dist, min_distance=3, exclude_border=True)

if len(maxima) == 1:
# print(object_id, ": max len")
return [object_id]

with lock:
old_offset = offset
offset += len(maxima)

seeds = np.zeros(seg.shape, dtype=int)
for i, pos in enumerate(maxima, 1):
seeds[tuple(pos)] = old_offset + i

if height_map is None:
hmap = dist.max() - dist
else:
hmap = height_map[bb]
new_seg = watershed(hmap, markers=seeds, mask=seg_mask)

seg_ids, sizes = np.unique(new_seg, return_counts=True)
seg_ids, sizes = seg_ids[1:], sizes[1:]

keep_ids = seg_ids[sizes > min_size]
if len(keep_ids) < 2:
# print(object_id, ": keep-id")
return [object_id]

elif len(keep_ids) != len(seg_ids):
new_seg[~np.isin(new_seg, keep_ids)] = 0
new_seg = watershed(hmap, markers=new_seg, mask=seg_mask)

with lock:
out = output[bb]
out[seg_mask] = new_seg[seg_mask]
output[bb] = out

# print(object_id, ":", len(keep_ids))
return keep_ids.tolist()

# import napari
# v = napari.Viewer()
# v.add_image(hmap)
# v.add_labels(seg)
# v.add_labels(new_seg)
# v.add_points(maxima)
# napari.run()

if component_labels is None:
object_ids = segmentation_table.label_id.values
else:
object_ids = segmentation_table[segmentation_table.component_labels.isin(component_labels)].label_id.values

if n_threads is None:
n_threads = mp.cpu_count()

# new_id_mapping = []
# for object_id in tqdm(object_ids, desc="Split non-convex objects"):
# new_id_mapping.append(split_object(object_id))

with futures.ThreadPoolExecutor(n_threads) as tp:
new_id_mapping = list(
tqdm(tp.map(split_object, object_ids), total=len(object_ids), desc="Split non-convex objects")
)

new_id_mapping = {object_id: mapped_ids for object_id, mapped_ids in zip(object_ids, new_id_mapping)}
return new_id_mapping
Loading
Loading