Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions .github/workflows/test_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ jobs:
steps:
- uses: actions/checkout@v4

#- name: Install Headless Plotting Libs
# if: runner.os == 'Linux'
# run: |
# sudo apt-get update
# sudo apt-get install -y libosmesa6 libgl1
- name: Install Headless Plotting Lib (Linux)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libosmesa6 libgl1

- name: Setup conda-forge
uses: conda-incubator/setup-miniconda@v3
Expand All @@ -47,14 +47,18 @@ jobs:
- name: Install Missing Tools (Linux/Win)
if: runner.os != 'macOS'
run: |
conda install snakemake mesalib
conda install snakemake mesalib pytest

# 2. Install for macOS (Excludes mesalib)
- name: Install Missing Tools (macOS)
if: runner.os == 'macOS'
run: |
conda install snakemake

conda install snakemake pytest

- name: Run unit tests
run: |
python -m pytest tests

- name: Run snakemake
run: |
snakemake --cores 2 -p --configfile ./config_files/test.yml
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ results
results/
logs/

*.egg-info

fTetWild
zipped
zipped_surf
Expand Down
10 changes: 6 additions & 4 deletions src/emimesh/process_image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ def mergecells(img, labels):

def ncells(img, ncells, keep_cell_labels=None):
cell_labels, cell_counts = fastremap.unique(img, return_counts=True)
cell_labels = cell_labels[np.argsort(cell_counts)]
if keep_cell_labels is None: cois =[]
cois = set(keep_cell_labels)
cell_labels = cell_labels[np.argsort(cell_counts)][::-1]
if keep_cell_labels is None:
cois = set()
else:
cois = set(keep_cell_labels)
for cid in cell_labels:
if len(cois) >= ncells: break
cois.add(cid)
img = np.where(np.isin(img, cois), img, 0)
img = np.where(np.isin(img, list(cois)), img, 0)
return img

def dilate(img, radius, labels=None):
Expand Down
9 changes: 7 additions & 2 deletions src/emimesh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import fastremap


def np2pv(arr, resolution, roimask=None):
def np2pv(arr, resolution, roimask=None, as_point_data=False):
dimensions = arr.shape
if not as_point_data: dimensions += + np.array([1, 1, 1])

grid = pv.ImageData(
dimensions=arr.shape + np.array((1, 1, 1)), spacing=resolution, origin=(0, 0, 0)
dimensions=dimensions,
spacing=resolution,
origin=(0, 0, 0)
)
grid[f"data"] = arr.flatten(order="F")
if roimask is not None:
Expand Down
53 changes: 53 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Global pytest configuration and fixtures for EMIMesh testing."""
import pytest
import numpy as np
import pyvista as pv
from pathlib import Path
import tempfile
import shutil


@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
temp_path = Path(tempfile.mkdtemp())
yield temp_path
shutil.rmtree(temp_path)


@pytest.fixture
def sample_image_data():
"""Create sample image data for testing."""
# Create a 3D array with some labeled regions
data = np.zeros((50, 50, 50), dtype=np.uint32)

# Add some labeled cells
data[10:20, 10:20, 10:20] = 1
data[30:40, 30:40, 30:40] = 2
data[15:25, 35:45, 15:25] = 3

return data


@pytest.fixture
def sample_resolution():
"""Sample resolution for testing."""
return [18.0, 18.0, 18.0]


@pytest.fixture
def sample_pyvista_grid(sample_image_data, sample_resolution):
"""Create a sample PyVista grid for testing."""
grid = pv.ImageData(
dimensions=sample_image_data.shape + np.array([1, 1, 1]),
spacing=sample_resolution,
origin=(0, 0, 0)
)
grid["data"] = sample_image_data.flatten(order="F")
return grid


@pytest.fixture
def test_data_dir():
"""Path to test data directory."""
return Path(__file__).parent / "data"
80 changes: 80 additions & 0 deletions tests/test_extract_surfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Tests for emimesh.extract_surfaces module."""
import numpy as np
import pyvista as pv

from emimesh.extract_surfaces import (
extract_surface, create_balanced_csg_tree,
clean_mesh_nan_points
)
from emimesh.utils import np2pv


class TestExtractSurface:
"""Test surface extraction functionality."""

def test_extract_surface_basic(self):
"""Test basic surface extraction."""
# Create a simple test volume
mask = np.zeros((10, 10, 10), dtype=np.uint32)
mask[2:7, 2:8, 2:8] = 1 # Cube in the center

grid = pv.ImageData(dimensions=mask.shape, spacing=(1,1,1), origin=(0, 0, 0))
print(grid)
print(mask)
result = extract_surface(mask, grid, mesh_reduction_factor=2, taubin_smooth_iter=5)

# Should return a valid mesh
assert isinstance(result, pv.PolyData)
assert result.is_manifold
assert not np.isnan(result.points).any()


def test_extract_surface_too_small(self):
"""Test surface extraction with too small volume."""
mask = np.zeros((10, 10, 10), dtype=np.uint32)
mask[5, 5, 5] = 1 # Single voxel

grid = pv.ImageData(dimensions=(10, 10, 10), spacing=(1, 1, 1))

result = extract_surface(mask, grid, mesh_reduction_factor=10, taubin_smooth_iter=5)

assert result is False

class TestCSGTree:
"""Test CSG tree creation."""

def test_create_balanced_csg_tree_single(self):
"""Test CSG tree creation with single surface."""
surface_files = ["surface1.ply"]

result = create_balanced_csg_tree(surface_files)

assert result == "surface1.ply"

def test_create_balanced_csg_tree_two(self):
"""Test CSG tree creation with two surfaces."""
surface_files = ["surface1.ply", "surface2.ply"]

result = create_balanced_csg_tree(surface_files)

expected = {
"operation": "union",
"left": "surface1.ply",
"right": "surface2.ply"
}
assert result == expected

def test_create_balanced_csg_tree_multiple(self):
"""Test CSG tree creation with multiple surfaces."""
surface_files = ["s1.ply", "s2.ply", "s3.ply", "s4.ply"]

result = create_balanced_csg_tree(surface_files)

# Should be a balanced tree structure
assert result["operation"] == "union"
assert "left" in result
assert "right" in result

# Left and right should each contain 2 surfaces
assert isinstance(result["left"], dict) or isinstance(result["left"], str)
assert isinstance(result["right"], dict) or isinstance(result["right"], str)
162 changes: 162 additions & 0 deletions tests/test_process_image_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Tests for emimesh.process_image_data module."""
import numpy as np
from unittest.mock import patch
from emimesh.process_image_data import (
mergecells, ncells, dilate, erode, smooth, removeislands,
opdict, parse_operations, _parse_to_dict
)

class TestImageOperations:
"""Test individual image processing operations."""

def test_mergecells_basic(self):
"""Test basic cell merging."""
img = np.array([[[1, 1, 2], [1, 3, 2], [4, 4, 4]]], dtype=np.uint32)
labels = [1, 2]

result = mergecells(img, labels)

# All 1s and 2s should become the first label (1)
non_zero_values = result[result > 0]
unique_values = np.unique(non_zero_values)

# Should only have values 1, 3, 4 (1 and 2 merged to 1)
assert set(unique_values) == {1, 3, 4}
assert 3 in result # 3 should remain unchanged
assert 4 in result # 4 should remain unchanged

def test_ncells_basic(self):
"""Test keeping only N largest cells."""
img = np.array([[[1, 1, 2], [1, 3, 2], [4, 4, 4]]], dtype=np.uint32)

result = ncells(img, ncells=2)

# Should keep only background (0) and the two largest cells (1 and 4)
assert np.allclose(np.unique(result), np.array([0, 1,4]))

def test_ncells_with_keep_labels(self):
"""Test keeping specific cells regardless of size."""
img = np.array([[[1, 1, 2], [1, 3, 2], [4, 4, 4]]], dtype=np.uint32)
keep_labels = [2]

result = ncells(img, ncells=1, keep_cell_labels=keep_labels)

# Should only have background (0) and the kept label (2)
assert np.allclose(np.unique(result), np.array([0, 2]))

def test_removeislands_basic(self):
"""Test removing small islands."""
# Create an image with small and large connected components
img = np.zeros((10, 10, 10), dtype=np.uint32)
img[2:4, 2:4, 2:4] = 1 # Small island (8 voxels)
img[6:9, 6:9, 6:9] = 2 # Large island (27 voxels)

result = removeislands(img, minsize=10)

# Small island should be removed, large one should remain
assert 1 not in np.unique(result)
assert 2 in np.unique(result)


class TestOperationDictionary:
"""Test the operation dictionary."""

def test_opdict_contains_all_operations(self):
"""Test that opdict contains all expected operations."""
expected_ops = ["merge", "smooth", "dilate", "erode", "removeislands", "ncells"]

for op in expected_ops:
assert op in opdict
assert callable(opdict[op])


class TestParseOperations:
"""Test operation parsing functionality."""

def test_parse_to_dict_basic(self):
"""Test basic dictionary parsing."""
values = ["key1='value1'", "key2=42", "key3=True"]

result = _parse_to_dict(values)

assert result["key1"] == "value1"
assert result["key2"] == 42
assert result["key3"] is True

def test_parse_to_dict_with_lists(self):
"""Test parsing with list values."""
values = ["labels='[1, 2, 3]'", "radius=5"]

result = _parse_to_dict(values)

assert result["labels"] == [1, 2, 3]
assert result["radius"] == 5

def test_parse_operations_basic(self):
"""Test basic operation parsing."""
ops = [["merge", "labels='[1, 2]'", "radius=5"]]

result = parse_operations(ops)

assert len(result) == 1
assert result[0][0] == "merge"
assert result[0][1]["labels"] == [1, 2]
assert result[0][1]["radius"] == 5

def test_parse_operations_multiple(self):
"""Test parsing multiple operations."""
ops = [
["merge", "labels='[1, 2]'"],
["removeislands", "minsize=100"],
["dilate", "radius=3"]
]

result = parse_operations(ops)

assert len(result) == 3
assert result[0][0] == "merge"
assert result[1][0] == "removeislands"
assert result[2][0] == "dilate"


class TestImageProcessingIntegration:
"""Integration tests for image processing operations."""

def test_dilate_operation(self):
"""Test dilation operation."""
img = np.zeros((10, 10, 10), dtype=np.uint32)
img[4:6, 4:6, 4:6] = 1

# Mock nbmorph.dilate_labels_spherical to avoid dependency
with patch('emimesh.process_image_data.nbmorph') as mock_nbmorph:
mock_nbmorph.dilate_labels_spherical.return_value = img # Return same for simplicity

result = dilate(img, radius=2)

mock_nbmorph.dilate_labels_spherical.assert_called_once_with(img, radius=2)

def test_erode_operation(self):
"""Test erosion operation."""
img = np.ones((10, 10, 10), dtype=np.uint32)

# Mock nbmorph.erode_labels_spherical to avoid dependency
with patch('emimesh.process_image_data.nbmorph') as mock_nbmorph:
mock_nbmorph.erode_labels_spherical.return_value = img # Return same for simplicity

result = erode(img, radius=2)

mock_nbmorph.erode_labels_spherical.assert_called_once_with(img, radius=2)

def test_smooth_operation(self):
"""Test smoothing operation."""
img = np.ones((10, 10, 10), dtype=np.uint32)

# Mock nbmorph.smooth_labels_spherical to avoid dependency
with patch('emimesh.process_image_data.nbmorph') as mock_nbmorph:
mock_nbmorph.smooth_labels_spherical.return_value = img # Return same for simplicity

result = smooth(img, iterations=5, radius=3)

mock_nbmorph.smooth_labels_spherical.assert_called_once_with(
img, radius=3, iterations=5, dilate_radius=3
)
Loading
Loading