diff --git a/mPyPl/keras.py b/mPyPl/keras.py index fda7757..390914a 100644 --- a/mPyPl/keras.py +++ b/mPyPl/keras.py @@ -5,7 +5,7 @@ import numpy as np @Pipe -def as_batch(flow, feature_field_name='features', label_field_name='label', batchsize=16): +def as_batch(flow, feature_field_name='features', label_field_name='label', batchsize=16, out_features_dtype=None, out_labels_dtype=None): """ Split input datastream into a sequence of batches suitable for keras training. :param flow: input datastream @@ -22,18 +22,21 @@ def as_batch(flow, feature_field_name='features', label_field_name='label', batc # explicitly compute all fields - this is needed for all fields to be computed only once for on-demand evaluation flds = { i : data[i] for i in (feature_field_name if isinstance(feature_field_name, list) else [feature_field_name])} lbls = data[label_field_name] # TODO: what happens when label_field_name is a list? + if batch is None: if isinstance(feature_field_name, list): - batch = [np.zeros((batchsize,)+flds[i].shape) for i in feature_field_name] + batch = [np.zeros((batchsize,)+flds[i].shape, dtype=flds[i].dtype if out_features_dtype is None else out_features_dtype) for i in feature_field_name] else: - batch = np.zeros((batchsize,)+flds[feature_field_name].shape) - lbls_shape = lbls.shape if lbls is np.ndarray else (1,) - labels = np.zeros((batchsize,)+lbls_shape) + batch = np.zeros((batchsize,)+flds[feature_field_name].shape, dtype=flds[feature_field_name].dtype if out_features_dtype is None else out_features_dtype) + + lbls_shape = lbls.shape if type(lbls) is np.ndarray else (1,) + out_labels_dtype = out_labels_dtype if out_labels_dtype is not None else lbls.dtype if type(lbls) is np.ndarray else None + labels = np.zeros((batchsize,)+lbls_shape, dtype=out_labels_dtype) if isinstance(feature_field_name, list): for j,n in enumerate(feature_field_name): batch[j][i] = flds[n] else: batch[i] = flds[feature_field_name] - labels[i] = data[label_field_name] + labels[i] = lbls yield (batch, labels) batch = labels = None