Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions examples/line_based_frc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np
import matplotlib.pyplot as plt

from cdtools.tools.analysis import line_based_frc


def create_stripe_test_pattern(shape, stripe_spacing, noise_level=0.1):
    """Build a noisy synthetic image of vertical stripes.

    The stripes vary sinusoidally along the x-axis, are modulated by a
    slower sinusoid along the y-axis, and the bottom eighth of the image
    is blanked before Gaussian noise is added everywhere.

    Parameters
    ----------
    shape : tuple
        (height, width) of the image in pixels
    stripe_spacing : float
        Period of the stripes along the x-axis, in pixels
    noise_level : float
        Scale factor applied to the standard-normal noise

    Returns
    -------
    image : np.ndarray
        The generated (height, width) test pattern
    """
    n_rows, n_cols = shape

    # A single row of the stripe pattern (varies along x only)
    columns = np.arange(n_cols)
    stripe_row = np.sin(2 * np.pi * columns / stripe_spacing)

    # Slow envelope along y, three periods over the image height
    rows = np.arange(n_rows)
    envelope = 1 + 0.3 * np.sin(2 * np.pi * rows / (n_rows // 3))

    # Broadcasting the row against the envelope fills the whole image
    stripes = stripe_row[np.newaxis, :] * envelope[:, np.newaxis]

    # Blank the lowest eighth of the rows
    first_blank_row = 7 * n_rows // 8
    stripes[first_blank_row:] = 0

    # Noise is drawn from numpy's global random state
    stripes += noise_level * np.random.randn(n_rows, n_cols)

    return stripes


# Create test images
shape = (200, 300)
stripe_spacing = 8  # pixels

unit = "nm"
pixel_size = 19.0  # nm, for example

# Two noise realisations of the same pattern simulate repeated measurements
np.random.seed(42)
image1 = create_stripe_test_pattern(
    shape,
    stripe_spacing,
    noise_level=0.05,
)

np.random.seed(43)  # Different seed for second image
image2 = create_stripe_test_pattern(
    shape,
    stripe_spacing,
    noise_level=0.05,
)

# Display test images
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(image1, cmap="gray")
ax1.set_title("Test Image 1")
ax1.axis("off")

ax2.imshow(image2, cmap="gray")
ax2.set_title("Test Image 2")
ax2.axis("off")
plt.tight_layout()


frc_thresholds = {"1/2": 0.5, "1/4": 0.25, "1/7": 1 / 7}

# Compute binned FRC
frequencies, frc_curve, resolution_dict = line_based_frc(
    image1, image2, axis=1, n_bins=100, thresholds=frc_thresholds,
    pixel_size=pixel_size, unit=unit
)

# line_based_frc returns frequencies in cycles per pixel, while the
# resolutions it reports are already scaled by pixel_size. Convert the
# frequency axis to 1/unit so that it matches the axis label below.
physical_frequencies = frequencies / pixel_size

###########################
# Plot FRC curve with threshold crossings
###########################

# Color map for different thresholds
colors = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray"]

fig, ax1 = plt.subplots(1, 1, figsize=(6, 6))

# FRC curve with all thresholds
ax1.plot(
    physical_frequencies, frc_curve, "k-", linewidth=3, label="FRC", zorder=10
)

for i, (method, threshold_val) in enumerate(frc_thresholds.items()):
    color = colors[i % len(colors)]
    resolution = resolution_dict[method]
    ax1.axhline(
        y=threshold_val,
        color=color,
        linestyle="--",
        alpha=0.8,
        label=f"{method} ({resolution:.3f} {unit})",
    )

    # Mark the frequency at which the curve crosses this threshold. An
    # infinite resolution means no usable crossing was found, so there is
    # nothing to mark.
    if np.isfinite(resolution) and resolution > 0:
        ax1.scatter(
            1 / resolution,
            threshold_val,
            color=color,
            s=80,
            edgecolor="black",
            zorder=20,
        )

ax1.set_xlabel(f"Spatial Frequency (1/{unit})")
ax1.set_ylabel("FRC")
ax1.set_title("Line-based FRC Curve")
ax1.legend()
plt.tight_layout()
plt.show()
84 changes: 60 additions & 24 deletions src/cdtools/datasets/ptycho_2d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,35 +300,71 @@ def plot_mean_pattern(self, log_offset=1):
cmap_label=cmap_label,
title=title,
)


def split(self, select_randomly: bool = True):
    """Split the dataset into two disjoint sub-datasets.

    Both returned datasets are deep copies of this one whose
    translations, patterns and (when present) intensities are
    restricted to complementary subsets of the items.

    Parameters
    ----------
    select_randomly : bool
        If True, the items are assigned to the two datasets using a
        fixed pseudorandom selection, so the split is reproducible.
        If False, the dataset is split by taking every second item:
        even-indexed items go to the first dataset and odd-indexed
        items to the second.

    Returns
    -------
    dataset_1, dataset_2 : tuple
        The two sub-datasets

    Raises
    ------
    ValueError
        If select_randomly is not a boolean, or if select_randomly is
        False and the dataset contains fewer than 2 items.
    """
    # Imported locally so the odd-length warning below does not depend on
    # a module-level import being present.
    import warnings

    if not isinstance(select_randomly, (bool, np.bool_)):
        raise ValueError(
            'select_randomly must be True or False, not '
            f'{type(select_randomly)}'
        )

    if select_randomly:
        # random_selection is a fixed sequence defined outside this
        # method; using it (rather than drawing new random numbers) keeps
        # the split reproducible. It is only 5,000 items long, so it is
        # repeated cyclically until it covers the whole dataset.
        base_selection = np.asarray(random_selection).astype(bool)
        first_half = np.resize(base_selection, len(self))
        second_half = ~first_half
    else:
        n_items = len(self.translations)
        if n_items < 2:
            raise ValueError(
                'The dataset is too small to split. It must contain at least 2 items.'
            )
        if n_items % 2 != 0:
            warnings.warn(
                'The dataset has an odd number of items, so the first half will be one item larger than the second half.'
            )
        first_half, second_half = slice(0, None, 2), slice(1, None, 2)

    def take(index):
        # Copy everything, then overwrite the per-item arrays with the
        # requested subset of the original data.
        subset = deepcopy(self)
        subset.translations = self.translations[index]
        subset.patterns = self.patterns[index]
        if getattr(self, 'intensities', None) is not None:
            subset.intensities = self.intensities[index]
        return subset

    return take(first_half), take(second_half)

def pad(self, to_pad, value=0, mask=True):
"""Pads all the diffraction patterns by a speficied amount
Expand Down
160 changes: 159 additions & 1 deletion src/cdtools/tools/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
'standardize_reconstruction_set',
'standardize_reconstruction_pair',
'calc_spectral_info',
'line_based_frc',
]


Expand Down Expand Up @@ -1525,4 +1526,161 @@ def calc_spectral_info(dataset, nbins=50):
pattern_snr = sum_spectrum_sq / sum_spectrum

return sum_pattern, frc_bins[:-1], mean_spectrum



def _frc_resolutions(frequencies, frc_curve, thresholds, pixel_size):
    """Turn the first downward threshold crossing of an FRC curve into a resolution."""
    resolution_dict = {}
    for method, threshold_val in thresholds.items():
        if frc_curve.size == 0:
            # Nothing to evaluate, so no resolution can be claimed
            resolution_dict[method] = np.inf
            continue

        above_threshold = frc_curve >= threshold_val
        # Indices where the curve goes from >= threshold to < threshold
        crossings = np.flatnonzero(np.diff(above_threshold.astype(int)) == -1)
        if crossings.size > 0:
            idx = crossings[0]
            freq1, freq2 = frequencies[idx], frequencies[idx + 1]
            frc1, frc2 = frc_curve[idx], frc_curve[idx + 1]
            # Linear interpolation between the two samples around the crossing
            resolution_frequency = (
                freq1 + (threshold_val - frc1) * (freq2 - freq1) / (frc2 - frc1)
            )
        elif np.all(above_threshold):
            # Never drops below the threshold: limited by the highest frequency
            resolution_frequency = frequencies[-1]
        else:
            resolution_frequency = 0

        if resolution_frequency > 0:
            resolution_dict[method] = (1 / resolution_frequency) * pixel_size
        else:
            resolution_dict[method] = np.inf
    return resolution_dict


def _frc_correlation(cross_sum, power1_sum, power2_sum):
    """Return |cross_sum| / sqrt(power1_sum * power2_sum), or 0 if the denominator is 0."""
    denominator = np.sqrt(power1_sum * power2_sum)
    if denominator == 0:
        return 0.0
    return np.abs(cross_sum) / denominator


def _frc_along_lines(lines1, lines2, n_bins):
    """Compute the FRC from 1D FFTs taken along axis 1 of both images."""
    n_samples = lines1.shape[1]
    n_freq = n_samples // 2

    # Only keep the non-negative frequencies below Nyquist
    fft1 = np.fft.fft(lines1, axis=1)[:, :n_freq]
    fft2 = np.fft.fft(lines2, axis=1)[:, :n_freq]
    raw_frequencies = np.fft.fftfreq(n_samples)[:n_freq]

    # Per-frequency sums over all lines; every bin below is built from these
    cross = np.sum(fft1 * np.conj(fft2), axis=0)
    power1 = np.sum(np.abs(fft1) ** 2, axis=0)
    power2 = np.sum(np.abs(fft2) ** 2, axis=0)

    if n_bins is None:
        denominator = np.sqrt(power1 * power2)
        frc_curve = np.zeros(n_freq)
        # Frequencies with no power in either image are reported as 0
        np.divide(np.abs(cross), denominator, out=frc_curve, where=denominator > 0)
        return raw_frequencies, frc_curve

    if n_freq == 0:
        return np.array([]), np.array([])

    # Sum over both lines and frequencies within each bin,
    # see van Heel, M. (2005). https://doi.org/10.1016/j.jsb.2005.05.009
    freq_bins = np.linspace(0, raw_frequencies[-1], n_bins + 1)
    frc_curve, frequencies = [], []
    for i in range(n_bins):
        lower, upper = freq_bins[i], freq_bins[i + 1]
        if i == n_bins - 1:
            # The last edge equals the highest frequency; close the bin on
            # the right so that sample is not silently dropped.
            mask = (raw_frequencies >= lower) & (raw_frequencies <= upper)
        else:
            mask = (raw_frequencies >= lower) & (raw_frequencies < upper)
        if np.any(mask):
            frc_curve.append(
                _frc_correlation(
                    np.sum(cross[mask]), np.sum(power1[mask]), np.sum(power2[mask])
                )
            )
            frequencies.append((lower + upper) / 2)
    return np.array(frequencies), np.array(frc_curve)


def _frc_radial(image1, image2, n_bins):
    """Compute the standard ring-averaged 2D FRC of two images."""
    # Shift zero frequency to center
    fft1 = np.fft.fftshift(np.fft.fft2(image1))
    fft2 = np.fft.fftshift(np.fft.fft2(image2))

    # Integer radius (in FFT pixels) of every Fourier component
    shape = fft1.shape
    center = [int(s // 2) for s in shape]
    Y, X = np.ogrid[:shape[0], :shape[1]]
    r = np.sqrt((Y - center[0]) ** 2 + (X - center[1]) ** 2).astype(np.int32)

    if n_bins is None:
        n_bins = min(shape) // 2
    max_r = np.min(center)
    bin_edges = np.linspace(0, max_r, n_bins + 1)

    cross = fft1 * np.conj(fft2)
    power1 = np.abs(fft1) ** 2
    power2 = np.abs(fft2) ** 2

    frc_curve, frequencies = [], []
    for i in range(n_bins):
        mask = (r >= bin_edges[i]) & (r < bin_edges[i + 1])
        if np.any(mask):
            frc_curve.append(
                _frc_correlation(
                    np.sum(cross[mask]), np.sum(power1[mask]), np.sum(power2[mask])
                )
            )
            # Frequency in cycles per pixel (max_r maps to 0.5)
            frequencies.append((bin_edges[i] + bin_edges[i + 1]) / 2 / max_r * 0.5)
    return np.array(frequencies), np.array(frc_curve)


def line_based_frc(
    image1, image2, axis=1, n_bins=None, thresholds=None, pixel_size=1.0,
    unit="pixels"
):
    """
    Compute line-based FRC with multiple threshold methods. This works great
    for highly anisotropic images, such as those with vertical or horizontal
    stripes.

    Parameters
    ----------
    image1, image2: np.ndarray or torch.Tensor
        Input images (must have the same shape). Must be 2D arrays. Should be
        aligned beforehand using ip.find_shift, ip.sinc_subpixel_shift, or
        similar methods.
    axis: int
        axis along which to extract lines (1 for horizontal lines, 0 for
        vertical). If None, standard 2D FRC is computed
    n_bins: int
        number of frequency bins. If None, no binning is applied for the
        line-based FRC, and min(shape) // 2 bins are used for the 2D FRC
    thresholds: dict
        Dictionary mapping a threshold name to its FRC value. If None, no
        resolutions are computed
    pixel_size: float
        Physical size of one pixel
    unit: str
        Unit for the pixel size and resolution output. It is not used in the
        calculation itself

    Returns
    -------
    frequencies: np.ndarray
        Spatial frequency array, in cycles per pixel
    frc_curve: np.ndarray
        FRC values as function of frequency. Frequencies at which either
        image has no power are reported as 0
    resolution_dict: dict
        Resolution at each threshold, in units of pixel_size. The value is
        np.inf when no resolution can be determined

    Raises
    ------
    TypeError
        If an image is neither a numpy array nor a torch tensor
    ValueError
        If the images differ in shape, are not 2D, or if axis or n_bins
        is invalid
    """
    if thresholds is None:
        thresholds = {}

    # Transform all torch tensors to np.arrays. detach() and cpu() make this
    # work for tensors that require grad or live on another device.
    if isinstance(image1, t.Tensor):
        image1 = image1.detach().cpu().numpy()
    if isinstance(image2, t.Tensor):
        image2 = image2.detach().cpu().numpy()

    if not isinstance(image1, np.ndarray):
        raise TypeError("image1 must be a numpy array or torch tensor")
    if not isinstance(image2, np.ndarray):
        raise TypeError("image2 must be a numpy array or torch tensor")
    if image1.shape != image2.shape:
        raise ValueError("Images must have same shape")
    if image1.ndim != 2 or image2.ndim != 2:
        raise ValueError("Images must be 2D")
    if n_bins is not None and n_bins < 1:
        raise ValueError("n_bins must be a positive integer or None")

    if axis == 1:
        frequencies, frc_curve = _frc_along_lines(image1, image2, n_bins)
    elif axis == 0:
        # Transposing turns the columns into the lines that get transformed
        frequencies, frc_curve = _frc_along_lines(image1.T, image2.T, n_bins)
    elif axis is None:
        frequencies, frc_curve = _frc_radial(image1, image2, n_bins)
    else:
        raise ValueError("Axis must be 0, 1, or None for line-based FRC")

    resolution_dict = _frc_resolutions(
        frequencies, frc_curve, thresholds, pixel_size
    )

    return frequencies, frc_curve, resolution_dict
Loading