diff --git a/.github/workflows/test_conda.yml b/.github/workflows/test_conda.yml index 127ac9a..fae89da 100644 --- a/.github/workflows/test_conda.yml +++ b/.github/workflows/test_conda.yml @@ -27,11 +27,11 @@ jobs: steps: - uses: actions/checkout@v4 - #- name: Install Headless Plotting Libs - # if: runner.os == 'Linux' - # run: | - # sudo apt-get update - # sudo apt-get install -y libosmesa6 libgl1 + - name: Install Headless Plotting Lib (Linux) + if: runner.os == 'Linux' + run: | + sudo apt-get update + sudo apt-get install -y libosmesa6 libgl1 - name: Setup conda-forge uses: conda-incubator/setup-miniconda@v3 @@ -47,14 +47,18 @@ jobs: - name: Install Missing Tools (Linux/Win) if: runner.os != 'macOS' run: | - conda install snakemake mesalib + conda install snakemake mesalib pytest # 2. Install for macOS (Excludes mesalib) - name: Install Missing Tools (macOS) if: runner.os == 'macOS' run: | - conda install snakemake - + conda install snakemake pytest + + - name: Run unit tests + run: | + python -m pytest tests + - name: Run snakemake run: | snakemake --cores 2 -p --configfile ./config_files/test.yml diff --git a/.gitignore b/.gitignore index a4f6ce5..4d7e53a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ results results/ logs/ +*.egg-info + fTetWild zipped zipped_surf diff --git a/src/emimesh/process_image_data.py b/src/emimesh/process_image_data.py index e7d106a..cbda3f6 100644 --- a/src/emimesh/process_image_data.py +++ b/src/emimesh/process_image_data.py @@ -14,13 +14,15 @@ def mergecells(img, labels): def ncells(img, ncells, keep_cell_labels=None): cell_labels, cell_counts = fastremap.unique(img, return_counts=True) - cell_labels = cell_labels[np.argsort(cell_counts)] - if keep_cell_labels is None: cois =[] - cois = set(keep_cell_labels) + cell_labels = cell_labels[np.argsort(cell_counts)][::-1] + if keep_cell_labels is None: + cois = set() + else: + cois = set(keep_cell_labels) for cid in cell_labels: if len(cois) >= ncells: break cois.add(cid) - img = np.where(np.isin(img, cois), img, 0) + img = np.where(np.isin(img, list(cois)), img, 0) return img def dilate(img, radius, labels=None): diff --git a/src/emimesh/utils.py b/src/emimesh/utils.py index 0678b7b..e93f964 100644 --- a/src/emimesh/utils.py +++ b/src/emimesh/utils.py @@ -3,9 +3,14 @@ import fastremap -def np2pv(arr, resolution, roimask=None): +def np2pv(arr, resolution, roimask=None, as_point_data=False): + dimensions = arr.shape + if not as_point_data: dimensions += + np.array([1, 1, 1]) + grid = pv.ImageData( - dimensions=arr.shape + np.array((1, 1, 1)), spacing=resolution, origin=(0, 0, 0) + dimensions=dimensions, + spacing=resolution, + origin=(0, 0, 0) ) grid[f"data"] = arr.flatten(order="F") if roimask is not None: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..819c646 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,53 @@ +"""Global pytest configuration and fixtures for EMIMesh testing.""" +import pytest +import numpy as np +import pyvista as pv +from pathlib import Path +import tempfile +import shutil + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_path = Path(tempfile.mkdtemp()) + yield temp_path + shutil.rmtree(temp_path) + + +@pytest.fixture +def sample_image_data(): + """Create sample image data for testing.""" + # Create a 3D array with some labeled regions + data = np.zeros((50, 50, 50), dtype=np.uint32) + + # Add some labeled cells + data[10:20, 10:20, 10:20] = 1 + data[30:40, 30:40, 30:40] = 2 + data[15:25, 35:45, 15:25] = 3 + + return data + + +@pytest.fixture +def sample_resolution(): + """Sample resolution for testing.""" + return [18.0, 18.0, 18.0] + + +@pytest.fixture +def sample_pyvista_grid(sample_image_data, sample_resolution): + """Create a sample PyVista grid for testing.""" + grid = pv.ImageData( + dimensions=sample_image_data.shape + np.array([1, 1, 1]), + spacing=sample_resolution, + origin=(0, 0, 0) + ) + grid["data"] = sample_image_data.flatten(order="F") + return grid + + +@pytest.fixture +def test_data_dir(): + """Path to test data directory.""" + return Path(__file__).parent / "data" diff --git a/tests/test_extract_surfaces.py b/tests/test_extract_surfaces.py new file mode 100644 index 0000000..4ee9259 --- /dev/null +++ b/tests/test_extract_surfaces.py @@ -0,0 +1,80 @@ +"""Tests for emimesh.extract_surfaces module.""" +import numpy as np +import pyvista as pv + +from emimesh.extract_surfaces import ( + extract_surface, create_balanced_csg_tree, + clean_mesh_nan_points +) +from emimesh.utils import np2pv + + +class TestExtractSurface: + """Test surface extraction functionality.""" + + def test_extract_surface_basic(self): + """Test basic surface extraction.""" + # Create a simple test volume + mask = np.zeros((10, 10, 10), dtype=np.uint32) + mask[2:7, 2:8, 2:8] = 1 # Cube in the center + + grid = pv.ImageData(dimensions=mask.shape, spacing=(1,1,1), origin=(0, 0, 0)) + print(grid) + print(mask) + result = extract_surface(mask, grid, mesh_reduction_factor=2, taubin_smooth_iter=5) + + # Should return a valid mesh + assert isinstance(result, pv.PolyData) + assert result.is_manifold + assert not np.isnan(result.points).any() + + + def test_extract_surface_too_small(self): + """Test surface extraction with too small volume.""" + mask = np.zeros((10, 10, 10), dtype=np.uint32) + mask[5, 5, 5] = 1 # Single voxel + + grid = pv.ImageData(dimensions=(10, 10, 10), spacing=(1, 1, 1)) + + result = extract_surface(mask, grid, mesh_reduction_factor=10, taubin_smooth_iter=5) + + assert result is False + +class TestCSGTree: + """Test CSG tree creation.""" + + def test_create_balanced_csg_tree_single(self): + """Test CSG tree creation with single surface.""" + surface_files = ["surface1.ply"] + + result = create_balanced_csg_tree(surface_files) + + assert result == "surface1.ply" + + def test_create_balanced_csg_tree_two(self): + """Test CSG tree creation with two surfaces.""" + surface_files = ["surface1.ply", "surface2.ply"] + + result = create_balanced_csg_tree(surface_files) + + expected = { + "operation": "union", + "left": "surface1.ply", + "right": "surface2.ply" + } + assert result == expected + + def test_create_balanced_csg_tree_multiple(self): + """Test CSG tree creation with multiple surfaces.""" + surface_files = ["s1.ply", "s2.ply", "s3.ply", "s4.ply"] + + result = create_balanced_csg_tree(surface_files) + + # Should be a balanced tree structure + assert result["operation"] == "union" + assert "left" in result + assert "right" in result + + # Left and right should each contain 2 surfaces + assert isinstance(result["left"], dict) or isinstance(result["left"], str) + assert isinstance(result["right"], dict) or isinstance(result["right"], str) \ No newline at end of file diff --git a/tests/test_process_image_data.py b/tests/test_process_image_data.py new file mode 100644 index 0000000..88e6615 --- /dev/null +++ b/tests/test_process_image_data.py @@ -0,0 +1,162 @@ +"""Tests for emimesh.process_image_data module.""" +import numpy as np +from unittest.mock import patch +from emimesh.process_image_data import ( + mergecells, ncells, dilate, erode, smooth, removeislands, + opdict, parse_operations, _parse_to_dict +) + +class TestImageOperations: + """Test individual image processing operations.""" + + def test_mergecells_basic(self): + """Test basic cell merging.""" + img = np.array([[[1, 1, 2], [1, 3, 2], [4, 4, 4]]], dtype=np.uint32) + labels = [1, 2] + + result = mergecells(img, labels) + + # All 1s and 2s should become the first label (1) + non_zero_values = result[result > 0] + unique_values = np.unique(non_zero_values) + + # Should only have values 1, 3, 4 (1 and 2 merged to 1) + assert set(unique_values) == {1, 3, 4} + assert 3 in result # 3 should remain unchanged + assert 4 in result # 4 should remain unchanged + + def test_ncells_basic(self): + """Test keeping only N largest cells.""" + img = np.array([[[1, 1, 2], [1, 3, 2], [4, 4, 4]]], dtype=np.uint32) + + result = ncells(img, ncells=2) + + # Should keep only background (0) and the two largest cells (1 and 4) + assert np.allclose(np.unique(result), np.array([0, 1,4])) + + def test_ncells_with_keep_labels(self): + """Test keeping specific cells regardless of size.""" + img = np.array([[[1, 1, 2], [1, 3, 2], [4, 4, 4]]], dtype=np.uint32) + keep_labels = [2] + + result = ncells(img, ncells=1, keep_cell_labels=keep_labels) + + # Should only have background (0) and the kept label (2) + assert np.allclose(np.unique(result), np.array([0, 2])) + + def test_removeislands_basic(self): + """Test removing small islands.""" + # Create an image with small and large connected components + img = np.zeros((10, 10, 10), dtype=np.uint32) + img[2:4, 2:4, 2:4] = 1 # Small island (8 voxels) + img[6:9, 6:9, 6:9] = 2 # Large island (27 voxels) + + result = removeislands(img, minsize=10) + + # Small island should be removed, large one should remain + assert 1 not in np.unique(result) + assert 2 in np.unique(result) + + +class TestOperationDictionary: + """Test the operation dictionary.""" + + def test_opdict_contains_all_operations(self): + """Test that opdict contains all expected operations.""" + expected_ops = ["merge", "smooth", "dilate", "erode", "removeislands", "ncells"] + + for op in expected_ops: + assert op in opdict + assert callable(opdict[op]) + + +class TestParseOperations: + """Test operation parsing functionality.""" + + def test_parse_to_dict_basic(self): + """Test basic dictionary parsing.""" + values = ["key1='value1'", "key2=42", "key3=True"] + + result = _parse_to_dict(values) + + assert result["key1"] == "value1" + assert result["key2"] == 42 + assert result["key3"] is True + + def test_parse_to_dict_with_lists(self): + """Test parsing with list values.""" + values = ["labels='[1, 2, 3]'", "radius=5"] + + result = _parse_to_dict(values) + + assert result["labels"] == [1, 2, 3] + assert result["radius"] == 5 + + def test_parse_operations_basic(self): + """Test basic operation parsing.""" + ops = [["merge", "labels='[1, 2]'", "radius=5"]] + + result = parse_operations(ops) + + assert len(result) == 1 + assert result[0][0] == "merge" + assert result[0][1]["labels"] == [1, 2] + assert result[0][1]["radius"] == 5 + + def test_parse_operations_multiple(self): + """Test parsing multiple operations.""" + ops = [ + ["merge", "labels='[1, 2]'"], + ["removeislands", "minsize=100"], + ["dilate", "radius=3"] + ] + + result = parse_operations(ops) + + assert len(result) == 3 + assert result[0][0] == "merge" + assert result[1][0] == "removeislands" + assert result[2][0] == "dilate" + + +class TestImageProcessingIntegration: + """Integration tests for image processing operations.""" + + def test_dilate_operation(self): + """Test dilation operation.""" + img = np.zeros((10, 10, 10), dtype=np.uint32) + img[4:6, 4:6, 4:6] = 1 + + # Mock nbmorph.dilate_labels_spherical to avoid dependency + with patch('emimesh.process_image_data.nbmorph') as mock_nbmorph: + mock_nbmorph.dilate_labels_spherical.return_value = img # Return same for simplicity + + result = dilate(img, radius=2) + + mock_nbmorph.dilate_labels_spherical.assert_called_once_with(img, radius=2) + + def test_erode_operation(self): + """Test erosion operation.""" + img = np.ones((10, 10, 10), dtype=np.uint32) + + # Mock nbmorph.erode_labels_spherical to avoid dependency + with patch('emimesh.process_image_data.nbmorph') as mock_nbmorph: + mock_nbmorph.erode_labels_spherical.return_value = img # Return same for simplicity + + result = erode(img, radius=2) + + mock_nbmorph.erode_labels_spherical.assert_called_once_with(img, radius=2) + + def test_smooth_operation(self): + """Test smoothing operation.""" + img = np.ones((10, 10, 10), dtype=np.uint32) + + # Mock nbmorph.smooth_labels_spherical to avoid dependency + with patch('emimesh.process_image_data.nbmorph') as mock_nbmorph: + mock_nbmorph.smooth_labels_spherical.return_value = img # Return same for simplicity + + result = smooth(img, iterations=5, radius=3) + + mock_nbmorph.smooth_labels_spherical.assert_called_once_with( + img, radius=3, iterations=5, dilate_radius=3 + ) \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..2a37a00 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,53 @@ +"""Tests for emimesh.utils module.""" +import numpy as np +import pyvista as pv +from emimesh.utils import np2pv, get_cell_frequencies + + +class TestNp2pv: + """Test the np2pv function.""" + + def test_np2pv_basic(self, sample_image_data, sample_resolution): + """Test basic np2pv functionality.""" + grid = np2pv(sample_image_data, sample_resolution) + + assert isinstance(grid, pv.ImageData) + assert np.array_equal(grid.dimensions, sample_image_data.shape + np.array([1, 1, 1])) + assert np.array_equal(grid.spacing, sample_resolution) + assert grid.origin == (0, 0, 0) + assert "data" in grid.array_names + assert grid["data"].shape == (sample_image_data.size,) + + def test_np2pv_with_roimask(self, sample_image_data, sample_resolution): + """Test np2pv with roimask.""" + roimask = np.ones_like(sample_image_data, dtype=bool) + roimask[10:20, 10:20, 10:20] = False + + grid = np2pv(sample_image_data, sample_resolution, roimask=roimask) + + assert "roimask" in grid.array_names + assert grid["roimask"].shape == (roimask.size,) + +class TestGetCellFrequencies: + """Test the get_cell_frequencies function.""" + + def test_get_cell_frequencies_basic(self, sample_image_data): + """Test basic cell frequency calculation.""" + frequencies = get_cell_frequencies(sample_image_data) + + assert frequencies.shape[0] == 2 # labels and counts + assert frequencies.shape[1] >= 3 # at least 3 unique values (0, 1, 2, 3) + + # Check that labels are sorted by frequency + counts = frequencies[1] + assert np.all(counts[:-1] <= counts[1:]) + + def test_get_cell_frequencies_single_cell(self): + """Test with single cell.""" + data = np.zeros((10, 10, 10), dtype=np.uint32) + data[2:8, 2:8, 2:8] = 5 + + frequencies = get_cell_frequencies(data) + + assert frequencies.shape[1] == 2 # background and cell 5 + assert 5 in frequencies[0] \ No newline at end of file