catwalk/evaluation.py (258 changes: 156 additions & 102 deletions)
@@ -5,7 +5,7 @@
from sqlalchemy.orm import sessionmaker
import logging
import time

import functools


def generate_binary_at_x(test_predictions, x_value, unit='top_n'):
@@ -106,71 +106,175 @@ def _validate_metrics(
elif not met.greater_is_better in (True, False):
raise ValueError("For custom metric {} greater_is_better must be boolean True or False".format(name))

def _generate_evaluations(
def _build_parameter_string(
self,
threshold_unit,
threshold_value,
parameter_combination
):
"""Encode the metric parameters and threshold into a short, human-parseable string

Examples are: '100_abs', '5_pct'

Args:
threshold_unit (string) the type of threshold, either 'percentile' or 'top_n'
threshold_value (int) the numeric threshold
parameter_combination (dict) The non-threshold parameter keys and values used

Returns: (string) A short, human-parseable string
"""
full_params = parameter_combination.copy()
if not (threshold_unit == 'percentile' and threshold_value == 100):
short_threshold_unit = 'pct' if threshold_unit == 'percentile' else 'abs'
full_params[short_threshold_unit] = threshold_value
parameter_string = '/'.join([
'{}_{}'.format(val, key)
for key, val in full_params.items()
])
return parameter_string
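# Illustrative behavior of _build_parameter_string (a sketch with hypothetical
# inputs, assuming Python 3.7+ dict insertion order):
#   self._build_parameter_string('percentile', 5, {})       -> '5_pct'
#   self._build_parameter_string('top_n', 100, {'beta': 2})  -> '2_beta/100_abs'
#   self._build_parameter_string('percentile', 100, {})      -> ''
# (a 100th-percentile "threshold" covers all predictions, so no suffix is added)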

def _filter_nan_labels(self, predicted_classes, labels):
"""Filter missing labels and their corresponding predictions

Args:
predicted_classes (list) Predicted binary classes, of the same length as labels
labels (list) Labels, possibly containing NaNs

Returns: (tuple) Copies of the input lists, with NaN labels removed
"""
labels = numpy.array(labels)
predicted_classes = numpy.array(predicted_classes)
nan_mask = numpy.isfinite(labels)
return (
(predicted_classes[nan_mask]).tolist(),
(labels[nan_mask]).tolist()
)
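# For example (a sketch): with predicted_classes=[1, 0, 1] and
# labels=[1.0, float('nan'), 0.0], the NaN-labeled row is dropped and this
# returns ([1, 1], [1.0, 0.0]).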

def _evaluations_for_threshold(
self,
metrics,
parameters,
threshold_config,
predictions_proba,
predictions_binary,
labels,
threshold_unit,
threshold_value,
):
"""Generate evaluations based on config and create ORM objects to hold them
"""Generate evaluations for a given threshold in a metric group, and create ORM objects to hold them

Args:
metrics (list) names of metrics to compute
parameters (list) dicts holding parameters to pass to metrics
threshold_config (dict) Unit type and value referring to how any
thresholds were computed. Combined with parameter string
to make a unique identifier for the parameter in the database
threshold_unit (string) the type of threshold, either 'percentile' or 'top_n'
threshold_value (int) the numeric threshold
predictions_proba (list) Probability predictions
predictions_binary (list) Binary predictions
labels (list) True labels
labels (list) True labels (may have NaNs)

Returns: (list) results_schema.Evaluation objects
Raises: UnknownMetricError if a given metric is not present in
self.available_metrics
"""

# using threshold configuration, convert probabilities to predicted classes
predicted_classes = generate_binary_at_x(
predictions_proba,
threshold_value,
unit=threshold_unit
)
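# generate_binary_at_x (body elided in this diff) converts the sorted scores to
# 0/1: the top threshold_value entries for 'top_n', or the top threshold_value
# percent for 'percentile' (so the 100th percentile covers every prediction)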
# filter out null labels
predicted_classes_with_labels, present_labels = self._filter_nan_labels(
predicted_classes,
labels,
)
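# the summary counts below are computed on the NaN-filtered lists, so rows with
# missing labels are excluded from them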
num_labeled_examples = len(present_labels)
num_labeled_above_threshold = predicted_classes_with_labels.count(1)
num_positive_labels = present_labels.count(1)
evaluations = []
num_labeled_examples = len(labels)
num_labeled_above_threshold = predictions_binary.count(1)
num_positive_labels = labels.count(1)
for metric in metrics:
if metric in self.available_metrics:
for parameter_combination in parameters:
value = self.available_metrics[metric](
predictions_proba,
predictions_binary,
labels,
parameter_combination
)

full_params = parameter_combination.copy()
full_params.update(threshold_config)
parameter_string = '/'.join([
'{}_{}'.format(val, key)
for key, val in full_params.items()
])
logging.info(
'Evaluations for %s%s, labeled examples %s, above threshold %s, positive labels %s, value %s',
metric,
parameter_string,
num_labeled_examples,
num_labeled_above_threshold,
num_positive_labels,
value
)
evaluations.append(Evaluation(
metric=metric,
parameter=parameter_string,
value=value,
num_labeled_examples=num_labeled_examples,
num_labeled_above_threshold=num_labeled_above_threshold,
num_positive_labels=num_positive_labels,
sort_seed=self.sort_seed
))
else:
if metric not in self.available_metrics:
raise UnknownMetricError()

for parameter_combination in parameters:
value = self.available_metrics[metric](
predictions_proba,
predicted_classes_with_labels,
present_labels,
parameter_combination
)

# convert the thresholds/parameters into something
# more readable
parameter_string = self._build_parameter_string(
threshold_unit=threshold_unit,
threshold_value=threshold_value,
parameter_combination=parameter_combination
)

logging.info(
'Evaluations for %s%s, labeled examples %s, above threshold %s, positive labels %s, value %s',
metric,
parameter_string,
num_labeled_examples,
num_labeled_above_threshold,
num_positive_labels,
value
)
evaluations.append(Evaluation(
metric=metric,
parameter=parameter_string,
value=value,
num_labeled_examples=num_labeled_examples,
num_labeled_above_threshold=num_labeled_above_threshold,
num_positive_labels=num_positive_labels,
sort_seed=self.sort_seed
))
return evaluations

def _evaluations_for_group(
self,
group,
predictions_proba_sorted,
labels_sorted
):
"""Generate evaluations for a given metric group, and create ORM objects to hold them

Args:
group (dict) A configuration dictionary for the group.
Should contain the key 'metrics', and optionally 'parameters' or 'thresholds'
predictions_proba_sorted (list) Probability predictions, sorted
labels_sorted (list) True labels (may have NaNs), in the same sort order as the predictions

Returns: (list) results_schema.Evaluation objects
"""
logging.info('Creating evaluations for metric group %s', group)
parameters = group.get('parameters', [{}])
generate_evaluations = functools.partial(
self._evaluations_for_threshold,
metrics=group['metrics'],
parameters=parameters,
predictions_proba=predictions_proba_sorted,
labels=labels_sorted
)
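# the partial binds the per-group arguments up front; each call below only
# supplies the threshold unit and value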
evaluations = []
if 'thresholds' not in group:
logging.info('Not a thresholded group, generating evaluation based on all predictions')
evaluations = evaluations + generate_evaluations(
threshold_unit='percentile',
threshold_value=100
)

for pct_thresh in group.get('thresholds', {}).get('percentiles', []):
logging.info('Processing percent threshold %s', pct_thresh)
evaluations = evaluations + generate_evaluations(
threshold_unit='percentile',
threshold_value=pct_thresh
)

for abs_thresh in group.get('thresholds', {}).get('top_n', []):
logging.info('Processing absolute threshold %s', abs_thresh)
evaluations = evaluations + generate_evaluations(
threshold_unit='top_n',
threshold_value=abs_thresh
)
return evaluations
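# For reference, a metric group passed to this method might look like the
# following (values illustrative, mirroring the test fixture in
# tests/test_evaluation.py):
#   {
#       'metrics': ['precision@', 'recall@'],
#       'thresholds': {'percentiles': [50.0], 'top_n': [3]},
#       'parameters': [{}],  # optional; this is the default
#   }
# Omitting 'thresholds' yields a single non-thresholded evaluation over all
# predictions (encoded internally as the 100th percentile).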

def evaluate(
Expand Down Expand Up @@ -204,64 +308,14 @@ def evaluate(
labels,
self.sort_seed
)
labels_sorted = numpy.array(labels_sorted)

evaluations = []
for group in self.metric_groups:
logging.info('Creating evaluations for metric group %s', group)
parameters = group.get('parameters', [{}])
if 'thresholds' not in group:
logging.info('Not a thresholded group, generating evaluation based on all predictions')
evaluations = evaluations + self._generate_evaluations(
group['metrics'],
parameters,
{},
predictions_proba,
generate_binary_at_x(
predictions_proba_sorted,
100,
unit='percentile'
),
labels_sorted.tolist(),
)

for pct_thresh in group.get('thresholds', {}).get('percentiles', []):
logging.info('Processing percent threshold %s', pct_thresh)
predicted_classes = numpy.array(generate_binary_at_x(
predictions_proba_sorted,
pct_thresh,
unit='percentile'
))
nan_mask = numpy.isfinite(labels_sorted)
predicted_classes = (predicted_classes[nan_mask]).tolist()
present_labels_sorted = (labels_sorted[nan_mask]).tolist()
evaluations = evaluations + self._generate_evaluations(
group['metrics'],
parameters,
{'pct': pct_thresh},
None,
predicted_classes,
present_labels_sorted,
)

for abs_thresh in group.get('thresholds', {}).get('top_n', []):
logging.info('Processing absolute threshold %s', abs_thresh)
predicted_classes = numpy.array(generate_binary_at_x(
predictions_proba_sorted,
abs_thresh,
unit='top_n'
))
nan_mask = numpy.isfinite(labels_sorted)
predicted_classes = (predicted_classes[nan_mask]).tolist()
present_labels_sorted = (labels_sorted[nan_mask]).tolist()
evaluations = evaluations + self._generate_evaluations(
group['metrics'],
parameters,
{'abs': abs_thresh},
None,
predicted_classes,
present_labels_sorted,
)
evaluations += self._evaluations_for_group(
group,
predictions_proba_sorted,
labels_sorted
)

logging.info('Writing metrics to db')
self._write_to_db(
tests/test_evaluation.py (7 changes: 6 additions & 1 deletion)
@@ -120,6 +120,9 @@ def test_model_scoring_inspections():
metric_groups = [{
'metrics': ['precision@', 'recall@', 'fpr@'],
'thresholds': {'percentiles': [50.0], 'top_n': [3]}
}, {
# ensure we test a non-thresholded metric as well
'metrics': ['accuracy'],
}]
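# with this config, the thresholded group should yield parameter strings such
# as '50.0_pct' and '3_abs', while the non-thresholded 'accuracy' group gets an
# empty parameter string (it is evaluated over all predictions)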

model_evaluator = ModelEvaluator(metric_groups, db_engine)
@@ -151,7 +154,9 @@ def test_model_scoring_inspections():
):
assert record['num_labeled_examples'] == 4
assert record['num_positive_labels'] == 2
if 'pct' in record['parameter']:
if record['parameter'] == '':
assert record['num_labeled_above_threshold'] == 4
elif 'pct' in record['parameter']:
assert record['num_labeled_above_threshold'] == 1
else:
assert record['num_labeled_above_threshold'] == 2