mp-02 committed on
Commit: e41ca05
1 parent: fa99101

Update sroie_inference.py: comment out the ImageDraw annotation block at the end of prediction(), so it returns the blurred receipt image without boxes or labels drawn on top.

Files changed (1):
  1. sroie_inference.py +114 -114
sroie_inference.py CHANGED
@@ -1,114 +1,114 @@
  import torch
  import cv2
  import numpy as np
  from PIL import Image, ImageDraw, ImageFont
  from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
  from utils import OCR, unnormalize_box


  labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
  id2label = {v: k for v, k in enumerate(labels)}
  label2id = {k: v for v, k in enumerate(labels)}

  tokenizer = LayoutLMv3TokenizerFast.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie", apply_ocr=False)
  processor = LayoutLMv3Processor.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie", apply_ocr=False)
  model = LayoutLMv3ForTokenClassification.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie")

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model.to(device)


  def blur(image, boxes):
      image = np.array(image)
      for box in boxes:

          blur_x = int(box[0])
          blur_y = int(box[1])
          blur_width = int(box[2] - box[0])
          blur_height = int(box[3] - box[1])

          roi = image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width]
          blur_image = cv2.GaussianBlur(roi, (201, 201), 0)
          image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width] = blur_image

      return Image.fromarray(image, 'RGB')


  def prediction(image):
      boxes, words = OCR(image)
      encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
      offset_mapping = encoding.pop('offset_mapping')

      for k, v in encoding.items():
          encoding[k] = v.to(device)

      outputs = model(**encoding)

      predictions = outputs.logits.argmax(-1).squeeze().tolist()
      token_boxes = encoding.bbox.squeeze().tolist()

      inp_ids = encoding.input_ids.squeeze().tolist()
      inp_words = [tokenizer.decode(i) for i in inp_ids]

      width, height = image.size
      is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

      true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
      true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
      true_words = []

      for id, i in enumerate(inp_words):
          if not is_subword[id]:
              true_words.append(i)
          else:
              true_words[-1] = true_words[-1] + i

      true_predictions = true_predictions[1:-1]
      true_boxes = true_boxes[1:-1]
      true_words = true_words[1:-1]

      preds = []
      l_words = []
      bboxes = []

      for i, j in enumerate(true_predictions):
          if j != 'others':
              preds.append(true_predictions[i])
              l_words.append(true_words[i])
              bboxes.append(true_boxes[i])

      d = {}
      for id, i in enumerate(preds):
          if i not in d.keys():
              d[i] = l_words[id]
          else:
              d[i] = d[i] + ", " + l_words[id]

      d = {k: v.strip() for (k, v) in d.items()}

      keys_to_pop = []
      for k, v in d.items():
          if k[:2] == "I-":
              d["B-" + k[2:]] = d["B-" + k[2:]] + ", " + v
              keys_to_pop.append(k)

      if "O" in d: d.pop("O")
      if "B-TOTAL" in d: d.pop("B-TOTAL")
      for k in keys_to_pop: d.pop(k)

      blur_boxes = []
      for prediction, box in zip(preds, bboxes):
          if prediction != 'O' and prediction[2:] != 'TOTAL':
              blur_boxes.append(box)

      image = blur(image, blur_boxes)

-     draw = ImageDraw.Draw(image, "RGBA")
-     font = ImageFont.load_default()
+     #draw = ImageDraw.Draw(image, "RGBA")
+     #font = ImageFont.load_default()

-     for prediction, box in zip(preds, bboxes):
-         draw.rectangle(box)
-         draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")
+     #for prediction, box in zip(preds, bboxes):
+     #    draw.rectangle(box)
+     #    draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")

      return d, image
 
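Reviewer note: to smoke-test the change, a minimal driver like the sketch below should work. It assumes this file is importable as sroie_inference and that a receipt image exists at the given path; the module name and file paths are illustrative, not part of the commit.

# Hypothetical driver for this script -- module name and paths are assumptions.
from PIL import Image

from sroie_inference import prediction

receipt = Image.open("receipt.jpg").convert("RGB")   # illustrative input path
fields, redacted = prediction(receipt)

print(fields)   # e.g. {"B-COMPANY": "...", "B-DATE": "...", "B-ADDRESS": "..."}
redacted.save("receipt_redacted.jpg")   # blurred image; after this commit, no boxes or labels drawn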
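The script also depends on a local utils module that sits outside this diff: OCR(image) must return (boxes, words) with boxes already normalized to the 0-1000 coordinate space LayoutLMv3 expects, and unnormalize_box must map such a box back to pixel coordinates. The real helpers are not shown here, but under that assumption unnormalize_box plausibly looks like:

# Plausible sketch of utils.unnormalize_box, assuming the standard
# LayoutLMv3 0-1000 box normalization; the actual helper is not in this diff.
def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]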
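Finally, the dictionary post-processing is the subtle part of prediction(): word strings are first accumulated per tag, then each I-FIELD value is appended onto its B-FIELD sibling before the I- keys, "O", and "B-TOTAL" are dropped. A stripped-down run of that folding step, with made-up values:

# Toy rerun of the committed folding logic (values are made up).
d = {"B-COMPANY": "ACME", "I-COMPANY": "SDN BHD", "O": "noise", "B-TOTAL": "9.90"}

keys_to_pop = []
for k, v in d.items():
    if k[:2] == "I-":
        d["B-" + k[2:]] = d["B-" + k[2:]] + ", " + v   # assumes the B- key exists
        keys_to_pop.append(k)

if "O" in d: d.pop("O")
if "B-TOTAL" in d: d.pop("B-TOTAL")
for k in keys_to_pop: d.pop(k)

print(d)   # {'B-COMPANY': 'ACME, SDN BHD'}

Note that the folding assumes a B-FIELD entry already exists whenever an I-FIELD one does; an I- tag predicted without its B- counterpart would raise a KeyError in the committed code as well.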