From 6de34919287391d3fb0febe1dce70b1ae12dc21e Mon Sep 17 00:00:00 2001 From: gnzng Date: Thu, 10 Jul 2025 14:24:24 -0700 Subject: [PATCH 1/9] Add line-based FRC function and corresponding tests, and example file - Implemented `line_based_frc` function for computing line-based Fourier Ring Correlation with multiple threshold methods. - Added unit tests for `line_based_frc` to validate functionality with both PyTorch tensors and NumPy arrays. - Created example script to demonstrate usage of `line_based_frc` with generated test images. --- examples/line_based_frc.py | 117 ++++++++++++++++++++ src/cdtools/tools/analysis/analysis.py | 146 ++++++++++++++++++++++++- tests/tools/test_analysis.py | 95 +++++++++++++++- 3 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 examples/line_based_frc.py diff --git a/examples/line_based_frc.py b/examples/line_based_frc.py new file mode 100644 index 00000000..980a7f4a --- /dev/null +++ b/examples/line_based_frc.py @@ -0,0 +1,117 @@ +import numpy as np +import matplotlib.pyplot as plt + +from cdtools.tools.analysis import line_based_frc + + +def create_stripe_test_pattern(shape, stripe_spacing, noise_level=0.1): + """Create test pattern with anisotropic stripes""" + h, w = shape + + # Create vertical stripes (varying along x-axis) + x = np.arange(w) + pattern = np.sin(2 * np.pi * x / stripe_spacing) + + # Broadcast to full image + image = np.tile(pattern, (h, 1)) + + # Add some modulation along y-axis for realism + y = np.arange(h) + y_mod = 1 + 0.3 * np.sin(2 * np.pi * y / (h // 3)) + image = image * y_mod[:, np.newaxis] + + # make the lower part 0 + image[7 * h // 8:] = 0 + + # Add noise + image += noise_level * np.random.randn(h, w) + + return image + + +# Create test images +shape = (200, 300) +stripe_spacing = 8 # pixels + +unit = "nm" +pixel_size = 19.0 # nm, for example + +# Create two slightly different versions to simulate repeated measurements +np.random.seed(42) +image1 = create_stripe_test_pattern( + 
shape, + stripe_spacing, + noise_level=0.05, +) + +np.random.seed(43) # Different seed for second image +image2 = create_stripe_test_pattern( + shape, + stripe_spacing, + noise_level=0.05, +) + +# Display test images +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) + +ax1.imshow(image1, cmap="gray") +ax1.set_title("Test Image 1") +ax1.axis("off") + +ax2.imshow(image2, cmap="gray") +ax2.set_title("Test Image 2") +ax2.axis("off") +plt.tight_layout() + + +frc_thresholds = {"1/2": 0.5, "1/4": 0.25, "1/7": 1 / 7} + +# Compute binned FRC +frequencies, frc_curve, resolution_dict = line_based_frc( + image1, image2, axis=1, n_bins=100, thresholds=frc_thresholds, + pixel_size=pixel_size, unit=unit +) + +########################### +# Plot FRC curve with threshold crossings +########################### + +# Color map for different thresholds +colors = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray"] + +fig, ax1 = plt.subplots(1, 1, figsize=(6, 6)) + +# FRC curve with all thresholds +ax1.plot(frequencies, frc_curve, "k-", linewidth=3, label="FRC", zorder=10) + +for i, (method, threshold_val) in enumerate(frc_thresholds.items()): + color = colors[i % len(colors)] + ax1.axhline( + y=threshold_val, + color=color, + linestyle="--", + alpha=0.8, + label=f"{method} ({resolution_dict[method]:.3f} {unit})", + ) + + # Mark threshold frequency using the resolution value + resolution = resolution_dict[method] + if resolution is not None and resolution > 0: + freq = pixel_size / resolution + # Find the closest frequency in the array + ax1.scatter( + # threshold_val * np.ones_like(freq), + freq, + threshold_val, + color=color, + s=80, + edgecolor="black", + zorder=20, + ) + +ax1.set_xlabel(f"Spatial Frequency (1/{unit})") +ax1.set_ylabel("FRC") +ax1.set_title("Line-based FRC Curve") +ax1.legend() +plt.tight_layout() +plt.show() diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 9844edb4..38b13060 100644 --- 
a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -34,6 +34,7 @@ 'standardize_reconstruction_set', 'standardize_reconstruction_pair', 'calc_spectral_info', + 'line_based_frc', ] @@ -1525,4 +1526,147 @@ def calc_spectral_info(dataset, nbins=50): pattern_snr = sum_spectrum_sq / sum_spectrum return sum_pattern, frc_bins[:-1], mean_spectrum - + + +def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_size=1.0, unit="pixels"): + """ + Compute line-based FRC with multiple threshold methods + + Parameters + ---------- + image1, image2: np.ndarray or torch.Tensor + Input images (should be same size). Must be 2D arrays. + axis: int + axis along which to extract lines (1 for horizontal lines, 0 for vertical) + n_bins: int + number of frequency bins (if None, no binning is applied) + thresholds: dict + Dictionary of threshold methods and values + pixel_size: float + Physical size of one pixel + unit: str + Unit for the pixel size and resolution output + + Returns + ------- + frequencies: np.ndarray + Spatial frequency array + frc_curve: np.ndarray + FRC values as function of frequency + resolution_dict: dict + Resolution at different thresholds + """ + + # transform all torch tensors to np.arrays + if isinstance(image1, t.Tensor): + image1 = image1.numpy() + if isinstance(image2, t.Tensor): + image2 = image2.numpy() + + assert image1.shape == image2.shape, "Images must have same shape" + + if image1.ndim != 2: + raise ValueError("Images must be 2D") + + if image2.ndim != 2: + raise ValueError("Images must be 2D") + + if axis == 1: + # Extract horizontal lines (perpendicular to vertical stripes) + lines1 = image1 + lines2 = image2 + n_freq = image1.shape[1] // 2 + + else: + # Extract vertical lines (perpendicular to horizontal stripes) + lines1 = image1.T + lines2 = image2.T + n_freq = image1.shape[0] // 2 + + # Compute 1D FFT for each line + fft1 = np.fft.fft(lines1, axis=1) + fft2 = np.fft.fft(lines2, axis=1) + + 
# Only keep positive frequencies + fft1 = fft1[:, :n_freq] + fft2 = fft2[:, :n_freq] + + # Get the raw frequency array + raw_frequencies = np.fft.fftfreq(lines1.shape[1])[:n_freq] + + if n_bins is None: + # No binning - use original approach + numerator = np.sum(fft1 * np.conj(fft2), axis=0) + denominator = np.sqrt(np.sum(np.abs(fft1)**2, axis=0) * np.sum(np.abs(fft2)**2, axis=0)) + + frc_curve = np.abs(numerator) / denominator + frequencies = raw_frequencies + + else: + # Binning approach - similar to 2D FRC + max_freq = raw_frequencies[-1] + freq_bins = np.linspace(0, max_freq, n_bins + 1) + + # Compute binned FRC + frc_curve = [] + frequencies = [] + + for i in range(len(freq_bins) - 1): + # Find frequency indices in this bin + mask = (raw_frequencies >= freq_bins[i]) & (raw_frequencies < freq_bins[i+1]) + n_freq_components = np.sum(mask) + + if n_freq_components > 0: + # Extract FFT values in this frequency bin + fft1_bin = fft1[:, mask] # Shape: (n_lines, n_freq_components) + fft2_bin = fft2[:, mask] + + # Compute FRC for this bin (sum over both lines and frequencies) + numerator = np.sum(fft1_bin * np.conj(fft2_bin)) + denominator = np.sqrt(np.sum(np.abs(fft1_bin)**2) * np.sum(np.abs(fft2_bin)**2)) + + frc_value = np.abs(numerator) / denominator + frc_curve.append(frc_value) + frequencies.append((freq_bins[i] + freq_bins[i+1]) / 2) + + # Convert lists to numpy arrays + frc_curve = np.array(frc_curve) + frequencies = np.array(frequencies) + + # Find resolution at each threshold using interpolation + resolution_dict = {} + for method, threshold_val in thresholds.items(): + # Find where FRC crosses the threshold + above_threshold = frc_curve >= threshold_val + + # Find all crossings (from above to below threshold) + crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] + + if crossings.size > 0: + # Use the first crossing (lowest frequency where FRC drops below threshold) + idx = crossings[0] + freq1 = frequencies[idx] + freq2 = frequencies[idx + 
1] + frc1 = frc_curve[idx] + frc2 = frc_curve[idx + 1] + + # Linear interpolation: find frequency where FRC == threshold_val + resolution_frequency = freq1 + (threshold_val - frc1) * (freq2 - freq1) / (frc2 - frc1) + else: + # No crossing found, check if all above or all below + if np.all(above_threshold): + # FRC never drops below threshold, use max frequency + resolution_frequency = frequencies[-1] + else: + # FRC never reaches threshold + resolution_frequency = 0 + + if resolution_frequency > 0: + resolution_pixels = 1 / resolution_frequency + resolution_physical = resolution_pixels * pixel_size + else: + resolution_physical = np.inf + + resolution_dict[method] = resolution_physical + + return frequencies, frc_curve, resolution_dict diff --git a/tests/tools/test_analysis.py b/tests/tools/test_analysis.py index f0a1cf14..265a5eaf 100644 --- a/tests/tools/test_analysis.py +++ b/tests/tools/test_analysis.py @@ -2,9 +2,9 @@ from scipy import linalg as la from scipy.sparse import linalg as spla import torch as t -from itertools import combinations from cdtools.tools import analysis, initializers +from cdtools.tools.analysis import line_based_frc def test_product_svd(): @@ -519,3 +519,96 @@ def test_calc_generalized_rms_error(): assert (analysis.calc_generalized_rms_error(fields_1, fields_2, dims=1).shape == t.Size([3])) + +def test_line_based_frc(): + # First test some basic input with two fields of the same size + field_1 = t.rand(30, 17, dtype=t.complex128) + field_2 = t.rand(30, 17, dtype=t.complex128) + + # Compute FRC + freq, frc, res_dict = line_based_frc(field_1, field_2, + axis=1, n_bins=50, + thresholds={"1/7": 1/7}, + ) + + # check if length of frc is longer than 0 + assert len(frc) > 0 + # check if freq is same length as frc + assert len(freq) == len(frc) + # check if res_dict is a dictionary + assert isinstance(res_dict, dict) + + # Now check if it works with numpy arrays + field_1_np = field_1.numpy() + field_2_np = field_2.numpy() + + freq_np, frc_np, 
res_dict_np = line_based_frc(field_1_np, field_2_np) + + # Do the same checks as above + assert len(frc_np) > 0 + assert len(freq_np) == len(frc_np) + assert isinstance(res_dict_np, dict) + + # Create test images from the provided example + def create_stripe_test_pattern(shape, stripe_spacing, noise_level=0.1): + """Create test pattern with anisotropic stripes""" + h, w = shape + + # Create vertical stripes (varying along x-axis) + x = np.arange(w) + pattern = np.sin(2 * np.pi * x / stripe_spacing) + + # Broadcast to full image + image = np.tile(pattern, (h, 1)) + + # Add some modulation along y-axis for realism + y = np.arange(h) + y_mod = 1 + 0.3 * np.sin(2 * np.pi * y / (h // 3)) + image = image * y_mod[:, np.newaxis] + + # make the lower part 0 + image[7 * h // 8:] = 0 + + # Add noise + image += noise_level * np.random.randn(h, w) + + return image + + shape = (200, 300) + stripe_spacing = 8 # pixels + + unit = "nm" + pixel_size = 19.0 # nm, for example + + # Create two slightly different versions to simulate repeated measurements + np.random.seed(42) + image1 = create_stripe_test_pattern( + shape, + stripe_spacing, + noise_level=0.05, + ) + + np.random.seed(43) # Different seed for second image + image2 = create_stripe_test_pattern( + shape, + stripe_spacing, + noise_level=0.05, + ) + + frc_thresholds = {"1/2": 0.5, "1/4": 0.25, "1/7": 1 / 7} + + # Compute binned FRC + frequencies, frc_curve, res_dict = line_based_frc( + image1, image2, axis=1, n_bins=100, thresholds=frc_thresholds, + pixel_size=pixel_size, unit=unit + ) + + # run through the above checks + assert len(frc) > 0 + assert len(freq) == len(frc) + assert isinstance(res_dict, dict) + + # check if the values are within expected ranges + assert res_dict["1/2"] > 70.0 and res_dict["1/2"] < 75.0 + assert res_dict["1/4"] > 52.0 and res_dict["1/4"] < 57.0 + assert res_dict["1/7"] > 45.0 and res_dict["1/7"] < 48.0 From e0cf3ea2e175398ce66c86156d392a9e73ad18a0 Mon Sep 17 00:00:00 2001 From: gnzng Date: Thu, 
10 Jul 2025 14:33:47 -0700 Subject: [PATCH 2/9] Enhance line_based_frc function with detailed docstring and input validation checks --- src/cdtools/tools/analysis/analysis.py | 6 +++++- tests/tools/test_analysis.py | 10 +++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 38b13060..055f028b 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -1530,7 +1530,8 @@ def calc_spectral_info(dataset, nbins=50): def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_size=1.0, unit="pixels"): """ - Compute line-based FRC with multiple threshold methods + Compute line-based FRC with multiple threshold methods. This works great for highly anisotropic + images, such as those with vertical or horizontal stripes. Parameters ---------- @@ -1563,6 +1564,9 @@ def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_siz if isinstance(image2, t.Tensor): image2 = image2.numpy() + assert isinstance(image1, np.ndarray), "image1 must be a numpy array or torch tensor" + assert isinstance(image2, np.ndarray), "image2 must be a numpy array or torch tensor" + assert image1.shape == image2.shape, "Images must have same shape" if image1.ndim != 2: diff --git a/tests/tools/test_analysis.py b/tests/tools/test_analysis.py index 265a5eaf..0c84a27f 100644 --- a/tests/tools/test_analysis.py +++ b/tests/tools/test_analysis.py @@ -1,7 +1,8 @@ import numpy as np +import pytest +import torch as t from scipy import linalg as la from scipy.sparse import linalg as spla -import torch as t from cdtools.tools import analysis, initializers from cdtools.tools.analysis import line_based_frc @@ -612,3 +613,10 @@ def create_stripe_test_pattern(shape, stripe_spacing, noise_level=0.1): assert res_dict["1/2"] > 70.0 and res_dict["1/2"] < 75.0 assert res_dict["1/4"] > 52.0 and res_dict["1/4"] < 57.0 assert res_dict["1/7"] > 
45.0 and res_dict["1/7"] < 48.0 + + # Now lets check if we catch some errors correctly + # Check if it raises an error for non-2D input + with pytest.raises(ValueError): + line_based_frc(np.random.rand(30, 17, 3), np.random.rand(30, 17, 3)) + line_based_frc(np.random.rand(30, 17, 3), np.random.rand(30, 17)) + line_based_frc("string", np.random.rand(30, 17)) From 027831d3a8e365a61a633668a1aa7b1e319eb98c Mon Sep 17 00:00:00 2001 From: gnzng Date: Thu, 10 Jul 2025 14:46:29 -0700 Subject: [PATCH 3/9] Add citation for FRC computation in line_based_frc function --- src/cdtools/tools/analysis/analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 055f028b..98c49864 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -1622,10 +1622,11 @@ def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_siz if n_freq_components > 0: # Extract FFT values in this frequency bin - fft1_bin = fft1[:, mask] # Shape: (n_lines, n_freq_components) + fft1_bin = fft1[:, mask] fft2_bin = fft2[:, mask] # Compute FRC for this bin (sum over both lines and frequencies) + # see van Heel, M. (2005). 
https://doi.org/10.1016/j.jsb.2005.05.009 numerator = np.sum(fft1_bin * np.conj(fft2_bin)) denominator = np.sqrt(np.sum(np.abs(fft1_bin)**2) * np.sum(np.abs(fft2_bin)**2)) From 505466b067a696ad248c62ad690f7bf252a8da59 Mon Sep 17 00:00:00 2001 From: gnzng Date: Tue, 15 Jul 2025 14:43:58 -0700 Subject: [PATCH 4/9] add to line_based_frc function to support standard 2D FRC computation for comparison --- src/cdtools/tools/analysis/analysis.py | 242 +++++++++++++++++-------- 1 file changed, 164 insertions(+), 78 deletions(-) diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 98c49864..7fecf2b4 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -1528,7 +1528,9 @@ def calc_spectral_info(dataset, nbins=50): return sum_pattern, frc_bins[:-1], mean_spectrum -def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_size=1.0, unit="pixels"): +def line_based_frc( + image1, image2, axis=1, n_bins=None, thresholds={}, pixel_size=1.0, unit="pixels" +): """ Compute line-based FRC with multiple threshold methods. This works great for highly anisotropic images, such as those with vertical or horizontal stripes. @@ -1539,6 +1541,7 @@ def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_siz Input images (should be same size). Must be 2D arrays. 
axis: int axis along which to extract lines (1 for horizontal lines, 0 for vertical) + if None, standard 2D FRC is computed n_bins: int number of frequency bins (if None, no binning is applied) thresholds: dict @@ -1564,8 +1567,12 @@ def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_siz if isinstance(image2, t.Tensor): image2 = image2.numpy() - assert isinstance(image1, np.ndarray), "image1 must be a numpy array or torch tensor" - assert isinstance(image2, np.ndarray), "image2 must be a numpy array or torch tensor" + assert isinstance( + image1, np.ndarray + ), "image1 must be a numpy array or torch tensor" + assert isinstance( + image2, np.ndarray + ), "image2 must be a numpy array or torch tensor" assert image1.shape == image2.shape, "Images must have same shape" @@ -1580,98 +1587,177 @@ def line_based_frc(image1, image2, axis=1, n_bins=None, thresholds={}, pixel_siz lines1 = image1 lines2 = image2 n_freq = image1.shape[1] // 2 - - else: + elif axis == 0: # Extract vertical lines (perpendicular to horizontal stripes) lines1 = image1.T lines2 = image2.T n_freq = image1.shape[0] // 2 + elif axis is None: + # we now assume axis=None, which means we compute standard 2D FRC + lines1 = image1 + lines2 = image2 + n_freq = min(image1.shape) // 2 # Use half the smaller dimension + else: + raise ValueError("Axis must be 0, 1, or None for line-based FRC") - # Compute 1D FFT for each line - fft1 = np.fft.fft(lines1, axis=1) - fft2 = np.fft.fft(lines2, axis=1) + if axis is not None: + # Compute 1D FFT for each line + fft1 = np.fft.fft(lines1, axis=1) + fft2 = np.fft.fft(lines2, axis=1) - # Only keep positive frequencies - fft1 = fft1[:, :n_freq] - fft2 = fft2[:, :n_freq] + # Only keep positive frequencies + fft1 = fft1[:, :n_freq] + fft2 = fft2[:, :n_freq] - # Get the raw frequency array - raw_frequencies = np.fft.fftfreq(lines1.shape[1])[:n_freq] + # Get the raw frequency array + raw_frequencies = np.fft.fftfreq(lines1.shape[1])[:n_freq] - if 
n_bins is None: - # No binning - use original approach - numerator = np.sum(fft1 * np.conj(fft2), axis=0) - denominator = np.sqrt(np.sum(np.abs(fft1)**2, axis=0) * np.sum(np.abs(fft2)**2, axis=0)) + if n_bins is None: + # No binning - use original approach + numerator = np.sum(fft1 * np.conj(fft2), axis=0) + denominator = np.sqrt( + np.sum(np.abs(fft1) ** 2, axis=0) * np.sum(np.abs(fft2) ** 2, axis=0) + ) - frc_curve = np.abs(numerator) / denominator - frequencies = raw_frequencies + frc_curve = np.abs(numerator) / denominator + frequencies = raw_frequencies - else: - # Binning approach - similar to 2D FRC - max_freq = raw_frequencies[-1] - freq_bins = np.linspace(0, max_freq, n_bins + 1) + else: + # Binning approach - similar to 2D FRC + max_freq = raw_frequencies[-1] + freq_bins = np.linspace(0, max_freq, n_bins + 1) + + # Compute binned FRC + frc_curve = [] + frequencies = [] + + for i in range(len(freq_bins) - 1): + # Find frequency indices in this bin + mask = (raw_frequencies >= freq_bins[i]) & ( + raw_frequencies < freq_bins[i + 1] + ) + n_freq_components = np.sum(mask) + + if n_freq_components > 0: + # Extract FFT values in this frequency bin + fft1_bin = fft1[:, mask] + fft2_bin = fft2[:, mask] + + # Compute FRC for this bin (sum over both lines and frequencies) + # see van Heel, M. (2005). 
https://doi.org/10.1016/j.jsb.2005.05.009 + numerator = np.sum(fft1_bin * np.conj(fft2_bin)) + denominator = np.sqrt( + np.sum(np.abs(fft1_bin) ** 2) * np.sum(np.abs(fft2_bin) ** 2) + ) + + frc_value = np.abs(numerator) / denominator + frc_curve.append(frc_value) + frequencies.append((freq_bins[i] + freq_bins[i + 1]) / 2) + + # Convert lists to numpy arrays + frc_curve = np.array(frc_curve) + frequencies = np.array(frequencies) + + # Find resolution at each threshold using interpolation + resolution_dict = {} + for method, threshold_val in thresholds.items(): + # Find where FRC crosses the threshold + above_threshold = frc_curve >= threshold_val + + # Find all crossings (from above to below threshold) + crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] + + if crossings.size > 0: + # Use the first crossing (lowest frequency where FRC drops below threshold) + idx = crossings[0] + freq1 = frequencies[idx] + freq2 = frequencies[idx + 1] + frc1 = frc_curve[idx] + frc2 = frc_curve[idx + 1] + + # Linear interpolation: find frequency where FRC == threshold_val + resolution_frequency = freq1 + (threshold_val - frc1) * ( + freq2 - freq1 + ) / (frc2 - frc1) + else: + # No crossing found, check if all above or all below + if np.all(above_threshold): + # FRC never drops below threshold, use max frequency + resolution_frequency = frequencies[-1] + else: + # FRC never reaches threshold + resolution_frequency = 0 + + if resolution_frequency > 0: + resolution_pixels = 1 / resolution_frequency + resolution_physical = resolution_pixels * pixel_size + else: + resolution_physical = np.inf - # Compute binned FRC + resolution_dict[method] = resolution_physical + + return frequencies, frc_curve, resolution_dict + else: + # Compute standard 2D FRC + fft1 = np.fft.fft2(lines1) + fft2 = np.fft.fft2(lines2) + + # Shift zero frequency to center + fft1 = np.fft.fftshift(fft1) + fft2 = np.fft.fftshift(fft2) + + # Compute radial frequencies + shape = fft1.shape + center = [int(s 
// 2) for s in shape] + Y, X = np.ogrid[:shape[0], :shape[1]] + r = np.sqrt((Y - center[0]) ** 2 + (X - center[1]) ** 2) + r = r.astype(np.int32) + + if n_bins is None: + n_bins = min(shape) // 2 + + max_r = np.min(center) + bin_edges = np.linspace(0, max_r, n_bins + 1) frc_curve = [] frequencies = [] - for i in range(len(freq_bins) - 1): - # Find frequency indices in this bin - mask = (raw_frequencies >= freq_bins[i]) & (raw_frequencies < freq_bins[i+1]) - n_freq_components = np.sum(mask) - - if n_freq_components > 0: - # Extract FFT values in this frequency bin - fft1_bin = fft1[:, mask] - fft2_bin = fft2[:, mask] - - # Compute FRC for this bin (sum over both lines and frequencies) - # see van Heel, M. (2005). https://doi.org/10.1016/j.jsb.2005.05.009 - numerator = np.sum(fft1_bin * np.conj(fft2_bin)) - denominator = np.sqrt(np.sum(np.abs(fft1_bin)**2) * np.sum(np.abs(fft2_bin)**2)) - - frc_value = np.abs(numerator) / denominator - frc_curve.append(frc_value) - frequencies.append((freq_bins[i] + freq_bins[i+1]) / 2) + for i in range(n_bins): + mask = (r >= bin_edges[i]) & (r < bin_edges[i + 1]) + if np.any(mask): + num = np.sum(fft1[mask] * np.conj(fft2[mask])) + denom = np.sqrt(np.sum(np.abs(fft1[mask]) ** 2) * np.sum(np.abs(fft2[mask]) ** 2)) + frc_val = np.abs(num) / denom if denom != 0 else 0 + frc_curve.append(frc_val) + # Frequency in cycles per pixel + freq = (bin_edges[i] + bin_edges[i + 1]) / 2 / max_r * 0.5 / pixel_size + frequencies.append(freq) - # Convert lists to numpy arrays frc_curve = np.array(frc_curve) frequencies = np.array(frequencies) - # Find resolution at each threshold using interpolation - resolution_dict = {} - for method, threshold_val in thresholds.items(): - # Find where FRC crosses the threshold - above_threshold = frc_curve >= threshold_val - - # Find all crossings (from above to below threshold) - crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] - - if crossings.size > 0: - # Use the first crossing (lowest 
frequency where FRC drops below threshold) - idx = crossings[0] - freq1 = frequencies[idx] - freq2 = frequencies[idx + 1] - frc1 = frc_curve[idx] - frc2 = frc_curve[idx + 1] - - # Linear interpolation: find frequency where FRC == threshold_val - resolution_frequency = freq1 + (threshold_val - frc1) * (freq2 - freq1) / (frc2 - frc1) - else: - # No crossing found, check if all above or all below - if np.all(above_threshold): - # FRC never drops below threshold, use max frequency - resolution_frequency = frequencies[-1] + # Find resolution at each threshold using interpolation + resolution_dict = {} + for method, threshold_val in thresholds.items(): + above_threshold = frc_curve >= threshold_val + crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] + if crossings.size > 0: + idx = crossings[0] + freq1 = frequencies[idx] + freq2 = frequencies[idx + 1] + frc1 = frc_curve[idx] + frc2 = frc_curve[idx + 1] + # Linear interpolation + resolution_frequency = freq1 + (threshold_val - frc1) * (freq2 - freq1) / (frc2 - frc1) else: - # FRC never reaches threshold - resolution_frequency = 0 - - if resolution_frequency > 0: - resolution_pixels = 1 / resolution_frequency - resolution_physical = resolution_pixels * pixel_size - else: - resolution_physical = np.inf - - resolution_dict[method] = resolution_physical + if np.all(above_threshold): + resolution_frequency = frequencies[-1] + else: + resolution_frequency = 0 + if resolution_frequency > 0: + resolution_pixels = 1 / resolution_frequency + resolution_physical = resolution_pixels * pixel_size + else: + resolution_physical = np.inf + resolution_dict[method] = resolution_physical - return frequencies, frc_curve, resolution_dict + return frequencies, frc_curve, resolution_dict From 4e19e1c1088b3da67b09f54a768f6208dbd8c94a Mon Sep 17 00:00:00 2001 From: gnzng Date: Tue, 15 Jul 2025 14:58:13 -0700 Subject: [PATCH 5/9] reduce code redundancy --- src/cdtools/tools/analysis/analysis.py | 160 +++++++------------------ 1 
file changed, 41 insertions(+), 119 deletions(-) diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 7fecf2b4..36553352 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -1561,42 +1561,51 @@ def line_based_frc( Resolution at different thresholds """ + def find_resolution(frequencies, frc_curve, thresholds): + resolution_dict = {} + for method, threshold_val in thresholds.items(): + above_threshold = frc_curve >= threshold_val + crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] + if crossings.size > 0: + idx = crossings[0] + freq1, freq2 = frequencies[idx], frequencies[idx + 1] + frc1, frc2 = frc_curve[idx], frc_curve[idx + 1] + # Linear interpolation + resolution_frequency = freq1 + (threshold_val - frc1) * (freq2 - freq1) / (frc2 - frc1) + else: + if np.all(above_threshold): + resolution_frequency = frequencies[-1] + else: + resolution_frequency = 0 + if resolution_frequency > 0: + resolution_pixels = 1 / resolution_frequency + resolution_physical = resolution_pixels * pixel_size + else: + resolution_physical = np.inf + resolution_dict[method] = resolution_physical + return resolution_dict + # transform all torch tensors to np.arrays if isinstance(image1, t.Tensor): image1 = image1.numpy() if isinstance(image2, t.Tensor): image2 = image2.numpy() - assert isinstance( - image1, np.ndarray - ), "image1 must be a numpy array or torch tensor" - assert isinstance( - image2, np.ndarray - ), "image2 must be a numpy array or torch tensor" - + assert isinstance(image1, np.ndarray), "image1 must be a numpy array or torch tensor" + assert isinstance(image2, np.ndarray), "image2 must be a numpy array or torch tensor" assert image1.shape == image2.shape, "Images must have same shape" - - if image1.ndim != 2: - raise ValueError("Images must be 2D") - - if image2.ndim != 2: + if image1.ndim != 2 or image2.ndim != 2: raise ValueError("Images must be 2D") if axis == 1: 
- # Extract horizontal lines (perpendicular to vertical stripes) - lines1 = image1 - lines2 = image2 + lines1, lines2 = image1, image2 n_freq = image1.shape[1] // 2 elif axis == 0: - # Extract vertical lines (perpendicular to horizontal stripes) - lines1 = image1.T - lines2 = image2.T + lines1, lines2 = image1.T, image2.T n_freq = image1.shape[0] // 2 elif axis is None: - # we now assume axis=None, which means we compute standard 2D FRC - lines1 = image1 - lines2 = image2 - n_freq = min(image1.shape) // 2 # Use half the smaller dimension + lines1, lines2 = image1, image2 + n_freq = min(image1.shape) // 2 else: raise ValueError("Axis must be 0, 1, or None for line-based FRC") @@ -1626,20 +1635,10 @@ def line_based_frc( # Binning approach - similar to 2D FRC max_freq = raw_frequencies[-1] freq_bins = np.linspace(0, max_freq, n_bins + 1) - - # Compute binned FRC - frc_curve = [] - frequencies = [] - + frc_curve, frequencies = [], [] for i in range(len(freq_bins) - 1): - # Find frequency indices in this bin - mask = (raw_frequencies >= freq_bins[i]) & ( - raw_frequencies < freq_bins[i + 1] - ) - n_freq_components = np.sum(mask) - - if n_freq_components > 0: - # Extract FFT values in this frequency bin + mask = (raw_frequencies >= freq_bins[i]) & (raw_frequencies < freq_bins[i + 1]) + if np.any(mask): fft1_bin = fft1[:, mask] fft2_bin = fft2[:, mask] @@ -1649,54 +1648,8 @@ def line_based_frc( denominator = np.sqrt( np.sum(np.abs(fft1_bin) ** 2) * np.sum(np.abs(fft2_bin) ** 2) ) - - frc_value = np.abs(numerator) / denominator - frc_curve.append(frc_value) + frc_curve.append(np.abs(numerator) / denominator) frequencies.append((freq_bins[i] + freq_bins[i + 1]) / 2) - - # Convert lists to numpy arrays - frc_curve = np.array(frc_curve) - frequencies = np.array(frequencies) - - # Find resolution at each threshold using interpolation - resolution_dict = {} - for method, threshold_val in thresholds.items(): - # Find where FRC crosses the threshold - above_threshold = frc_curve 
>= threshold_val - - # Find all crossings (from above to below threshold) - crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] - - if crossings.size > 0: - # Use the first crossing (lowest frequency where FRC drops below threshold) - idx = crossings[0] - freq1 = frequencies[idx] - freq2 = frequencies[idx + 1] - frc1 = frc_curve[idx] - frc2 = frc_curve[idx + 1] - - # Linear interpolation: find frequency where FRC == threshold_val - resolution_frequency = freq1 + (threshold_val - frc1) * ( - freq2 - freq1 - ) / (frc2 - frc1) - else: - # No crossing found, check if all above or all below - if np.all(above_threshold): - # FRC never drops below threshold, use max frequency - resolution_frequency = frequencies[-1] - else: - # FRC never reaches threshold - resolution_frequency = 0 - - if resolution_frequency > 0: - resolution_pixels = 1 / resolution_frequency - resolution_physical = resolution_pixels * pixel_size - else: - resolution_physical = np.inf - - resolution_dict[method] = resolution_physical - - return frequencies, frc_curve, resolution_dict else: # Compute standard 2D FRC fft1 = np.fft.fft2(lines1) @@ -1710,17 +1663,12 @@ def line_based_frc( shape = fft1.shape center = [int(s // 2) for s in shape] Y, X = np.ogrid[:shape[0], :shape[1]] - r = np.sqrt((Y - center[0]) ** 2 + (X - center[1]) ** 2) - r = r.astype(np.int32) - + r = np.sqrt((Y - center[0]) ** 2 + (X - center[1]) ** 2).astype(np.int32) if n_bins is None: n_bins = min(shape) // 2 - max_r = np.min(center) bin_edges = np.linspace(0, max_r, n_bins + 1) - frc_curve = [] - frequencies = [] - + frc_curve, frequencies = [], [] for i in range(n_bins): mask = (r >= bin_edges[i]) & (r < bin_edges[i + 1]) if np.any(mask): @@ -1731,33 +1679,7 @@ def line_based_frc( # Frequency in cycles per pixel freq = (bin_edges[i] + bin_edges[i + 1]) / 2 / max_r * 0.5 / pixel_size frequencies.append(freq) - - frc_curve = np.array(frc_curve) - frequencies = np.array(frequencies) - - # Find resolution at each 
threshold using interpolation - resolution_dict = {} - for method, threshold_val in thresholds.items(): - above_threshold = frc_curve >= threshold_val - crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] - if crossings.size > 0: - idx = crossings[0] - freq1 = frequencies[idx] - freq2 = frequencies[idx + 1] - frc1 = frc_curve[idx] - frc2 = frc_curve[idx + 1] - # Linear interpolation - resolution_frequency = freq1 + (threshold_val - frc1) * (freq2 - freq1) / (frc2 - frc1) - else: - if np.all(above_threshold): - resolution_frequency = frequencies[-1] - else: - resolution_frequency = 0 - if resolution_frequency > 0: - resolution_pixels = 1 / resolution_frequency - resolution_physical = resolution_pixels * pixel_size - else: - resolution_physical = np.inf - resolution_dict[method] = resolution_physical - - return frequencies, frc_curve, resolution_dict + frc_curve, frequencies = np.array(frc_curve), np.array(frequencies) + resolution_dict = find_resolution(frequencies, frc_curve, thresholds) + + return frequencies, frc_curve, resolution_dict From 3a71d742d2e447c863c17d68ac356c05d5b4599a Mon Sep 17 00:00:00 2001 From: gnzng Date: Tue, 29 Jul 2025 14:28:07 -0700 Subject: [PATCH 6/9] Enhance split method in Ptycho2DDataset to allow non-random dataset splitting --- src/cdtools/datasets/ptycho_2d_dataset.py | 84 ++++++++++++++++------- tests/test_datasets.py | 35 ++++++++++ 2 files changed, 95 insertions(+), 24 deletions(-) diff --git a/src/cdtools/datasets/ptycho_2d_dataset.py b/src/cdtools/datasets/ptycho_2d_dataset.py index 3825d6de..1b6b58ad 100644 --- a/src/cdtools/datasets/ptycho_2d_dataset.py +++ b/src/cdtools/datasets/ptycho_2d_dataset.py @@ -296,35 +296,71 @@ def plot_mean_pattern(self, log_offset=1): cmap_label=cmap_label, title=title, ) - - - def split(self): - """Splits a dataset into two pseudorandomly selected sub-datasets + + def split(self, select_randomly: bool = True): """ + Splits a dataset into two pseudorandomly selected sub-datasets - # 
the selection is only 5,000 items long, so we repeat it to be long - # enough for the dataset - repeated_random_selection = (random_selection - * int(np.ceil(len(self) / len(random_selection)))) + select_randomly : bool + If True, the dataset is split into two disjoint datasets + using a pseudorandom selection. If False, the dataset is split + into two halves. + """ - repeated_random_selection = np.array(repeated_random_selection) - # Here, I use a fixed random selection for reproducibility - cut_random_selection =repeated_random_selection.astype(bool)[:len(self)] - - dataset_1 = deepcopy(self) - dataset_1.translations = self.translations[cut_random_selection] - dataset_1.patterns = self.patterns[cut_random_selection] - if hasattr(self, 'intensities') and self.intensities is not None: - dataset_1.intensities = self.intensities[cut_random_selection] - - dataset_2 = deepcopy(self) - dataset_2.translations = self.translations[~cut_random_selection] - dataset_2.patterns = self.patterns[~cut_random_selection] - if hasattr(self, 'intensities') and self.intensities is not None: - dataset_2.intensities = self.intensities[~cut_random_selection] + if select_randomly is True: + # the selection is only 5,000 items long, so we repeat it to be long + # enough for the dataset + repeated_random_selection = (random_selection * int(np.ceil(len(self) / len(random_selection)))) + + repeated_random_selection = np.array(repeated_random_selection) + # Here, I use a fixed random selection for reproducibility + cut_random_selection = repeated_random_selection.astype(bool)[:len(self)] + + dataset_1 = deepcopy(self) + dataset_1.translations = self.translations[cut_random_selection] + dataset_1.patterns = self.patterns[cut_random_selection] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_1.intensities = self.intensities[cut_random_selection] + + dataset_2 = deepcopy(self) + dataset_2.translations = self.translations[~cut_random_selection] + dataset_2.patterns = 
self.patterns[~cut_random_selection] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_2.intensities = self.intensities[~cut_random_selection] - return dataset_1, dataset_2 + return dataset_1, dataset_2 + elif select_randomly is False: + # If we are not randomly selecting, we just split the dataset in half + # by taking every second item of the dataset + + if len(self.translations) < 2: + raise ValueError( + 'The dataset is too small to split. It must contain at least 2 items.' + ) + if len(self.translations) % 2 != 0: + warnings.warn( + 'The dataset has an odd number of items, so the first half will be one item larger than the second half.' + ) + + dataset_1 = deepcopy(self) + dataset_1.translations = self.translations[::2] + dataset_1.patterns = self.patterns[::2] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_1.intensities = self.intensities[::2] + + dataset_2 = deepcopy(self) + dataset_2.translations = self.translations[1::2] + dataset_2.patterns = self.patterns[1::2] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_2.intensities = self.intensities[1::2] + + return dataset_1, dataset_2 + else: + raise ValueError( + 'select_randomly must be True or a list of booleans, not ' + f'{type(select_randomly)}' + ) def pad(self, to_pad, value=0, mask=True): """Pads all the diffraction patterns by a speficied amount diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 4d17f0b3..ac9c7bc3 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -510,3 +510,38 @@ def test_Ptycho2DDataset_crop_translations(ptycho_cxi_1): assert t.allclose(copied_dataset.patterns, dataset.patterns[10:-10, :]) assert t.allclose(copied_dataset.translations, dataset.translations[10:-10, :]) + + +def test_Ptycho2DDataset_split(ptycho_cxi_1): + # Grab dataset + cxi, expected = ptycho_cxi_1 + dataset = Ptycho2DDataset.from_cxi(cxi) + + # Test: Split the dataset into two datasets, select 
randomly, default True + dataset_1, dataset_2 = dataset.split() + + # This is using a fixed random selection, so we can check the beginning of this. + # from random_selection.py: random_selection = [0, 0, 1, 0, 0, 1, 0, ...] + assert len(dataset_1) + len(dataset_2) == len(dataset) + + assert t.allclose(dataset_1.patterns[0], dataset.patterns[2]) + assert t.allclose(dataset_1.patterns[1], dataset.patterns[5]) + assert t.allclose(dataset_2.patterns[0], dataset.patterns[0]) + assert t.allclose(dataset_2.patterns[1], dataset.patterns[1]) + assert t.allclose(dataset_2.patterns[2], dataset.patterns[3]) + + # Test the translations + assert t.allclose(dataset_1.translations[0], dataset.translations[2]) + assert t.allclose(dataset_1.translations[1], dataset.translations[5]) + assert t.allclose(dataset_2.translations[0], dataset.translations[0]) + assert t.allclose(dataset_2.translations[1], dataset.translations[1]) + assert t.allclose(dataset_2.translations[2], dataset.translations[3]) + + # Test: Split the dataset into two datasets + dataset_1, dataset_2 = dataset.split(select_randomly=False) + + assert len(dataset_1) + len(dataset_2) == len(dataset) + assert t.allclose(dataset_1.patterns, dataset.patterns[::2]) + assert t.allclose(dataset_2.patterns, dataset.patterns[1::2]) + assert t.allclose(dataset_1.translations, dataset.translations[::2]) + assert t.allclose(dataset_2.translations, dataset.translations[1::2]) From 4c021148f02f4553b1e29d0f90c46d4761bf9dc8 Mon Sep 17 00:00:00 2001 From: gnzng Date: Tue, 29 Jul 2025 16:10:45 -0700 Subject: [PATCH 7/9] Enhance line_based_frc function to align images using hann window and subpixel shifting --- src/cdtools/tools/analysis/analysis.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 36553352..28bd275b 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -1597,6 +1597,14 @@ def 
find_resolution(frequencies, frc_curve, thresholds): if image1.ndim != 2 or image2.ndim != 2: raise ValueError("Images must be 2D") + # Aligning the images using find_shift and sinc_subpixel_shift + # using hann window to reduce artifacts + shift = ip.find_shift( + t.as_tensor(ip.hann_window(image1)), + t.as_tensor(ip.hann_window(image2)) + ) + image2 = ip.sinc_subpixel_shift(t.as_tensor(image2), shift).numpy() + if axis == 1: lines1, lines2 = image1, image2 n_freq = image1.shape[1] // 2 From 368a2c616ca3666c68d686895b84e858b1206d29 Mon Sep 17 00:00:00 2001 From: gnzng Date: Tue, 29 Jul 2025 16:35:52 -0700 Subject: [PATCH 8/9] remove alignment and assume aligned images as input --- src/cdtools/tools/analysis/analysis.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 28bd275b..f48867b4 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -1538,7 +1538,8 @@ def line_based_frc( Parameters ---------- image1, image2: np.ndarray or torch.Tensor - Input images (should be same size). Must be 2D arrays. + Input images (should be same size). Must be 2D arrays. Should be aligned + beforehand using ip.find_shift, ip.sinc_subpixel_shift, or similar methods.
axis: int axis along which to extract lines (1 for horizontal lines, 0 for vertical) if None, standard 2D FRC is computed @@ -1597,14 +1598,6 @@ def find_resolution(frequencies, frc_curve, thresholds): if image1.ndim != 2 or image2.ndim != 2: raise ValueError("Images must be 2D") - # Aligning the images using find_shift and sinc_subpixel_shift - # using hann window to reduce artifacts - shift = ip.find_shift( - t.as_tensor(ip.hann_window(image1)), - t.as_tensor(ip.hann_window(image2)) - ) - image2 = ip.sinc_subpixel_shift(t.as_tensor(image2), shift).numpy() - if axis == 1: lines1, lines2 = image1, image2 n_freq = image1.shape[1] // 2 From e4a713cdf1a855ee373b068c49b797f6355bfa6d Mon Sep 17 00:00:00 2001 From: gnzng Date: Tue, 21 Oct 2025 20:15:40 -0700 Subject: [PATCH 9/9] Fix frequency calculation in line_based_frc for 2D case --- src/cdtools/tools/analysis/analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index f48867b4..e55b7f38 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -1678,9 +1678,9 @@ def find_resolution(frequencies, frc_curve, thresholds): frc_val = np.abs(num) / denom if denom != 0 else 0 frc_curve.append(frc_val) # Frequency in cycles per pixel - freq = (bin_edges[i] + bin_edges[i + 1]) / 2 / max_r * 0.5 / pixel_size + freq = (bin_edges[i] + bin_edges[i + 1]) / 2 / max_r * 0.5 frequencies.append(freq) frc_curve, frequencies = np.array(frc_curve), np.array(frequencies) resolution_dict = find_resolution(frequencies, frc_curve, thresholds) - + return frequencies, frc_curve, resolution_dict