diff --git a/examples/line_based_frc.py b/examples/line_based_frc.py new file mode 100644 index 00000000..980a7f4a --- /dev/null +++ b/examples/line_based_frc.py @@ -0,0 +1,117 @@ +import numpy as np +import matplotlib.pyplot as plt + +from cdtools.tools.analysis import line_based_frc + + +def create_stripe_test_pattern(shape, stripe_spacing, noise_level=0.1): + """Create test pattern with anisotropic stripes""" + h, w = shape + + # Create vertical stripes (varying along x-axis) + x = np.arange(w) + pattern = np.sin(2 * np.pi * x / stripe_spacing) + + # Broadcast to full image + image = np.tile(pattern, (h, 1)) + + # Add some modulation along y-axis for realism + y = np.arange(h) + y_mod = 1 + 0.3 * np.sin(2 * np.pi * y / (h // 3)) + image = image * y_mod[:, np.newaxis] + + # make the lower part 0 + image[7 * h // 8:] = 0 + + # Add noise + image += noise_level * np.random.randn(h, w) + + return image + + +# Create test images +shape = (200, 300) +stripe_spacing = 8 # pixels + +unit = "nm" +pixel_size = 19.0 # nm, for example + +# Create two slightly different versions to simulate repeated measurements +np.random.seed(42) +image1 = create_stripe_test_pattern( + shape, + stripe_spacing, + noise_level=0.05, +) + +np.random.seed(43) # Different seed for second image +image2 = create_stripe_test_pattern( + shape, + stripe_spacing, + noise_level=0.05, +) + +# Display test images +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) + +ax1.imshow(image1, cmap="gray") +ax1.set_title("Test Image 1") +ax1.axis("off") + +ax2.imshow(image2, cmap="gray") +ax2.set_title("Test Image 2") +ax2.axis("off") +plt.tight_layout() + + +frc_thresholds = {"1/2": 0.5, "1/4": 0.25, "1/7": 1 / 7} + +# Compute binned FRC +frequencies, frc_curve, resolution_dict = line_based_frc( + image1, image2, axis=1, n_bins=100, thresholds=frc_thresholds, + pixel_size=pixel_size, unit=unit +) + +########################### +# Plot FRC curve with threshold crossings +########################### + +# Color map for different thresholds +colors = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray"] + +fig, ax1 = plt.subplots(1, 1, figsize=(6, 6)) + +# FRC curve with all thresholds +ax1.plot(frequencies, frc_curve, "k-", linewidth=3, label="FRC", zorder=10) + +for i, (method, threshold_val) in enumerate(frc_thresholds.items()): + color = colors[i % len(colors)] + ax1.axhline( + y=threshold_val, + color=color, + linestyle="--", + alpha=0.8, + label=f"{method} ({resolution_dict[method]:.3f} {unit})", + ) + + # Mark threshold frequency using the resolution value + resolution = resolution_dict[method] + if resolution is not None and resolution > 0: + freq = pixel_size / resolution + # Find the closest frequency in the array + ax1.scatter( + # threshold_val * np.ones_like(freq), + freq, + threshold_val, + color=color, + s=80, + edgecolor="black", + zorder=20, + ) + +ax1.set_xlabel(f"Spatial Frequency (1/{unit})") +ax1.set_ylabel("FRC") +ax1.set_title("Line-based FRC Curve") +ax1.legend() +plt.tight_layout() +plt.show() diff --git a/src/cdtools/datasets/ptycho_2d_dataset.py b/src/cdtools/datasets/ptycho_2d_dataset.py index adfe866e..83860925 100644 --- a/src/cdtools/datasets/ptycho_2d_dataset.py +++ b/src/cdtools/datasets/ptycho_2d_dataset.py @@ -300,35 +300,71 @@ def plot_mean_pattern(self, log_offset=1): cmap_label=cmap_label, title=title, ) - - - def split(self): - """Splits a dataset into two pseudorandomly selected sub-datasets + + def split(self, select_randomly: bool = True): """ + Splits a dataset into two 
pseudorandomly selected sub-datasets - # the selection is only 5,000 items long, so we repeat it to be long - # enough for the dataset - repeated_random_selection = (random_selection - * int(np.ceil(len(self) / len(random_selection)))) + select_randomly : bool + If True, the dataset is split into two disjoint datasets + using a pseudorandom selection. If False, the dataset is split + into two halves. + """ - repeated_random_selection = np.array(repeated_random_selection) - # Here, I use a fixed random selection for reproducibility - cut_random_selection =repeated_random_selection.astype(bool)[:len(self)] - - dataset_1 = deepcopy(self) - dataset_1.translations = self.translations[cut_random_selection] - dataset_1.patterns = self.patterns[cut_random_selection] - if hasattr(self, 'intensities') and self.intensities is not None: - dataset_1.intensities = self.intensities[cut_random_selection] - - dataset_2 = deepcopy(self) - dataset_2.translations = self.translations[~cut_random_selection] - dataset_2.patterns = self.patterns[~cut_random_selection] - if hasattr(self, 'intensities') and self.intensities is not None: - dataset_2.intensities = self.intensities[~cut_random_selection] + if select_randomly is True: + # the selection is only 5,000 items long, so we repeat it to be long + # enough for the dataset + repeated_random_selection = (random_selection * int(np.ceil(len(self) / len(random_selection)))) + + repeated_random_selection = np.array(repeated_random_selection) + # Here, I use a fixed random selection for reproducibility + cut_random_selection = repeated_random_selection.astype(bool)[:len(self)] + + dataset_1 = deepcopy(self) + dataset_1.translations = self.translations[cut_random_selection] + dataset_1.patterns = self.patterns[cut_random_selection] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_1.intensities = self.intensities[cut_random_selection] + + dataset_2 = deepcopy(self) + dataset_2.translations = self.translations[~cut_random_selection] + dataset_2.patterns = self.patterns[~cut_random_selection] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_2.intensities = self.intensities[~cut_random_selection] - return dataset_1, dataset_2 + return dataset_1, dataset_2 + elif select_randomly is False: + # If we are not randomly selecting, we just split the dataset in half + # by taking every second item of the dataset + + if len(self.translations) < 2: + raise ValueError( + 'The dataset is too small to split. It must contain at least 2 items.' + ) + if len(self.translations) % 2 != 0: + warnings.warn( + 'The dataset has an odd number of items, so the first half will be one item larger than the second half.' 
+ ) + + dataset_1 = deepcopy(self) + dataset_1.translations = self.translations[::2] + dataset_1.patterns = self.patterns[::2] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_1.intensities = self.intensities[::2] + + dataset_2 = deepcopy(self) + dataset_2.translations = self.translations[1::2] + dataset_2.patterns = self.patterns[1::2] + if hasattr(self, 'intensities') and self.intensities is not None: + dataset_2.intensities = self.intensities[1::2] + + return dataset_1, dataset_2 + else: + raise ValueError( + 'select_randomly must be True or a list of booleans, not ' + f'{type(select_randomly)}' + ) def pad(self, to_pad, value=0, mask=True): """Pads all the diffraction patterns by a speficied amount diff --git a/src/cdtools/tools/analysis/analysis.py b/src/cdtools/tools/analysis/analysis.py index 9844edb4..e55b7f38 100644 --- a/src/cdtools/tools/analysis/analysis.py +++ b/src/cdtools/tools/analysis/analysis.py @@ -34,6 +34,7 @@ 'standardize_reconstruction_set', 'standardize_reconstruction_pair', 'calc_spectral_info', + 'line_based_frc', ] @@ -1525,4 +1526,161 @@ def calc_spectral_info(dataset, nbins=50): pattern_snr = sum_spectrum_sq / sum_spectrum return sum_pattern, frc_bins[:-1], mean_spectrum - + + +def line_based_frc( + image1, image2, axis=1, n_bins=None, thresholds={}, pixel_size=1.0, unit="pixels" +): + """ + Compute line-based FRC with multiple threshold methods. This works great for highly anisotropic + images, such as those with vertical or horizontal stripes. + + Parameters + ---------- + image1, image2: np.ndarray or torch.Tensor + Input images (should be same size). Must be 2D arrays. Should be aligned before + hand using ip.find_shift, ip.sinc_subpixel_shift, or similar methods. + axis: int + axis along which to extract lines (1 for horizontal lines, 0 for vertical) + if None, standard 2D FRC is computed + n_bins: int + number of frequency bins (if None, no binning is applied) + thresholds: dict + Dictionary of threshold methods and values + pixel_size: float + Physical size of one pixel + unit: str + Unit for the pixel size and resolution output + + Returns + ------- + frequencies: np.ndarray + Spatial frequency array + frc_curve: np.ndarray + FRC values as function of frequency + resolution_dict: dict + Resolution at different thresholds + """ + + def find_resolution(frequencies, frc_curve, thresholds): + resolution_dict = {} + for method, threshold_val in thresholds.items(): + above_threshold = frc_curve >= threshold_val + crossings = np.where(np.diff(above_threshold.astype(int)) == -1)[0] + if crossings.size > 0: + idx = crossings[0] + freq1, freq2 = frequencies[idx], frequencies[idx + 1] + frc1, frc2 = frc_curve[idx], frc_curve[idx + 1] + # Linear interpolation + resolution_frequency = freq1 + (threshold_val - frc1) * (freq2 - freq1) / (frc2 - frc1) + else: + if np.all(above_threshold): + resolution_frequency = frequencies[-1] + else: + resolution_frequency = 0 + if resolution_frequency > 0: + resolution_pixels = 1 / resolution_frequency + resolution_physical = resolution_pixels * pixel_size + else: + resolution_physical = np.inf + resolution_dict[method] = resolution_physical + return resolution_dict + + # transform all torch tensors to np.arrays + if isinstance(image1, t.Tensor): + image1 = image1.numpy() + if isinstance(image2, t.Tensor): + image2 = image2.numpy() + + assert isinstance(image1, np.ndarray), "image1 must be a numpy array or torch tensor" + assert isinstance(image2, np.ndarray), "image2 must be a numpy array or torch 
tensor" + assert image1.shape == image2.shape, "Images must have same shape" + if image1.ndim != 2 or image2.ndim != 2: + raise ValueError("Images must be 2D") + + if axis == 1: + lines1, lines2 = image1, image2 + n_freq = image1.shape[1] // 2 + elif axis == 0: + lines1, lines2 = image1.T, image2.T + n_freq = image1.shape[0] // 2 + elif axis is None: + lines1, lines2 = image1, image2 + n_freq = min(image1.shape) // 2 + else: + raise ValueError("Axis must be 0, 1, or None for line-based FRC") + + if axis is not None: + # Compute 1D FFT for each line + fft1 = np.fft.fft(lines1, axis=1) + fft2 = np.fft.fft(lines2, axis=1) + + # Only keep positive frequencies + fft1 = fft1[:, :n_freq] + fft2 = fft2[:, :n_freq] + + # Get the raw frequency array + raw_frequencies = np.fft.fftfreq(lines1.shape[1])[:n_freq] + + if n_bins is None: + # No binning - use original approach + numerator = np.sum(fft1 * np.conj(fft2), axis=0) + denominator = np.sqrt( + np.sum(np.abs(fft1) ** 2, axis=0) * np.sum(np.abs(fft2) ** 2, axis=0) + ) + + frc_curve = np.abs(numerator) / denominator + frequencies = raw_frequencies + + else: + # Binning approach - similar to 2D FRC + max_freq = raw_frequencies[-1] + freq_bins = np.linspace(0, max_freq, n_bins + 1) + frc_curve, frequencies = [], [] + for i in range(len(freq_bins) - 1): + mask = (raw_frequencies >= freq_bins[i]) & (raw_frequencies < freq_bins[i + 1]) + if np.any(mask): + fft1_bin = fft1[:, mask] + fft2_bin = fft2[:, mask] + + # Compute FRC for this bin (sum over both lines and frequencies) + # see van Heel, M. (2005). https://doi.org/10.1016/j.jsb.2005.05.009 + numerator = np.sum(fft1_bin * np.conj(fft2_bin)) + denominator = np.sqrt( + np.sum(np.abs(fft1_bin) ** 2) * np.sum(np.abs(fft2_bin) ** 2) + ) + frc_curve.append(np.abs(numerator) / denominator) + frequencies.append((freq_bins[i] + freq_bins[i + 1]) / 2) + else: + # Compute standard 2D FRC + fft1 = np.fft.fft2(lines1) + fft2 = np.fft.fft2(lines2) + + # Shift zero frequency to center + fft1 = np.fft.fftshift(fft1) + fft2 = np.fft.fftshift(fft2) + + # Compute radial frequencies + shape = fft1.shape + center = [int(s // 2) for s in shape] + Y, X = np.ogrid[:shape[0], :shape[1]] + r = np.sqrt((Y - center[0]) ** 2 + (X - center[1]) ** 2).astype(np.int32) + if n_bins is None: + n_bins = min(shape) // 2 + max_r = np.min(center) + bin_edges = np.linspace(0, max_r, n_bins + 1) + frc_curve, frequencies = [], [] + for i in range(n_bins): + mask = (r >= bin_edges[i]) & (r < bin_edges[i + 1]) + if np.any(mask): + num = np.sum(fft1[mask] * np.conj(fft2[mask])) + denom = np.sqrt(np.sum(np.abs(fft1[mask]) ** 2) * np.sum(np.abs(fft2[mask]) ** 2)) + frc_val = np.abs(num) / denom if denom != 0 else 0 + frc_curve.append(frc_val) + # Frequency in cycles per pixel + freq = (bin_edges[i] + bin_edges[i + 1]) / 2 / max_r * 0.5 + frequencies.append(freq) + frc_curve, frequencies = np.array(frc_curve), np.array(frequencies) + resolution_dict = find_resolution(frequencies, frc_curve, thresholds) + + return frequencies, frc_curve, resolution_dict diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 4d17f0b3..ac9c7bc3 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -510,3 +510,38 @@ def test_Ptycho2DDataset_crop_translations(ptycho_cxi_1): assert t.allclose(copied_dataset.patterns, dataset.patterns[10:-10, :]) assert t.allclose(copied_dataset.translations, dataset.translations[10:-10, :]) + + +def test_Ptycho2DDataset_split(ptycho_cxi_1): + # Grab dataset + cxi, expected = ptycho_cxi_1 + dataset = 
Ptycho2DDataset.from_cxi(cxi) + + # Test: Split the dataset into two datasets, select randomly, default True + dataset_1, dataset_2 = dataset.split() + + # This is using a fixed random selection, so we can check the beginning of this. + # from random_selection.py: random_selection = [0, 0, 1, 0, 0, 1, 0, ...] + assert len(dataset_1) + len(dataset_2) == len(dataset) + + assert t.allclose(dataset_1.patterns[0], dataset.patterns[2]) + assert t.allclose(dataset_1.patterns[1], dataset.patterns[5]) + assert t.allclose(dataset_2.patterns[0], dataset.patterns[0]) + assert t.allclose(dataset_2.patterns[1], dataset.patterns[1]) + assert t.allclose(dataset_2.patterns[2], dataset.patterns[3]) + + # Test the translations + assert t.allclose(dataset_1.translations[0], dataset.translations[2]) + assert t.allclose(dataset_1.translations[1], dataset.translations[5]) + assert t.allclose(dataset_2.translations[0], dataset.translations[0]) + assert t.allclose(dataset_2.translations[1], dataset.translations[1]) + assert t.allclose(dataset_2.translations[2], dataset.translations[3]) + + # Test: Split the dataset into two datasets + dataset_1, dataset_2 = dataset.split(select_randomly=False) + + assert len(dataset_1) + len(dataset_2) == len(dataset) + assert t.allclose(dataset_1.patterns, dataset.patterns[::2]) + assert t.allclose(dataset_2.patterns, dataset.patterns[1::2]) + assert t.allclose(dataset_1.translations, dataset.translations[::2]) + assert t.allclose(dataset_2.translations, dataset.translations[1::2]) diff --git a/tests/tools/test_analysis.py b/tests/tools/test_analysis.py index ad06757e..d293c9fb 100644 --- a/tests/tools/test_analysis.py +++ b/tests/tools/test_analysis.py @@ -1,9 +1,11 @@ import numpy as np +import pytest import torch as t from scipy import linalg as la from scipy.sparse import linalg as spla from cdtools.tools import analysis, initializers +from cdtools.tools.analysis import line_based_frc def test_product_svd(): @@ -559,6 +561,107 @@ def test_calc_generalized_rms_error(): fields_2 = t.rand(3, 1, 17, dtype=t.complex128) fields_3 = fields_2.flip(0) - assert analysis.calc_generalized_rms_error( - fields_1, fields_2, dims=1 - ).shape == t.Size([3]) + + assert (analysis.calc_generalized_rms_error(fields_1, fields_2, dims=1).shape == t.Size([3])) + + +def test_line_based_frc(): + # First test some basic input with two fields of the same size + field_1 = t.rand(30, 17, dtype=t.complex128) + field_2 = t.rand(30, 17, dtype=t.complex128) + + # Compute FRC + freq, frc, res_dict = line_based_frc(field_1, field_2, + axis=1, n_bins=50, + thresholds={"1/7": 1/7}, + ) + + # check if length of frc is longer than 0 + assert len(frc) > 0 + # check if freq is same length as frc + assert len(freq) == len(frc) + # check if res_dict is a dictionary + assert isinstance(res_dict, dict) + + # Now check if it works with numpy arrays + field_1_np = field_1.numpy() + field_2_np = field_2.numpy() + + freq_np, frc_np, res_dict_np = line_based_frc(field_1_np, field_2_np) + + # Do the same checks as above + assert len(frc_np) > 0 + assert len(freq_np) == len(frc_np) + assert isinstance(res_dict_np, dict) + + # Create test images from the provided example + def create_stripe_test_pattern(shape, stripe_spacing, noise_level=0.1): + """Create test pattern with anisotropic stripes""" + h, w = shape + + # Create vertical stripes (varying along x-axis) + x = np.arange(w) + pattern = np.sin(2 * np.pi * x / stripe_spacing) + + # Broadcast to full image + image = np.tile(pattern, (h, 1)) + + # Add some modulation along 
y-axis for realism + y = np.arange(h) + y_mod = 1 + 0.3 * np.sin(2 * np.pi * y / (h // 3)) + image = image * y_mod[:, np.newaxis] + + # make the lower part 0 + image[7 * h // 8:] = 0 + + # Add noise + image += noise_level * np.random.randn(h, w) + + return image + + shape = (200, 300) + stripe_spacing = 8 # pixels + + unit = "nm" + pixel_size = 19.0 # nm, for example + + # Create two slightly different versions to simulate repeated measurements + np.random.seed(42) + image1 = create_stripe_test_pattern( + shape, + stripe_spacing, + noise_level=0.05, + ) + + np.random.seed(43) # Different seed for second image + image2 = create_stripe_test_pattern( + shape, + stripe_spacing, + noise_level=0.05, + ) + + frc_thresholds = {"1/2": 0.5, "1/4": 0.25, "1/7": 1 / 7} + + # Compute binned FRC + frequencies, frc_curve, res_dict = line_based_frc( + image1, image2, axis=1, n_bins=100, thresholds=frc_thresholds, + pixel_size=pixel_size, unit=unit + ) + + # Run the same structural checks on the binned results + assert len(frc_curve) > 0 + assert len(frequencies) == len(frc_curve) + assert isinstance(res_dict, dict) + + # check if the values are within expected ranges + assert res_dict["1/2"] > 70.0 and res_dict["1/2"] < 75.0 + assert res_dict["1/4"] > 52.0 and res_dict["1/4"] < 57.0 + assert res_dict["1/7"] > 45.0 and res_dict["1/7"] < 48.0 + + # Now let's check that errors are caught correctly + # Non-2D input should raise a ValueError + with pytest.raises(ValueError): + line_based_frc(np.random.rand(30, 17, 3), np.random.rand(30, 17, 3)) + # Mismatched shapes and non-array inputs are rejected by assertions + with pytest.raises(AssertionError): + line_based_frc(np.random.rand(30, 17, 3), np.random.rand(30, 17)) + with pytest.raises(AssertionError): + line_based_frc("string", np.random.rand(30, 17)) +
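
Usage note (not part of the diff itself): a minimal sketch of how the two `split` modes might be called, assuming `Ptycho2DDataset` is importable from `cdtools.datasets` as elsewhere in the package; the `.cxi` filename is a placeholder for illustration only.

```python
from cdtools.datasets import Ptycho2DDataset

# Placeholder file name, for illustration only
dataset = Ptycho2DDataset.from_cxi('my_scan.cxi')

# Default: pseudorandom but reproducible split, because the selection
# sequence is fixed (see random_selection.py)
half_a, half_b = dataset.split()

# Deterministic split into even- and odd-indexed exposures
even_half, odd_half = dataset.split(select_randomly=False)

# Each mode produces two disjoint sub-datasets that cover the original
assert len(half_a) + len(half_b) == len(dataset)
assert len(even_half) + len(odd_half) == len(dataset)
```

Either pair of disjoint sub-datasets can then be reconstructed independently, which is the kind of input `line_based_frc` is designed to compare.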
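For the FRC itself, a self-contained sketch on synthetic data, independent of the example script above; the array sizes, noise level, and pixel size are arbitrary illustrative choices. Each bin holds the van Heel correlation |Σ F₁F₂*| / sqrt(Σ|F₁|² Σ|F₂|²), so identical inputs give a curve of ones (to floating-point precision), and independent noise added to each copy pulls the curve down.

```python
import numpy as np

from cdtools.tools.analysis import line_based_frc

rng = np.random.default_rng(0)
signal = rng.standard_normal((64, 128))

# Identical inputs: the normalized correlation is 1 in every frequency bin
freq, frc, _ = line_based_frc(signal, signal, axis=1, n_bins=20)
assert np.allclose(frc, 1.0)

# Independent noise on each copy lowers the whole curve
noisy_1 = signal + 0.5 * rng.standard_normal(signal.shape)
noisy_2 = signal + 0.5 * rng.standard_normal(signal.shape)
freq, frc, res = line_based_frc(
    noisy_1, noisy_2, axis=1, n_bins=20,
    thresholds={"1/7": 1 / 7}, pixel_size=10.0, unit="nm",
)
print(res["1/7"])  # resolution estimate in nm at the 1/7 threshold
```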