Commit 1d3d5c8 • truong-xuan-linh committed
Parent(s): b2813ce

update visualize
Files changed:
- app.py +38 -11
- pre-requirements.txt +2 -0
- src/feature_extraction.py +87 -76
- src/image_visualization.py +21 -0
- src/model.py +168 -52
- src/ocr.py +68 -62
- utils/config.py +1 -1
- visualization/.gitkeep +0 -0
app.py
CHANGED
@@ -2,11 +2,13 @@ import glob
 import streamlit as st
 
 from streamlit_image_select import image_select
+import streamlit.components.v1 as components
 
-#Trick to not init function multitime
+# Trick to not init function multitime
 if "model" not in st.session_state:
     print("INIT MODEL")
     from src.model import Model
+
     st.session_state.model = Model()
     print("DONE INIT MODEL")
 
@@ -16,17 +18,25 @@ hide_menu_style = """
 footer {visibility: hidden;}
 </style>
 """
-st.markdown(hide_menu_style, unsafe_allow_html=
+st.markdown(hide_menu_style, unsafe_allow_html=True)
 
 mapper = {
+    "images/000000000645.jpg": "Đây là đâu",
+    "images/000000000661.jpg": "Tốc độ tối đa trên đoạn đường này là bao nhiêu",
+    "images/000000000674.jpg": "Còn bao xa nữa là tới Huế",
+    "images/000000000706.jpg": "Cầu này dài bao nhiêu",
+    "images/000000000777.jpg": "Chè khúc bạch giá bao nhiêu",
 }
 
-image = st.file_uploader(
+image = st.file_uploader(
+    "Choose an image file",
+    type=[
+        "jpg",
+        "jpeg",
+        "png",
+        "webp",
+    ],
+)
 example = image_select("Examples", glob.glob("images/*.jpg"))
 
 if image:
@@ -40,10 +50,27 @@ else:
     st.session_state.question = mapper[example]
     st.session_state.image = example
 
-if
+if "image" in st.session_state:
     st.image(st.session_state.image)
     question = st.text_input("**Question:** ", value=st.session_state.question)
+    visualize = True
     if question:
-        answer =
+        answer, text_attention_html, images_visualize = (
+            st.session_state.model.inference(
+                st.session_state.image, question, visualize
+            )
+        )
         st.write(f"**Answer:** {answer}")
+
+        if visualize:
+            st.write("**Explanation**")
+            col1, col2 = st.columns([1, 2])
+            # st.markdown(text_attention_html, unsafe_allow_html=True)
+            with col1:
+                st.write("*Text Attention*")
+                components.html(text_attention_html, height=960, scrolling=True)
+
+            with col2:
+                st.write("*Image Attention*")
+                for image_visualize in images_visualize:
+                    st.image(image_visualize)
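The explanation panel embeds the bertviz output with streamlit.components.v1.html rather than st.markdown (note the commented-out st.markdown call above): head_view returns a full HTML document with embedded JavaScript, which is easier to render inside components.html's iframe. A minimal sketch of that embedding, assuming html_string holds the value returned by the model's explain path:

import streamlit.components.v1 as components

# html_string: the .data payload of bertviz head_view(..., html_action="return")
components.html(html_string, height=960, scrolling=True)  # rendered in an iframe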
pre-requirements.txt
CHANGED
@@ -6,3 +6,5 @@ torchvision==0.18.0
 streamlit==1.35.0
 transformers==4.41.2
 streamlit-image-select==0.6.0
+bertviz==1.4.0
+ipython==8.18.1
src/feature_extraction.py
CHANGED
@@ -1,4 +1,3 @@
-
 import torch
 import requests
 from PIL import Image, ImageFont, ImageDraw, ImageTransform
@@ -9,7 +8,9 @@ from src.ocr import OCRDetector
 
 class ViT:
     def __init__(self) -> None:
-        self.processor = AutoImageProcessor.from_pretrained(
+        self.processor = AutoImageProcessor.from_pretrained(
+            "google/vit-base-patch16-224-in21k"
+        )
         self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
         self.model.to(Config.device)
 
@@ -23,7 +24,9 @@ class ViT:
         with torch.no_grad():
             outputs = self.model(**inputs)
         last_hidden_states = outputs.last_hidden_state
-        attention_mask = torch.ones(
+        attention_mask = torch.ones(
+            (last_hidden_states.shape[0], last_hidden_states.shape[1])
+        )
 
         return last_hidden_states.to(Config.device), attention_mask.to(Config.device)
 
@@ -34,16 +37,20 @@ class ViT:
         image_outputs = self.model(**image_inputs)
         image_pooler_output = image_outputs.pooler_output
         image_pooler_output = torch.unsqueeze(image_pooler_output, 0)
-        image_attention_mask = torch.ones(
-        return image_pooler_output.to(Config.device), image_attention_mask.to(Config.device)
+        image_attention_mask = torch.ones(
+            (image_pooler_output.shape[0], image_pooler_output.shape[1])
+        )
+
+        return image_pooler_output.to(Config.device), image_attention_mask.to(
+            Config.device
+        )
 
 
 class OCR:
     def __init__(self) -> None:
         self.ocr_detector = OCRDetector()
 
     def extraction(self, image_dir):
-
         ocr_results = self.ocr_detector.text_detector(image_dir)
         if not ocr_results:
             print("NOT OCR1")
@@ -53,7 +60,6 @@ class OCR:
         ocrs = self.post_process(ocr_results)
 
         if not ocrs:
-
             return "", [], []
 
         ocrs.reverse()
@@ -74,10 +80,9 @@
         ocr_content = " ".join(ocr_content.split())
         ocr_content = "<extra_id_0>" + ocr_content
 
-
         return ocr_content, groups_box, paragraph_boxes
 
-    def post_process(self,ocr_results):
+    def post_process(self, ocr_results):
         ocrs = []
         for result in ocr_results:
             text = result["text"]
@@ -96,10 +101,7 @@
             # if w*h < 300:
             #     continue
 
-            ocrs.append(
-                {"text": text.lower(),
-                "box": box}
-            )
+            ocrs.append({"text": text.lower(), "box": box})
         return ocrs
 
     @staticmethod
@@ -107,87 +109,96 @@
         (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
         w = x2 - x1
         h = y4 - y1
-        scl = h//7
-        new_box =
+        scl = h // 7
+        new_box = (
+            [max(x1 - scl, 0), max(y1 - scl, 0)],
+            [x2 + scl, y2 - scl],
+            [x3 + scl, y3 + scl],
+            [x4 - scl, y4 + scl],
+        )
         (x1, y1), (x2, y2), (x3, y3), (x4, y4) = new_box
         # Define 8-tuple with x,y coordinates of top-left, bottom-left, bottom-right and top-right corners and apply
         transform = [x1, y1, x4, y4, x3, y3, x2, y2]
-        result = image.transform((w,h), ImageTransform.QuadTransform(transform))
+        result = image.transform((w, h), ImageTransform.QuadTransform(transform))
        return result
 
     @staticmethod
     def check_point_in_rectangle(box, point, padding_devide):
+        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
+        x_min = min(x1, x4)
+        x_max = max(x2, x3)
 
+        padding = (x_max - x_min) // padding_devide
+        x_min = x_min - padding
+        x_max = x_max + padding
 
+        y_min = min(y1, y2)
+        y_max = max(y3, y4)
 
+        y_min = y_min - padding
+        y_max = y_max + padding
 
+        x, y = point
 
+        if x >= x_min and x <= x_max and y >= y_min and y <= y_max:
+            return True
 
+        return False
 
     @staticmethod
     def check_rectangle_overlap(rec1, rec2, padding_devide):
+        for point in rec1:
+            if OCR.check_point_in_rectangle(rec2, point, padding_devide):
+                return True
 
+        for point in rec2:
+            if OCR.check_point_in_rectangle(rec1, point, padding_devide):
+                return True
 
+        return False
 
     @staticmethod
     def group_boxes(boxes, texts):
+        groups = []
+        groups_text = []
+        paragraph_boxes = []
+        processed = []
+        boxes_cp = boxes.copy()
+        for i, (box, text) in enumerate(zip(boxes_cp, texts)):
+            (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
+
+            if i not in processed:
+                processed.append(i)
+            else:
+                continue
+
+            groups.append([box])
+            groups_text.append([text])
+            for j, (box2, text2) in enumerate(zip(boxes_cp[i + 1 :], texts[i + 1 :])):
+                if j + i + 1 in processed:
+                    continue
+                padding_devide = len(groups[-1]) * 4
+                is_overlap = OCR.check_rectangle_overlap(box, box2, padding_devide)
+                if is_overlap:
+                    (xx1, yy1), (xx2, yy2), (xx3, yy3), (xx4, yy4) = box2
+                    processed.append(j + i + 1)
+                    groups[-1].append(box2)
+                    groups_text[-1].append(text2)
+                    new_x1 = min(x1, xx1)
+                    new_y1 = min(y1, yy1)
+                    new_x2 = max(x2, xx2)
+                    new_y2 = min(y2, yy2)
+                    new_x3 = max(x3, xx3)
+                    new_y3 = max(y3, yy3)
+                    new_x4 = min(x4, xx4)
+                    new_y4 = max(y4, yy4)
+
+                    box = [
+                        (new_x1, new_y1),
+                        (new_x2, new_y2),
+                        (new_x3, new_y3),
+                        (new_x4, new_y4),
+                    ]
+
+            paragraph_boxes.append(box)
+        return groups, groups_text, paragraph_boxes
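The new box-grouping helpers are plain static methods, so they can be called on the OCR class without instantiating it or loading any weights. A minimal sketch, assuming the box format used throughout this file (four (x, y) corners ordered top-left, top-right, bottom-right, bottom-left); the coordinates and texts below are made up:

from src.feature_extraction import OCR

boxes = [
    [(10, 10), (60, 10), (60, 30), (10, 30)],    # word box 1
    [(55, 12), (120, 12), (120, 32), (55, 32)],  # word box 2, overlaps box 1
]
texts = ["bus", "stop"]

groups, groups_text, paragraph_boxes = OCR.group_boxes(boxes, texts)
print(groups_text)         # [['bus', 'stop']] -- overlapping boxes end up in one group
print(paragraph_boxes[0])  # single merged quadrilateral covering both word boxes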
src/image_visualization.py
ADDED
@@ -0,0 +1,21 @@
+import matplotlib.pyplot as plt
+
+
+# Show attention
+def plot_attention(img, result, attention_plot, image_dir):
+    # img = img.numpy().transpose((1, 2, 0))
+    temp_image = img
+
+    fig = plt.figure(figsize=(15, 15))
+
+    len_result = len(result)
+    for l in range(len_result):
+        temp_att = attention_plot[l][1:].reshape(14, 14)
+        # temp_att = np.resize(attention_plot[l].detach().numpy(),(98,98))
+        ax = fig.add_subplot(len_result // 2, len_result // 2, l + 1)
+        ax.set_title(result[l], fontsize=18)
+        img = ax.imshow(temp_image)
+        ax.imshow(temp_att, alpha=0.6, cmap="jet", extent=img.get_extent())
+
+    plt.tight_layout()
+    plt.savefig(image_dir)
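plot_attention expects one attention vector per decoder token, where each vector covers the 197 ViT positions (1 CLS token plus the 14×14 = 196 patches of a 224×224 image at patch size 16); it drops the CLS entry and draws the rest as a heat map over the image. Note the subplot grid is len_result // 2 on each side, so the token list needs (len_result // 2)**2 >= len_result (4 tokens works). A minimal sketch with dummy attention values; the image path, token strings, and output path are placeholders:

import numpy as np
from PIL import Image
from src.image_visualization import plot_attention

img = Image.open("images/000000000645.jpg").convert("RGB")
tokens = ["▁Đây", "▁là", "▁đâu", "</s>"]      # one subplot per decoder token
attention = np.random.rand(len(tokens), 197)  # (num_tokens, 1 CLS + 196 patches)
# assumes the visualization/ directory exists (this commit adds it via .gitkeep)
plot_attention(img, tokens, attention, "visualization/example.jpg")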
src/model.py
CHANGED
@@ -8,12 +8,16 @@ from typing import *
 from transformers import T5ForConditionalGeneration, AutoTokenizer
 from utils.config import Config
 from src.feature_extraction import ViT, OCR
+from bertviz import model_view, head_view
+from src.image_visualization import plot_attention
+import numpy as np
+from PIL import Image
 
 _CONFIG_FOR_DOC = "T5Config"
 _CHECKPOINT_FOR_DOC = "google-t5/t5-small"
 
-class CustomT5Stack(T5Stack):
 
+class CustomT5Stack(T5Stack):
     def forward(
         self,
         input_ids=None,
@@ -35,11 +39,19 @@ class CustomT5Stack(T5Stack):
             torch.cuda.set_device(self.first_device)
             self.embed_tokens = self.embed_tokens.to(self.first_device)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        output_attentions =
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if input_ids is not None and inputs_embeds is not None:
             err_msg_prefix = "decoder_" if self.is_decoder else ""
@@ -53,11 +65,15 @@ class CustomT5Stack(T5Stack):
             input_shape = inputs_embeds.size()[:-1]
         else:
             err_msg_prefix = "decoder_" if self.is_decoder else ""
-            raise ValueError(
+            raise ValueError(
+                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
+            )
 
         if inputs_embeds is None:
             if self.embed_tokens is None:
-                raise ValueError(
+                raise ValueError(
+                    "You have to initialize the model with valid token embeddings"
+                )
             inputs_embeds = self.embed_tokens(input_ids)
         if not self.is_decoder and images_embeds is not None:
             inputs_embeds = torch.concat([inputs_embeds, images_embeds], dim=1)
@@ -66,33 +82,47 @@ class CustomT5Stack(T5Stack):
         batch_size, seq_length = input_shape
 
         # required mask seq length can be calculated via length of past
-        mask_seq_length =
+        mask_seq_length = (
+            past_key_values[0][0].shape[2] + seq_length
+            if past_key_values is not None
+            else seq_length
+        )
 
         if use_cache is True:
             if not self.is_decoder:
-                raise ValueError(
+                raise ValueError(
+                    f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+                )
 
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
 
         if attention_mask is None:
-            attention_mask = torch.ones(
+            attention_mask = torch.ones(
+                batch_size, mask_seq_length, device=inputs_embeds.device
+            )
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask = self.get_extended_attention_mask(
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, input_shape
+        )
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.is_decoder and encoder_hidden_states is not None:
-            encoder_batch_size, encoder_sequence_length, _ =
+            encoder_batch_size, encoder_sequence_length, _ = (
+                encoder_hidden_states.size()
+            )
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(
                     encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
                 )
-            encoder_extended_attention_mask = self.invert_attention_mask(
+            encoder_extended_attention_mask = self.invert_attention_mask(
+                encoder_attention_mask
+            )
         else:
             encoder_extended_attention_mask = None
 
@@ -105,7 +135,9 @@ class CustomT5Stack(T5Stack):
 
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
-        cross_attn_head_mask = self.get_head_mask(
+        cross_attn_head_mask = self.get_head_mask(
+            cross_attn_head_mask, self.config.num_layers
+        )
         present_key_value_states = () if use_cache else None
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -115,7 +147,9 @@ class CustomT5Stack(T5Stack):
 
         hidden_states = self.dropout(inputs_embeds)
 
-        for i, (layer_module, past_key_value) in enumerate(
+        for i, (layer_module, past_key_value) in enumerate(
+            zip(self.block, past_key_values)
+        ):
             layer_head_mask = head_mask[i]
             cross_attn_layer_head_mask = cross_attn_head_mask[i]
             # Model parallel
@@ -127,15 +161,23 @@ class CustomT5Stack(T5Stack):
                 if position_bias is not None:
                     position_bias = position_bias.to(hidden_states.device)
                 if encoder_hidden_states is not None:
-                    encoder_hidden_states = encoder_hidden_states.to(
+                    encoder_hidden_states = encoder_hidden_states.to(
+                        hidden_states.device
+                    )
                 if encoder_extended_attention_mask is not None:
-                    encoder_extended_attention_mask =
+                    encoder_extended_attention_mask = (
+                        encoder_extended_attention_mask.to(hidden_states.device)
+                    )
                 if encoder_decoder_position_bias is not None:
-                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(
+                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(
+                        hidden_states.device
+                    )
                 if layer_head_mask is not None:
                     layer_head_mask = layer_head_mask.to(hidden_states.device)
                 if cross_attn_layer_head_mask is not None:
-                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
+                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
+                        hidden_states.device
+                    )
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -181,10 +223,14 @@ class CustomT5Stack(T5Stack):
             # (cross-attention position bias), (cross-attention weights)
             position_bias = layer_outputs[2]
             if self.is_decoder and encoder_hidden_states is not None:
-                encoder_decoder_position_bias = layer_outputs[
+                encoder_decoder_position_bias = layer_outputs[
+                    4 if output_attentions else 3
+                ]
             # append next layer key value states
             if use_cache:
-                present_key_value_states = present_key_value_states + (
+                present_key_value_states = present_key_value_states + (
+                    present_key_value_state,
+                )
 
             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[3],)
@@ -227,7 +273,9 @@ class CustomT5Stack(T5Stack):
 
 class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
     @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
-    @replace_return_docstrings(
+    @replace_return_docstrings(
+        output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -280,7 +328,9 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
         >>> # studies have shown that owning a dog is good for you.
         ```"""
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict =
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
         if head_mask is not None and decoder_head_mask is None:
@@ -299,7 +349,7 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
-                images_embeds=images_embeds
+                images_embeds=images_embeds,
             )
         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
             encoder_outputs = BaseModelOutput(
@@ -313,7 +363,11 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
         if self.model_parallel:
             torch.cuda.set_device(self.decoder.first_device)
 
-        if
+        if (
+            labels is not None
+            and decoder_input_ids is None
+            and decoder_inputs_embeds is None
+        ):
             # get decoder inputs from shifting lm labels to the right
             decoder_input_ids = self._shift_right(labels)
 
@@ -326,7 +380,9 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
             if attention_mask is not None:
                 attention_mask = attention_mask.to(self.decoder.first_device)
             if decoder_attention_mask is not None:
-                decoder_attention_mask = decoder_attention_mask.to(
+                decoder_attention_mask = decoder_attention_mask.to(
+                    self.decoder.first_device
+                )
 
         # Decode
         decoder_outputs = self.decoder(
@@ -382,64 +438,124 @@ class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
         )
+
+
 transformers.models.t5.modeling_t5.T5Stack = CustomT5Stack
-transformers.models.t5.modeling_t5.T5ForConditionalGeneration =
+transformers.models.t5.modeling_t5.T5ForConditionalGeneration = (
+    CustomT5ForConditionalGeneration
+)
 transformers.T5ForConditionalGeneration = CustomT5ForConditionalGeneration
+from transformers import T5ForConditionalGeneration
 
 
 class Model:
     def __init__(self) -> None:
         os.makedirs("storage", exist_ok=True)
+
         if not os.path.exists("storage/vlsp_transfomer_vietocr.pth"):
             print("DOWNLOADING model")
-            gdown.download(
+            gdown.download(
+                Config.model_url, output="storage/vlsp_transfomer_vietocr.pth"
+            )
         self.vit5_tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
-        self.model = T5ForConditionalGeneration.from_pretrained(
+        self.model = T5ForConditionalGeneration.from_pretrained(
+            "truong-xuan-linh/VQA-vit5",
+            revision=Config.revision,
+            output_attentions=True,
+        )
         self.model.to(Config.device)
 
         self.vit = ViT()
         self.ocr = OCR()
 
     def get_inputs(self, image_dir: str, question: str):
-        #VIT
+        # VIT
         image_feature, image_mask = self.vit.extraction(image_dir)
 
         ocr_content, groups_box, paragraph_boxes = self.ocr.extraction(image_dir)
         print("Input: ", question + " " + ocr_content)
-        #VIT5
-        input_ = self.vit5_tokenizer(
+        # VIT5
+        input_ = self.vit5_tokenizer(
+            question + " " + ocr_content,
+            padding="max_length",
+            truncation=True,
+            max_length=Config.question_maxlen + Config.ocr_maxlen,
+            return_tensors="pt",
+        )
 
         input_ids = input_.input_ids
         attention_mask = input_.attention_mask
         mask = torch.cat((attention_mask, image_mask), 1)
         return {
+            "input_ids": input_ids,
+            "attention_mask": mask,
+            "images_embeds": image_feature,
+        }
 
-    def inference(self, image_dir: str, question: str):
+    def inference(self, image_dir: str, question: str, explain: bool = False):
         inputs = self.get_inputs(image_dir, question)
         with torch.no_grad():
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            images_embeds = inputs["images_embeds"]
            generated_ids = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                images_embeds=images_embeds,
+                num_beams=2,
+                max_length=Config.answer_maxlen,
+            )
+
+        pred_answer = self.vit5_tokenizer.decode(
+            generated_ids[0], skip_special_tokens=True
+        )
+        if not explain:
+            return pred_answer, None, None
 
+        with self.vit5_tokenizer.as_target_tokenizer():
+            decoder_input_ids = self.vit5_tokenizer(
+                pred_answer, return_tensors="pt", add_special_tokens=True
+            ).input_ids
 
+        with torch.no_grad():
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                images_embeds=images_embeds,
+                decoder_input_ids=decoder_input_ids,
+            )
+
+        encoder_text = self.vit5_tokenizer.convert_ids_to_tokens(input_ids[0])
+        decoder_text = self.vit5_tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
+        while "<pad>" in encoder_text:
+            encoder_text.remove("<pad>")
+
+        text_encoder_attentions = [
+            att[:, :, : len(encoder_text), : len(encoder_text)]
+            for att in outputs.encoder_attentions
+        ]
+        text_cross_attentions = [
+            att[:, :, :, : len(encoder_text)] for att in outputs.cross_attentions
+        ]
+
+        html_output = head_view(
+            encoder_attention=text_encoder_attentions,
+            decoder_attention=outputs.decoder_attentions,
+            cross_attention=text_cross_attentions,
+            encoder_tokens=encoder_text[: len(encoder_text)],
+            decoder_tokens=decoder_text,
+            # display_mode="light",
+            html_action="return",
+        )
+
+        img = Image.open(image_dir).convert("RGB")
+        image_dirs = []
+
+        for i in range(len(outputs.cross_attentions[:1])):
+            image_dir = f"visualization/test_image_visualize_{i}.jpg"
+            image_dirs.append(image_dir)
+            attention_plot = np.mean(
+                outputs.cross_attentions[i][0, :, :, -197:].detach().numpy(), axis=0
+            )
+            plot_attention(img, decoder_text, attention_plot, image_dir)
+        return pred_answer, html_output.data, image_dirs
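Model.inference now returns a 3-tuple instead of a bare answer string, which is what the updated app.py unpacks; the third argument (named visualize in app.py, explain here) switches the extra attention outputs on. A minimal usage sketch (the image path and question are taken from the app's example mapper):

from src.model import Model

model = Model()  # downloads the VietOCR checkpoint into storage/ on first run
answer, attention_html, attention_images = model.inference(
    "images/000000000645.jpg", "Đây là đâu", True
)
print(answer)           # decoded ViT5 answer
# attention_html   : bertviz head_view HTML (None when explain=False)
# attention_images : saved heat maps, e.g. visualization/test_image_visualize_0.jpg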
src/ocr.py
CHANGED
@@ -6,74 +6,80 @@ import requests
 import numpy as np
 from PIL import Image, ImageTransform
 
+
 class OCRDetector:
+    def __init__(self) -> None:
+        self.paddle_ocr = PaddleOCR(
+            lang="en",
+            use_angle_cls=False,
+            use_gpu=True if Config.device == "cpu" else False,
+            show_log=False,
+        )
+        # config['weights'] = './weights/transformerocr.pth'
 
+        vietocr_config = Cfg.load_config_from_name("vgg_transformer")
+        vietocr_config["weights"] = Config.ocr_path
+        vietocr_config["cnn"]["pretrained"] = False
+        vietocr_config["device"] = Config.device
+        vietocr_config["predictor"]["beamsearch"] = False
+        self.viet_ocr = Predictor(vietocr_config)
 
+    def find_box(self, image):
+        """Xác định box dựa vào mô hình paddle_ocr"""
+        result = self.paddle_ocr.ocr(image, cls=False, rec=False)
+        result = result[0]
+        # Extracting detected components
+        boxes = result  # [res[0] for res in result]
+        boxes = np.array(boxes).astype(int)
 
+        # scores = [res[1][1] for res in result]
+        return boxes
 
+    def cut_image_polygon(self, image, box):
+        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
+        w = x2 - x1
+        h = y4 - y1
+        scl = h // 7
+        new_box = (
+            [max(x1 - scl, 0), max(y1 - scl, 0)],
+            [x2 + scl, y2 - scl],
+            [x3 + scl, y3 + scl],
+            [x4 - scl, y4 + scl],
+        )
+        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = new_box
+        # Define 8-tuple with x,y coordinates of top-left, bottom-left, bottom-right and top-right corners and apply
+        transform = [x1, y1, x4, y4, x3, y3, x2, y2]
+        result = image.transform((w, h), ImageTransform.QuadTransform(transform))
+        return result
 
+    def vietnamese_text(self, boxes, image):
+        """Xác định text dựa vào mô hình viet_ocr"""
+        results = []
+        for box in boxes:
+            try:
+                cut_image = self.cut_image_polygon(image, box)
+                # cut_image = Image.fromarray(np.uint8(cut_image))
+                text, score = self.viet_ocr.predict(cut_image, return_prob=True)
+                if score > Config.vietocr_threshold:
+                    results.append({"text": text, "score": score, "box": box})
+            except:
-                continue
+                continue
-        return results
+        return results
 
+    # Merge
+    def text_detector(self, image_path):
+        if image_path.startswith("https://"):
+            image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
+        else:
+            image = Image.open(image_path).convert("RGB")
+        # np_image = np.array(image)
 
+        boxes = self.find_box(image_path)
+        if not boxes.any():
+            return None
 
+        results = self.vietnamese_text(boxes, image)
+        if results != []:
+            return results
+        else:
+            return None
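The rewritten OCRDetector keeps the two-stage design: PaddleOCR proposes text boxes and VietOCR reads each cropped quadrilateral, keeping only predictions above Config.vietocr_threshold. A minimal sketch of the call site (the image path is just an example; the weights referenced in utils/config.py must already be in place):

from src.ocr import OCRDetector

detector = OCRDetector()                     # loads PaddleOCR + VietOCR once
results = detector.text_detector("images/000000000661.jpg")
if results:                                  # None when nothing is detected
    for r in results:
        print(r["text"], float(r["score"]), r["box"])  # text, confidence, 4-corner box
else:
    print("no text found")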
utils/config.py
CHANGED
@@ -10,4 +10,4 @@ class Config:
     ocr_maxobj = 10000
     num_ocr = 32
     num_beams = 3
-    revision = "version_2_with_extra_id_0"
+    revision = "version_2_with_extra_id_0"
visualization/.gitkeep
ADDED
File without changes