Spaces:
Sleeping
Sleeping
Raghu
commited on
Commit
·
6fe5290
1
Parent(s):
b8f0f36
Improve LayoutLM total detection: add confidence scores, validate against OCR text, use OCR fallback when LayoutLM misses total
Browse files
app.py
CHANGED
|
@@ -876,6 +876,7 @@ class LayoutLMFieldExtractor:
|
|
| 876 |
return words, boxes
|
| 877 |
|
| 878 |
def predict_fields(self, image, ocr_results=None):
|
|
|
|
| 879 |
if self.model is None:
|
| 880 |
self.load()
|
| 881 |
|
|
@@ -901,40 +902,122 @@ class LayoutLMFieldExtractor:
|
|
| 901 |
with torch.no_grad():
|
| 902 |
outputs = self.model(**encoding)
|
| 903 |
logits = outputs.logits[0]
|
|
|
|
|
|
|
| 904 |
preds = logits.argmax(-1).cpu().tolist()
|
|
|
|
| 905 |
tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu())
|
| 906 |
|
|
|
|
| 907 |
entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
|
| 908 |
-
|
|
|
|
|
|
|
| 909 |
|
| 910 |
-
for token, pred in zip(tokens, preds):
|
| 911 |
label = self.id2label.get(pred, "O")
|
|
|
|
|
|
|
| 912 |
if token in ["[PAD]", "[CLS]", "[SEP]"]:
|
| 913 |
continue
|
|
|
|
| 914 |
if label.startswith("B-"):
|
| 915 |
-
#
|
| 916 |
if current["label"] and current["tokens"]:
|
| 917 |
-
|
| 918 |
-
|
|
|
|
|
|
|
|
|
|
| 919 |
elif label.startswith("I-") and current["label"] == label[2:]:
|
| 920 |
current["tokens"].append(token)
|
| 921 |
else:
|
| 922 |
if current["label"] and current["tokens"]:
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
|
|
|
| 927 |
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 931 |
|
| 932 |
-
return
|
| 933 |
-
"vendor": pick_first("VENDOR"),
|
| 934 |
-
"date": pick_first("DATE"),
|
| 935 |
-
"total": pick_first("TOTAL"),
|
| 936 |
-
"time": pick_first("TIME"),
|
| 937 |
-
}
|
| 938 |
|
| 939 |
|
| 940 |
# ============================================================================
|
|
@@ -1157,6 +1240,13 @@ def process_receipt(image):
|
|
| 1157 |
if not ocr_val and layoutlm_val:
|
| 1158 |
# OCR didn't find it, use LayoutLM
|
| 1159 |
fields[field_name] = layoutlm_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1160 |
elif ocr_val and layoutlm_val and field_name == 'total':
|
| 1161 |
# For total: validate LayoutLM against OCR text
|
| 1162 |
ocr_text = ' '.join([r['text'] for r in ocr_results])
|
|
@@ -1170,6 +1260,9 @@ def process_receipt(image):
|
|
| 1170 |
else:
|
| 1171 |
# LayoutLM doesn't match OCR, trust OCR (more reliable)
|
| 1172 |
fields['total'] = ocr_val
|
|
|
|
|
|
|
|
|
|
| 1173 |
elif ocr_val and layoutlm_val and field_name != 'total':
|
| 1174 |
# For other fields, prefer LayoutLM if it's longer/more complete
|
| 1175 |
if len(str(layoutlm_val)) > len(str(ocr_val)):
|
|
|
|
| 876 |
return words, boxes
|
| 877 |
|
| 878 |
def predict_fields(self, image, ocr_results=None):
|
| 879 |
+
"""Predict fields with confidence scores and improved total extraction."""
|
| 880 |
if self.model is None:
|
| 881 |
self.load()
|
| 882 |
|
|
|
|
| 902 |
with torch.no_grad():
|
| 903 |
outputs = self.model(**encoding)
|
| 904 |
logits = outputs.logits[0]
|
| 905 |
+
# Get softmax probabilities for confidence
|
| 906 |
+
probs = torch.softmax(logits, dim=-1)
|
| 907 |
preds = logits.argmax(-1).cpu().tolist()
|
| 908 |
+
probs_np = probs.cpu().numpy()
|
| 909 |
tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu())
|
| 910 |
|
| 911 |
+
# Extract entities with confidence scores
|
| 912 |
entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
|
| 913 |
+
entity_confidences = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
|
| 914 |
+
entity_positions = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
|
| 915 |
+
current = {"label": None, "tokens": [], "start_idx": None}
|
| 916 |
|
| 917 |
+
for idx, (token, pred) in enumerate(zip(tokens, preds)):
|
| 918 |
label = self.id2label.get(pred, "O")
|
| 919 |
+
conf = float(probs_np[idx, pred])
|
| 920 |
+
|
| 921 |
if token in ["[PAD]", "[CLS]", "[SEP]"]:
|
| 922 |
continue
|
| 923 |
+
|
| 924 |
if label.startswith("B-"):
|
| 925 |
+
# Flush previous
|
| 926 |
if current["label"] and current["tokens"]:
|
| 927 |
+
entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
|
| 928 |
+
entities[current["label"]].append(entity_text)
|
| 929 |
+
entity_confidences[current["label"]].append(conf)
|
| 930 |
+
entity_positions[current["label"]].append(current["start_idx"])
|
| 931 |
+
current = {"label": label[2:], "tokens": [token], "start_idx": idx}
|
| 932 |
elif label.startswith("I-") and current["label"] == label[2:]:
|
| 933 |
current["tokens"].append(token)
|
| 934 |
else:
|
| 935 |
if current["label"] and current["tokens"]:
|
| 936 |
+
entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
|
| 937 |
+
entities[current["label"]].append(entity_text)
|
| 938 |
+
entity_confidences[current["label"]].append(conf)
|
| 939 |
+
entity_positions[current["label"]].append(current["start_idx"])
|
| 940 |
+
current = {"label": None, "tokens": [], "start_idx": None}
|
| 941 |
|
| 942 |
+
if current["label"] and current["tokens"]:
|
| 943 |
+
entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
|
| 944 |
+
entities[current["label"]].append(entity_text)
|
| 945 |
+
entity_confidences[current["label"]].append(conf)
|
| 946 |
+
entity_positions[current["label"]].append(current["start_idx"])
|
| 947 |
+
|
| 948 |
+
# Smart field selection with confidence and position awareness
|
| 949 |
+
result = {}
|
| 950 |
+
|
| 951 |
+
# Vendor: prefer first high-confidence result
|
| 952 |
+
if entities["VENDOR"]:
|
| 953 |
+
best_vendor_idx = max(range(len(entities["VENDOR"])),
|
| 954 |
+
key=lambda i: entity_confidences["VENDOR"][i])
|
| 955 |
+
if entity_confidences["VENDOR"][best_vendor_idx] > 0.3:
|
| 956 |
+
result["vendor"] = entities["VENDOR"][best_vendor_idx]
|
| 957 |
+
|
| 958 |
+
# Date: prefer first high-confidence result
|
| 959 |
+
if entities["DATE"]:
|
| 960 |
+
best_date_idx = max(range(len(entities["DATE"])),
|
| 961 |
+
key=lambda i: entity_confidences["DATE"][i])
|
| 962 |
+
if entity_confidences["DATE"][best_date_idx] > 0.3:
|
| 963 |
+
result["date"] = entities["DATE"][best_date_idx]
|
| 964 |
+
|
| 965 |
+
# Time: prefer first high-confidence result
|
| 966 |
+
if entities["TIME"]:
|
| 967 |
+
best_time_idx = max(range(len(entities["TIME"])),
|
| 968 |
+
key=lambda i: entity_confidences["TIME"][i])
|
| 969 |
+
if entity_confidences["TIME"][best_time_idx] > 0.3:
|
| 970 |
+
result["time"] = entities["TIME"][best_time_idx]
|
| 971 |
+
|
| 972 |
+
# Total: improved extraction - look for amounts near "TOTAL" keyword in OCR
|
| 973 |
+
if entities["TOTAL"]:
|
| 974 |
+
# Get all total candidates with confidence
|
| 975 |
+
total_candidates = [(entities["TOTAL"][i], entity_confidences["TOTAL"][i],
|
| 976 |
+
entity_positions["TOTAL"][i])
|
| 977 |
+
for i in range(len(entities["TOTAL"]))]
|
| 978 |
+
|
| 979 |
+
# If OCR results available, validate against OCR text
|
| 980 |
+
if ocr_results:
|
| 981 |
+
ocr_text = ' '.join([r['text'] for r in ocr_results]).upper()
|
| 982 |
+
ocr_lines = [r['text'] for r in ocr_results]
|
| 983 |
+
|
| 984 |
+
# Find amounts near "TOTAL" keyword
|
| 985 |
+
best_total = None
|
| 986 |
+
best_conf = 0
|
| 987 |
+
|
| 988 |
+
for total_val, conf, pos in total_candidates:
|
| 989 |
+
# Clean the total value
|
| 990 |
+
total_clean = str(total_val).replace('$', '').replace(',', '').replace('.', '').strip()
|
| 991 |
+
|
| 992 |
+
# Check if this total appears near "TOTAL" keyword in OCR
|
| 993 |
+
for i, line in enumerate(ocr_lines):
|
| 994 |
+
line_upper = line.upper()
|
| 995 |
+
if 'TOTAL' in line_upper or 'AMOUNT DUE' in line_upper:
|
| 996 |
+
# Check this line and next 2 lines for the amount
|
| 997 |
+
search_text = ' '.join(ocr_lines[i:min(i+3, len(ocr_lines))])
|
| 998 |
+
search_clean = search_text.replace('$', '').replace(',', '').replace('.', '')
|
| 999 |
+
|
| 1000 |
+
if total_clean in search_clean:
|
| 1001 |
+
# Found near TOTAL keyword - high confidence
|
| 1002 |
+
if conf > best_conf:
|
| 1003 |
+
best_total = total_val
|
| 1004 |
+
best_conf = conf
|
| 1005 |
+
break
|
| 1006 |
+
|
| 1007 |
+
if best_total:
|
| 1008 |
+
result["total"] = best_total
|
| 1009 |
+
else:
|
| 1010 |
+
# Fallback: use highest confidence total
|
| 1011 |
+
best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
|
| 1012 |
+
if total_candidates[best_idx][1] > 0.3:
|
| 1013 |
+
result["total"] = total_candidates[best_idx][0]
|
| 1014 |
+
else:
|
| 1015 |
+
# No OCR, use highest confidence
|
| 1016 |
+
best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
|
| 1017 |
+
if total_candidates[best_idx][1] > 0.3:
|
| 1018 |
+
result["total"] = total_candidates[best_idx][0]
|
| 1019 |
|
| 1020 |
+
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1021 |
|
| 1022 |
|
| 1023 |
# ============================================================================
|
|
|
|
| 1240 |
if not ocr_val and layoutlm_val:
|
| 1241 |
# OCR didn't find it, use LayoutLM
|
| 1242 |
fields[field_name] = layoutlm_val
|
| 1243 |
+
elif ocr_val and not layoutlm_val:
|
| 1244 |
+
# LayoutLM didn't find it, but OCR did - use OCR (especially for total)
|
| 1245 |
+
if field_name == 'total':
|
| 1246 |
+
fields[field_name] = ocr_val
|
| 1247 |
+
else:
|
| 1248 |
+
# For other fields, prefer OCR if LayoutLM missed it
|
| 1249 |
+
fields[field_name] = ocr_val
|
| 1250 |
elif ocr_val and layoutlm_val and field_name == 'total':
|
| 1251 |
# For total: validate LayoutLM against OCR text
|
| 1252 |
ocr_text = ' '.join([r['text'] for r in ocr_results])
|
|
|
|
| 1260 |
else:
|
| 1261 |
# LayoutLM doesn't match OCR, trust OCR (more reliable)
|
| 1262 |
fields['total'] = ocr_val
|
| 1263 |
+
elif ocr_val and not layoutlm_val and field_name == 'total':
|
| 1264 |
+
# LayoutLM didn't find total, but OCR did - use OCR
|
| 1265 |
+
fields['total'] = ocr_val
|
| 1266 |
elif ocr_val and layoutlm_val and field_name != 'total':
|
| 1267 |
# For other fields, prefer LayoutLM if it's longer/more complete
|
| 1268 |
if len(str(layoutlm_val)) > len(str(ocr_val)):
|