diff --git a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
index 10a5ea8..907e4e3 100644
--- a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
+++ b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
@@ -15,18 +15,13 @@
     above_threshold_or_optional,
     add_confidence_to_ground_truth,
     create_form_config_from_model,
-    filter_away_low_confidence_lines,
-    filter_by_top1,
     filter_optional_fields,
     format_verified_output,
-    get_column_names,
     get_labels,
     is_enum,
     is_line,
-    concatenate_lines_from_different_pages_and_merge_continued_lines,
-    concatenate_lines_from_different_pages,
     merge_predictions_and_gt,
-    patch_empty_predictions,
+    patch_and_filter_predictions,
     required_labels,
     threshold_is_zero_for_all,
 )
@@ -165,23 +160,14 @@ def make_predictions(las_client, event):
     try:
         if predictions:
             field_config = form_config['config']['fields']
-            column_names = get_column_names(field_config)
-            labels = labels.union(column_names)
-            top1_preds = filter_by_top1(predictions, labels)
-
-            if model_metadata.get('mergeContinuedLines'):
-                predictions = concatenate_lines_from_different_pages_and_merge_continued_lines(
-                    predictions=predictions,
-                    field_config=field_config,
-                )
-            else:
-                predictions = concatenate_lines_from_different_pages(predictions, field_config)
-
-            predictions = patch_empty_predictions(predictions, labels, no_empty_prediction_fields)
-            predictions = filter_away_low_confidence_lines(predictions, field_config)
-
+            predictions, top1_preds = patch_and_filter_predictions(
+                predictions=predictions,
+                field_config=field_config,
+                labels=labels,
+                merge_continued_lines=model_metadata.get('mergeContinuedLines'),
+                no_empty_prediction_fields=no_empty_prediction_fields,
+            )
             logging.info(f'patched and filtered predictions {predictions}')
-
             all_above_threshold_or_optional = True
             if threshold_is_zero_for_all(field_config):
                 needs_validation = False
@@ -206,7 +192,7 @@ def make_predictions(las_client, event):
                 if not above_threshold_or_optional(prediction, field_config):
                     all_above_threshold_or_optional = False
 
-            has_all_required_labels = required_labels(field_config) <= set(map(lambda p: p['label'], predictions))
+            has_all_required_labels = required_labels(field_config) <= set(map(lambda p: p['label'], top1_preds))
             needs_validation = not has_all_required_labels or not all_above_threshold_or_optional
 
             logging.info(f'All predictions above threshold (or optional): {all_above_threshold_or_optional}')
diff --git a/docs/workflows/transitions/preprocess/preprocess/utils.py b/docs/workflows/transitions/preprocess/preprocess/utils.py
index 621b5fe..e4e9ea4 100644
--- a/docs/workflows/transitions/preprocess/preprocess/utils.py
+++ b/docs/workflows/transitions/preprocess/preprocess/utils.py
@@ -85,6 +85,25 @@ def top1(predictions, label):
     return result
 
 
+def patch_and_filter_predictions(predictions, field_config, labels, merge_continued_lines, no_empty_prediction_fields):
+    column_names = get_column_names(field_config)
+    labels = labels.union(column_names)
+
+    if merge_continued_lines:
+        predictions = concatenate_lines_from_different_pages_and_merge_continued_lines(
+            predictions=predictions,
+            field_config=field_config,
+        )
+    else:
+        predictions = concatenate_lines_from_different_pages(predictions, field_config)
+
+    predictions = patch_empty_predictions(predictions, labels, no_empty_prediction_fields)
+    predictions = filter_away_low_confidence_lines(predictions, field_config)
+    top1_preds = filter_by_top1(predictions, labels)
+
+    return predictions, top1_preds
+
+
 def above_threshold_or_optional(prediction, field_config):
     label, confidence = prediction['label'], prediction.get('confidence')
     if label not in field_config:
diff --git a/docs/workflows/transitions/preprocess/tests/test_handler.py b/docs/workflows/transitions/preprocess/tests/test_handler.py
index 4bc0402..3681076 100644
--- a/docs/workflows/transitions/preprocess/tests/test_handler.py
+++ b/docs/workflows/transitions/preprocess/tests/test_handler.py
@@ -9,13 +9,14 @@
 
 from ..preprocess.make_predictions import make_predictions
 from ..preprocess.utils import (
+    concatenate_lines_from_different_pages_and_merge_continued_lines,
     create_form_config_from_model,
+    filter_away_low_confidence_lines,
     filter_by_top1,
-    concatenate_lines_from_different_pages_and_merge_continued_lines,
-    patch_empty_predictions,
-    get_labels,
     get_column_names,
-    filter_away_low_confidence_lines,
+    get_labels,
+    patch_and_filter_predictions,
+    patch_empty_predictions,
 )
 
 
@@ -958,3 +959,63 @@ def test_create_form_config_from_model(field_config, form_config):
         assert conf_levels
     else:
         assert conf_levels == form_config[field]['fields'][line_field]['confidenceLevels']['automated']
+
+
+@pytest.fixture
+def form_config_simple():
+    yield base64.b64encode(json.dumps({
+        'config': {
+            'fields': {
+                'total_amount': {
+                    'type': 'amount',
+                    'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5}
+                },
+                'line_items': {
+                    'type': 'lines',
+                    'fields': {
+                        'subtotal': {
+                            'type': 'string',
+                            'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5},
+                        }
+                    }
+                }
+            }
+        }
+    }).encode('utf-8'))
+
+
+@pytest.mark.parametrize('predictions', [[
+    {'label': 'total_amount', 'page': 0, 'value': None, 'confidence': 0.99},
+    {'label': 'total_amount', 'page': 1, 'value': '100.00', 'confidence': 0.35},
+    {'label': 'line_items', 'value': [
+        [
+            {'label': 'subtotal', 'page': 0, 'value': '50.00', 'confidence': 0.99},
+        ], [
+            {'label': 'subtotal', 'page': 0, 'value': '30.00', 'confidence': 0.99},
+        ], [],
+    ]},
+    {'label': 'line_items', 'value': [
+        [
+            {'label': 'subtotal', 'page': 1, 'value': '72.15', 'confidence': 0.9},
+        ],
+    ]},
+]])
+@pytest.mark.parametrize('no_empty_prediction_fields', [{'total_amount', 'line_items'}, {}])
+def test_patch_and_filter(form_config_simple, predictions, no_empty_prediction_fields):
+    form_config = json.loads(base64.b64decode(form_config_simple))
+    field_config = form_config['config']['fields']
+    labels = get_labels(field_config)
+    _, top_1 = patch_and_filter_predictions(predictions, field_config, labels, False, no_empty_prediction_fields)
+    best_total_amount = [p for p in top_1 if p['label'] == 'total_amount']
+    assert len(best_total_amount) == 1
+    # If a field is in no_empty_prediction_fields, at least one page has a
+    # non-empty prediction for it, so its empty predictions are patched away
+    # and the non-empty candidate ('100.00') wins the top-1 pick.
+    # If a field is not in no_empty_prediction_fields, no page has a prediction
+    # for it. The data above fakes that second case: the empty prediction is
+    # kept and, having the highest confidence, makes best_total_amount None.
+    assert best_total_amount[0]['value'] == ('100.00' if no_empty_prediction_fields else None)
+    line_items = [p for p in top_1 if p['label'] == 'line_items']
+    assert len(line_items) == 1
+    line_values = line_items[0]['value']
+    assert [line_candidate for line_candidate in line_values if line_candidate[0]['page'] == 1]
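
A minimal usage sketch of the extracted helper (not part of the diff above), mirroring the non-faked scenario in test_patch_and_filter: the field is listed in no_empty_prediction_fields, so the empty prediction is patched away. The import path `preprocess.utils` assumes the transition directory is on sys.path; the data and signature are taken from the diff.

    from preprocess.utils import get_labels, patch_and_filter_predictions

    field_config = {
        'total_amount': {
            'type': 'amount',
            'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5},
        },
    }
    predictions = [
        {'label': 'total_amount', 'page': 0, 'value': None, 'confidence': 0.99},
        {'label': 'total_amount', 'page': 1, 'value': '100.00', 'confidence': 0.35},
    ]

    # With total_amount in no_empty_prediction_fields, the empty (None)
    # prediction is patched away, so the page-1 value wins the top-1 pick
    # even though its confidence is lower.
    _, top1_preds = patch_and_filter_predictions(
        predictions=predictions,
        field_config=field_config,
        labels=get_labels(field_config),
        merge_continued_lines=False,
        no_empty_prediction_fields={'total_amount'},
    )
    assert [p['value'] for p in top1_preds if p['label'] == 'total_amount'] == ['100.00']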