Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions src/cdtools/models/bragg_2d_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def from_dataset(
obj_padding=200,
obj_view_crop=None,
units='um',
surface_normal=None
):
wavelength = dataset.wavelength
det_basis = dataset.detector_geometry['basis']
Expand All @@ -278,24 +279,38 @@ def from_dataset(
distance,
oversampling=oversampling)

# now we grab the sample surface normal
if hasattr(dataset, 'sample_info') and \
dataset.sample_info is not None and \
'orientation' in dataset.sample_info:
surface_normal = dataset.sample_info['orientation'][2]
else:
surface_normal = np.array([0.,0.,1.])

# If this information is supplied when the function is called,
# then we override the information in the .cxi file
if scattering_mode in {'t', 'transmission'}:
# Now we define the surface normal
# The surface normal definition is based on the following heirarchy:
# manual surface_normal definition > scattering_mode
# > dataset.sample_info['orientation'] > transmission geometry
if surface_normal is not None:
surface_normal = np.asarray(surface_normal)
elif scattering_mode.strip().lower() in {'t', 'transmission'}:
surface_normal = np.array([0.,0.,1.])
elif scattering_mode in {'r', 'reflection'}:
elif scattering_mode.strip().lower() in {'r', 'reflection'}:
outgoing_dir = np.cross(det_basis[:,0], det_basis[:,1])
outgoing_dir /= np.linalg.norm(outgoing_dir)
surface_normal = outgoing_dir + np.array([0.,0.,1.])
surface_normal /= np.linalg.norm(outgoing_dir)
elif scattering_mode is not None:
raise ValueError(
'Scattering mode must be either "transmission" ("t"), "reflection" ("r"), or the default of None.'
)
elif hasattr(dataset, 'sample_info') and \
dataset.sample_info is not None and \
'orientation' in dataset.sample_info:
# If the scattering_mode has not been defined, we grab
# this from the cxi file if its present.
surface_normal = dataset.sample_info['orientation'][2]
else:
surface_normal = np.array([0., 0., 1.])

# Guard against any surface_normal entries that are not castable
# to a length-3 numpy vector, with a sensible error message
if not surface_normal.shape == (3,):
raise ValueError(
'`surface_normal` needs to be a numpy vector with 3 elements. If it was set incorrectly from dataset.sample_info, consider explicitly setting it via the `surface_normal` keyword argument.'
)
# and we use that to generate the probe basis

ew_normal = np.cross(np.array(ew_basis)[:,1],
Expand Down