From 8445431794690ed09611b056023723d31477b289 Mon Sep 17 00:00:00 2001 From: jtschwar Date: Sat, 25 Oct 2025 00:56:55 -0700 Subject: [PATCH 01/12] feat: add bandpass filter as a process command --- copick_torch/cli.py | 5 +- copick_torch/entry_points/run_filter3d.py | 158 ++++++++++++++++ copick_torch/filters/bandpass.py | 214 ++++++++++++++++++++++ 3 files changed, 375 insertions(+), 2 deletions(-) create mode 100644 copick_torch/entry_points/run_filter3d.py create mode 100644 copick_torch/filters/bandpass.py diff --git a/copick_torch/cli.py b/copick_torch/cli.py index b69d9ac..8d8d13b 100644 --- a/copick_torch/cli.py +++ b/copick_torch/cli.py @@ -1,17 +1,18 @@ import click import copick_torch -from copick_torch.entry_points.run_downsample import downsample from copick_torch.entry_points.run_membrane_seg import membrain_seg +from copick_torch.entry_points.run_downsample import downsample +from copick_torch.entry_points.run_filter3d import bandpass @click.group() def routines(): pass - routines.add_command(membrain_seg) routines.add_command(downsample) +routines.add_command(bandpass) if __name__ == "__main__": routines() diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py new file mode 100644 index 0000000..994f0f0 --- /dev/null +++ b/copick_torch/entry_points/run_filter3d.py @@ -0,0 +1,158 @@ +import click + +def low_pass_commands(func): + """Decorator to add common options to a Click command.""" + options = [ + click.option("--lp-freq", type=float, required=False, default=0, + help="Low-pass cutoff frequency (in Angstroms)"), + click.option("--lp-decay", type=float, required=False, default=0, + help="Low-pass decay width (in pixels)"), + click.option("--hp-freq", type=float, required=False, default=0, + help="High-pass cutoff frequency (in Angstroms)"), + click.option("--hp-decay", type=float, required=False, default=0, + help="High-pass decay width (in pixels)"), + click.option("--show-filter", type=bool, required=False, default=False, + help="Save the filter as a Png (filter3d.png)") + ] + for option in reversed(options): # Add options in reverse order to preserve correct order + func = option(func) + return func + +def copick_commands(func): + """Decorator to add common options to a Click command.""" + options = [ + click.option("--config", type=str, required=True, + help="Path to Copick Config for Processing Data"), + click.option("--run-ids", type=str, required=False, default=None, + help="Run ID to process (No Input would process the entire dataset.)"), + click.option("--tomo-alg", type=str, required=True, + help="Tomogram Algorithm to use"), + click.option("--voxel-size", type=float, required=False, default=10, + help="Voxel Size to Query the Data"), + click.option("--show-filter", type=bool, required=False, default=True, + help="Save the filter as a Png (filter3d.png)") + ] + for option in reversed(options): # Add options in reverse order to preserve correct order + func = option(func) + return func + +def input_check(lp_freq, hp_freq, voxel_size): + if lp_freq == 0 and hp_freq == 0: + raise ValueError("Low-pass and high-pass frequencies cannot both be 0.") + elif lp_freq < voxel_size * 2: + raise ValueError("Low-pass frequency cannot be less than twice the Nyquist resolution.") + elif hp_freq < voxel_size * 2: + raise ValueError("High-pass frequency cannot be less than twice the Nyquist resolution.") + elif lp_freq > hp_freq and lp_freq > 0 and hp_freq > 0: + raise ValueError("Low-pass cutoff resolution must be less than high-pass cutoff resolution.") + +def print_header(lp_freq, lp_decay, hp_freq, hp_decay): + print('----------------------------------------') + print(f'Low-Pass Frequency: {lp_freq} Angstroms') + print(f'Low-Pass Decay: {lp_decay} Pixels') + print(f'High-Pass Frequency: {hp_freq} Angstroms') + print(f'High-Pass Decay: {hp_decay} Pixels') + print('----------------------------------------') + +@click.command(context_settings={"show_default": True}) +@low_pass_commands +@copick_commands +def bandpass( + config: str, + run_ids: str, + lp_freq: float, + lp_decay: float, + hp_freq: float, + hp_decay: float, + tomo_alg: str, + voxel_size: float, + show_filter: bool + ): + + run_filter3d(config, run_ids, lp_freq, lp_decay, hp_freq, hp_decay, tomo_alg, voxel_size, show_filter) + +def run_filter3d( + config: str, + run_ids: str, + lp_freq: float, lp_decay: float, + hp_freq: float, hp_decay: float, + tomo_alg: str, voxel_size: float, + show_filter: bool + ): + from copick_torch.filters.bandpass import Filter3D + from tqdm import tqdm + import os, copick + + input_check(lp_freq, hp_freq, voxel_size) + + # Load Copick Project + if os.path.exists(config): root = copick.from_file(config) + else: raise ValueError(f"Config file {config} does not exist.") + + print_header(lp_freq, lp_decay, hp_freq, hp_decay) + + # Get Run IDs + if run_ids is None: run_ids = [run.name for run in root.runs] + else: run_ids = run_ids.split(",") + + # Determine Write Algorithm + write_algorithm = tomo_alg + if lp_freq > 0: write_algorithm = write_algorithm + f'-lp{lp_freq:0.0f}A' + if hp_freq > 0: write_algorithm = write_algorithm + f'-hp{hp_freq:0.0f}A' + + # Get Tomogram for Initializing 3D Filter + vol = readers.tomogram(root.get_run(run_ids[0]), voxel_size, tomo_alg) + + # Create 3D Filter + filter = Filter3D( + apix=voxel_size, sz=vol.shape, + lp= lp_freq, lpd = lp_decay, + hp=hp_freq, hpd=hp_decay) + + # Save Filter + if show_filter: + filter.show_filter() + + # Get Tomogram and Process + for run_id in tqdm(run_ids): + + run = root.get_run(run_id) + vol = readers.tomogram(run, voxel_size, tomo_alg) + + # Apply Low-pass Filter + vol = filter.apply(vol) + + # Write Tomogram + writers.tomogram(run, vol.cpu().numpy(), voxel_size, write_algorithm) + + print('Applying Filters to All Tomograms Complete...') + +def save_parameters(config, tomo_info, parameters): + import os + + import copick + + from copick_torch.entry_points.utils import save_parameters_yaml + + root = copick.from_file(config) + overlay_root = root.config.overlay_root + if overlay_root[:8] == "local://": + overlay_root = overlay_root[8:] + group = { + "input": { + "config": config, + "tomo_alg": tomo_info[0], + "voxel_size": tomo_info[1], + }, + "parameters": { + "Low-Pass Frequency (Angstroms)": parameters[0], + "Low-Pass Decay (Pixels)": parameters[1], + "High-Pass Frequency (Angstroms)": parameters[2], + "High-Pass Decay (Pixels)": parameters[3], + } + } + os.makedirs(os.path.join(overlay_root, "logs"), exist_ok=True) + path = os.path.join(overlay_root, "logs", f"process-filter3d_{tomo_info[0]}_{tomo_info[1]}A.yaml") + save_parameters_yaml(group, path) + print(f"📝 Saved Parameters to {path}") + diff --git a/copick_torch/filters/bandpass.py b/copick_torch/filters/bandpass.py new file mode 100644 index 0000000..37a0c14 --- /dev/null +++ b/copick_torch/filters/bandpass.py @@ -0,0 +1,214 @@ +from torch.fft import fftshift, ifftshift, fftn, ifftn +import matplotlib.pyplot as plt +import numpy as np +import torch, math + +""" +This module contains functions for creating cosine-low pass filter and applying it to tomograms. +This is a written translation of the MATLAB code cosine_filter.m from the artia-wrapper package +(https://github.com/uermel/artia-wrapper/tree/master) +""" + +class Filter3D: + def __init__(self, apix, sz, lp=0, lpd=0, hp=0, hpd=0, device=None): + """ + Initialize the Filter3D class. + + Args: + apix (float): Pixel size in angstrom. + sz (tuple): Size of the tomogram (D, H, W). + lp (float): Low-pass cutoff resolution in angstroms. + lpd (float): Low-pass decay width in pixels. + hp (float): High-pass cutoff resolution in angstroms. + hpd (float): High-pass decay width in pixels. + device (torch.device, optional): Device for the filter tensor. + """ + # Set Parameters + self.apix = apix + self.sz = sz + self.lp = lp + self.lpd = lpd + self.hp = hp + self.hpd = hpd + self.dtype = torch.float32 + + # Set Device + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + + # Check if low-pass cutoff resolution is less than high-pass cutoff resolution + if self.lp > self.hp and self.lp > 0 and self.hp > 0: + raise ValueError("Low-pass cutoff resolution must be less than high-pass cutoff resolution.") + + # Convert cutoff values from angstroms to pixels + self.lp_pix = self.angst_to_pix(self.lp) if self.lp > 0 else 0 # Low-pass cutoff in pixels + self.hp_pix = self.angst_to_pix(self.hp) if self.hp > 0 else 0 # High-pass cutoff in pixels + # LPD and HPD are always in pixels; do not convert + self.lpd_pix = self.lpd # Decay width in pixels + self.hpd_pix = self.hpd # Decay width in pixels + + print('Constructing 3D Cosine Filter...') + self.cosine_filter() + + def angst_to_pix(self, ang): + """ + Convert angstroms to pixels based on the pixel size. + + Args: + ang (float): Measurement in angstroms. + + Returns: + float: Measurement in pixels. + """ + return max(self.sz) / (ang / self.apix) + + def cosine_filter(self): + """ + Creates a combined low-pass and high-pass cosine filter for 3D tomograms. + """ + D, H, W = self.sz + + # Create spatial frequency grids in pixel space (centered at zero) + zz, yy, xx = torch.meshgrid( + torch.arange(D, device=self.device, dtype=self.dtype) - D // 2, + torch.arange(H, device=self.device, dtype=self.dtype) - H // 2, + torch.arange(W, device=self.device, dtype=self.dtype) - W // 2, + indexing='ij' + ) + r = torch.sqrt(xx**2 + yy**2 + zz**2) # Radial distance in pixels + + # Low-pass filter + lpv = self.construct_filter(r, self.lp_pix, self.lpd_pix, mode='lp') + + # High-pass filter + hpv = self.construct_filter(r, self.hp_pix, self.hpd_pix, mode='hp') + + # Combined Filter + self.filter = lpv * hpv + + def construct_filter(self, r, freq, freqdecay, mode='lp'): + """ + Constructs a low-pass or high-pass filter based on the mode. + + Args: + r (torch.Tensor): Radial spatial frequency tensor in pixels. + freq (float): Cutoff frequency in pixels. + freqdecay (float): Decay width in pixels. + mode (str): 'lp' for low-pass, 'hp' for high-pass. + + Returns: + torch.Tensor: Filter mask. + """ + if mode not in ['lp', 'hp']: + raise ValueError("Mode must be 'lp' for low-pass or 'hp' for high-pass.") + + # Skip filter + if freq == 0 and freqdecay == 0: + filter_mask = torch.ones_like(r) + # Box Filter + elif freq > 0 and freqdecay == 0: + filter_mask = (r < freq).float() + if mode == 'hp': filter_mask = 1 - filter_mask + # # Cosine Filter starting with 1 (low-pass) or 0 (high-pass) + # elif freq == 0 and freqdecay > 0: + # filter_mask = (r < freq).float() + # sel = r <= freqdecay + # filter_mask[sel] = 0.5 + 0.5 * torch.cos(math.pi * r[sel] / freqdecay) + # if mode == 'lp': filter_mask = 1 - filter_mask + # Box Filter with cosine decay + else: + half_decay = freqdecay / 2.0 + filter_mask = (r < freq).float() + sel = (r > (freq - half_decay)) & (r < (freq + half_decay)) + filter_mask[sel] = 0.5 + 0.5 * torch.cos(math.pi * (r[sel] - (freq - half_decay)) / freqdecay) + if mode == 'hp': filter_mask = 1 - filter_mask + + return filter_mask + + def extract_1d_profile(self, axis='x'): + """ + Extracts a 1D profile from the 3D filter along the specified axis. + + Returns: + freqs (np.ndarray): Frequency values in cycles per angstrom (1/Å). + profile (np.ndarray): Filter magnitude values along the specified axis. + """ + + filter_tensor = self.filter.cpu().numpy() + D, H, W = filter_tensor.shape + + # Determine the axis + if axis == 'x': + central_slice = filter_tensor[D // 2, H // 2, :] + freqs = np.fft.fftfreq(W, d=self.apix) # 1/Å + elif axis == 'y': + central_slice = filter_tensor[D // 2, :, W // 2] + freqs = np.fft.fftfreq(H, d=self.apix) # 1/Å + elif axis == 'z': + central_slice = filter_tensor[:, H // 2, W // 2] + freqs = np.fft.fftfreq(D, d=self.apix) # 1/Å + else: + raise ValueError("Axis must be one of 'x', 'y', or 'z'.") + + # Only keep positive frequencies + mask = freqs >= 0 + freqs_positive = freqs[mask] + profile_positive = central_slice[mask] + + return freqs_positive[::-1], profile_positive + + def apply(self, data): + """ + Applies the filter to a tomogram. + + Args: + data (torch.Tensor): Input data tensor of shape (D, H, W). + + Returns: + torch.Tensor: Filtered data tensor. + """ + + # Assuming 'vol' is your input data + if isinstance(data, np.ndarray): + # Convert NumPy array to PyTorch tensor + data = torch.from_numpy(data).float() # Ensure the dtype matches your filter + + # Ensure data is on the same device and dtype + data = data.to(self.device).type(self.dtype) + + # Compute the Fourier Transform of the data + filtered_data = ifftn(ifftshift(fftshift(fftn(data)) * self.filter)).real + + return filtered_data + + def show_filter(self): + """ + Displays the 3D filter as a 3D plot. + """ + + # Extract 1D profile + freqs, profile = self.extract_1d_profile(axis='x') + + # Create a 2x1 plot + fig, axs = plt.subplots(2, 1, figsize=(8, 12)) + + # Plot the 1D frequency profile + axs[0].plot(freqs, profile, label='Axis Profile') + axs[0].set_xlabel('Frequency (1/Å)') + axs[0].set_ylabel('Filter Magnitude') + axs[0].set_title('1D Frequency Profile Along Axis') + axs[0].legend() + axs[0].grid(True) + axs[0].set_xlim(0, max(freqs)) # Set x-axis range + + # Plot the 2D slice of the filter + (Nx, Ny, Nz) = self.filter.shape + axs[1].imshow(self.filter[int(Nx//2), :, :], cmap='gray') + axs[1].axis('off') + + # Show the plot + # plt.tight_layout() + plt.savefig('filter.png') + From 284567c3144e0fcf25e224a916d84c2406f2a66e Mon Sep 17 00:00:00 2001 From: jtschwar Date: Thu, 30 Oct 2025 10:28:20 -0700 Subject: [PATCH 02/12] add 3D bandpass filtering and lazy loading --- copick_torch/cli.py | 18 ----- copick_torch/entry_points/run_downsample.py | 54 +++++---------- copick_torch/entry_points/run_filter3d.py | 66 +++++++++---------- copick_torch/entry_points/run_membrane_seg.py | 16 +++-- copick_torch/filters/bandpass.py | 30 +++++++++ copick_torch/filters/downsample.py | 33 ++++++++++ pyproject.toml | 1 + 7 files changed, 124 insertions(+), 94 deletions(-) delete mode 100644 copick_torch/cli.py diff --git a/copick_torch/cli.py b/copick_torch/cli.py deleted file mode 100644 index 8d8d13b..0000000 --- a/copick_torch/cli.py +++ /dev/null @@ -1,18 +0,0 @@ -import click - -import copick_torch -from copick_torch.entry_points.run_membrane_seg import membrain_seg -from copick_torch.entry_points.run_downsample import downsample -from copick_torch.entry_points.run_filter3d import bandpass - - -@click.group() -def routines(): - pass - -routines.add_command(membrain_seg) -routines.add_command(downsample) -routines.add_command(bandpass) - -if __name__ == "__main__": - routines() diff --git a/copick_torch/entry_points/run_downsample.py b/copick_torch/entry_points/run_downsample.py index 4f46318..205afeb 100644 --- a/copick_torch/entry_points/run_downsample.py +++ b/copick_torch/entry_points/run_downsample.py @@ -1,6 +1,5 @@ import click - def downsample_commands(func): """Decorator to add common options to a Click command.""" options = [ @@ -38,10 +37,20 @@ def downsample( target_resolution: float, delete_source: bool, ): - import copick + """ + Runs the downsampling command. + """ + + run(config, tomo_alg, voxel_size, target_resolution, delete_source) + +def run(config, tomo_alg, voxel_size, target_resolution, delete_source): + """ + Runs the downsampling. + """ - from copick_torch import parallelization from copick_torch.filters import downsample + from copick_torch import parallelization + import copick root = copick.from_file(config) run_ids = [run.name for run in root.runs] @@ -57,7 +66,7 @@ def downsample( # Execute try: pool.execute( - run_downsampler, + downsample.run_downsampler, tasks, task_ids=run_ids, progress_desc="Downsampling Tomograms", @@ -69,42 +78,13 @@ def downsample( print("✅ Completed the Downsampling!") -def run_downsampler(run, tomo_alg, voxel_size, target_resolution, delete_source, gpu_id, models): - from copick_utils.io import readers, writers - - # Get the Downsampler - downsampler = models - - # Get the Tomogram - tomo = readers.tomogram(run, voxel_size, tomo_alg) - - # Check if Tomogram Exists - if tomo is None: - print(f"⚠️ Skipping Run {run.name}: No Tomogram found for Algorithm {tomo_alg} at Voxel Size {voxel_size}A") - return - - # Downsample the Tomogram - downsampled_tomo = downsampler.run(tomo) - - # Save the Downsampled Tomogram - writers.tomogram(run, downsampled_tomo, target_resolution, tomo_alg) - - # Delete the source tomograms if requested - if delete_source: - vs = run.get_voxel_spacing(voxel_size) - vs.delete_tomograms(tomo_alg) - - # If the Voxel Spacing is Empty, lets delete it as well - if vs.tomograms == []: - vs.delete() - - def save_parameters(config, tomo_alg, voxel_size, target_resolution): - import os - - import copick + """ + Save the parameters for the downsampling. + """ from copick_torch.entry_points.utils import save_parameters_yaml + import copick, os root = copick.from_file(config) overlay_root = root.config.overlay_root diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py index 994f0f0..16f8c74 100644 --- a/copick_torch/entry_points/run_filter3d.py +++ b/copick_torch/entry_points/run_filter3d.py @@ -28,7 +28,7 @@ def copick_commands(func): click.option("--tomo-alg", type=str, required=True, help="Tomogram Algorithm to use"), click.option("--voxel-size", type=float, required=False, default=10, - help="Voxel Size to Query the Data"), + help="Voxel Size to Query the Data"), click.option("--show-filter", type=bool, required=False, default=True, help="Save the filter as a Png (filter3d.png)") ] @@ -79,8 +79,8 @@ def run_filter3d( tomo_alg: str, voxel_size: float, show_filter: bool ): - from copick_torch.filters.bandpass import Filter3D - from tqdm import tqdm + from copick_torch.filters.bandpass import init_filter3d, run_filter3d + from copick_torch import parallelization import os, copick input_check(lp_freq, hp_freq, voxel_size) @@ -100,39 +100,31 @@ def run_filter3d( if lp_freq > 0: write_algorithm = write_algorithm + f'-lp{lp_freq:0.0f}A' if hp_freq > 0: write_algorithm = write_algorithm + f'-hp{hp_freq:0.0f}A' - # Get Tomogram for Initializing 3D Filter - vol = readers.tomogram(root.get_run(run_ids[0]), voxel_size, tomo_alg) - - # Create 3D Filter - filter = Filter3D( - apix=voxel_size, sz=vol.shape, - lp= lp_freq, lpd = lp_decay, - hp=hp_freq, hpd=hp_decay) - - # Save Filter - if show_filter: - filter.show_filter() - - # Get Tomogram and Process - for run_id in tqdm(run_ids): - - run = root.get_run(run_id) - vol = readers.tomogram(run, voxel_size, tomo_alg) - - # Apply Low-pass Filter - vol = filter.apply(vol) - - # Write Tomogram - writers.tomogram(run, vol.cpu().numpy(), voxel_size, write_algorithm) - - print('Applying Filters to All Tomograms Complete...') - -def save_parameters(config, tomo_info, parameters): - import os - - import copick - + # Initialize Parallelization Pool + pool = parallelization.GPUPool( + init_fn=init_filter3d, + init_args=(voxel_size, lp_freq, lp_decay, hp_freq, hp_decay), + verbose=True, + ) + + # Execute + tasks = [(run, tomo_alg, voxel_size, write_algorithm) for run in run_ids] + try: + pool.execute( + run_filter3d, + tasks, + task_ids=run_ids, + progress_desc="Filtering Tomograms", + ) + finally: + pool.shutdown() + + save_parameters(config, [tomo_alg, voxel_size], [lp_freq, lp_decay, hp_freq, hp_decay], write_algorithm) + print('✅ Completed the Filtering!') + +def save_parameters(config, tomo_info, parameters, write_algorithm): from copick_torch.entry_points.utils import save_parameters_yaml + import copick, os root = copick.from_file(config) overlay_root = root.config.overlay_root @@ -149,6 +141,10 @@ def save_parameters(config, tomo_info, parameters): "Low-Pass Decay (Pixels)": parameters[1], "High-Pass Frequency (Angstroms)": parameters[2], "High-Pass Decay (Pixels)": parameters[3], + }, + 'output': { + 'tomo_alg': write_algorithm, + 'voxel_size': tomo_info[1] } } os.makedirs(os.path.join(overlay_root, "logs"), exist_ok=True) diff --git a/copick_torch/entry_points/run_membrane_seg.py b/copick_torch/entry_points/run_membrane_seg.py index ee8d6f6..7622c77 100644 --- a/copick_torch/entry_points/run_membrane_seg.py +++ b/copick_torch/entry_points/run_membrane_seg.py @@ -1,6 +1,5 @@ import click - def segment_commands(func): """Decorator to add common options to a Click command.""" options = [ @@ -46,10 +45,19 @@ def membrain_seg( threshold: float, user_id: str, ): - import copick + """ + Runs the membrane segmentation command. + """ + run(config, tomo_alg, voxel_size, session_id, threshold, user_id) - from copick_torch import parallelization + +def run(config, tomo_alg, voxel_size, session_id, threshold, user_id): + """ + Runs the membrane segmentation. + """ from copick_torch.inference import membrain_seg + from copick_torch import parallelization + import copick print("Starting Membrane Segmentation...") print(f"Using Tomograms with Voxel Size: {voxel_size} and Algorithm: {tomo_alg}") @@ -87,9 +95,9 @@ def membrain_seg( def run_segmenter(run, tomo_alg, voxel_size, session_id, threshold, user_id, gpu_id, models): - from copick_utils.io import readers, writers from copick_torch.inference import membrain_seg + from copick_utils.io import readers, writers # Default Sliding Window Parameters sw_batch_size = 4 diff --git a/copick_torch/filters/bandpass.py b/copick_torch/filters/bandpass.py index 37a0c14..abc2afa 100644 --- a/copick_torch/filters/bandpass.py +++ b/copick_torch/filters/bandpass.py @@ -212,3 +212,33 @@ def show_filter(self): # plt.tight_layout() plt.savefig('filter.png') +def init_filter3d(gpu_id: int, apix: float, sz: tuple, lp: float, lpd: float, hp: float, hpd: float): + """ + Initializes the filter3d class. + """ + device = torch.device(f"cuda:{gpu_id}") + filter = Filter3D(apix, sz, lp, lpd, hp, hpd, device=device) + return filter + +def run_filter3d(run, tomo_alg, voxel_size, write_algorithm, gpu_id, models): + """ + Runs the filter3d class. + """ + from copick_utils.io import readers, writers + + # Get the Filter + filter = models + + # Get the Tomogram + tomo = readers.tomogram(run, voxel_size, tomo_alg) + + # Check if Tomogram Exists + if tomo is None: + print(f"⚠️ Skipping Run {run.name}: No Tomogram found for Algorithm {tomo_alg} at Voxel Size {voxel_size}A") + return + + # Apply the Filter + filtered_tomo = filter.apply(tomo) + + # Save the Filtered Tomogram + writers.tomogram(run, filtered_tomo.cpu().numpy(), voxel_size, write_algorithm) \ No newline at end of file diff --git a/copick_torch/filters/downsample.py b/copick_torch/filters/downsample.py index 4ed73f0..c8b9eeb 100644 --- a/copick_torch/filters/downsample.py +++ b/copick_torch/filters/downsample.py @@ -141,3 +141,36 @@ def downsample_init(gpu_id: int, voxel_size: float, target_resolution: float): downsampler.device = torch.device(f"cuda:{gpu_id}") return downsampler + +def run_downsampler(run, tomo_alg, voxel_size, target_resolution, delete_source, gpu_id, models): + """ + Runs the downsampler class. + """ + from copick_utils.io import readers, writers + + # Get the Downsampler + downsampler = models + + # Get the Tomogram + tomo = readers.tomogram(run, voxel_size, tomo_alg) + + # Check if Tomogram Exists + if tomo is None: + print(f"⚠️ Skipping Run {run.name}: No Tomogram found for Algorithm {tomo_alg} at Voxel Size {voxel_size}A") + return + + # Downsample the Tomogram + downsampled_tomo = downsampler.run(tomo) + + # Save the Downsampled Tomogram + writers.tomogram(run, downsampled_tomo, target_resolution, tomo_alg) + + # Delete the source tomograms if requested + if delete_source: + vs = run.get_voxel_spacing(voxel_size) + vs.delete_tomograms(tomo_alg) + + # If the Voxel Spacing is Empty, lets delete it as well + if vs.tomograms == []: + vs.delete() + diff --git a/pyproject.toml b/pyproject.toml index 174e62d..875eba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ coverage-report = "scripts.coverage_report:main" [project.entry-points."copick.process.commands"] downsample = "copick_torch.entry_points.run_downsample:downsample" +bandpass = "copick_torch.entry_points.run_filter3d:bandpass" [project.entry-points."copick.inference.commands"] membrain-seg = "copick_torch.entry_points.run_membrane_seg:membrain_seg" From df2e6b23585d28227b4b63568d3ae489fe5b2f92 Mon Sep 17 00:00:00 2001 From: jtschwar Date: Thu, 30 Oct 2025 10:45:35 -0700 Subject: [PATCH 03/12] linting --- copick_torch/entry_points/run_downsample.py | 12 +- copick_torch/entry_points/run_filter3d.py | 141 +++++++++++------- copick_torch/entry_points/run_membrane_seg.py | 9 +- copick_torch/filters/bandpass.py | 61 ++++---- copick_torch/filters/downsample.py | 2 +- 5 files changed, 141 insertions(+), 84 deletions(-) diff --git a/copick_torch/entry_points/run_downsample.py b/copick_torch/entry_points/run_downsample.py index 205afeb..1cc473f 100644 --- a/copick_torch/entry_points/run_downsample.py +++ b/copick_torch/entry_points/run_downsample.py @@ -1,5 +1,6 @@ import click + def downsample_commands(func): """Decorator to add common options to a Click command.""" options = [ @@ -43,15 +44,17 @@ def downsample( run(config, tomo_alg, voxel_size, target_resolution, delete_source) + def run(config, tomo_alg, voxel_size, target_resolution, delete_source): """ Runs the downsampling. """ - from copick_torch.filters import downsample - from copick_torch import parallelization import copick + from copick_torch import parallelization + from copick_torch.filters import downsample + root = copick.from_file(config) run_ids = [run.name for run in root.runs] @@ -83,8 +86,11 @@ def save_parameters(config, tomo_alg, voxel_size, target_resolution): Save the parameters for the downsampling. """ + import os + + import copick + from copick_torch.entry_points.utils import save_parameters_yaml - import copick, os root = copick.from_file(config) overlay_root = root.config.overlay_root diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py index 16f8c74..eaa7125 100644 --- a/copick_torch/entry_points/run_filter3d.py +++ b/copick_torch/entry_points/run_filter3d.py @@ -1,41 +1,64 @@ import click + def low_pass_commands(func): """Decorator to add common options to a Click command.""" options = [ - click.option("--lp-freq", type=float, required=False, default=0, - help="Low-pass cutoff frequency (in Angstroms)"), - click.option("--lp-decay", type=float, required=False, default=0, - help="Low-pass decay width (in pixels)"), - click.option("--hp-freq", type=float, required=False, default=0, - help="High-pass cutoff frequency (in Angstroms)"), - click.option("--hp-decay", type=float, required=False, default=0, - help="High-pass decay width (in pixels)"), - click.option("--show-filter", type=bool, required=False, default=False, - help="Save the filter as a Png (filter3d.png)") + click.option( + "--lp-freq", + type=float, + required=False, + default=0, + help="Low-pass cutoff frequency (in Angstroms)", + ), + click.option("--lp-decay", type=float, required=False, default=0, help="Low-pass decay width (in pixels)"), + click.option( + "--hp-freq", + type=float, + required=False, + default=0, + help="High-pass cutoff frequency (in Angstroms)", + ), + click.option("--hp-decay", type=float, required=False, default=0, help="High-pass decay width (in pixels)"), + click.option( + "--show-filter", + type=bool, + required=False, + default=False, + help="Save the filter as a Png (filter3d.png)", + ), ] for option in reversed(options): # Add options in reverse order to preserve correct order func = option(func) return func + def copick_commands(func): """Decorator to add common options to a Click command.""" options = [ - click.option("--config", type=str, required=True, - help="Path to Copick Config for Processing Data"), - click.option("--run-ids", type=str, required=False, default=None, - help="Run ID to process (No Input would process the entire dataset.)"), - click.option("--tomo-alg", type=str, required=True, - help="Tomogram Algorithm to use"), - click.option("--voxel-size", type=float, required=False, default=10, - help="Voxel Size to Query the Data"), - click.option("--show-filter", type=bool, required=False, default=True, - help="Save the filter as a Png (filter3d.png)") + click.option("--config", type=str, required=True, help="Path to Copick Config for Processing Data"), + click.option( + "--run-ids", + type=str, + required=False, + default=None, + help="Run ID to process (No Input would process the entire dataset.)", + ), + click.option("--tomo-alg", type=str, required=True, help="Tomogram Algorithm to use"), + click.option("--voxel-size", type=float, required=False, default=10, help="Voxel Size to Query the Data"), + click.option( + "--show-filter", + type=bool, + required=False, + default=True, + help="Save the filter as a Png (filter3d.png)", + ), ] for option in reversed(options): # Add options in reverse order to preserve correct order func = option(func) return func + def input_check(lp_freq, hp_freq, voxel_size): if lp_freq == 0 and hp_freq == 0: raise ValueError("Low-pass and high-pass frequencies cannot both be 0.") @@ -44,15 +67,17 @@ def input_check(lp_freq, hp_freq, voxel_size): elif hp_freq < voxel_size * 2: raise ValueError("High-pass frequency cannot be less than twice the Nyquist resolution.") elif lp_freq > hp_freq and lp_freq > 0 and hp_freq > 0: - raise ValueError("Low-pass cutoff resolution must be less than high-pass cutoff resolution.") + raise ValueError("Low-pass cutoff resolution must be less than high-pass cutoff resolution.") + def print_header(lp_freq, lp_decay, hp_freq, hp_decay): - print('----------------------------------------') - print(f'Low-Pass Frequency: {lp_freq} Angstroms') - print(f'Low-Pass Decay: {lp_decay} Pixels') - print(f'High-Pass Frequency: {hp_freq} Angstroms') - print(f'High-Pass Decay: {hp_decay} Pixels') - print('----------------------------------------') + print("----------------------------------------") + print(f"Low-Pass Frequency: {lp_freq} Angstroms") + print(f"Low-Pass Decay: {lp_decay} Pixels") + print(f"High-Pass Frequency: {hp_freq} Angstroms") + print(f"High-Pass Decay: {hp_decay} Pixels") + print("----------------------------------------") + @click.command(context_settings={"show_default": True}) @low_pass_commands @@ -66,39 +91,52 @@ def bandpass( hp_decay: float, tomo_alg: str, voxel_size: float, - show_filter: bool - ): + show_filter: bool, +): run_filter3d(config, run_ids, lp_freq, lp_decay, hp_freq, hp_decay, tomo_alg, voxel_size, show_filter) + def run_filter3d( config: str, run_ids: str, - lp_freq: float, lp_decay: float, - hp_freq: float, hp_decay: float, - tomo_alg: str, voxel_size: float, - show_filter: bool - ): - from copick_torch.filters.bandpass import init_filter3d, run_filter3d + lp_freq: float, + lp_decay: float, + hp_freq: float, + hp_decay: float, + tomo_alg: str, + voxel_size: float, + show_filter: bool, +): + import os + + import copick + from copick_torch import parallelization - import os, copick + from copick_torch.filters.bandpass import init_filter3d, run_filter3d input_check(lp_freq, hp_freq, voxel_size) # Load Copick Project - if os.path.exists(config): root = copick.from_file(config) - else: raise ValueError(f"Config file {config} does not exist.") + if os.path.exists(config): + root = copick.from_file(config) + else: + raise ValueError(f"Config file {config} does not exist.") print_header(lp_freq, lp_decay, hp_freq, hp_decay) - + # Get Run IDs - if run_ids is None: run_ids = [run.name for run in root.runs] - else: run_ids = run_ids.split(",") + if run_ids is None: + run_ids = [run.name for run in root.runs] + else: + run_ids = run_ids.split(",") # Determine Write Algorithm write_algorithm = tomo_alg - if lp_freq > 0: write_algorithm = write_algorithm + f'-lp{lp_freq:0.0f}A' - if hp_freq > 0: write_algorithm = write_algorithm + f'-hp{hp_freq:0.0f}A' + if lp_freq > 0: + write_algorithm = write_algorithm + f"-lp{lp_freq:0.0f}A" + if hp_freq > 0: + write_algorithm = write_algorithm + f"-hp{hp_freq:0.0f}A" # Initialize Parallelization Pool pool = parallelization.GPUPool( @@ -120,11 +158,15 @@ def run_filter3d( pool.shutdown() save_parameters(config, [tomo_alg, voxel_size], [lp_freq, lp_decay, hp_freq, hp_decay], write_algorithm) - print('✅ Completed the Filtering!') + print("✅ Completed the Filtering!") + def save_parameters(config, tomo_info, parameters, write_algorithm): + import os + + import copick + from copick_torch.entry_points.utils import save_parameters_yaml - import copick, os root = copick.from_file(config) overlay_root = root.config.overlay_root @@ -142,13 +184,12 @@ def save_parameters(config, tomo_info, parameters, write_algorithm): "High-Pass Frequency (Angstroms)": parameters[2], "High-Pass Decay (Pixels)": parameters[3], }, - 'output': { - 'tomo_alg': write_algorithm, - 'voxel_size': tomo_info[1] - } + "output": { + "tomo_alg": write_algorithm, + "voxel_size": tomo_info[1], + }, } os.makedirs(os.path.join(overlay_root, "logs"), exist_ok=True) path = os.path.join(overlay_root, "logs", f"process-filter3d_{tomo_info[0]}_{tomo_info[1]}A.yaml") save_parameters_yaml(group, path) print(f"📝 Saved Parameters to {path}") - diff --git a/copick_torch/entry_points/run_membrane_seg.py b/copick_torch/entry_points/run_membrane_seg.py index 7622c77..22104ae 100644 --- a/copick_torch/entry_points/run_membrane_seg.py +++ b/copick_torch/entry_points/run_membrane_seg.py @@ -1,5 +1,6 @@ import click + def segment_commands(func): """Decorator to add common options to a Click command.""" options = [ @@ -55,10 +56,11 @@ def run(config, tomo_alg, voxel_size, session_id, threshold, user_id): """ Runs the membrane segmentation. """ - from copick_torch.inference import membrain_seg - from copick_torch import parallelization import copick + from copick_torch import parallelization + from copick_torch.inference import membrain_seg + print("Starting Membrane Segmentation...") print(f"Using Tomograms with Voxel Size: {voxel_size} and Algorithm: {tomo_alg}") print(f"Saving Segmentations with {user_id}_{session_id}_membranes Query") @@ -96,9 +98,10 @@ def run(config, tomo_alg, voxel_size, session_id, threshold, user_id): def run_segmenter(run, tomo_alg, voxel_size, session_id, threshold, user_id, gpu_id, models): - from copick_torch.inference import membrain_seg from copick_utils.io import readers, writers + from copick_torch.inference import membrain_seg + # Default Sliding Window Parameters sw_batch_size = 4 sw_window_size = 160 diff --git a/copick_torch/filters/bandpass.py b/copick_torch/filters/bandpass.py index abc2afa..63b4c9f 100644 --- a/copick_torch/filters/bandpass.py +++ b/copick_torch/filters/bandpass.py @@ -1,7 +1,9 @@ -from torch.fft import fftshift, ifftshift, fftn, ifftn +import math + import matplotlib.pyplot as plt import numpy as np -import torch, math +import torch +from torch.fft import fftn, fftshift, ifftn, ifftshift """ This module contains functions for creating cosine-low pass filter and applying it to tomograms. @@ -9,6 +11,7 @@ (https://github.com/uermel/artia-wrapper/tree/master) """ + class Filter3D: def __init__(self, apix, sz, lp=0, lpd=0, hp=0, hpd=0, device=None): """ @@ -34,10 +37,10 @@ def __init__(self, apix, sz, lp=0, lpd=0, hp=0, hpd=0, device=None): # Set Device if device is None: - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = device - + # Check if low-pass cutoff resolution is less than high-pass cutoff resolution if self.lp > self.hp and self.lp > 0 and self.hp > 0: raise ValueError("Low-pass cutoff resolution must be less than high-pass cutoff resolution.") @@ -49,7 +52,7 @@ def __init__(self, apix, sz, lp=0, lpd=0, hp=0, hpd=0, device=None): self.lpd_pix = self.lpd # Decay width in pixels self.hpd_pix = self.hpd # Decay width in pixels - print('Constructing 3D Cosine Filter...') + print("Constructing 3D Cosine Filter...") self.cosine_filter() def angst_to_pix(self, ang): @@ -75,20 +78,20 @@ def cosine_filter(self): torch.arange(D, device=self.device, dtype=self.dtype) - D // 2, torch.arange(H, device=self.device, dtype=self.dtype) - H // 2, torch.arange(W, device=self.device, dtype=self.dtype) - W // 2, - indexing='ij' + indexing="ij", ) r = torch.sqrt(xx**2 + yy**2 + zz**2) # Radial distance in pixels # Low-pass filter - lpv = self.construct_filter(r, self.lp_pix, self.lpd_pix, mode='lp') + lpv = self.construct_filter(r, self.lp_pix, self.lpd_pix, mode="lp") # High-pass filter - hpv = self.construct_filter(r, self.hp_pix, self.hpd_pix, mode='hp') + hpv = self.construct_filter(r, self.hp_pix, self.hpd_pix, mode="hp") # Combined Filter self.filter = lpv * hpv - def construct_filter(self, r, freq, freqdecay, mode='lp'): + def construct_filter(self, r, freq, freqdecay, mode="lp"): """ Constructs a low-pass or high-pass filter based on the mode. @@ -101,7 +104,7 @@ def construct_filter(self, r, freq, freqdecay, mode='lp'): Returns: torch.Tensor: Filter mask. """ - if mode not in ['lp', 'hp']: + if mode not in ["lp", "hp"]: raise ValueError("Mode must be 'lp' for low-pass or 'hp' for high-pass.") # Skip filter @@ -110,7 +113,8 @@ def construct_filter(self, r, freq, freqdecay, mode='lp'): # Box Filter elif freq > 0 and freqdecay == 0: filter_mask = (r < freq).float() - if mode == 'hp': filter_mask = 1 - filter_mask + if mode == "hp": + filter_mask = 1 - filter_mask # # Cosine Filter starting with 1 (low-pass) or 0 (high-pass) # elif freq == 0 and freqdecay > 0: # filter_mask = (r < freq).float() @@ -123,11 +127,12 @@ def construct_filter(self, r, freq, freqdecay, mode='lp'): filter_mask = (r < freq).float() sel = (r > (freq - half_decay)) & (r < (freq + half_decay)) filter_mask[sel] = 0.5 + 0.5 * torch.cos(math.pi * (r[sel] - (freq - half_decay)) / freqdecay) - if mode == 'hp': filter_mask = 1 - filter_mask + if mode == "hp": + filter_mask = 1 - filter_mask return filter_mask - def extract_1d_profile(self, axis='x'): + def extract_1d_profile(self, axis="x"): """ Extracts a 1D profile from the 3D filter along the specified axis. @@ -140,13 +145,13 @@ def extract_1d_profile(self, axis='x'): D, H, W = filter_tensor.shape # Determine the axis - if axis == 'x': + if axis == "x": central_slice = filter_tensor[D // 2, H // 2, :] freqs = np.fft.fftfreq(W, d=self.apix) # 1/Å - elif axis == 'y': + elif axis == "y": central_slice = filter_tensor[D // 2, :, W // 2] freqs = np.fft.fftfreq(H, d=self.apix) # 1/Å - elif axis == 'z': + elif axis == "z": central_slice = filter_tensor[:, H // 2, W // 2] freqs = np.fft.fftfreq(D, d=self.apix) # 1/Å else: @@ -189,43 +194,45 @@ def show_filter(self): """ # Extract 1D profile - freqs, profile = self.extract_1d_profile(axis='x') + freqs, profile = self.extract_1d_profile(axis="x") # Create a 2x1 plot fig, axs = plt.subplots(2, 1, figsize=(8, 12)) # Plot the 1D frequency profile - axs[0].plot(freqs, profile, label='Axis Profile') - axs[0].set_xlabel('Frequency (1/Å)') - axs[0].set_ylabel('Filter Magnitude') - axs[0].set_title('1D Frequency Profile Along Axis') + axs[0].plot(freqs, profile, label="Axis Profile") + axs[0].set_xlabel("Frequency (1/Å)") + axs[0].set_ylabel("Filter Magnitude") + axs[0].set_title("1D Frequency Profile Along Axis") axs[0].legend() axs[0].grid(True) axs[0].set_xlim(0, max(freqs)) # Set x-axis range # Plot the 2D slice of the filter (Nx, Ny, Nz) = self.filter.shape - axs[1].imshow(self.filter[int(Nx//2), :, :], cmap='gray') - axs[1].axis('off') + axs[1].imshow(self.filter[int(Nx // 2), :, :], cmap="gray") + axs[1].axis("off") # Show the plot # plt.tight_layout() - plt.savefig('filter.png') + plt.savefig("filter.png") + def init_filter3d(gpu_id: int, apix: float, sz: tuple, lp: float, lpd: float, hp: float, hpd: float): """ Initializes the filter3d class. """ - device = torch.device(f"cuda:{gpu_id}") + device = torch.device(f"cuda:{gpu_id}") filter = Filter3D(apix, sz, lp, lpd, hp, hpd, device=device) return filter + def run_filter3d(run, tomo_alg, voxel_size, write_algorithm, gpu_id, models): """ Runs the filter3d class. """ from copick_utils.io import readers, writers - + # Get the Filter filter = models @@ -241,4 +248,4 @@ def run_filter3d(run, tomo_alg, voxel_size, write_algorithm, gpu_id, models): filtered_tomo = filter.apply(tomo) # Save the Filtered Tomogram - writers.tomogram(run, filtered_tomo.cpu().numpy(), voxel_size, write_algorithm) \ No newline at end of file + writers.tomogram(run, filtered_tomo.cpu().numpy(), voxel_size, write_algorithm) diff --git a/copick_torch/filters/downsample.py b/copick_torch/filters/downsample.py index c8b9eeb..999a7bb 100644 --- a/copick_torch/filters/downsample.py +++ b/copick_torch/filters/downsample.py @@ -142,6 +142,7 @@ def downsample_init(gpu_id: int, voxel_size: float, target_resolution: float): return downsampler + def run_downsampler(run, tomo_alg, voxel_size, target_resolution, delete_source, gpu_id, models): """ Runs the downsampler class. @@ -173,4 +174,3 @@ def run_downsampler(run, tomo_alg, voxel_size, target_resolution, delete_source, # If the Voxel Spacing is Empty, lets delete it as well if vs.tomograms == []: vs.delete() - From 96206006e8af64a32dec5a9a5c11f84e68f34582 Mon Sep 17 00:00:00 2001 From: jtschwar Date: Thu, 30 Oct 2025 10:55:25 -0700 Subject: [PATCH 04/12] formatting for doc string --- copick_torch/filters/bandpass.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/copick_torch/filters/bandpass.py b/copick_torch/filters/bandpass.py index 63b4c9f..e835642 100644 --- a/copick_torch/filters/bandpass.py +++ b/copick_torch/filters/bandpass.py @@ -1,3 +1,9 @@ +""" +This module contains functions for creating cosine-low pass filter and applying it to tomograms. +This is a written translation of the MATLAB code cosine_filter.m from the artia-wrapper package +(https://github.com/uermel/artia-wrapper/tree/master) +""" + import math import matplotlib.pyplot as plt @@ -5,12 +11,6 @@ import torch from torch.fft import fftn, fftshift, ifftn, ifftshift -""" -This module contains functions for creating cosine-low pass filter and applying it to tomograms. -This is a written translation of the MATLAB code cosine_filter.m from the artia-wrapper package -(https://github.com/uermel/artia-wrapper/tree/master) -""" - class Filter3D: def __init__(self, apix, sz, lp=0, lpd=0, hp=0, hpd=0, device=None): From 2d7ed63c446af2d181917b9384d18b533f91b3a3 Mon Sep 17 00:00:00 2001 From: Jonathan Schwartz Date: Thu, 30 Oct 2025 12:44:32 -0700 Subject: [PATCH 05/12] add safe guards to cli commands --- copick_torch/entry_points/run_downsample.py | 2 +- copick_torch/entry_points/run_filter3d.py | 101 +++++++++++++++----- copick_torch/filters/bandpass.py | 65 ++++++------- copick_torch/filters/downsample.py | 25 ++++- 4 files changed, 134 insertions(+), 59 deletions(-) diff --git a/copick_torch/entry_points/run_downsample.py b/copick_torch/entry_points/run_downsample.py index 1cc473f..31a05c4 100644 --- a/copick_torch/entry_points/run_downsample.py +++ b/copick_torch/entry_points/run_downsample.py @@ -39,7 +39,7 @@ def downsample( delete_source: bool, ): """ - Runs the downsampling command. + Downsample tomograms with Fourier Re-Scaling. """ run(config, tomo_alg, voxel_size, target_resolution, delete_source) diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py index eaa7125..0d1d7ce 100644 --- a/copick_torch/entry_points/run_filter3d.py +++ b/copick_torch/entry_points/run_filter3d.py @@ -11,7 +11,7 @@ def low_pass_commands(func): default=0, help="Low-pass cutoff frequency (in Angstroms)", ), - click.option("--lp-decay", type=float, required=False, default=0, help="Low-pass decay width (in pixels)"), + click.option("--lp-decay", type=float, required=False, default=50, help="Low-pass decay width (in pixels)"), click.option( "--hp-freq", type=float, @@ -19,12 +19,12 @@ def low_pass_commands(func): default=0, help="High-pass cutoff frequency (in Angstroms)", ), - click.option("--hp-decay", type=float, required=False, default=0, help="High-pass decay width (in pixels)"), + click.option("--hp-decay", type=float, required=False, default=50, help="High-pass decay width (in pixels)"), click.option( "--show-filter", type=bool, required=False, - default=False, + default=True, help="Save the filter as a Png (filter3d.png)", ), ] @@ -46,28 +46,36 @@ def copick_commands(func): ), click.option("--tomo-alg", type=str, required=True, help="Tomogram Algorithm to use"), click.option("--voxel-size", type=float, required=False, default=10, help="Voxel Size to Query the Data"), - click.option( - "--show-filter", - type=bool, - required=False, - default=True, - help="Save the filter as a Png (filter3d.png)", - ), ] for option in reversed(options): # Add options in reverse order to preserve correct order func = option(func) return func -def input_check(lp_freq, hp_freq, voxel_size): - if lp_freq == 0 and hp_freq == 0: - raise ValueError("Low-pass and high-pass frequencies cannot both be 0.") - elif lp_freq < voxel_size * 2: - raise ValueError("Low-pass frequency cannot be less than twice the Nyquist resolution.") - elif hp_freq < voxel_size * 2: - raise ValueError("High-pass frequency cannot be less than twice the Nyquist resolution.") - elif lp_freq > hp_freq and lp_freq > 0 and hp_freq > 0: - raise ValueError("Low-pass cutoff resolution must be less than high-pass cutoff resolution.") +def input_check(lp_res, hp_res, apix): + """ + lp_res, hp_res: resolutions in Å (0 means 'disabled') + apix: pixel size in Å/pixel + require_filter: if True, disallow the all-pass case (both 0) + """ + nyquist_res = 2.0 * apix # smallest physically valid resolution (Å) + + # All-pass allowed unless explicitly forbidden + if lp_res == 0 and hp_res == 0: + raise ValueError("Low-pass and high-pass cannot both be 0 (no filtering).") + + # Low-pass: if enabled, it cannot be finer (smaller Å) than Nyquist + if lp_res > 0 and lp_res < nyquist_res: + raise ValueError(f"Low-pass resolution {lp_res:.3g} Å is finer than Nyquist (2*apix = {nyquist_res:.3g} Å).") + + # High-pass: if enabled, it also cannot be finer than Nyquist (no frequencies exist beyond Nyquist) + if hp_res > 0 and hp_res < nyquist_res: + raise ValueError(f"High-pass resolution {hp_res:.3g} Å is finer than Nyquist (2*apix = {nyquist_res:.3g} Å).") + + # Band-pass consistency: in Å, larger number = lower spatial frequency + # Require hp > lp when both are enabled + if lp_res > 0 and hp_res > 0 and not (hp_res > lp_res): + raise ValueError("For band-pass, require hp (Å) > lp (Å).") def print_header(lp_freq, lp_decay, hp_freq, hp_decay): @@ -80,8 +88,8 @@ def print_header(lp_freq, lp_decay, hp_freq, hp_decay): @click.command(context_settings={"show_default": True}) -@low_pass_commands @copick_commands +@low_pass_commands def bandpass( config: str, run_ids: str, @@ -93,6 +101,9 @@ def bandpass( voxel_size: float, show_filter: bool, ): + """ + 3D bandpass filter tomograms. + """ run_filter3d(config, run_ids, lp_freq, lp_decay, hp_freq, hp_decay, tomo_alg, voxel_size, show_filter) @@ -115,7 +126,12 @@ def run_filter3d( from copick_torch import parallelization from copick_torch.filters.bandpass import init_filter3d, run_filter3d + # Input Check - Set Decay to 0 if Unused input_check(lp_freq, hp_freq, voxel_size) + if lp_freq == 0: + lp_decay = 0 + if hp_freq == 0: + hp_decay = 0 # Load Copick Project if os.path.exists(config): @@ -138,15 +154,21 @@ def run_filter3d( if hp_freq > 0: write_algorithm = write_algorithm + f"-hp{hp_freq:0.0f}A" + # Get Volume Shape + vol_shape = get_tomo_shape(root, run_ids, tomo_alg, voxel_size) + # Initialize Parallelization Pool pool = parallelization.GPUPool( init_fn=init_filter3d, - init_args=(voxel_size, lp_freq, lp_decay, hp_freq, hp_decay), + init_args=(voxel_size, vol_shape, lp_freq, lp_decay, hp_freq, hp_decay), verbose=True, ) + # Save Filter Image + if show_filter: + save_filter((voxel_size, vol_shape, lp_freq, lp_decay, hp_freq, hp_decay)) # Execute - tasks = [(run, tomo_alg, voxel_size, write_algorithm) for run in run_ids] + tasks = [(root.get_run(run), tomo_alg, voxel_size, write_algorithm) for run in run_ids] try: pool.execute( run_filter3d, @@ -160,6 +182,41 @@ def run_filter3d( save_parameters(config, [tomo_alg, voxel_size], [lp_freq, lp_decay, hp_freq, hp_decay], write_algorithm) print("✅ Completed the Filtering!") +def get_tomo_shape(root, run_ids, tomo_alg, voxel_size): + import numpy as np + import zarr + + for runID in run_ids: + # Get Volume Shape from First Run + run = root.get_run(runID) + + # Get Target Shape + vs = run.get_voxel_spacing(voxel_size) + if vs is None: + continue + tomo = vs.get_tomogram(tomo_alg) + if tomo is None: + continue + loc = tomo.zarr() + shape = zarr.open(loc)['0'].shape + target = np.zeros(shape, dtype=np.uint8) + + return target.shape + +def save_filter(params): + from copick_torch.filters.bandpass import Filter3D + import torch + + filter = Filter3D( + params[0], params[1], params[2], + params[3], params[4], params[5], + device=torch.device("cpu") + ) + + filter.show_filter() + + del filter + def save_parameters(config, tomo_info, parameters, write_algorithm): import os diff --git a/copick_torch/filters/bandpass.py b/copick_torch/filters/bandpass.py index e835642..102deaf 100644 --- a/copick_torch/filters/bandpass.py +++ b/copick_torch/filters/bandpass.py @@ -41,9 +41,9 @@ def __init__(self, apix, sz, lp=0, lpd=0, hp=0, hpd=0, device=None): else: self.device = device - # Check if low-pass cutoff resolution is less than high-pass cutoff resolution - if self.lp > self.hp and self.lp > 0 and self.hp > 0: - raise ValueError("Low-pass cutoff resolution must be less than high-pass cutoff resolution.") + # Allow LP-only, HP-only, or band-pass (hp > lp in Å) + if self.lp > 0 and self.hp > 0 and not (self.hp > self.lp): + raise ValueError("For band-pass, require hp (Å) > lp (Å).") # Convert cutoff values from angstroms to pixels self.lp_pix = self.angst_to_pix(self.lp) if self.lp > 0 else 0 # Low-pass cutoff in pixels @@ -94,41 +94,42 @@ def cosine_filter(self): def construct_filter(self, r, freq, freqdecay, mode="lp"): """ Constructs a low-pass or high-pass filter based on the mode. - - Args: - r (torch.Tensor): Radial spatial frequency tensor in pixels. - freq (float): Cutoff frequency in pixels. - freqdecay (float): Decay width in pixels. - mode (str): 'lp' for low-pass, 'hp' for high-pass. - - Returns: - torch.Tensor: Filter mask. + Handles pure LP or HP cases properly. """ if mode not in ["lp", "hp"]: raise ValueError("Mode must be 'lp' for low-pass or 'hp' for high-pass.") - # Skip filter - if freq == 0 and freqdecay == 0: - filter_mask = torch.ones_like(r) - # Box Filter - elif freq > 0 and freqdecay == 0: - filter_mask = (r < freq).float() - if mode == "hp": - filter_mask = 1 - filter_mask - # # Cosine Filter starting with 1 (low-pass) or 0 (high-pass) - # elif freq == 0 and freqdecay > 0: - # filter_mask = (r < freq).float() - # sel = r <= freqdecay - # filter_mask[sel] = 0.5 + 0.5 * torch.cos(math.pi * r[sel] / freqdecay) - # if mode == 'lp': filter_mask = 1 - filter_mask - # Box Filter with cosine decay - else: - half_decay = freqdecay / 2.0 - filter_mask = (r < freq).float() + # Start with neutral mask + filter_mask = torch.ones_like(r, dtype=self.dtype, device=self.device) + + # Skip filter: freq==0 disables that side + if freq == 0: + if mode == "lp": + # LP disabled → all-pass (1) + filter_mask[:] = 1.0 + elif mode == "hp": + # HP disabled → all-pass (1) + filter_mask[:] = 1.0 + return filter_mask + + # Box filter + if freqdecay == 0: + if mode == "lp": + filter_mask = (r <= freq).float() + else: # hp + filter_mask = (r >= freq).float() + return filter_mask + + # Cosine transition region + half_decay = freqdecay / 2.0 + if mode == "lp": + filter_mask = (r <= freq - half_decay).float() sel = (r > (freq - half_decay)) & (r < (freq + half_decay)) filter_mask[sel] = 0.5 + 0.5 * torch.cos(math.pi * (r[sel] - (freq - half_decay)) / freqdecay) - if mode == "hp": - filter_mask = 1 - filter_mask + else: # high-pass + filter_mask = (r >= freq + half_decay).float() + sel = (r > (freq - half_decay)) & (r < (freq + half_decay)) + filter_mask[sel] = 0.5 - 0.5 * torch.cos(math.pi * (r[sel] - (freq - half_decay)) / freqdecay) return filter_mask diff --git a/copick_torch/filters/downsample.py b/copick_torch/filters/downsample.py index 999a7bb..75478ee 100644 --- a/copick_torch/filters/downsample.py +++ b/copick_torch/filters/downsample.py @@ -49,10 +49,16 @@ def run(self, volume): if self.device.type == "cpu" and volume.dim() == 4: raise AssertionError("Batched volumes are not allowed on CPU. Please provide a single volume.") - if volume.dim() == 4: - output = self.batched_rescale(volume) - else: - output = self.single_rescale(volume) + # Try to Run on the GPU, if there's memory issues, fall back to CPU + try: + output = self.submit(volume) + except: + # Free GPU cache before retrying on CPU + torch.cuda.empty_cache() + + print('⚠️ GPU memory issue encountered, falling back to CPU for downsampling.') + self.device = torch.device("cpu") + output = self.submit(volume) # Return to CPU if Compute is on GPU if self.device != torch.device("cpu"): @@ -65,6 +71,17 @@ def run(self, volume): else: return output + def submit(self, volume: torch.Tensor) -> torch.Tensor: + """ + Submit the volume for rescaling based on its dimensionality. + """ + if volume.dim() == 4: + output = self.batched_rescale(volume) + else: + output = self.single_rescale(volume) + return output + + def batched_rescale(self, volume: torch.Tensor): """ Process a (batched) volume: move to device, perform FFT, crop in Fourier space, From 60dfb5615c11cae001d14d4468108a15810dd41b Mon Sep 17 00:00:00 2001 From: jtschwar Date: Thu, 30 Oct 2025 12:48:00 -0700 Subject: [PATCH 06/12] linting --- copick_torch/entry_points/run_filter3d.py | 17 ++++++++++++----- copick_torch/filters/downsample.py | 7 +++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py index 0d1d7ce..01a0740 100644 --- a/copick_torch/entry_points/run_filter3d.py +++ b/copick_torch/entry_points/run_filter3d.py @@ -182,6 +182,7 @@ def run_filter3d( save_parameters(config, [tomo_alg, voxel_size], [lp_freq, lp_decay, hp_freq, hp_decay], write_algorithm) print("✅ Completed the Filtering!") + def get_tomo_shape(root, run_ids, tomo_alg, voxel_size): import numpy as np import zarr @@ -198,19 +199,25 @@ def get_tomo_shape(root, run_ids, tomo_alg, voxel_size): if tomo is None: continue loc = tomo.zarr() - shape = zarr.open(loc)['0'].shape + shape = zarr.open(loc)["0"].shape target = np.zeros(shape, dtype=np.uint8) return target.shape + def save_filter(params): - from copick_torch.filters.bandpass import Filter3D import torch + from copick_torch.filters.bandpass import Filter3D + filter = Filter3D( - params[0], params[1], params[2], - params[3], params[4], params[5], - device=torch.device("cpu") + params[0], + params[1], + params[2], + params[3], + params[4], + params[5], + device=torch.device("cpu"), ) filter.show_filter() diff --git a/copick_torch/filters/downsample.py b/copick_torch/filters/downsample.py index 75478ee..b9ba0ba 100644 --- a/copick_torch/filters/downsample.py +++ b/copick_torch/filters/downsample.py @@ -50,13 +50,13 @@ def run(self, volume): raise AssertionError("Batched volumes are not allowed on CPU. Please provide a single volume.") # Try to Run on the GPU, if there's memory issues, fall back to CPU - try: + try: output = self.submit(volume) - except: + except Exception: # Free GPU cache before retrying on CPU torch.cuda.empty_cache() - print('⚠️ GPU memory issue encountered, falling back to CPU for downsampling.') + print("⚠️ GPU memory issue encountered, falling back to CPU for downsampling.") self.device = torch.device("cpu") output = self.submit(volume) @@ -81,7 +81,6 @@ def submit(self, volume: torch.Tensor) -> torch.Tensor: output = self.single_rescale(volume) return output - def batched_rescale(self, volume: torch.Tensor): """ Process a (batched) volume: move to device, perform FFT, crop in Fourier space, From fd05d15069fb09e810189fb7f3143945ff69ba8f Mon Sep 17 00:00:00 2001 From: Jonathan Schwartz <32110921+jtschwar@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:23:05 -0800 Subject: [PATCH 07/12] Update copick_torch/entry_points/run_filter3d.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- copick_torch/entry_points/run_filter3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py index 01a0740..8dd0da4 100644 --- a/copick_torch/entry_points/run_filter3d.py +++ b/copick_torch/entry_points/run_filter3d.py @@ -25,7 +25,7 @@ def low_pass_commands(func): type=bool, required=False, default=True, - help="Save the filter as a Png (filter3d.png)", + help="Save the filter as a PNG (filter3d.png)", ), ] for option in reversed(options): # Add options in reverse order to preserve correct order From 5889f6ad50e95be9507f2195836ce83da4431691 Mon Sep 17 00:00:00 2001 From: Jonathan Schwartz <32110921+jtschwar@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:23:17 -0800 Subject: [PATCH 08/12] Update copick_torch/filters/bandpass.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- copick_torch/filters/bandpass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/copick_torch/filters/bandpass.py b/copick_torch/filters/bandpass.py index 102deaf..74abd2e 100644 --- a/copick_torch/filters/bandpass.py +++ b/copick_torch/filters/bandpass.py @@ -1,5 +1,5 @@ """ -This module contains functions for creating cosine-low pass filter and applying it to tomograms. +This module contains functions for creating cosine low-pass filter and applying it to tomograms. This is a written translation of the MATLAB code cosine_filter.m from the artia-wrapper package (https://github.com/uermel/artia-wrapper/tree/master) """ From ab18e6520d99b21eac8324c4b96f51fba863aa0f Mon Sep 17 00:00:00 2001 From: Jonathan Schwartz <32110921+jtschwar@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:23:26 -0800 Subject: [PATCH 09/12] Update copick_torch/entry_points/run_membrane_seg.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- copick_torch/entry_points/run_membrane_seg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/copick_torch/entry_points/run_membrane_seg.py b/copick_torch/entry_points/run_membrane_seg.py index 22104ae..4d9c331 100644 --- a/copick_torch/entry_points/run_membrane_seg.py +++ b/copick_torch/entry_points/run_membrane_seg.py @@ -97,7 +97,6 @@ def run(config, tomo_alg, voxel_size, session_id, threshold, user_id): def run_segmenter(run, tomo_alg, voxel_size, session_id, threshold, user_id, gpu_id, models): - from copick_utils.io import readers, writers from copick_torch.inference import membrain_seg From 617bdb89bc20dcb5ed57a5c0a2a6d146c25b9007 Mon Sep 17 00:00:00 2001 From: Jonathan Schwartz <32110921+jtschwar@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:25:00 -0800 Subject: [PATCH 10/12] Update copick_torch/filters/downsample.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- copick_torch/filters/downsample.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/copick_torch/filters/downsample.py b/copick_torch/filters/downsample.py index b9ba0ba..0deda85 100644 --- a/copick_torch/filters/downsample.py +++ b/copick_torch/filters/downsample.py @@ -52,14 +52,19 @@ def run(self, volume): # Try to Run on the GPU, if there's memory issues, fall back to CPU try: output = self.submit(volume) - except Exception: - # Free GPU cache before retrying on CPU - torch.cuda.empty_cache() - - print("⚠️ GPU memory issue encountered, falling back to CPU for downsampling.") - self.device = torch.device("cpu") - output = self.submit(volume) - + except (torch.cuda.OutOfMemoryError, RuntimeError) as e: + # Only handle GPU out-of-memory errors, not all exceptions + if isinstance(e, torch.cuda.OutOfMemoryError) or ( + isinstance(e, RuntimeError) and "out of memory" in str(e).lower() + ): + # Free GPU cache before retrying on CPU + torch.cuda.empty_cache() + + print("⚠️ GPU memory issue encountered, falling back to CPU for downsampling.") + self.device = torch.device("cpu") + output = self.submit(volume) + else: + raise # Return to CPU if Compute is on GPU if self.device != torch.device("cpu"): output = output.cpu() From cc0ab68b75712858d115664e89b7ff2bc3703a56 Mon Sep 17 00:00:00 2001 From: Jonathan Schwartz <32110921+jtschwar@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:25:15 -0800 Subject: [PATCH 11/12] Update copick_torch/entry_points/run_filter3d.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- copick_torch/entry_points/run_filter3d.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py index 8dd0da4..5b09246 100644 --- a/copick_torch/entry_points/run_filter3d.py +++ b/copick_torch/entry_points/run_filter3d.py @@ -200,11 +200,7 @@ def get_tomo_shape(root, run_ids, tomo_alg, voxel_size): continue loc = tomo.zarr() shape = zarr.open(loc)["0"].shape - target = np.zeros(shape, dtype=np.uint8) - - return target.shape - - + return shape def save_filter(params): import torch From 73385622515fcb234512d2e320af01b7b616fa36 Mon Sep 17 00:00:00 2001 From: jtschwar Date: Thu, 29 Jan 2026 13:21:38 -0800 Subject: [PATCH 12/12] Format with black --- copick_torch/entry_points/run_filter3d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/copick_torch/entry_points/run_filter3d.py b/copick_torch/entry_points/run_filter3d.py index 5b09246..73c5c8c 100644 --- a/copick_torch/entry_points/run_filter3d.py +++ b/copick_torch/entry_points/run_filter3d.py @@ -201,6 +201,8 @@ def get_tomo_shape(root, run_ids, tomo_alg, voxel_size): loc = tomo.zarr() shape = zarr.open(loc)["0"].shape return shape + + def save_filter(params): import torch