Raghu committed on
Commit
6fe5290
·
1 Parent(s): b8f0f36

Improve LayoutLM total detection: add confidence scores, validate against OCR text, use OCR fallback when LayoutLM misses total

Browse files
Files changed (1) hide show
  1. app.py +111 -18
app.py CHANGED
@@ -876,6 +876,7 @@ class LayoutLMFieldExtractor:
876
  return words, boxes
877
 
878
  def predict_fields(self, image, ocr_results=None):
 
879
  if self.model is None:
880
  self.load()
881
 
@@ -901,40 +902,122 @@ class LayoutLMFieldExtractor:
901
  with torch.no_grad():
902
  outputs = self.model(**encoding)
903
  logits = outputs.logits[0]
 
 
904
  preds = logits.argmax(-1).cpu().tolist()
 
905
  tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu())
906
 
 
907
  entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
908
- current = {"label": None, "tokens": []}
 
 
909
 
910
- for token, pred in zip(tokens, preds):
911
  label = self.id2label.get(pred, "O")
 
 
912
  if token in ["[PAD]", "[CLS]", "[SEP]"]:
913
  continue
 
914
  if label.startswith("B-"):
915
- # flush previous
916
  if current["label"] and current["tokens"]:
917
- entities[current["label"]].append(" ".join(current["tokens"]))
918
- current = {"label": label[2:], "tokens": [token]}
 
 
 
919
  elif label.startswith("I-") and current["label"] == label[2:]:
920
  current["tokens"].append(token)
921
  else:
922
  if current["label"] and current["tokens"]:
923
- entities[current["label"]].append(" ".join(current["tokens"]))
924
- current = {"label": None, "tokens": []}
925
- if current["label"] and current["tokens"]:
926
- entities[current["label"]].append(" ".join(current["tokens"]))
 
927
 
928
- def pick_first(key):
929
- vals = entities.get(key, [])
930
- return vals[0].replace("", " ").strip() if vals else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931
 
932
- return {
933
- "vendor": pick_first("VENDOR"),
934
- "date": pick_first("DATE"),
935
- "total": pick_first("TOTAL"),
936
- "time": pick_first("TIME"),
937
- }
938
 
939
 
940
  # ============================================================================
@@ -1157,6 +1240,13 @@ def process_receipt(image):
1157
  if not ocr_val and layoutlm_val:
1158
  # OCR didn't find it, use LayoutLM
1159
  fields[field_name] = layoutlm_val
 
 
 
 
 
 
 
1160
  elif ocr_val and layoutlm_val and field_name == 'total':
1161
  # For total: validate LayoutLM against OCR text
1162
  ocr_text = ' '.join([r['text'] for r in ocr_results])
@@ -1170,6 +1260,9 @@ def process_receipt(image):
1170
  else:
1171
  # LayoutLM doesn't match OCR, trust OCR (more reliable)
1172
  fields['total'] = ocr_val
 
 
 
1173
  elif ocr_val and layoutlm_val and field_name != 'total':
1174
  # For other fields, prefer LayoutLM if it's longer/more complete
1175
  if len(str(layoutlm_val)) > len(str(ocr_val)):
 
876
  return words, boxes
877
 
878
  def predict_fields(self, image, ocr_results=None):
879
+ """Predict fields with confidence scores and improved total extraction."""
880
  if self.model is None:
881
  self.load()
882
 
 
902
  with torch.no_grad():
903
  outputs = self.model(**encoding)
904
  logits = outputs.logits[0]
905
+ # Get softmax probabilities for confidence
906
+ probs = torch.softmax(logits, dim=-1)
907
  preds = logits.argmax(-1).cpu().tolist()
908
+ probs_np = probs.cpu().numpy()
909
  tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu())
910
 
911
+ # Extract entities with confidence scores
912
  entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
913
+ entity_confidences = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
914
+ entity_positions = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
915
+ current = {"label": None, "tokens": [], "start_idx": None}
916
 
917
+ for idx, (token, pred) in enumerate(zip(tokens, preds)):
918
  label = self.id2label.get(pred, "O")
919
+ conf = float(probs_np[idx, pred])
920
+
921
  if token in ["[PAD]", "[CLS]", "[SEP]"]:
922
  continue
923
+
924
  if label.startswith("B-"):
925
+ # Flush previous
926
  if current["label"] and current["tokens"]:
927
+ entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
928
+ entities[current["label"]].append(entity_text)
929
+ entity_confidences[current["label"]].append(conf)
930
+ entity_positions[current["label"]].append(current["start_idx"])
931
+ current = {"label": label[2:], "tokens": [token], "start_idx": idx}
932
  elif label.startswith("I-") and current["label"] == label[2:]:
933
  current["tokens"].append(token)
934
  else:
935
  if current["label"] and current["tokens"]:
936
+ entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
937
+ entities[current["label"]].append(entity_text)
938
+ entity_confidences[current["label"]].append(conf)
939
+ entity_positions[current["label"]].append(current["start_idx"])
940
+ current = {"label": None, "tokens": [], "start_idx": None}
941
 
942
+ if current["label"] and current["tokens"]:
943
+ entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
944
+ entities[current["label"]].append(entity_text)
945
+ entity_confidences[current["label"]].append(conf)
946
+ entity_positions[current["label"]].append(current["start_idx"])
947
+
948
+ # Smart field selection with confidence and position awareness
949
+ result = {}
950
+
951
+ # Vendor: prefer first high-confidence result
952
+ if entities["VENDOR"]:
953
+ best_vendor_idx = max(range(len(entities["VENDOR"])),
954
+ key=lambda i: entity_confidences["VENDOR"][i])
955
+ if entity_confidences["VENDOR"][best_vendor_idx] > 0.3:
956
+ result["vendor"] = entities["VENDOR"][best_vendor_idx]
957
+
958
+ # Date: prefer first high-confidence result
959
+ if entities["DATE"]:
960
+ best_date_idx = max(range(len(entities["DATE"])),
961
+ key=lambda i: entity_confidences["DATE"][i])
962
+ if entity_confidences["DATE"][best_date_idx] > 0.3:
963
+ result["date"] = entities["DATE"][best_date_idx]
964
+
965
+ # Time: prefer first high-confidence result
966
+ if entities["TIME"]:
967
+ best_time_idx = max(range(len(entities["TIME"])),
968
+ key=lambda i: entity_confidences["TIME"][i])
969
+ if entity_confidences["TIME"][best_time_idx] > 0.3:
970
+ result["time"] = entities["TIME"][best_time_idx]
971
+
972
+ # Total: improved extraction - look for amounts near "TOTAL" keyword in OCR
973
+ if entities["TOTAL"]:
974
+ # Get all total candidates with confidence
975
+ total_candidates = [(entities["TOTAL"][i], entity_confidences["TOTAL"][i],
976
+ entity_positions["TOTAL"][i])
977
+ for i in range(len(entities["TOTAL"]))]
978
+
979
+ # If OCR results available, validate against OCR text
980
+ if ocr_results:
981
+ ocr_text = ' '.join([r['text'] for r in ocr_results]).upper()
982
+ ocr_lines = [r['text'] for r in ocr_results]
983
+
984
+ # Find amounts near "TOTAL" keyword
985
+ best_total = None
986
+ best_conf = 0
987
+
988
+ for total_val, conf, pos in total_candidates:
989
+ # Clean the total value
990
+ total_clean = str(total_val).replace('$', '').replace(',', '').replace('.', '').strip()
991
+
992
+ # Check if this total appears near "TOTAL" keyword in OCR
993
+ for i, line in enumerate(ocr_lines):
994
+ line_upper = line.upper()
995
+ if 'TOTAL' in line_upper or 'AMOUNT DUE' in line_upper:
996
+ # Check this line and next 2 lines for the amount
997
+ search_text = ' '.join(ocr_lines[i:min(i+3, len(ocr_lines))])
998
+ search_clean = search_text.replace('$', '').replace(',', '').replace('.', '')
999
+
1000
+ if total_clean in search_clean:
1001
+ # Found near TOTAL keyword - high confidence
1002
+ if conf > best_conf:
1003
+ best_total = total_val
1004
+ best_conf = conf
1005
+ break
1006
+
1007
+ if best_total:
1008
+ result["total"] = best_total
1009
+ else:
1010
+ # Fallback: use highest confidence total
1011
+ best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
1012
+ if total_candidates[best_idx][1] > 0.3:
1013
+ result["total"] = total_candidates[best_idx][0]
1014
+ else:
1015
+ # No OCR, use highest confidence
1016
+ best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
1017
+ if total_candidates[best_idx][1] > 0.3:
1018
+ result["total"] = total_candidates[best_idx][0]
1019
 
1020
+ return result
 
 
 
 
 
1021
 
1022
 
1023
  # ============================================================================
 
1240
  if not ocr_val and layoutlm_val:
1241
  # OCR didn't find it, use LayoutLM
1242
  fields[field_name] = layoutlm_val
1243
+ elif ocr_val and not layoutlm_val:
1244
+ # LayoutLM didn't find it, but OCR did - use OCR (especially for total)
1245
+ if field_name == 'total':
1246
+ fields[field_name] = ocr_val
1247
+ else:
1248
+ # For other fields, prefer OCR if LayoutLM missed it
1249
+ fields[field_name] = ocr_val
1250
  elif ocr_val and layoutlm_val and field_name == 'total':
1251
  # For total: validate LayoutLM against OCR text
1252
  ocr_text = ' '.join([r['text'] for r in ocr_results])
 
1260
  else:
1261
  # LayoutLM doesn't match OCR, trust OCR (more reliable)
1262
  fields['total'] = ocr_val
1263
+ elif ocr_val and not layoutlm_val and field_name == 'total':
1264
+ # LayoutLM didn't find total, but OCR did - use OCR
1265
+ fields['total'] = ocr_val
1266
  elif ocr_val and layoutlm_val and field_name != 'total':
1267
  # For other fields, prefer LayoutLM if it's longer/more complete
1268
  if len(str(layoutlm_val)) > len(str(ocr_val)):