diff --git a/copick_torch/cli.py b/copick_torch/cli.py deleted file mode 100644 index b69d9ac..0000000 --- a/copick_torch/cli.py +++ /dev/null @@ -1,17 +0,0 @@ -import click - -import copick_torch -from copick_torch.entry_points.run_downsample import downsample -from copick_torch.entry_points.run_membrane_seg import membrain_seg - - -@click.group() -def routines(): - pass - - -routines.add_command(membrain_seg) -routines.add_command(downsample) - -if __name__ == "__main__": - routines() diff --git a/copick_torch/entry_points/run_downsample.py b/copick_torch/entry_points/run_downsample.py index 4f46318..31a05c4 100644 --- a/copick_torch/entry_points/run_downsample.py +++ b/copick_torch/entry_points/run_downsample.py @@ -38,6 +38,18 @@ def downsample( target_resolution: float, delete_source: bool, ): + """ + Downsample tomograms with Fourier Re-Scaling. + """ + + run(config, tomo_alg, voxel_size, target_resolution, delete_source) + + +def run(config, tomo_alg, voxel_size, target_resolution, delete_source): + """ + Runs the downsampling. + """ + import copick from copick_torch import parallelization @@ -57,7 +69,7 @@ def downsample( # Execute try: pool.execute( - run_downsampler, + downsample.run_downsampler, tasks, task_ids=run_ids, progress_desc="Downsampling Tomograms", @@ -69,37 +81,11 @@ def downsample( print("✅ Completed the Downsampling!") -def run_downsampler(run, tomo_alg, voxel_size, target_resolution, delete_source, gpu_id, models): - from copick_utils.io import readers, writers - - # Get the Downsampler - downsampler = models - - # Get the Tomogram - tomo = readers.tomogram(run, voxel_size, tomo_alg) - - # Check if Tomogram Exists - if tomo is None: - print(f"⚠️ Skipping Run {run.name}: No Tomogram found for Algorithm {tomo_alg} at Voxel Size {voxel_size}A") - return - - # Downsample the Tomogram - downsampled_tomo = downsampler.run(tomo) - - # Save the Downsampled Tomogram - writers.tomogram(run, downsampled_tomo, target_resolution, tomo_alg) - - # Delete the source tomograms if requested - if delete_source: - vs = run.get_voxel_spacing(voxel_size) - vs.delete_tomograms(tomo_alg) - - # If the Voxel Spacing is Empty, lets delete it as well - if vs.tomograms == []: - vs.delete() - - def save_parameters(config, tomo_alg, voxel_size, target_resolution): + """ + Save the parameters for the downsampling. + """ + import os import copick diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py new file mode 100644 index 0000000..73c5c8c --- /dev/null +++ b/copick_torch/entry_points/run_filter3d.py @@ -0,0 +1,257 @@ +import click + + +def low_pass_commands(func): + """Decorator to add common options to a Click command.""" + options = [ + click.option( + "--lp-freq", + type=float, + required=False, + default=0, + help="Low-pass cutoff frequency (in Angstroms)", + ), + click.option("--lp-decay", type=float, required=False, default=50, help="Low-pass decay width (in pixels)"), + click.option( + "--hp-freq", + type=float, + required=False, + default=0, + help="High-pass cutoff frequency (in Angstroms)", + ), + click.option("--hp-decay", type=float, required=False, default=50, help="High-pass decay width (in pixels)"), + click.option( + "--show-filter", + type=bool, + required=False, + default=True, + help="Save the filter as a PNG (filter3d.png)", + ), + ] + for option in reversed(options): # Add options in reverse order to preserve correct order + func = option(func) + return func + + +def copick_commands(func): + """Decorator to add common options to a Click command.""" + options = [ + click.option("--config", type=str, required=True, help="Path to Copick Config for Processing Data"), + click.option( + "--run-ids", + type=str, + required=False, + default=None, + help="Run ID to process (No Input would process the entire dataset.)", + ), + click.option("--tomo-alg", type=str, required=True, help="Tomogram Algorithm to use"), + click.option("--voxel-size", type=float, required=False, default=10, help="Voxel Size to Query the Data"), + ] + for option in reversed(options): # Add options in reverse order to preserve correct order + func = option(func) + return func + + +def input_check(lp_res, hp_res, apix): + """ + lp_res, hp_res: resolutions in Å (0 means 'disabled') + apix: pixel size in Å/pixel + require_filter: if True, disallow the all-pass case (both 0) + """ + nyquist_res = 2.0 * apix # smallest physically valid resolution (Å) + + # All-pass allowed unless explicitly forbidden + if lp_res == 0 and hp_res == 0: + raise ValueError("Low-pass and high-pass cannot both be 0 (no filtering).") + + # Low-pass: if enabled, it cannot be finer (smaller Å) than Nyquist + if lp_res > 0 and lp_res < nyquist_res: + raise ValueError(f"Low-pass resolution {lp_res:.3g} Å is finer than Nyquist (2*apix = {nyquist_res:.3g} Å).") + + # High-pass: if enabled, it also cannot be finer than Nyquist (no frequencies exist beyond Nyquist) + if hp_res > 0 and hp_res < nyquist_res: + raise ValueError(f"High-pass resolution {hp_res:.3g} Å is finer than Nyquist (2*apix = {nyquist_res:.3g} Å).") + + # Band-pass consistency: in Å, larger number = lower spatial frequency + # Require hp > lp when both are enabled + if lp_res > 0 and hp_res > 0 and not (hp_res > lp_res): + raise ValueError("For band-pass, require hp (Å) > lp (Å).") + + +def print_header(lp_freq, lp_decay, hp_freq, hp_decay): + print("----------------------------------------") + print(f"Low-Pass Frequency: {lp_freq} Angstroms") + print(f"Low-Pass Decay: {lp_decay} Pixels") + print(f"High-Pass Frequency: {hp_freq} Angstroms") + print(f"High-Pass Decay: {hp_decay} Pixels") + print("----------------------------------------") + + +@click.command(context_settings={"show_default": True}) +@copick_commands +@low_pass_commands +def bandpass( + config: str, + run_ids: str, + lp_freq: float, + lp_decay: float, + hp_freq: float, + hp_decay: float, + tomo_alg: str, + voxel_size: float, + show_filter: bool, +): + """ + 3D bandpass filter tomograms. + """ + + run_filter3d(config, run_ids, lp_freq, lp_decay, hp_freq, hp_decay, tomo_alg, voxel_size, show_filter) + + +def run_filter3d( + config: str, + run_ids: str, + lp_freq: float, + lp_decay: float, + hp_freq: float, + hp_decay: float, + tomo_alg: str, + voxel_size: float, + show_filter: bool, +): + import os + + import copick + + from copick_torch import parallelization + from copick_torch.filters.bandpass import init_filter3d, run_filter3d + + # Input Check - Set Decay to 0 if Unused + input_check(lp_freq, hp_freq, voxel_size) + if lp_freq == 0: + lp_decay = 0 + if hp_freq == 0: + hp_decay = 0 + + # Load Copick Project + if os.path.exists(config): + root = copick.from_file(config) + else: + raise ValueError(f"Config file {config} does not exist.") + + print_header(lp_freq, lp_decay, hp_freq, hp_decay) + + # Get Run IDs + if run_ids is None: + run_ids = [run.name for run in root.runs] + else: + run_ids = run_ids.split(",") + + # Determine Write Algorithm + write_algorithm = tomo_alg + if lp_freq > 0: + write_algorithm = write_algorithm + f"-lp{lp_freq:0.0f}A" + if hp_freq > 0: + write_algorithm = write_algorithm + f"-hp{hp_freq:0.0f}A" + + # Get Volume Shape + vol_shape = get_tomo_shape(root, run_ids, tomo_alg, voxel_size) + + # Initialize Parallelization Pool + pool = parallelization.GPUPool( + init_fn=init_filter3d, + init_args=(voxel_size, vol_shape, lp_freq, lp_decay, hp_freq, hp_decay), + verbose=True, + ) + # Save Filter Image + if show_filter: + save_filter((voxel_size, vol_shape, lp_freq, lp_decay, hp_freq, hp_decay)) + + # Execute + tasks = [(root.get_run(run), tomo_alg, voxel_size, write_algorithm) for run in run_ids] + try: + pool.execute( + run_filter3d, + tasks, + task_ids=run_ids, + progress_desc="Filtering Tomograms", + ) + finally: + pool.shutdown() + + save_parameters(config, [tomo_alg, voxel_size], [lp_freq, lp_decay, hp_freq, hp_decay], write_algorithm) + print("✅ Completed the Filtering!") + + +def get_tomo_shape(root, run_ids, tomo_alg, voxel_size): + import numpy as np + import zarr + + for runID in run_ids: + # Get Volume Shape from First Run + run = root.get_run(runID) + + # Get Target Shape + vs = run.get_voxel_spacing(voxel_size) + if vs is None: + continue + tomo = vs.get_tomogram(tomo_alg) + if tomo is None: + continue + loc = tomo.zarr() + shape = zarr.open(loc)["0"].shape + return shape + + +def save_filter(params): + import torch + + from copick_torch.filters.bandpass import Filter3D + + filter = Filter3D( + params[0], + params[1], + params[2], + params[3], + params[4], + params[5], + device=torch.device("cpu"), + ) + + filter.show_filter() + + del filter + + +def save_parameters(config, tomo_info, parameters, write_algorithm): + import os + + import copick + + from copick_torch.entry_points.utils import save_parameters_yaml + + root = copick.from_file(config) + overlay_root = root.config.overlay_root + if overlay_root[:8] == "local://": + overlay_root = overlay_root[8:] + group = { + "input": { + "config": config, + "tomo_alg": tomo_info[0], + "voxel_size": tomo_info[1], + }, + "parameters": { + "Low-Pass Frequency (Angstroms)": parameters[0], + "Low-Pass Decay (Pixels)": parameters[1], + "High-Pass Frequency (Angstroms)": parameters[2], + "High-Pass Decay (Pixels)": parameters[3], + }, + "output": { + "tomo_alg": write_algorithm, + "voxel_size": tomo_info[1], + }, + } + os.makedirs(os.path.join(overlay_root, "logs"), exist_ok=True) + path = os.path.join(overlay_root, "logs", f"process-filter3d_{tomo_info[0]}_{tomo_info[1]}A.yaml") + save_parameters_yaml(group, path) + print(f"📝 Saved Parameters to {path}") diff --git a/copick_torch/entry_points/run_membrane_seg.py b/copick_torch/entry_points/run_membrane_seg.py index ee8d6f6..4d9c331 100644 --- a/copick_torch/entry_points/run_membrane_seg.py +++ b/copick_torch/entry_points/run_membrane_seg.py @@ -46,6 +46,16 @@ def membrain_seg( threshold: float, user_id: str, ): + """ + Runs the membrane segmentation command. + """ + run(config, tomo_alg, voxel_size, session_id, threshold, user_id) + + +def run(config, tomo_alg, voxel_size, session_id, threshold, user_id): + """ + Runs the membrane segmentation. + """ import copick from copick_torch import parallelization diff --git a/copick_torch/filters/bandpass.py b/copick_torch/filters/bandpass.py new file mode 100644 index 0000000..74abd2e --- /dev/null +++ b/copick_torch/filters/bandpass.py @@ -0,0 +1,252 @@ +""" +This module contains functions for creating cosine low-pass filter and applying it to tomograms. +This is a written translation of the MATLAB code cosine_filter.m from the artia-wrapper package +(https://github.com/uermel/artia-wrapper/tree/master) +""" + +import math + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.fft import fftn, fftshift, ifftn, ifftshift + + +class Filter3D: + def __init__(self, apix, sz, lp=0, lpd=0, hp=0, hpd=0, device=None): + """ + Initialize the Filter3D class. + + Args: + apix (float): Pixel size in angstrom. + sz (tuple): Size of the tomogram (D, H, W). + lp (float): Low-pass cutoff resolution in angstroms. + lpd (float): Low-pass decay width in pixels. + hp (float): High-pass cutoff resolution in angstroms. + hpd (float): High-pass decay width in pixels. + device (torch.device, optional): Device for the filter tensor. + """ + # Set Parameters + self.apix = apix + self.sz = sz + self.lp = lp + self.lpd = lpd + self.hp = hp + self.hpd = hpd + self.dtype = torch.float32 + + # Set Device + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = device + + # Allow LP-only, HP-only, or band-pass (hp > lp in Å) + if self.lp > 0 and self.hp > 0 and not (self.hp > self.lp): + raise ValueError("For band-pass, require hp (Å) > lp (Å).") + + # Convert cutoff values from angstroms to pixels + self.lp_pix = self.angst_to_pix(self.lp) if self.lp > 0 else 0 # Low-pass cutoff in pixels + self.hp_pix = self.angst_to_pix(self.hp) if self.hp > 0 else 0 # High-pass cutoff in pixels + # LPD and HPD are always in pixels; do not convert + self.lpd_pix = self.lpd # Decay width in pixels + self.hpd_pix = self.hpd # Decay width in pixels + + print("Constructing 3D Cosine Filter...") + self.cosine_filter() + + def angst_to_pix(self, ang): + """ + Convert angstroms to pixels based on the pixel size. + + Args: + ang (float): Measurement in angstroms. + + Returns: + float: Measurement in pixels. + """ + return max(self.sz) / (ang / self.apix) + + def cosine_filter(self): + """ + Creates a combined low-pass and high-pass cosine filter for 3D tomograms. + """ + D, H, W = self.sz + + # Create spatial frequency grids in pixel space (centered at zero) + zz, yy, xx = torch.meshgrid( + torch.arange(D, device=self.device, dtype=self.dtype) - D // 2, + torch.arange(H, device=self.device, dtype=self.dtype) - H // 2, + torch.arange(W, device=self.device, dtype=self.dtype) - W // 2, + indexing="ij", + ) + r = torch.sqrt(xx**2 + yy**2 + zz**2) # Radial distance in pixels + + # Low-pass filter + lpv = self.construct_filter(r, self.lp_pix, self.lpd_pix, mode="lp") + + # High-pass filter + hpv = self.construct_filter(r, self.hp_pix, self.hpd_pix, mode="hp") + + # Combined Filter + self.filter = lpv * hpv + + def construct_filter(self, r, freq, freqdecay, mode="lp"): + """ + Constructs a low-pass or high-pass filter based on the mode. + Handles pure LP or HP cases properly. + """ + if mode not in ["lp", "hp"]: + raise ValueError("Mode must be 'lp' for low-pass or 'hp' for high-pass.") + + # Start with neutral mask + filter_mask = torch.ones_like(r, dtype=self.dtype, device=self.device) + + # Skip filter: freq==0 disables that side + if freq == 0: + if mode == "lp": + # LP disabled → all-pass (1) + filter_mask[:] = 1.0 + elif mode == "hp": + # HP disabled → all-pass (1) + filter_mask[:] = 1.0 + return filter_mask + + # Box filter + if freqdecay == 0: + if mode == "lp": + filter_mask = (r <= freq).float() + else: # hp + filter_mask = (r >= freq).float() + return filter_mask + + # Cosine transition region + half_decay = freqdecay / 2.0 + if mode == "lp": + filter_mask = (r <= freq - half_decay).float() + sel = (r > (freq - half_decay)) & (r < (freq + half_decay)) + filter_mask[sel] = 0.5 + 0.5 * torch.cos(math.pi * (r[sel] - (freq - half_decay)) / freqdecay) + else: # high-pass + filter_mask = (r >= freq + half_decay).float() + sel = (r > (freq - half_decay)) & (r < (freq + half_decay)) + filter_mask[sel] = 0.5 - 0.5 * torch.cos(math.pi * (r[sel] - (freq - half_decay)) / freqdecay) + + return filter_mask + + def extract_1d_profile(self, axis="x"): + """ + Extracts a 1D profile from the 3D filter along the specified axis. + + Returns: + freqs (np.ndarray): Frequency values in cycles per angstrom (1/Å). + profile (np.ndarray): Filter magnitude values along the specified axis. + """ + + filter_tensor = self.filter.cpu().numpy() + D, H, W = filter_tensor.shape + + # Determine the axis + if axis == "x": + central_slice = filter_tensor[D // 2, H // 2, :] + freqs = np.fft.fftfreq(W, d=self.apix) # 1/Å + elif axis == "y": + central_slice = filter_tensor[D // 2, :, W // 2] + freqs = np.fft.fftfreq(H, d=self.apix) # 1/Å + elif axis == "z": + central_slice = filter_tensor[:, H // 2, W // 2] + freqs = np.fft.fftfreq(D, d=self.apix) # 1/Å + else: + raise ValueError("Axis must be one of 'x', 'y', or 'z'.") + + # Only keep positive frequencies + mask = freqs >= 0 + freqs_positive = freqs[mask] + profile_positive = central_slice[mask] + + return freqs_positive[::-1], profile_positive + + def apply(self, data): + """ + Applies the filter to a tomogram. + + Args: + data (torch.Tensor): Input data tensor of shape (D, H, W). + + Returns: + torch.Tensor: Filtered data tensor. + """ + + # Assuming 'vol' is your input data + if isinstance(data, np.ndarray): + # Convert NumPy array to PyTorch tensor + data = torch.from_numpy(data).float() # Ensure the dtype matches your filter + + # Ensure data is on the same device and dtype + data = data.to(self.device).type(self.dtype) + + # Compute the Fourier Transform of the data + filtered_data = ifftn(ifftshift(fftshift(fftn(data)) * self.filter)).real + + return filtered_data + + def show_filter(self): + """ + Displays the 3D filter as a 3D plot. + """ + + # Extract 1D profile + freqs, profile = self.extract_1d_profile(axis="x") + + # Create a 2x1 plot + fig, axs = plt.subplots(2, 1, figsize=(8, 12)) + + # Plot the 1D frequency profile + axs[0].plot(freqs, profile, label="Axis Profile") + axs[0].set_xlabel("Frequency (1/Å)") + axs[0].set_ylabel("Filter Magnitude") + axs[0].set_title("1D Frequency Profile Along Axis") + axs[0].legend() + axs[0].grid(True) + axs[0].set_xlim(0, max(freqs)) # Set x-axis range + + # Plot the 2D slice of the filter + (Nx, Ny, Nz) = self.filter.shape + axs[1].imshow(self.filter[int(Nx // 2), :, :], cmap="gray") + axs[1].axis("off") + + # Show the plot + # plt.tight_layout() + plt.savefig("filter.png") + + +def init_filter3d(gpu_id: int, apix: float, sz: tuple, lp: float, lpd: float, hp: float, hpd: float): + """ + Initializes the filter3d class. + """ + device = torch.device(f"cuda:{gpu_id}") + filter = Filter3D(apix, sz, lp, lpd, hp, hpd, device=device) + return filter + + +def run_filter3d(run, tomo_alg, voxel_size, write_algorithm, gpu_id, models): + """ + Runs the filter3d class. + """ + from copick_utils.io import readers, writers + + # Get the Filter + filter = models + + # Get the Tomogram + tomo = readers.tomogram(run, voxel_size, tomo_alg) + + # Check if Tomogram Exists + if tomo is None: + print(f"⚠️ Skipping Run {run.name}: No Tomogram found for Algorithm {tomo_alg} at Voxel Size {voxel_size}A") + return + + # Apply the Filter + filtered_tomo = filter.apply(tomo) + + # Save the Filtered Tomogram + writers.tomogram(run, filtered_tomo.cpu().numpy(), voxel_size, write_algorithm) diff --git a/copick_torch/filters/downsample.py b/copick_torch/filters/downsample.py index 4ed73f0..0deda85 100644 --- a/copick_torch/filters/downsample.py +++ b/copick_torch/filters/downsample.py @@ -49,11 +49,22 @@ def run(self, volume): if self.device.type == "cpu" and volume.dim() == 4: raise AssertionError("Batched volumes are not allowed on CPU. Please provide a single volume.") - if volume.dim() == 4: - output = self.batched_rescale(volume) - else: - output = self.single_rescale(volume) - + # Try to Run on the GPU, if there's memory issues, fall back to CPU + try: + output = self.submit(volume) + except (torch.cuda.OutOfMemoryError, RuntimeError) as e: + # Only handle GPU out-of-memory errors, not all exceptions + if isinstance(e, torch.cuda.OutOfMemoryError) or ( + isinstance(e, RuntimeError) and "out of memory" in str(e).lower() + ): + # Free GPU cache before retrying on CPU + torch.cuda.empty_cache() + + print("⚠️ GPU memory issue encountered, falling back to CPU for downsampling.") + self.device = torch.device("cpu") + output = self.submit(volume) + else: + raise # Return to CPU if Compute is on GPU if self.device != torch.device("cpu"): output = output.cpu() @@ -65,6 +76,16 @@ def run(self, volume): else: return output + def submit(self, volume: torch.Tensor) -> torch.Tensor: + """ + Submit the volume for rescaling based on its dimensionality. + """ + if volume.dim() == 4: + output = self.batched_rescale(volume) + else: + output = self.single_rescale(volume) + return output + def batched_rescale(self, volume: torch.Tensor): """ Process a (batched) volume: move to device, perform FFT, crop in Fourier space, @@ -141,3 +162,36 @@ def downsample_init(gpu_id: int, voxel_size: float, target_resolution: float): downsampler.device = torch.device(f"cuda:{gpu_id}") return downsampler + + +def run_downsampler(run, tomo_alg, voxel_size, target_resolution, delete_source, gpu_id, models): + """ + Runs the downsampler class. + """ + from copick_utils.io import readers, writers + + # Get the Downsampler + downsampler = models + + # Get the Tomogram + tomo = readers.tomogram(run, voxel_size, tomo_alg) + + # Check if Tomogram Exists + if tomo is None: + print(f"⚠️ Skipping Run {run.name}: No Tomogram found for Algorithm {tomo_alg} at Voxel Size {voxel_size}A") + return + + # Downsample the Tomogram + downsampled_tomo = downsampler.run(tomo) + + # Save the Downsampled Tomogram + writers.tomogram(run, downsampled_tomo, target_resolution, tomo_alg) + + # Delete the source tomograms if requested + if delete_source: + vs = run.get_voxel_spacing(voxel_size) + vs.delete_tomograms(tomo_alg) + + # If the Voxel Spacing is Empty, lets delete it as well + if vs.tomograms == []: + vs.delete() diff --git a/pyproject.toml b/pyproject.toml index 174e62d..875eba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ coverage-report = "scripts.coverage_report:main" [project.entry-points."copick.process.commands"] downsample = "copick_torch.entry_points.run_downsample:downsample" +bandpass = "copick_torch.entry_points.run_filter3d:bandpass" [project.entry-points."copick.inference.commands"] membrain-seg = "copick_torch.entry_points.run_membrane_seg:membrain_seg"