@@ -15,18 +15,13 @@
above_threshold_or_optional,
add_confidence_to_ground_truth,
create_form_config_from_model,
filter_away_low_confidence_lines,
filter_by_top1,
filter_optional_fields,
format_verified_output,
get_column_names,
get_labels,
is_enum,
is_line,
concatenate_lines_from_different_pages_and_merge_continued_lines,
concatenate_lines_from_different_pages,
merge_predictions_and_gt,
patch_empty_predictions,
patch_and_filter_predictions,
required_labels,
threshold_is_zero_for_all,
)
@@ -165,23 +160,14 @@ def make_predictions(las_client, event):
try:
if predictions:
field_config = form_config['config']['fields']
column_names = get_column_names(field_config)
labels = labels.union(column_names)
top1_preds = filter_by_top1(predictions, labels)

if model_metadata.get('mergeContinuedLines'):
predictions = concatenate_lines_from_different_pages_and_merge_continued_lines(
predictions=predictions,
field_config=field_config,
)
else:
predictions = concatenate_lines_from_different_pages(predictions, field_config)

predictions = patch_empty_predictions(predictions, labels, no_empty_prediction_fields)
predictions = filter_away_low_confidence_lines(predictions, field_config)

predictions, top1_preds = patch_and_filter_predictions(
predictions=predictions,
field_config=field_config,
labels=labels,
merge_continued_lines=model_metadata.get('mergeContinuedLines'),
no_empty_prediction_fields=no_empty_prediction_fields,
)
logging.info(f'patched and filtered predictions {predictions}')

all_above_threshold_or_optional = True
if threshold_is_zero_for_all(field_config):
needs_validation = False
@@ -206,7 +192,7 @@ def make_predictions(las_client, event):
if not above_threshold_or_optional(prediction, field_config):
all_above_threshold_or_optional = False

has_all_required_labels = required_labels(field_config) <= set(map(lambda p: p['label'], predictions))
has_all_required_labels = required_labels(field_config) <= set(map(lambda p: p['label'], top1_preds))
needs_validation = not has_all_required_labels or not all_above_threshold_or_optional

logging.info(f'All predictions above threshold (or optional): {all_above_threshold_or_optional}')
docs/workflows/transitions/preprocess/preprocess/utils.py (19 additions & 0 deletions)
@@ -85,6 +85,25 @@ def top1(predictions, label):
return result


def patch_and_filter_predictions(predictions, field_config, labels, merge_continued_lines, no_empty_prediction_fields):
column_names = get_column_names(field_config)
labels = labels.union(column_names)

if merge_continued_lines:
predictions = concatenate_lines_from_different_pages_and_merge_continued_lines(
predictions=predictions,
field_config=field_config,
)
else:
predictions = concatenate_lines_from_different_pages(predictions, field_config)

predictions = patch_empty_predictions(predictions, labels, no_empty_prediction_fields)
predictions = filter_away_low_confidence_lines(predictions, field_config)
top1_preds = filter_by_top1(predictions, labels)

return predictions, top1_preds


def above_threshold_or_optional(prediction, field_config):
label, confidence = prediction['label'], prediction.get('confidence')
if label not in field_config:
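For reference, a minimal sketch of how the new helper is meant to be called, mirroring the updated call site in make_predictions above. The import path, the tiny field_config, and the raw predictions below are illustrative assumptions, not part of this change; only the argument names and the (predictions, top1_preds) return contract come from the diff.

# Illustrative sketch only; assumes the preprocess package from this repo is importable.
from preprocess.utils import get_labels, patch_and_filter_predictions  # hypothetical import path

field_config = {
    'total_amount': {'type': 'amount', 'confidenceLevels': {'automated': 0.98}},
    'line_items': {
        'type': 'lines',
        'fields': {'subtotal': {'type': 'string', 'confidenceLevels': {'automated': 0.98}}},
    },
}
raw_predictions = [
    {'label': 'total_amount', 'page': 0, 'value': '100.00', 'confidence': 0.95},
    {'label': 'line_items', 'value': [[{'label': 'subtotal', 'page': 0, 'value': '50.00', 'confidence': 0.99}]]},
]

labels = get_labels(field_config)  # assumed to return the set of top-level labels
predictions, top1_preds = patch_and_filter_predictions(
    predictions=raw_predictions,
    field_config=field_config,
    labels=labels,
    merge_continued_lines=False,  # mirrors model_metadata.get('mergeContinuedLines') at the call site
    no_empty_prediction_fields={'total_amount', 'line_items'},
)
# predictions: concatenated, patched, and filtered candidates for further processing;
# top1_preds: best candidate per label, now used for the required-labels check.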
docs/workflows/transitions/preprocess/tests/test_handler.py (65 additions & 4 deletions)
@@ -9,13 +9,14 @@

from ..preprocess.make_predictions import make_predictions
from ..preprocess.utils import (
concatenate_lines_from_different_pages_and_merge_continued_lines,
create_form_config_from_model,
filter_away_low_confidence_lines,
filter_by_top1,
concatenate_lines_from_different_pages_and_merge_continued_lines,
patch_empty_predictions,
get_labels,
get_column_names,
filter_away_low_confidence_lines,
get_labels,
patch_and_filter_predictions,
patch_empty_predictions,
)


@@ -958,3 +959,63 @@ def test_create_form_config_from_model(field_config, form_config):
assert conf_levels
else:
assert conf_levels == form_config[field]['fields'][line_field]['confidenceLevels']['automated']


@pytest.fixture
def form_config_simple():
yield base64.b64encode(json.dumps({
'config': {
'fields': {
'total_amount': {
'type': 'amount',
'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5}
},
'line_items': {
'type': 'lines',
'fields': {
'subtotal': {
'type': 'string',
'confidenceLevels': {'automated': 0.98, 'high': 0.97, 'medium': 0.9, 'low': 0.5},
}
}
}
}
}
}).encode('utf-8'))


@pytest.mark.parametrize('predictions', [[
{'label': 'total_amount', 'page': 0, 'value': None, 'confidence': 0.99},
{'label': 'total_amount', 'page': 1, 'value': '100.00', 'confidence': 0.35},
{'label': 'line_items', 'value': [
[
{'label': 'subtotal', 'page': 0, 'value': '50.00', 'confidence': 0.99},
], [
{'label': 'subtotal', 'page': 0, 'value': '30.00', 'confidence': 0.99},
], [],
]},
{'label': 'line_items', 'value': [
[
{'label': 'subtotal', 'page': 1, 'value': '72.15', 'confidence': 0.9},
],
]},
]])
@pytest.mark.parametrize('no_empty_prediction_fields', [{'total_amount', 'line_items'}, {}])
def test_patch_and_filter(form_config_simple, predictions, no_empty_prediction_fields):
form_config = json.loads(base64.b64decode(form_config_simple))
field_config = form_config['config']['fields']
labels = get_labels(field_config)
_, top_1 = patch_and_filter_predictions(predictions, field_config, labels, False, no_empty_prediction_fields)
best_total_amount = [p for p in top_1 if p['label'] == 'total_amount']
assert len(best_total_amount) == 1
    # If a field is in no_empty_prediction_fields, at least one page produced a real
    # prediction for it, so the empty (None) candidate does not win top-1 and '100.00'
    # is returned for total_amount.
    # If a field is not in no_empty_prediction_fields, no page produced a prediction for it.
    # That is not actually true for the predictions above, but faking it keeps the empty
    # candidate, which then wins as best_total_amount because it has the highest confidence.
assert best_total_amount[0]['value'] == ('100.00' if no_empty_prediction_fields else None)
line_items = [p for p in top_1 if p['label'] == 'line_items']
assert len(line_items) == 1
line_values = line_items[0]['value']
assert [line_candidate for line_candidate in line_values if line_candidate[0]['page'] == 1]
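
To exercise only the new test, something along these lines should work; the test path comes from the file header above, while the pytest invocation and working-directory assumptions are mine.

# Sketch: run just the new parametrized test. Assumes pytest is installed and that this
# is executed from the repository root so the test module's relative imports resolve.
import pytest

pytest.main([
    'docs/workflows/transitions/preprocess/tests/test_handler.py',
    '-k', 'test_patch_and_filter',
    '-q',
])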