diff --git a/src/cdtools/models/bragg_2d_ptycho.py b/src/cdtools/models/bragg_2d_ptycho.py index bc10ab4d..b3869a9c 100644 --- a/src/cdtools/models/bragg_2d_ptycho.py +++ b/src/cdtools/models/bragg_2d_ptycho.py @@ -258,6 +258,7 @@ def from_dataset( obj_padding=200, obj_view_crop=None, units='um', + surface_normal=None ): wavelength = dataset.wavelength det_basis = dataset.detector_geometry['basis'] @@ -278,24 +279,38 @@ def from_dataset( distance, oversampling=oversampling) - # now we grab the sample surface normal - if hasattr(dataset, 'sample_info') and \ - dataset.sample_info is not None and \ - 'orientation' in dataset.sample_info: - surface_normal = dataset.sample_info['orientation'][2] - else: - surface_normal = np.array([0.,0.,1.]) - - # If this information is supplied when the function is called, - # then we override the information in the .cxi file - if scattering_mode in {'t', 'transmission'}: + # Now we define the surface normal + # The surface normal definition is based on the following heirarchy: + # manual surface_normal definition > scattering_mode + # > dataset.sample_info['orientation'] > transmission geometry + if surface_normal is not None: + surface_normal = np.asarray(surface_normal) + elif scattering_mode.strip().lower() in {'t', 'transmission'}: surface_normal = np.array([0.,0.,1.]) - elif scattering_mode in {'r', 'reflection'}: + elif scattering_mode.strip().lower() in {'r', 'reflection'}: outgoing_dir = np.cross(det_basis[:,0], det_basis[:,1]) outgoing_dir /= np.linalg.norm(outgoing_dir) surface_normal = outgoing_dir + np.array([0.,0.,1.]) surface_normal /= np.linalg.norm(outgoing_dir) + elif scattering_mode is not None: + raise ValueError( + 'Scattering mode must be either "transmission" ("t"), "reflection" ("r"), or the default of None.' + ) + elif hasattr(dataset, 'sample_info') and \ + dataset.sample_info is not None and \ + 'orientation' in dataset.sample_info: + # If the scattering_mode has not been defined, we grab + # this from the cxi file if its present. + surface_normal = dataset.sample_info['orientation'][2] + else: + surface_normal = np.array([0., 0., 1.]) + # Guard against any surface_normal entries that are not castable + # to a length-3 numpy vector, with a sensible error message + if not surface_normal.shape == (3,): + raise ValueError( + '`surface_normal` needs to be a numpy vector with 3 elements. If it was set incorrectly from dataset.sample_info, consider explicitly setting it via the `surface_normal` keyword argument.' + ) # and we use that to generate the probe basis ew_normal = np.cross(np.array(ew_basis)[:,1],