mp-02 commited on
Commit
97095a3
·
verified ·
1 Parent(s): b248a87

Update sroie_inference.py

Browse files
Files changed (1) hide show
  1. sroie_inference.py +9 -4
sroie_inference.py CHANGED
@@ -46,6 +46,8 @@ def prediction(image):
46
 
47
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
48
  token_boxes = encoding.bbox.squeeze().tolist()
 
 
49
 
50
  inp_ids = encoding.input_ids.squeeze().tolist()
51
  inp_words = [tokenizer.decode(i) for i in inp_ids]
@@ -55,6 +57,7 @@ def prediction(image):
55
 
56
  true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
57
  true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
 
58
  true_words = []
59
 
60
  for id, i in enumerate(inp_words):
@@ -66,16 +69,18 @@ def prediction(image):
66
  true_predictions = true_predictions[1:-1]
67
  true_boxes = true_boxes[1:-1]
68
  true_words = true_words[1:-1]
 
69
 
70
  preds = []
71
  l_words = []
72
  bboxes = []
73
 
74
  for i, j in enumerate(true_predictions):
75
- if j != 'others':
76
- preds.append(true_predictions[i])
77
- l_words.append(true_words[i])
78
- bboxes.append(true_boxes[i])
 
79
 
80
  d = {}
81
  for id, i in enumerate(preds):
 
46
 
47
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
48
  token_boxes = encoding.bbox.squeeze().tolist()
49
+ probabilities = torch.softmax(outputs.logits, dim=-1)
50
+ confidence_scores = probabilities.max(-1).values.squeeze().tolist()
51
 
52
  inp_ids = encoding.input_ids.squeeze().tolist()
53
  inp_words = [tokenizer.decode(i) for i in inp_ids]
 
57
 
58
  true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
59
  true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
60
+ true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
61
  true_words = []
62
 
63
  for id, i in enumerate(inp_words):
 
69
  true_predictions = true_predictions[1:-1]
70
  true_boxes = true_boxes[1:-1]
71
  true_words = true_words[1:-1]
72
+ true_confidence_scores = true_confidence_scores[1:-1]
73
 
74
  preds = []
75
  l_words = []
76
  bboxes = []
77
 
78
  for i, j in enumerate(true_predictions):
79
+ if true_confidence_scores[i] < 0.9: #####################################àà
80
+ true_predictions[i] = "O"
81
+ preds.append(true_predictions[i])
82
+ l_words.append(true_words[i])
83
+ bboxes.append(true_boxes[i])
84
 
85
  d = {}
86
  for id, i in enumerate(preds):