From ea2382af67a95c218c54e0672b50387b155c7f36 Mon Sep 17 00:00:00 2001
From: Tristan Crockett
Date: Fri, 15 Sep 2017 18:10:28 -0500
Subject: [PATCH] Evaluations refactoring and non-thresholded bugfix, plus test
 to ensure that bug is fixed

---
 catwalk/evaluation.py    | 258 +++++++++++++++++++++++----------------
 tests/test_evaluation.py |   7 +-
 2 files changed, 162 insertions(+), 103 deletions(-)

diff --git a/catwalk/evaluation.py b/catwalk/evaluation.py
index 200f0e4..ecda8fd 100644
--- a/catwalk/evaluation.py
+++ b/catwalk/evaluation.py
@@ -5,7 +5,7 @@
 from sqlalchemy.orm import sessionmaker
 import logging
 import time
-
+import functools
 
 
 def generate_binary_at_x(test_predictions, x_value, unit='top_n'):
@@ -106,71 +106,175 @@ def _validate_metrics(
             elif not met.greater_is_better in (True, False):
                 raise ValueError("For custom metric {} greater_is_better must be boolean True or False".format(name))
 
-    def _generate_evaluations(
+    def _build_parameter_string(
+        self,
+        threshold_unit,
+        threshold_value,
+        parameter_combination
+    ):
+        """Encode the metric parameters and threshold into a short, human-parseable string
+
+        Examples are: '100_abs', '5_pct'
+
+        Args:
+            threshold_unit (string) the type of threshold, either 'percentile' or 'top_n'
+            threshold_value (int) the numeric threshold
+            parameter_combination (dict) The non-threshold parameter keys and values used
+
+        Returns: (string) A short, human-parseable string
+        """
+        full_params = parameter_combination.copy()
+        if not (threshold_unit == 'percentile' and threshold_value == 100):
+            short_threshold_unit = 'pct' if threshold_unit == 'percentile' else 'abs'
+            full_params[short_threshold_unit] = threshold_value
+        parameter_string = '/'.join([
+            '{}_{}'.format(val, key)
+            for key, val in full_params.items()
+        ])
+        return parameter_string
+
+    def _filter_nan_labels(self, predicted_classes, labels):
+        """Filter missing labels and their corresponding predictions
+
+        Args:
+            predicted_classes (list) Predicted binary classes, of same length as labels
+            labels (list) Labels, maybe containing NaNs
+
+        Returns: (tuple) Copies of the input lists, with NaN labels removed
+        """
+        labels = numpy.array(labels)
+        predicted_classes = numpy.array(predicted_classes)
+        nan_mask = numpy.isfinite(labels)
+        return (
+            (predicted_classes[nan_mask]).tolist(),
+            (labels[nan_mask]).tolist()
+        )
+
+    def _evaluations_for_threshold(
         self,
         metrics,
         parameters,
-        threshold_config,
         predictions_proba,
-        predictions_binary,
         labels,
+        threshold_unit,
+        threshold_value,
     ):
-        """Generate evaluations based on config and create ORM objects to hold them
+        """Generate evaluations for a given threshold in a metric group, and create ORM objects to hold them
 
         Args:
             metrics (list) names of metric to compute
             parameters (list) dicts holding parameters to pass to metrics
-            threshold_config (dict) Unit type and value referring to how any
-                thresholds were computed. Combined with parameter string
-                to make a unique identifier for the parameter in the database
+            threshold_unit (string) the type of threshold, either 'percentile' or 'top_n'
+            threshold_value (int) the numeric threshold
             predictions_proba (list) Probability predictions
-            predictions_binary (list) Binary predictions
-            labels (list) True labels
+            labels (list) True labels (may have NaNs)
 
         Returns: (list) results_schema.Evaluation objects
         Raises: UnknownMetricError if a given metric is not present in
             self.available_metrics
         """
+
+        # using threshold configuration, convert probabilities to predicted classes
+        predicted_classes = generate_binary_at_x(
+            predictions_proba,
+            threshold_value,
+            unit=threshold_unit
+        )
+        # filter out null labels
+        predicted_classes_with_labels, present_labels = self._filter_nan_labels(
+            predicted_classes,
+            labels,
+        )
+        num_labeled_examples = len(present_labels)
+        num_labeled_above_threshold = predicted_classes_with_labels.count(1)
+        num_positive_labels = present_labels.count(1)
         evaluations = []
-        num_labeled_examples = len(labels)
-        num_labeled_above_threshold = predictions_binary.count(1)
-        num_positive_labels = labels.count(1)
         for metric in metrics:
-            if metric in self.available_metrics:
-                for parameter_combination in parameters:
-                    value = self.available_metrics[metric](
-                        predictions_proba,
-                        predictions_binary,
-                        labels,
-                        parameter_combination
-                    )
-
-                    full_params = parameter_combination.copy()
-                    full_params.update(threshold_config)
-                    parameter_string = '/'.join([
-                        '{}_{}'.format(val, key)
-                        for key, val in full_params.items()
-                    ])
-                    logging.info(
-                        'Evaluations for %s%s, labeled examples %s, above threshold %s, positive labels %s, value %s',
-                        metric,
-                        parameter_string,
-                        num_labeled_examples,
-                        num_labeled_above_threshold,
-                        num_positive_labels,
-                        value
-                    )
-                    evaluations.append(Evaluation(
-                        metric=metric,
-                        parameter=parameter_string,
-                        value=value,
-                        num_labeled_examples=num_labeled_examples,
-                        num_labeled_above_threshold=num_labeled_above_threshold,
-                        num_positive_labels=num_positive_labels,
-                        sort_seed=self.sort_seed
-                    ))
-            else:
+            if metric not in self.available_metrics:
                 raise UnknownMetricError()
+
+            for parameter_combination in parameters:
+                value = self.available_metrics[metric](
+                    predictions_proba,
+                    predicted_classes_with_labels,
+                    present_labels,
+                    parameter_combination
+                )
+
+                # convert the thresholds/parameters into something
+                # more readable
+                parameter_string = self._build_parameter_string(
+                    threshold_unit=threshold_unit,
+                    threshold_value=threshold_value,
+                    parameter_combination=parameter_combination
+                )
+
+                logging.info(
+                    'Evaluations for %s%s, labeled examples %s, above threshold %s, positive labels %s, value %s',
+                    metric,
+                    parameter_string,
+                    num_labeled_examples,
+                    num_labeled_above_threshold,
+                    num_positive_labels,
+                    value
+                )
+                evaluations.append(Evaluation(
+                    metric=metric,
+                    parameter=parameter_string,
+                    value=value,
+                    num_labeled_examples=num_labeled_examples,
+                    num_labeled_above_threshold=num_labeled_above_threshold,
+                    num_positive_labels=num_positive_labels,
+                    sort_seed=self.sort_seed
+                ))
+        return evaluations
+
+    def _evaluations_for_group(
+        self,
+        group,
+        predictions_proba_sorted,
+        labels_sorted
+    ):
+        """Generate evaluations for a given metric group, and create ORM objects to hold them
+
+        Args:
+            group (dict) A configuration dictionary for the group.
+                Should contain the key 'metrics', and optionally 'parameters' or 'thresholds'
+            predictions_proba_sorted (list) Probability predictions
+            labels_sorted (list) True labels (may have NaNs)
+
+        Returns: (list) results_schema.Evaluation objects
+        """
+        logging.info('Creating evaluations for metric group %s', group)
+        parameters = group.get('parameters', [{}])
+        generate_evaluations = functools.partial(
+            self._evaluations_for_threshold,
+            metrics=group['metrics'],
+            parameters=parameters,
+            predictions_proba=predictions_proba_sorted,
+            labels=labels_sorted
+        )
+        evaluations = []
+        if 'thresholds' not in group:
+            logging.info('Not a thresholded group, generating evaluation based on all predictions')
+            evaluations = evaluations + generate_evaluations(
+                threshold_unit='percentile',
+                threshold_value=100
+            )
+
+        for pct_thresh in group.get('thresholds', {}).get('percentiles', []):
+            logging.info('Processing percent threshold %s', pct_thresh)
+            evaluations = evaluations + generate_evaluations(
+                threshold_unit='percentile',
+                threshold_value=pct_thresh
+            )
+
+        for abs_thresh in group.get('thresholds', {}).get('top_n', []):
+            logging.info('Processing absolute threshold %s', abs_thresh)
+            evaluations = evaluations + generate_evaluations(
+                threshold_unit='top_n',
+                threshold_value=abs_thresh
+            )
         return evaluations
 
     def evaluate(
@@ -204,64 +308,14 @@ def evaluate(
             labels,
             self.sort_seed
        )
-        labels_sorted = numpy.array(labels_sorted)
 
         evaluations = []
         for group in self.metric_groups:
-            logging.info('Creating evaluations for metric group %s', group)
-            parameters = group.get('parameters', [{}])
-            if 'thresholds' not in group:
-                logging.info('Not a thresholded group, generating evaluation based on all predictions')
-                evaluations = evaluations + self._generate_evaluations(
-                    group['metrics'],
-                    parameters,
-                    {},
-                    predictions_proba,
-                    generate_binary_at_x(
-                        predictions_proba_sorted,
-                        100,
-                        unit='percentile'
-                    ),
-                    labels_sorted.tolist(),
-                )
-
-            for pct_thresh in group.get('thresholds', {}).get('percentiles', []):
-                logging.info('Processing percent threshold %s', pct_thresh)
-                predicted_classes = numpy.array(generate_binary_at_x(
-                    predictions_proba_sorted,
-                    pct_thresh,
-                    unit='percentile'
-                ))
-                nan_mask = numpy.isfinite(labels_sorted)
-                predicted_classes = (predicted_classes[nan_mask]).tolist()
-                present_labels_sorted = (labels_sorted[nan_mask]).tolist()
-                evaluations = evaluations + self._generate_evaluations(
-                    group['metrics'],
-                    parameters,
-                    {'pct': pct_thresh},
-                    None,
-                    predicted_classes,
-                    present_labels_sorted,
-                )
-
-            for abs_thresh in group.get('thresholds', {}).get('top_n', []):
-                logging.info('Processing absolute threshold %s', abs_thresh)
-                predicted_classes = numpy.array(generate_binary_at_x(
-                    predictions_proba_sorted,
-                    abs_thresh,
-                    unit='top_n'
-                ))
-                nan_mask = numpy.isfinite(labels_sorted)
-                predicted_classes = (predicted_classes[nan_mask]).tolist()
-                present_labels_sorted = (labels_sorted[nan_mask]).tolist()
-                evaluations = evaluations + self._generate_evaluations(
-                    group['metrics'],
-                    parameters,
-                    {'abs': abs_thresh},
-                    None,
-                    predicted_classes,
-                    present_labels_sorted,
-                )
+            evaluations += self._evaluations_for_group(
+                group,
+                predictions_proba_sorted,
+                labels_sorted
+            )
 
         logging.info('Writing metrics to db')
         self._write_to_db(
diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py
index 320eac4..6e9c0bf 100644
--- a/tests/test_evaluation.py
+++ b/tests/test_evaluation.py
@@ -120,6 +120,9 @@ def test_model_scoring_inspections():
     metric_groups = [{
         'metrics': ['precision@', 'recall@', 'fpr@'],
         'thresholds': {'percentiles': [50.0], 'top_n': [3]}
+    }, {
+        # ensure we test a non-thresholded metric as well
+        'metrics': ['accuracy'],
     }]
     model_evaluator = ModelEvaluator(metric_groups, db_engine)
 
@@ -151,7 +154,9 @@ def test_model_scoring_inspections():
     ):
         assert record['num_labeled_examples'] == 4
         assert record['num_positive_labels'] == 2
-        if 'pct' in record['parameter']:
+        if record['parameter'] == '':
+            assert record['num_labeled_above_threshold'] == 4
+        elif 'pct' in record['parameter']:
            assert record['num_labeled_above_threshold'] == 1
         else:
             assert record['num_labeled_above_threshold'] == 2
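
Illustrative sketch (not part of the patch): the snippet below mirrors the branch structure that _evaluations_for_group introduces above, to make the non-thresholded bugfix concrete. A metric group with no 'thresholds' key now gets exactly one pass at the 100th percentile (i.e. all predictions), while a thresholded group gets one pass per percentile and one per top_n value. The helper name thresholds_for_group is invented for this example.

    # Illustrative only: mirrors the dispatch in _evaluations_for_group.
    def thresholds_for_group(group):
        pairs = []
        if 'thresholds' not in group:
            # non-thresholded group: a single evaluation over all predictions
            pairs.append(('percentile', 100))
        for pct in group.get('thresholds', {}).get('percentiles', []):
            pairs.append(('percentile', pct))
        for top_n in group.get('thresholds', {}).get('top_n', []):
            pairs.append(('top_n', top_n))
        return pairs

    assert thresholds_for_group({'metrics': ['accuracy']}) == [('percentile', 100)]
    assert thresholds_for_group({
        'metrics': ['precision@', 'recall@', 'fpr@'],
        'thresholds': {'percentiles': [50.0], 'top_n': [3]},
    }) == [('percentile', 50.0), ('top_n', 3)]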
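
Illustrative sketch (not part of the patch): the parameter strings the test asserts on follow directly from _build_parameter_string. The implicit 100th-percentile pass of a non-thresholded group encodes to an empty string, which is why the new test branch checks record['parameter'] == ''. A standalone reproduction of the naming scheme, with a hypothetical 'beta' metric parameter added purely for illustration:

    def build_parameter_string(threshold_unit, threshold_value, parameter_combination):
        # Same scheme as _build_parameter_string in the patch above.
        full_params = parameter_combination.copy()
        if not (threshold_unit == 'percentile' and threshold_value == 100):
            short_unit = 'pct' if threshold_unit == 'percentile' else 'abs'
            full_params[short_unit] = threshold_value
        return '/'.join('{}_{}'.format(val, key) for key, val in full_params.items())

    assert build_parameter_string('percentile', 100, {}) == ''   # non-thresholded group
    assert build_parameter_string('percentile', 50.0, {}) == '50.0_pct'
    assert build_parameter_string('top_n', 3, {}) == '3_abs'
    # a hypothetical extra metric parameter is joined with '/' (order follows dict order)
    assert build_parameter_string('top_n', 3, {'beta': 2}) in ('2_beta/3_abs', '3_abs/2_beta')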
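
Illustrative sketch (not part of the patch): the NaN handling that evaluate previously repeated for each threshold now lives in _filter_nan_labels. A minimal standalone version of that behaviour, assuming labels arrive as floats with NaN meaning "no label observed":

    import numpy

    def filter_nan_labels(predicted_classes, labels):
        # Drop any prediction whose label is NaN, keeping the two lists aligned.
        labels = numpy.array(labels, dtype=float)
        predicted_classes = numpy.array(predicted_classes)
        nan_mask = numpy.isfinite(labels)
        return predicted_classes[nan_mask].tolist(), labels[nan_mask].tolist()

    predicted = [1, 1, 0, 0, 1]
    labels = [1.0, float('nan'), 0.0, 1.0, float('nan')]
    assert filter_nan_labels(predicted, labels) == ([1, 0, 0], [1.0, 0.0, 1.0])
    # With the NaN rows removed, counts such as num_labeled_examples and
    # num_labeled_above_threshold are computed over labeled rows only.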