From ff7d0295c5b14e034cb511211d4da21d1901e812 Mon Sep 17 00:00:00 2001
From: magnurud
Date: Tue, 17 Jun 2025 16:13:41 +0200
Subject: [PATCH 1/4] first step move top1 preds to the end refactoring

---
 .../preprocess/preprocess/make_predictions.py | 26 +-------
 .../preprocess/preprocess/utils.py            | 19 ++++++
 .../preprocess/tests/test_handler.py          | 61 +++++++++++++++++--
 3 files changed, 79 insertions(+), 27 deletions(-)

diff --git a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
index 10a5ea8..4ff6e39 100644
--- a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
+++ b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
@@ -15,18 +15,13 @@
     above_threshold_or_optional,
     add_confidence_to_ground_truth,
     create_form_config_from_model,
-    filter_away_low_confidence_lines,
-    filter_by_top1,
     filter_optional_fields,
     format_verified_output,
-    get_column_names,
     get_labels,
     is_enum,
     is_line,
-    concatenate_lines_from_different_pages_and_merge_continued_lines,
-    concatenate_lines_from_different_pages,
     merge_predictions_and_gt,
-    patch_empty_predictions,
+    patch_and_filter_predictions,
     required_labels,
     threshold_is_zero_for_all,
 )
@@ -165,23 +160,8 @@ def make_predictions(las_client, event):
     try:
         if predictions:
             field_config = form_config['config']['fields']
-            column_names = get_column_names(field_config)
-            labels = labels.union(column_names)
-            top1_preds = filter_by_top1(predictions, labels)
-
-            if model_metadata.get('mergeContinuedLines'):
-                predictions = concatenate_lines_from_different_pages_and_merge_continued_lines(
-                    predictions=predictions,
-                    field_config=field_config,
-                )
-            else:
-                predictions = concatenate_lines_from_different_pages(predictions, field_config)
-
-            predictions = patch_empty_predictions(predictions, labels, no_empty_prediction_fields)
-            predictions = filter_away_low_confidence_lines(predictions, field_config)
-
+            predictions, top1_preds = patch_and_filter_predictions(predictions, field_config, labels, model_metadata.get('mergeContinuedLines'), no_empty_prediction_fields)
             logging.info(f'patched and filtered predictions {predictions}')
-
             all_above_threshold_or_optional = True
             if threshold_is_zero_for_all(field_config):
                 needs_validation = False
@@ -206,7 +186,7 @@ def make_predictions(las_client, event):
                 if not above_threshold_or_optional(prediction, field_config):
                     all_above_threshold_or_optional = False
 
-            has_all_required_labels = required_labels(field_config) <= set(map(lambda p: p['label'], predictions))
+            has_all_required_labels = required_labels(field_config) <= set(map(lambda p: p['label'], top1_preds))
             needs_validation = not has_all_required_labels or not all_above_threshold_or_optional
 
             logging.info(f'All predictions above threshold (or optional): {all_above_threshold_or_optional}')
diff --git a/docs/workflows/transitions/preprocess/preprocess/utils.py b/docs/workflows/transitions/preprocess/preprocess/utils.py
index 621b5fe..e4e9ea4 100644
--- a/docs/workflows/transitions/preprocess/preprocess/utils.py
+++ b/docs/workflows/transitions/preprocess/preprocess/utils.py
@@ -85,6 +85,25 @@ def top1(predictions, label):
     return result
 
 
+def patch_and_filter_predictions(predictions, field_config, labels, merge_continued_lines, no_empty_prediction_fields):
+    column_names = get_column_names(field_config)
+    labels = labels.union(column_names)
+
+    if merge_continued_lines:
+        predictions = concatenate_lines_from_different_pages_and_merge_continued_lines(
+            predictions=predictions,
+            field_config=field_config,
+        )
+    else:
+        predictions = concatenate_lines_from_different_pages(predictions, field_config)
+
+    predictions = patch_empty_predictions(predictions, labels, no_empty_prediction_fields)
+    predictions = filter_away_low_confidence_lines(predictions, field_config)
+    top1_preds = filter_by_top1(predictions, labels)
+
+    return predictions, top1_preds
+
+
 def above_threshold_or_optional(prediction, field_config):
     label, confidence = prediction['label'], prediction.get('confidence')
     if label not in field_config:
diff --git a/docs/workflows/transitions/preprocess/tests/test_handler.py b/docs/workflows/transitions/preprocess/tests/test_handler.py
index 4bc0402..8ba6bb0 100644
--- a/docs/workflows/transitions/preprocess/tests/test_handler.py
+++ b/docs/workflows/transitions/preprocess/tests/test_handler.py
@@ -9,13 +9,14 @@
 
 from ..preprocess.make_predictions import make_predictions
 from ..preprocess.utils import (
+    concatenate_lines_from_different_pages_and_merge_continued_lines,
     create_form_config_from_model,
+    filter_away_low_confidence_lines,
     filter_by_top1,
-    concatenate_lines_from_different_pages_and_merge_continued_lines,
-    patch_empty_predictions,
-    get_labels,
     get_column_names,
-    filter_away_low_confidence_lines,
+    get_labels,
+    patch_and_filter_predictions,
+    patch_empty_predictions,
 )
 
 
@@ -958,3 +959,55 @@ def test_create_form_config_from_model(field_config, form_config):
             assert conf_levels
         else:
             assert conf_levels == form_config[field]['fields'][line_field]['confidenceLevels']['automated']
+
+@pytest.fixture
+def form_config_simple():
+    yield base64.b64encode(json.dumps({
+        'config': {
+            'fields': {
+                'total_amount': {
+                    'type': 'amount',
+                    'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5}
+                },
+                'line_items': {
+                    'type': 'lines',
+                    'fields': {
+                        'subtotal': {
+                            'type': 'string',
+                            'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5},
+                        }
+                    }
+                }
+            }
+        }
+    }).encode('utf-8'))
+
+@pytest.mark.parametrize('predictions', [[
+    {'label': 'total_amount', 'page': 0, 'value': None, 'confidence': 0.99},
+    {'label': 'total_amount', 'page': 1, 'value': '100.00', 'confidence': 0.35},
+    {'label': 'line_items', 'value': [
+        [
+            {'label': 'subtotal', 'page': 0, 'value': '50.00', 'confidence': 0.99},
+        ], [
+            {'label': 'subtotal', 'page': 0, 'value': '30.00', 'confidence': 0.99},
+        ], [],
+    ]},
+    {'label': 'line_items', 'value': [
+        [
+            {'label': 'subtotal', 'page': 1, 'value': '72.15', 'confidence': 0.9},
+        ],
+    ]},
+]])
+@pytest.mark.parametrize('no_empty_prediction_fields', [{'total_amount', 'line_items'}, {}])
+def test_patch_and_filter(form_config_simple, predictions, no_empty_prediction_fields):
+    form_config = json.loads(base64.b64decode(form_config_simple))
+    field_config = form_config['config']['fields']
+    labels = get_labels(field_config)
+    new_predictions, top_1 = patch_and_filter_predictions(predictions, field_config, labels, False, no_empty_prediction_fields)
+    best_total_amount = [p for p in top_1 if p['label'] == 'total_amount']
+    assert len(best_total_amount) == 1
+    assert best_total_amount[0]['value'] == ('100.00' if no_empty_prediction_fields else None)
+    line_items = [p for p in top_1 if p['label'] == 'line_items']
+    assert len(line_items) == 1
+    line_values = line_items[0]['value']
+    assert [line_candidate for line_candidate in line_values if line_candidate[0]['page'] == 1]

From 3be52ed72e36bde8b0cef45f269deda77ff485ef Mon Sep 17 00:00:00 2001
From: magnurud
Date: Wed, 18 Jun 2025 10:28:56 +0200
Subject: [PATCH 2/4] Add comment

---
 docs/workflows/transitions/preprocess/tests/test_handler.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/docs/workflows/transitions/preprocess/tests/test_handler.py b/docs/workflows/transitions/preprocess/tests/test_handler.py
index 8ba6bb0..3062147 100644
--- a/docs/workflows/transitions/preprocess/tests/test_handler.py
+++ b/docs/workflows/transitions/preprocess/tests/test_handler.py
@@ -1006,6 +1006,9 @@ def test_patch_and_filter(form_config_simple, predictions, no_empty_prediction_f
     new_predictions, top_1 = patch_and_filter_predictions(predictions, field_config, labels, False, no_empty_prediction_fields)
     best_total_amount = [p for p in top_1 if p['label'] == 'total_amount']
     assert len(best_total_amount) == 1
+    # If a field is part of no_empty_prediction_fields, it means that there exists at least one page with a prediction
+    # If a field is not part of no_empty_prediction_fields, it means that there is no prediction for that field on any page
+    # This would not be the case for the predictions above, but faking it will make the best_total_amount None since it has the highest confidence.
     assert best_total_amount[0]['value'] == ('100.00' if no_empty_prediction_fields else None)
     line_items = [p for p in top_1 if p['label'] == 'line_items']
     assert len(line_items) == 1

From 3cd9afc2ce115c6a520d1bb0b7a03cbfaedc946d Mon Sep 17 00:00:00 2001
From: magnurud
Date: Wed, 18 Jun 2025 13:46:49 +0200
Subject: [PATCH 3/4] fix linting

---
 .../preprocess/preprocess/make_predictions.py    |  8 +++++++-
 .../transitions/preprocess/tests/test_handler.py | 13 +++++++++----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
index 4ff6e39..f963f5f 100644
--- a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
+++ b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
@@ -160,7 +160,13 @@ def make_predictions(las_client, event):
     try:
         if predictions:
             field_config = form_config['config']['fields']
-            predictions, top1_preds = patch_and_filter_predictions(predictions, field_config, labels, model_metadata.get('mergeContinuedLines'), no_empty_prediction_fields)
+            predictions, top1_preds = patch_and_filter_predictions(
+                predictions=predictions,
+                field_config=field_config,
+                labels=labels,
+                merge_continued_lines=model_metadata.get('mergeContinuedLines'),
+                no_empty_prediction_fields=no_empty_prediction_fields,
+            )
             logging.info(f'patched and filtered predictions {predictions}')
             all_above_threshold_or_optional = True
             if threshold_is_zero_for_all(field_config):
                 needs_validation = False
diff --git a/docs/workflows/transitions/preprocess/tests/test_handler.py b/docs/workflows/transitions/preprocess/tests/test_handler.py
index 3062147..49cd12d 100644
--- a/docs/workflows/transitions/preprocess/tests/test_handler.py
+++ b/docs/workflows/transitions/preprocess/tests/test_handler.py
@@ -960,6 +960,7 @@ def test_create_form_config_from_model(field_config, form_config):
         else:
             assert conf_levels == form_config[field]['fields'][line_field]['confidenceLevels']['automated']
 
+
 @pytest.fixture
 def form_config_simple():
     yield base64.b64encode(json.dumps({
@@ -982,6 +983,7 @@ def form_config_simple():
             }
         }
     }).encode('utf-8'))
+
 @pytest.mark.parametrize('predictions', [[
     {'label': 'total_amount', 'page': 0, 'value': None, 'confidence': 0.99},
     {'label': 'total_amount', 'page': 1, 'value': '100.00', 'confidence': 0.35},
@@ -1003,12 +1005,15 @@ def test_patch_and_filter(form_config_simple, predictions, no_empty_prediction_f
     form_config = json.loads(base64.b64decode(form_config_simple))
     field_config = form_config['config']['fields']
     labels = get_labels(field_config)
-    new_predictions, top_1 = patch_and_filter_predictions(predictions, field_config, labels, False, no_empty_prediction_fields)
+    _, top_1 = patch_and_filter_predictions(predictions, field_config, labels, False, no_empty_prediction_fields)
     best_total_amount = [p for p in top_1 if p['label'] == 'total_amount']
     assert len(best_total_amount) == 1
-    # If a field is part of no_empty_prediction_fields, it means that there exists at least one page with a prediction
-    # If a field is not part of no_empty_prediction_fields, it means that there is no prediction for that field on any page
-    # This would not be the case for the predictions above, but faking it will make the best_total_amount None since it has the highest confidence.
+    # If a field is part of no_empty_prediction_fields,
+    # it means that there exists at least one page with a prediction
+    # If a field is not part of no_empty_prediction_fields,
+    # it means that there is no prediction for that field on any page
+    # This would not be the case for the predictions above,
+    # but faking it will make the best_total_amount None since it has the highest confidence.
     assert best_total_amount[0]['value'] == ('100.00' if no_empty_prediction_fields else None)
     line_items = [p for p in top_1 if p['label'] == 'line_items']
     assert len(line_items) == 1

From 40ab3ba987580174ceb2b03feef5847b8eb96403 Mon Sep 17 00:00:00 2001
From: magnurud
Date: Wed, 18 Jun 2025 13:50:45 +0200
Subject: [PATCH 4/4] remove additional whitespace

---
 .../transitions/preprocess/preprocess/make_predictions.py | 8 ++++----
 .../transitions/preprocess/tests/test_handler.py          | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
index f963f5f..907e4e3 100644
--- a/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
+++ b/docs/workflows/transitions/preprocess/preprocess/make_predictions.py
@@ -161,10 +161,10 @@ def make_predictions(las_client, event):
         if predictions:
             field_config = form_config['config']['fields']
             predictions, top1_preds = patch_and_filter_predictions(
-                predictions=predictions, 
-                field_config=field_config, 
-                labels=labels, 
-                merge_continued_lines=model_metadata.get('mergeContinuedLines'), 
+                predictions=predictions,
+                field_config=field_config,
+                labels=labels,
+                merge_continued_lines=model_metadata.get('mergeContinuedLines'),
                 no_empty_prediction_fields=no_empty_prediction_fields,
             )
             logging.info(f'patched and filtered predictions {predictions}')
diff --git a/docs/workflows/transitions/preprocess/tests/test_handler.py b/docs/workflows/transitions/preprocess/tests/test_handler.py
index 49cd12d..3681076 100644
--- a/docs/workflows/transitions/preprocess/tests/test_handler.py
+++ b/docs/workflows/transitions/preprocess/tests/test_handler.py
@@ -1008,11 +1008,11 @@ def test_patch_and_filter(form_config_simple, predictions, no_empty_prediction_f
     _, top_1 = patch_and_filter_predictions(predictions, field_config, labels, False, no_empty_prediction_fields)
     best_total_amount = [p for p in top_1 if p['label'] == 'total_amount']
     assert len(best_total_amount) == 1
-    # If a field is part of no_empty_prediction_fields, 
+    # If a field is part of no_empty_prediction_fields,
     # it means that there exists at least one page with a prediction
-    # If a field is not part of no_empty_prediction_fields, 
+    # If a field is not part of no_empty_prediction_fields,
     # it means that there is no prediction for that field on any page
-    # This would not be the case for the predictions above, 
+    # This would not be the case for the predictions above,
     # but faking it will make the best_total_amount None since it has the highest confidence.
     assert best_total_amount[0]['value'] == ('100.00' if no_empty_prediction_fields else None)
     line_items = [p for p in top_1 if p['label'] == 'line_items']
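
Usage sketch (not part of the patch series above): the snippet below illustrates roughly how the refactored helper is expected to be called after these changes, mirroring the call site in make_predictions.py and the new test. The import path and the trimmed-down field_config are assumptions made for illustration only; adjust them to the actual package layout.

# Illustrative sketch only: assumes the preprocess package is importable as
# `preprocess` and uses a trimmed field_config mirroring the form_config_simple
# fixture from the tests (a single 'total_amount' field).
from preprocess.utils import get_labels, patch_and_filter_predictions

field_config = {
    'total_amount': {
        'type': 'amount',
        'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5},
    },
}

predictions = [
    {'label': 'total_amount', 'page': 0, 'value': None, 'confidence': 0.99},
    {'label': 'total_amount', 'page': 1, 'value': '100.00', 'confidence': 0.35},
]

# One call now covers the sequence that used to be inlined in make_predictions:
# concatenate lines across pages, patch empty predictions, drop low-confidence
# lines, and keep the top-1 prediction per label.
predictions, top1_preds = patch_and_filter_predictions(
    predictions=predictions,
    field_config=field_config,
    labels=get_labels(field_config),
    merge_continued_lines=False,
    no_empty_prediction_fields={'total_amount'},
)

print(top1_preds)  # at most one prediction per label

The second return value is the behavioural change that matters in patch 1: make_predictions now checks required_labels(field_config) against top1_preds instead of against the full patched and filtered prediction list.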