Spaces:
Sleeping
Sleeping
Upload 53 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -35
- 0.jpg +0 -0
- README.md +1 -12
- app.py +254 -0
- basenet/__init__.py +0 -0
- basenet/__pycache__/__init__.cpython-310.pyc +0 -0
- basenet/__pycache__/vgg16_bn.cpython-310.pyc +0 -0
- basenet/vgg16_bn.py +71 -0
- craft.py +86 -0
- craft_utils.py +273 -0
- crop_Word/crop_1.jpg +0 -0
- crop_Word/crop_10.jpg +0 -0
- crop_Word/crop_11.jpg +0 -0
- crop_Word/crop_12.jpg +0 -0
- crop_Word/crop_13.jpg +0 -0
- crop_Word/crop_14.jpg +0 -0
- crop_Word/crop_15.jpg +0 -0
- crop_Word/crop_16.jpg +0 -0
- crop_Word/crop_17.jpg +0 -0
- crop_Word/crop_18.jpg +0 -0
- crop_Word/crop_19.jpg +0 -0
- crop_Word/crop_2.jpg +0 -0
- crop_Word/crop_20.jpg +0 -0
- crop_Word/crop_21.jpg +0 -0
- crop_Word/crop_22.jpg +0 -0
- crop_Word/crop_23.jpg +0 -0
- crop_Word/crop_24.jpg +0 -0
- crop_Word/crop_25.jpg +0 -0
- crop_Word/crop_26.jpg +0 -0
- crop_Word/crop_27.jpg +0 -0
- crop_Word/crop_28.jpg +0 -0
- crop_Word/crop_29.jpg +0 -0
- crop_Word/crop_3.jpg +0 -0
- crop_Word/crop_30.jpg +0 -0
- crop_Word/crop_31.jpg +0 -0
- crop_Word/crop_32.jpg +0 -0
- crop_Word/crop_33.jpg +0 -0
- crop_Word/crop_34.jpg +0 -0
- crop_Word/crop_4.jpg +0 -0
- crop_Word/crop_5.jpg +0 -0
- crop_Word/crop_6.jpg +0 -0
- crop_Word/crop_7.jpg +0 -0
- crop_Word/crop_8.jpg +0 -0
- crop_Word/crop_9.jpg +0 -0
- crop_images.py +73 -0
- file_utils.py +104 -0
- git +0 -0
- imgproc.py +74 -0
- refinenet.py +65 -0
- requirements.txt +15 -0
.gitattributes
CHANGED
@@ -1,35 +1 @@
|
|
1 |
-
*.
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
0.jpg
ADDED
README.md
CHANGED
@@ -1,12 +1 @@
|
|
1 |
-
|
2 |
-
title: Test Ocr
|
3 |
-
emoji: 🌍
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: purple
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.36.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
# ocr_all
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
+
from tkinter import ttk
|
4 |
+
import tkinter as tk
|
5 |
+
from tkinter import filedialog
|
6 |
+
from PIL import Image, ImageTk
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.backends.cudnn as cudnn
|
12 |
+
from torch.autograd import Variable
|
13 |
+
from test import copyStateDict
|
14 |
+
from PIL import Image
|
15 |
+
import cv2
|
16 |
+
from skimage import io
|
17 |
+
import numpy as np
|
18 |
+
import test
|
19 |
+
import file_utils
|
20 |
+
import pandas as pd
|
21 |
+
from craft import CRAFT
|
22 |
+
from collections import OrderedDict
|
23 |
+
from PIL import Image
|
24 |
+
from vietocr.tool.predictor import Predictor
|
25 |
+
from vietocr.tool.config import Cfg
|
26 |
+
import os
|
27 |
+
import tkinter as tk
|
28 |
+
from tkinter import filedialog
|
29 |
+
import matplotlib.pyplot as plt
|
30 |
+
import numpy as np
|
31 |
+
from pathlib import Path
|
32 |
+
import cv2
|
33 |
+
import glob
|
34 |
+
import matplotlib.pyplot as plt
|
35 |
+
import tkinter as tk
|
36 |
+
from tkinter import filedialog
|
37 |
+
from PIL import Image, ImageTk
|
38 |
+
import file_utils
|
39 |
+
import os
|
40 |
+
import tkinter.messagebox as messagebox
|
41 |
+
import imgproc
|
42 |
+
|
43 |
+
|
44 |
+
# Tạo yêu cầu đến mô hình
|
45 |
+
def str2bool(v):
|
46 |
+
return v.lower() in ("yes", "y", "true", "t", "1")
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
'''
|
51 |
+
|
52 |
+
'''
|
53 |
+
# CRAFT
|
54 |
+
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
|
55 |
+
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
|
56 |
+
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
|
57 |
+
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
|
58 |
+
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
|
59 |
+
parser.add_argument('--cpu', default=True, type=str2bool, help='Use cpu for inference')
|
60 |
+
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
|
61 |
+
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
|
62 |
+
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
|
63 |
+
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
|
64 |
+
parser.add_argument('--test_folder', default='data_image', type=str, help='đường dẫn tới ảnh đầu vào')
|
65 |
+
parser.add_argument('--refine', default=True, action='store_true', help='enable link refiner')
|
66 |
+
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str,
|
67 |
+
help='pretrained refiner model')
|
68 |
+
|
69 |
+
args = parser.parse_args()
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
#########################################################################################
|
74 |
+
csv_columns = ['x_top_left', 'y_top_left', 'x_top_right', 'y_top_right', 'x_bot_right', 'y_bot_right', 'x_bot_left',
|
75 |
+
'y_bot_left']
|
76 |
+
# load net
|
77 |
+
net = CRAFT() # initialize
|
78 |
+
print('Đang thực hiện load weight (' + args.trained_model + ')')
|
79 |
+
'''
|
80 |
+
nhảy sang file test, đưa vào train model
|
81 |
+
'''
|
82 |
+
if args.cpu:
|
83 |
+
net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
|
84 |
+
else:
|
85 |
+
net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
|
86 |
+
|
87 |
+
if args.cpu:
|
88 |
+
net = net.cpu()
|
89 |
+
net = torch.nn.DataParallel(net)
|
90 |
+
cudnn.benchmark = False
|
91 |
+
|
92 |
+
net.eval()
|
93 |
+
# LinkRefiner Đoạn này code không chạy qua nên không cần đọc vì weight đã load ở cái bên trên
|
94 |
+
# còn refine để mặc định bên trên là False nên sẽ bị bỏ qua
|
95 |
+
# ------------------------------------------------------------------------------------------
|
96 |
+
# ------------------------------------------------------------------------------------------
|
97 |
+
refine_net = None
|
98 |
+
if args.refine:
|
99 |
+
from refinenet import RefineNet
|
100 |
+
|
101 |
+
refine_net = RefineNet()
|
102 |
+
print('Đang thực hiện load weight (' + args.refiner_model + ')')
|
103 |
+
if args.cpu:
|
104 |
+
refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
|
105 |
+
refine_net = refine_net.cpu()
|
106 |
+
refine_net = torch.nn.DataParallel(refine_net)
|
107 |
+
else:
|
108 |
+
refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
|
109 |
+
|
110 |
+
refine_net.eval()
|
111 |
+
args.poly = True
|
112 |
+
|
113 |
+
|
114 |
+
config = Cfg.load_config_from_name('vgg_transformer')
|
115 |
+
config['export'] = 'transformerocr_checkpoint.pth'
|
116 |
+
config['device'] = 'cpu'
|
117 |
+
config['predictor']['beamsearch'] = False
|
118 |
+
|
119 |
+
detector = Predictor(config)
|
120 |
+
# ------------------------------------------------------------------------------------------
|
121 |
+
# ------------------------------------------------------------------------------------------
|
122 |
+
|
123 |
+
|
124 |
+
# Tạo tiêu đề và phần tải lên hình ảnh
|
125 |
+
st.title("Trích xuất thông tin từ căn cước công dân")
|
126 |
+
uploaded_file = st.file_uploader("Tải lên ảnh căn cước công dân", type=["jpg", "jpeg", "png"])
|
127 |
+
|
128 |
+
if uploaded_file is not None:
|
129 |
+
# Hiển thị hình ảnh tải lên
|
130 |
+
image = Image.open(uploaded_file)
|
131 |
+
image.save("uploaded_image.jpg")
|
132 |
+
st.image(image, caption='Hình ảnh căn cước', use_column_width=True)
|
133 |
+
import tempfile
|
134 |
+
# Lưu trữ tạm thời và lấy đường dẫn
|
135 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
136 |
+
# Save the image to the temp file
|
137 |
+
temp_file.write(uploaded_file.read())
|
138 |
+
temp_file_path = temp_file.name
|
139 |
+
|
140 |
+
# Hiển thị đường dẫn tạm thời
|
141 |
+
print(f"Đường dẫn tạm thời của file: {temp_file_path}")
|
142 |
+
# Nút chạy
|
143 |
+
if st.button("Run"):
|
144 |
+
print("ok")
|
145 |
+
image_path = "uploaded_image.jpg"
|
146 |
+
k = 1
|
147 |
+
crop_folder = "crop_Word"
|
148 |
+
result_folder = "Results"
|
149 |
+
image = imgproc.loadImage(image_path)
|
150 |
+
|
151 |
+
bboxes, polys, score_text, det_scores = test.test_net(net, image, args.text_threshold, args.link_threshold,
|
152 |
+
args.low_text, args.cpu, args.poly, args, refine_net)
|
153 |
+
bbox_score = {}
|
154 |
+
|
155 |
+
def crop_polygon(image, vertices, box_num1):
|
156 |
+
# Tạo mặt nạ
|
157 |
+
mask = np.zeros(image.shape[:2], dtype=np.uint8)
|
158 |
+
cv2.fillPoly(mask, [np.int32(vertices)], 255)
|
159 |
+
|
160 |
+
# Tìm bounding rect để crop vùng chứa đa giác
|
161 |
+
rect = cv2.boundingRect(np.int32(vertices))
|
162 |
+
|
163 |
+
# Crop và lấy hình ảnh con theo bounding rect
|
164 |
+
cropped = image[rect[1]:rect[1]+rect[3], rect[0]:rect[0]+rect[2]]
|
165 |
+
|
166 |
+
# Tạo mặt nạ cho vùng đã crop
|
167 |
+
cropped_mask = mask[rect[1]:rect[1]+rect[3], rect[0]:rect[0]+rect[2]]
|
168 |
+
|
169 |
+
# Lọc vùng bằng mặt nạ
|
170 |
+
result = cv2.bitwise_and(cropped, cropped, mask=cropped_mask)
|
171 |
+
crop_path = os.path.join(crop_folder, f"crop_{box_num1 + 1}.jpg")
|
172 |
+
cv2.imwrite(crop_path, result)
|
173 |
+
return result
|
174 |
+
|
175 |
+
if len(bboxes) == 0:
|
176 |
+
with open(f"data_text//text_{k}.txt", "w", encoding="utf-8") as f:
|
177 |
+
f.write(" ")
|
178 |
+
|
179 |
+
else:
|
180 |
+
# for box_num, item in enumerate(bboxes):
|
181 |
+
# # Crop the bbox from the image
|
182 |
+
# pts = np.array(item, np.int32).reshape((-1, 1, 2))
|
183 |
+
# rect = cv2.boundingRect(pts)
|
184 |
+
# x, y, w, h = rect
|
185 |
+
# cropped_img = image[y:y+h, x:x+w].copy()
|
186 |
+
# crop_path = os.path.join(crop_folder, f"crop_{box_num + 1}.jpg")
|
187 |
+
# cv2.imwrite(crop_path, cropped_img)
|
188 |
+
for box_num in range(len(bboxes)):
|
189 |
+
item = bboxes[box_num]
|
190 |
+
data = np.array([[int(item[0][0]), int(item[0][1]), int(item[1][0]), int(item[1][1]), int(item[2][0]),
|
191 |
+
int(item[2][1]), int(item[3][0]), int(item[3][1])]])
|
192 |
+
csvdata = pd.DataFrame(data, columns=csv_columns)
|
193 |
+
csvdata.to_csv(f'data{k}.csv', index=False, mode='a', header=False)
|
194 |
+
|
195 |
+
# save score text
|
196 |
+
filename, file_ext = os.path.splitext(os.path.basename(image_path))
|
197 |
+
mask_file = result_folder + "/res_" + filename + '_mask.jpg' # tạo đường dẫn file bản đồ nhiệt
|
198 |
+
|
199 |
+
cv2.imwrite(mask_file, score_text) # in ra bản đồ nhiệt
|
200 |
+
#
|
201 |
+
file_utils.saveResult(image_path, image[:, :, ::-1], polys, dirname=result_folder)
|
202 |
+
|
203 |
+
cropped_images = []
|
204 |
+
for i, box in enumerate(bboxes):
|
205 |
+
cropped = crop_polygon(image, box, i)
|
206 |
+
cropped_images.append(cropped)
|
207 |
+
|
208 |
+
|
209 |
+
print(f"Đã cắt {len(cropped_images)} vùng bounding box.")
|
210 |
+
path = glob.glob("crop_Word/*.jpg")
|
211 |
+
cv_img = [str(detector.predict(Image.open(f'crop_Word/crop_' + str(i + 1) + '.jpg'))) for i in
|
212 |
+
range(len(bboxes))]
|
213 |
+
print(cv_img)
|
214 |
+
# from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
215 |
+
# import google.generativeai as genai
|
216 |
+
|
217 |
+
# genai.configure(api_key="AIzaSyAH4ayK6nL71wxPtuYOCe32OdZVZAANWic")
|
218 |
+
|
219 |
+
# # Khởi tạo mô hình
|
220 |
+
# model = genai.GenerativeModel(model_name='gemini-1.5-flash')
|
221 |
+
|
222 |
+
# # Thiết lập safe_setting cho các loại harm có sẵn và hợp lệ
|
223 |
+
# safety_settings = [
|
224 |
+
# {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
225 |
+
# {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
226 |
+
# {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
227 |
+
# {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE}
|
228 |
+
# ]
|
229 |
+
# print("ok")
|
230 |
+
# response = model.generate_content(
|
231 |
+
# [f"tìm tên, ngày sinh, nơi cư trú, số căn cước, hạn sử dụng trong mảng sau {cv_img}, chỉ trả về tên, ngày sinh, nơi cư trú, số căn cước,hạn sử dụng mà model tìm thấy được, không trả lời thêm gì, ví dụ 'NGUYỄN THANH SANG, 18/05/1981, 223/11 Kv Bỉnh- Dương, Long Hòa, Bình Thủy, Cần Thơ, 092081007131, 18/05/2041 '"],
|
232 |
+
# safety_settings=safety_settings
|
233 |
+
# )
|
234 |
+
# # print(f"tìm tên, ngày sinh, nơi cư trú, số căn cước, hạn sử dụng trong mảng sau {cv_img}, chỉ trả về tên, ngày sinh, nơi cư trú, số căn cước,hạn sử dụng mà model tìm thấy được, không trả lời thêm gì, ví dụ 'NGUYỄN THANH SANG, 18/05/1981, 223/11 Kv Bỉnh- Dương, Long Hòa, Bình Thủy, Cần Thơ, 092081007131, 18/05/2041 '")
|
235 |
+
# print(response.text)
|
236 |
+
|
237 |
+
|
238 |
+
for box in bboxes:
|
239 |
+
cv2.polylines(image, [np.int32(box)], isClosed=True, color=(0, 255, 0), thickness=1)
|
240 |
+
|
241 |
+
plt.figure(figsize=(20, 20))
|
242 |
+
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
243 |
+
plt.title('Detected Text Bounding Boxes')
|
244 |
+
plt.show()
|
245 |
+
print(f"đã load xong ảnh {k + 1}")
|
246 |
+
|
247 |
+
# # Hiển thị kết quả
|
248 |
+
st.subheader("Kết quả trích xuất:")
|
249 |
+
st.text_area("ALL TEXT",cv_img)
|
250 |
+
# st.text_area("Tên", response.text.get('name', ''))
|
251 |
+
# st.text_area("Ngày sinh", response.text.get('dob', ''))
|
252 |
+
# st.text_area("Nơi cư trú", response.text.get('address', ''))
|
253 |
+
# st.text_area("Số căn cước", response.text.get('id_number', ''))
|
254 |
+
# st.text_area("Hạn sử dụng", response.text.get('expiry', ''))
|
basenet/__init__.py
ADDED
File without changes
|
basenet/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (175 Bytes). View file
|
|
basenet/__pycache__/vgg16_bn.cpython-310.pyc
ADDED
Binary file (2.28 kB). View file
|
|
basenet/vgg16_bn.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.init as init
|
6 |
+
from torchvision import models
|
7 |
+
|
8 |
+
def init_weights(modules):
|
9 |
+
for m in modules:
|
10 |
+
if isinstance(m, nn.Conv2d):
|
11 |
+
init.xavier_uniform_(m.weight.data)
|
12 |
+
if m.bias is not None:
|
13 |
+
m.bias.data.zero_()
|
14 |
+
elif isinstance(m, nn.BatchNorm2d):
|
15 |
+
m.weight.data.fill_(1)
|
16 |
+
m.bias.data.zero_()
|
17 |
+
elif isinstance(m, nn.Linear):
|
18 |
+
m.weight.data.normal_(0, 0.01)
|
19 |
+
m.bias.data.zero_()
|
20 |
+
|
21 |
+
class vgg16_bn(torch.nn.Module):
|
22 |
+
def __init__(self, pretrained=True, freeze=True):
|
23 |
+
super(vgg16_bn, self).__init__()
|
24 |
+
vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
|
25 |
+
self.slice1 = torch.nn.Sequential()
|
26 |
+
self.slice2 = torch.nn.Sequential()
|
27 |
+
self.slice3 = torch.nn.Sequential()
|
28 |
+
self.slice4 = torch.nn.Sequential()
|
29 |
+
self.slice5 = torch.nn.Sequential()
|
30 |
+
for x in range(12): # conv2_2
|
31 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
32 |
+
for x in range(12, 19): # conv3_3
|
33 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
34 |
+
for x in range(19, 29): # conv4_3
|
35 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
36 |
+
for x in range(29, 39): # conv5_3
|
37 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
38 |
+
|
39 |
+
# fc6, fc7 without atrous conv
|
40 |
+
self.slice5 = torch.nn.Sequential(
|
41 |
+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
42 |
+
nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
|
43 |
+
nn.Conv2d(1024, 1024, kernel_size=1)
|
44 |
+
)
|
45 |
+
|
46 |
+
if not pretrained:
|
47 |
+
init_weights(self.slice1.modules())
|
48 |
+
init_weights(self.slice2.modules())
|
49 |
+
init_weights(self.slice3.modules())
|
50 |
+
init_weights(self.slice4.modules())
|
51 |
+
|
52 |
+
init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
|
53 |
+
|
54 |
+
if freeze:
|
55 |
+
for param in self.slice1.parameters(): # only first conv
|
56 |
+
param.requires_grad= False
|
57 |
+
|
58 |
+
def forward(self, X):
|
59 |
+
h = self.slice1(X)
|
60 |
+
h_relu2_2 = h
|
61 |
+
h = self.slice2(h)
|
62 |
+
h_relu3_2 = h
|
63 |
+
h = self.slice3(h)
|
64 |
+
h_relu4_3 = h
|
65 |
+
h = self.slice4(h)
|
66 |
+
h_relu5_3 = h
|
67 |
+
h = self.slice5(h)
|
68 |
+
h_fc7 = h
|
69 |
+
vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
|
70 |
+
out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
|
71 |
+
return out
|
craft.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2019-present NAVER Corp.
|
3 |
+
MIT License
|
4 |
+
"""
|
5 |
+
|
6 |
+
# -*- coding: utf-8 -*-
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from basenet.vgg16_bn import vgg16_bn, init_weights
|
12 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
+
print(device)
|
14 |
+
class double_conv(nn.Module):
|
15 |
+
def __init__(self, in_ch, mid_ch, out_ch):
|
16 |
+
super(double_conv, self).__init__()
|
17 |
+
self.conv = nn.Sequential(
|
18 |
+
nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
|
19 |
+
nn.BatchNorm2d(mid_ch),
|
20 |
+
nn.ReLU(inplace=True),
|
21 |
+
nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
|
22 |
+
nn.BatchNorm2d(out_ch),
|
23 |
+
nn.ReLU(inplace=True)
|
24 |
+
)
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.conv(x)
|
28 |
+
return x
|
29 |
+
|
30 |
+
|
31 |
+
class CRAFT(nn.Module):
|
32 |
+
def __init__(self, pretrained=False, freeze=False):
|
33 |
+
super(CRAFT, self).__init__()
|
34 |
+
|
35 |
+
""" Base network """
|
36 |
+
self.basenet = vgg16_bn(pretrained, freeze)
|
37 |
+
|
38 |
+
""" U network """
|
39 |
+
self.upconv1 = double_conv(1024, 512, 256)
|
40 |
+
self.upconv2 = double_conv(512, 256, 128)
|
41 |
+
self.upconv3 = double_conv(256, 128, 64)
|
42 |
+
self.upconv4 = double_conv(128, 64, 32)
|
43 |
+
|
44 |
+
num_class = 2
|
45 |
+
self.conv_cls = nn.Sequential(
|
46 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
47 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
48 |
+
nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
|
49 |
+
nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
|
50 |
+
nn.Conv2d(16, num_class, kernel_size=1),
|
51 |
+
)
|
52 |
+
|
53 |
+
init_weights(self.upconv1.modules())
|
54 |
+
init_weights(self.upconv2.modules())
|
55 |
+
init_weights(self.upconv3.modules())
|
56 |
+
init_weights(self.upconv4.modules())
|
57 |
+
init_weights(self.conv_cls.modules())
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
""" Base network """
|
61 |
+
sources = self.basenet(x)
|
62 |
+
|
63 |
+
""" U network """
|
64 |
+
y = torch.cat([sources[0], sources[1]], dim=1)
|
65 |
+
y = self.upconv1(y)
|
66 |
+
|
67 |
+
y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
|
68 |
+
y = torch.cat([y, sources[2]], dim=1)
|
69 |
+
y = self.upconv2(y)
|
70 |
+
|
71 |
+
y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
|
72 |
+
y = torch.cat([y, sources[3]], dim=1)
|
73 |
+
y = self.upconv3(y)
|
74 |
+
|
75 |
+
y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
|
76 |
+
y = torch.cat([y, sources[4]], dim=1)
|
77 |
+
feature = self.upconv4(y)
|
78 |
+
|
79 |
+
y = self.conv_cls(feature)
|
80 |
+
|
81 |
+
return y.permute(0,2,3,1), feature
|
82 |
+
|
83 |
+
if __name__ == '__main__':
|
84 |
+
model = CRAFT(pretrained=True).cuda()
|
85 |
+
output, _ = model(torch.randn(1, 3, 768, 768).cuda())
|
86 |
+
print(output.shape)
|
craft_utils.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modify to Return Scores of Detection Boxes"""
|
2 |
+
|
3 |
+
"""
|
4 |
+
Copyright (c) 2019-present NAVER Corp.
|
5 |
+
MIT License
|
6 |
+
"""
|
7 |
+
|
8 |
+
# -*- coding: utf-8 -*-
|
9 |
+
import numpy as np
|
10 |
+
import cv2
|
11 |
+
import math
|
12 |
+
|
13 |
+
""" auxilary functions """
|
14 |
+
|
15 |
+
|
16 |
+
# unwarp corodinates
|
17 |
+
def warpCoord(Minv, pt):
|
18 |
+
out = np.matmul(Minv, (pt[0], pt[1], 1))
|
19 |
+
return np.array([out[0] / out[2], out[1] / out[2]])
|
20 |
+
|
21 |
+
|
22 |
+
""" end of auxilary functions """
|
23 |
+
|
24 |
+
|
25 |
+
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
|
26 |
+
# prepare data
|
27 |
+
linkmap = linkmap.copy()
|
28 |
+
textmap = textmap.copy()
|
29 |
+
img_h, img_w = textmap.shape
|
30 |
+
|
31 |
+
# Helper function for generating random colors
|
32 |
+
def random_color():
|
33 |
+
return tuple(np.random.randint(0, 255, 3).tolist())
|
34 |
+
|
35 |
+
""" labeling method """
|
36 |
+
ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
|
37 |
+
ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
|
38 |
+
|
39 |
+
text_score_comb = np.clip(text_score + link_score, 0, 1)
|
40 |
+
nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8),
|
41 |
+
connectivity=4)
|
42 |
+
|
43 |
+
# Create a color version of linkmap for visualization
|
44 |
+
visualized_linkmap = cv2.cvtColor(linkmap, cv2.COLOR_GRAY2BGR)
|
45 |
+
det = []
|
46 |
+
det_scores = []
|
47 |
+
mapper = []
|
48 |
+
for k in range(1,nLabels):
|
49 |
+
# visualize stats on the original linkmap
|
50 |
+
x, y, w, h = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP], stats[k, cv2.CC_STAT_WIDTH], stats[
|
51 |
+
k, cv2.CC_STAT_HEIGHT]
|
52 |
+
cv2.rectangle(visualized_linkmap, (x, y), (x + w, y + h), random_color(), 2)
|
53 |
+
|
54 |
+
# size filtering
|
55 |
+
size = stats[k, cv2.CC_STAT_AREA]
|
56 |
+
if size < 10: continue
|
57 |
+
|
58 |
+
# thresholding
|
59 |
+
if np.max(textmap[labels == k]) < text_threshold: continue
|
60 |
+
|
61 |
+
# make segmentation map
|
62 |
+
segmap = np.zeros(textmap.shape, dtype=np.uint8)
|
63 |
+
segmap[labels==k] = 255
|
64 |
+
segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area
|
65 |
+
x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
|
66 |
+
w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
|
67 |
+
niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
|
68 |
+
sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
|
69 |
+
# boundary check
|
70 |
+
if sx < 0 : sx = 0
|
71 |
+
if sy < 0 : sy = 0
|
72 |
+
if ex >= img_w: ex = img_w
|
73 |
+
if ey >= img_h: ey = img_h
|
74 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
|
75 |
+
segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
|
76 |
+
|
77 |
+
# make box
|
78 |
+
np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
|
79 |
+
rectangle = cv2.minAreaRect(np_contours)
|
80 |
+
box = cv2.boxPoints(rectangle)
|
81 |
+
|
82 |
+
# align diamond-shape
|
83 |
+
w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
|
84 |
+
box_ratio = max(w, h) / (min(w, h) + 1e-5)
|
85 |
+
if abs(1 - box_ratio) <= 0.1:
|
86 |
+
l, r = min(np_contours[:,0]), max(np_contours[:,0])
|
87 |
+
t, b = min(np_contours[:,1]), max(np_contours[:,1])
|
88 |
+
box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
|
89 |
+
|
90 |
+
# make clock-wise order
|
91 |
+
startidx = box.sum(axis=1).argmin()
|
92 |
+
box = np.roll(box, 4-startidx, 0)
|
93 |
+
box = np.array(box)
|
94 |
+
|
95 |
+
det.append(box)
|
96 |
+
mapper.append(k)
|
97 |
+
det_scores.append(np.max(textmap[labels==k]))
|
98 |
+
# # Show the visualized linkmap with stats drawn
|
99 |
+
# cv2.imshow("Visualized Linkmap with Stats", visualized_linkmap)
|
100 |
+
# cv2.waitKey(0)
|
101 |
+
# cv2.destroyAllWindows()
|
102 |
+
return det, labels, mapper, det_scores
|
103 |
+
|
104 |
+
def getPoly_core(boxes, labels, mapper, linkmap):
|
105 |
+
# configs
|
106 |
+
num_cp = 5
|
107 |
+
max_len_ratio = 0.7
|
108 |
+
expand_ratio = 1.45
|
109 |
+
max_r = 2.0
|
110 |
+
step_r = 0.2
|
111 |
+
|
112 |
+
polys = []
|
113 |
+
for k, box in enumerate(boxes):
|
114 |
+
# size filter for small instance
|
115 |
+
w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
|
116 |
+
if w < 10 or h < 10:
|
117 |
+
polys.append(None);
|
118 |
+
continue
|
119 |
+
|
120 |
+
# warp image
|
121 |
+
tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
|
122 |
+
M = cv2.getPerspectiveTransform(box, tar)
|
123 |
+
word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
|
124 |
+
try:
|
125 |
+
Minv = np.linalg.inv(M)
|
126 |
+
except:
|
127 |
+
polys.append(None);
|
128 |
+
continue
|
129 |
+
|
130 |
+
# binarization for selected label
|
131 |
+
cur_label = mapper[k]
|
132 |
+
word_label[word_label != cur_label] = 0
|
133 |
+
word_label[word_label > 0] = 1
|
134 |
+
|
135 |
+
""" Polygon generation """
|
136 |
+
# find top/bottom contours
|
137 |
+
cp = []
|
138 |
+
max_len = -1
|
139 |
+
for i in range(w):
|
140 |
+
region = np.where(word_label[:, i] != 0)[0]
|
141 |
+
if len(region) < 2: continue
|
142 |
+
cp.append((i, region[0], region[-1]))
|
143 |
+
length = region[-1] - region[0] + 1
|
144 |
+
if length > max_len: max_len = length
|
145 |
+
|
146 |
+
# pass if max_len is similar to h
|
147 |
+
if h * max_len_ratio < max_len:
|
148 |
+
polys.append(None);
|
149 |
+
continue
|
150 |
+
|
151 |
+
# get pivot points with fixed length
|
152 |
+
tot_seg = num_cp * 2 + 1
|
153 |
+
seg_w = w / tot_seg # segment width
|
154 |
+
pp = [None] * num_cp # init pivot points
|
155 |
+
cp_section = [[0, 0]] * tot_seg
|
156 |
+
seg_height = [0] * num_cp
|
157 |
+
seg_num = 0
|
158 |
+
num_sec = 0
|
159 |
+
prev_h = -1
|
160 |
+
for i in range(0, len(cp)):
|
161 |
+
(x, sy, ey) = cp[i]
|
162 |
+
if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
|
163 |
+
# average previous segment
|
164 |
+
if num_sec == 0: break
|
165 |
+
cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
|
166 |
+
num_sec = 0
|
167 |
+
|
168 |
+
# reset variables
|
169 |
+
seg_num += 1
|
170 |
+
prev_h = -1
|
171 |
+
|
172 |
+
# accumulate center points
|
173 |
+
cy = (sy + ey) * 0.5
|
174 |
+
cur_h = ey - sy + 1
|
175 |
+
cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
|
176 |
+
num_sec += 1
|
177 |
+
|
178 |
+
if seg_num % 2 == 0: continue # No polygon area
|
179 |
+
|
180 |
+
if prev_h < cur_h:
|
181 |
+
pp[int((seg_num - 1) / 2)] = (x, cy)
|
182 |
+
seg_height[int((seg_num - 1) / 2)] = cur_h
|
183 |
+
prev_h = cur_h
|
184 |
+
|
185 |
+
# processing last segment
|
186 |
+
if num_sec != 0:
|
187 |
+
cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
|
188 |
+
|
189 |
+
# pass if num of pivots is not sufficient or segment widh is smaller than character height
|
190 |
+
if None in pp or seg_w < np.max(seg_height) * 0.25:
|
191 |
+
polys.append(None);
|
192 |
+
continue
|
193 |
+
|
194 |
+
# calc median maximum of pivot points
|
195 |
+
half_char_h = np.median(seg_height) * expand_ratio / 2
|
196 |
+
|
197 |
+
# calc gradiant and apply to make horizontal pivots
|
198 |
+
new_pp = []
|
199 |
+
for i, (x, cy) in enumerate(pp):
|
200 |
+
dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
|
201 |
+
dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
|
202 |
+
if dx == 0: # gradient if zero
|
203 |
+
new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
|
204 |
+
continue
|
205 |
+
rad = - math.atan2(dy, dx)
|
206 |
+
c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
|
207 |
+
new_pp.append([x - s, cy - c, x + s, cy + c])
|
208 |
+
|
209 |
+
# get edge points to cover character heatmaps
|
210 |
+
isSppFound, isEppFound = False, False
|
211 |
+
grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
|
212 |
+
grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
|
213 |
+
for r in np.arange(0.5, max_r, step_r):
|
214 |
+
dx = 2 * half_char_h * r
|
215 |
+
if not isSppFound:
|
216 |
+
line_img = np.zeros(word_label.shape, dtype=np.uint8)
|
217 |
+
dy = grad_s * dx
|
218 |
+
p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
|
219 |
+
cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
|
220 |
+
if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
|
221 |
+
spp = p
|
222 |
+
isSppFound = True
|
223 |
+
if not isEppFound:
|
224 |
+
line_img = np.zeros(word_label.shape, dtype=np.uint8)
|
225 |
+
dy = grad_e * dx
|
226 |
+
p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
|
227 |
+
cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
|
228 |
+
if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
|
229 |
+
epp = p
|
230 |
+
isEppFound = True
|
231 |
+
if isSppFound and isEppFound:
|
232 |
+
break
|
233 |
+
|
234 |
+
# pass if boundary of polygon is not found
|
235 |
+
if not (isSppFound and isEppFound):
|
236 |
+
polys.append(None);
|
237 |
+
continue
|
238 |
+
|
239 |
+
# make final polygon
|
240 |
+
poly = []
|
241 |
+
poly.append(warpCoord(Minv, (spp[0], spp[1])))
|
242 |
+
for p in new_pp:
|
243 |
+
poly.append(warpCoord(Minv, (p[0], p[1])))
|
244 |
+
poly.append(warpCoord(Minv, (epp[0], epp[1])))
|
245 |
+
poly.append(warpCoord(Minv, (epp[2], epp[3])))
|
246 |
+
for p in reversed(new_pp):
|
247 |
+
poly.append(warpCoord(Minv, (p[2], p[3])))
|
248 |
+
poly.append(warpCoord(Minv, (spp[2], spp[3])))
|
249 |
+
|
250 |
+
# add to final result
|
251 |
+
polys.append(np.array(poly))
|
252 |
+
|
253 |
+
return polys
|
254 |
+
|
255 |
+
|
256 |
+
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
|
257 |
+
boxes, labels, mapper, det_scores = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
|
258 |
+
|
259 |
+
if poly:
|
260 |
+
polys = getPoly_core(boxes, labels, mapper, linkmap)
|
261 |
+
else:
|
262 |
+
polys = [None] * len(boxes)
|
263 |
+
|
264 |
+
return boxes, polys, det_scores
|
265 |
+
|
266 |
+
|
267 |
+
def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
|
268 |
+
for i in range(len(polys)):
|
269 |
+
if polys[i] is not None:
|
270 |
+
for j in range(len(polys[i])):
|
271 |
+
polys[i][j][0] *= ratio_w * ratio_net
|
272 |
+
polys[i][j][1] *= ratio_h * ratio_net
|
273 |
+
return polys
|
crop_Word/crop_1.jpg
ADDED
crop_Word/crop_10.jpg
ADDED
crop_Word/crop_11.jpg
ADDED
crop_Word/crop_12.jpg
ADDED
crop_Word/crop_13.jpg
ADDED
crop_Word/crop_14.jpg
ADDED
crop_Word/crop_15.jpg
ADDED
crop_Word/crop_16.jpg
ADDED
crop_Word/crop_17.jpg
ADDED
crop_Word/crop_18.jpg
ADDED
crop_Word/crop_19.jpg
ADDED
crop_Word/crop_2.jpg
ADDED
crop_Word/crop_20.jpg
ADDED
crop_Word/crop_21.jpg
ADDED
crop_Word/crop_22.jpg
ADDED
crop_Word/crop_23.jpg
ADDED
crop_Word/crop_24.jpg
ADDED
crop_Word/crop_25.jpg
ADDED
crop_Word/crop_26.jpg
ADDED
crop_Word/crop_27.jpg
ADDED
crop_Word/crop_28.jpg
ADDED
crop_Word/crop_29.jpg
ADDED
crop_Word/crop_3.jpg
ADDED
crop_Word/crop_30.jpg
ADDED
crop_Word/crop_31.jpg
ADDED
crop_Word/crop_32.jpg
ADDED
crop_Word/crop_33.jpg
ADDED
crop_Word/crop_34.jpg
ADDED
crop_Word/crop_4.jpg
ADDED
crop_Word/crop_5.jpg
ADDED
crop_Word/crop_6.jpg
ADDED
crop_Word/crop_7.jpg
ADDED
crop_Word/crop_8.jpg
ADDED
crop_Word/crop_9.jpg
ADDED
crop_images.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
|
7 |
+
def crop(pts, image):
|
8 |
+
"""
|
9 |
+
Takes inputs as 8 points
|
10 |
+
and Returns cropped, masked image with a white background
|
11 |
+
"""
|
12 |
+
# Giới hạn giá trị của pts
|
13 |
+
pts[:, 0] = np.clip(pts[:, 0], 0, image.shape[1] - 1)
|
14 |
+
pts[:, 1] = np.clip(pts[:, 1], 0, image.shape[0] - 1)
|
15 |
+
|
16 |
+
rect = cv2.boundingRect(pts)
|
17 |
+
x, y, w, h = rect
|
18 |
+
x = int(x)
|
19 |
+
y = int(y)
|
20 |
+
w = int(w)
|
21 |
+
if h == 0 or w == 0:
|
22 |
+
return np.ones((10, 10, 3),
|
23 |
+
np.uint8) * 255 # Trả về một ảnh trắng 10x10, bạn có thể thay đổi kích thước này nếu muốn
|
24 |
+
|
25 |
+
cropped = image[y:y + h, x:x + w].copy()
|
26 |
+
pts = pts - pts.min(axis=0)
|
27 |
+
mask = np.zeros(cropped.shape[:2], np.uint8)
|
28 |
+
# print("Kích thước của mask:", mask.shape)
|
29 |
+
# print("Kích thước của cropped:", cropped.shape)
|
30 |
+
# print("Giá trị của pts:", pts)
|
31 |
+
|
32 |
+
cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
|
33 |
+
dst = cv2.bitwise_and(cropped, cropped, mask=mask)
|
34 |
+
bg = np.ones_like(cropped, np.uint8) * 255
|
35 |
+
cv2.bitwise_not(bg, bg, mask=mask)
|
36 |
+
dst2 = bg + dst
|
37 |
+
return dst2
|
38 |
+
|
39 |
+
|
40 |
+
def generate_words(image_name, score_bbox, image):
|
41 |
+
num_bboxes = len(score_bbox)
|
42 |
+
for num in range(num_bboxes):
|
43 |
+
bbox_coords = score_bbox[num].split(':')[-1].split(',\n')
|
44 |
+
if bbox_coords != ['{}']:
|
45 |
+
l_t = float(bbox_coords[0].strip(' array([').strip(']').split(',')[0])
|
46 |
+
t_l = float(bbox_coords[0].strip(' array([').strip(']').split(',')[1])
|
47 |
+
r_t = float(bbox_coords[1].strip(' [').strip(']').split(',')[0])
|
48 |
+
t_r = float(bbox_coords[1].strip(' [').strip(']').split(',')[1])
|
49 |
+
r_b = float(bbox_coords[2].strip(' [').strip(']').split(',')[0])
|
50 |
+
b_r = float(bbox_coords[2].strip(' [').strip(']').split(',')[1])
|
51 |
+
l_b = float(bbox_coords[3].strip(' [').strip(']').split(',')[0])
|
52 |
+
b_l = float(bbox_coords[3].strip(' [').strip(']').split(',')[1].strip(']'))
|
53 |
+
pts = np.array([[int(l_t), int(t_l)], [int(r_t), int(t_r)], [int(r_b), int(b_r)], [int(l_b), int(b_l)]])
|
54 |
+
|
55 |
+
if np.all(pts) > 0:
|
56 |
+
|
57 |
+
word = crop(pts, image)
|
58 |
+
|
59 |
+
folder = '/'.join(image_name.split('/')[:-1])
|
60 |
+
# CHANGE DIR
|
61 |
+
dir = '/content/Pipeline/Crop Words/'
|
62 |
+
if os.path.isdir(os.path.join(dir + folder)) == False:
|
63 |
+
os.makedirs(os.path.join(dir + folder))
|
64 |
+
try:
|
65 |
+
file_name = os.path.join(dir + image_name)
|
66 |
+
cv2.imwrite(
|
67 |
+
file_name + '_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t, t_r, r_b, b_r, l_b, b_l), word)
|
68 |
+
print('Image saved to ' + file_name + '_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t, t_r, r_b,
|
69 |
+
b_r, l_b, b_l))
|
70 |
+
except:
|
71 |
+
continue
|
72 |
+
|
73 |
+
|
file_utils.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import crop_images
|
7 |
+
import imgproc
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
|
11 |
+
# borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
|
12 |
+
def get_files(img_dir):
|
13 |
+
imgs, masks, xmls = list_files(img_dir)
|
14 |
+
return imgs, masks, xmls
|
15 |
+
|
16 |
+
|
17 |
+
def list_files(in_path):
|
18 |
+
img_files = []
|
19 |
+
mask_files = []
|
20 |
+
gt_files = []
|
21 |
+
for (dirpath, dirnames, filenames) in os.walk(in_path):
|
22 |
+
for file in filenames:
|
23 |
+
filename, ext = os.path.splitext(file)
|
24 |
+
ext = str.lower(ext)
|
25 |
+
if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm':
|
26 |
+
img_files.append(os.path.join(dirpath, file))
|
27 |
+
elif ext == '.bmp':
|
28 |
+
mask_files.append(os.path.join(dirpath, file))
|
29 |
+
elif ext == '.xml' or ext == '.gt' or ext == '.txt':
|
30 |
+
gt_files.append(os.path.join(dirpath, file))
|
31 |
+
elif ext == '.zip':
|
32 |
+
continue
|
33 |
+
# img_files.sort()
|
34 |
+
# mask_files.sort()
|
35 |
+
# gt_files.sort()
|
36 |
+
return img_files, mask_files, gt_files
|
37 |
+
|
38 |
+
|
39 |
+
def saveResult(img_file, img, boxes, dirname='Results', verticals=None, texts=None):
|
40 |
+
""" save text detection result one by one
|
41 |
+
Args:
|
42 |
+
img_file (str): image file name
|
43 |
+
img (array): raw image context
|
44 |
+
boxes (array): array of result file
|
45 |
+
Shape: [num_detections, 4] for BB output / [num_detections, 4] for QUAD output
|
46 |
+
Return:
|
47 |
+
None
|
48 |
+
"""
|
49 |
+
|
50 |
+
img = np.array(img)
|
51 |
+
|
52 |
+
# make result file list: tên ảnh và đuôi ảnh
|
53 |
+
filename, file_ext = os.path.splitext(os.path.basename(img_file))
|
54 |
+
# result directory
|
55 |
+
res_file = dirname + "res_" + filename + '.txt'
|
56 |
+
res_img_file = dirname + "res_" + filename + '.jpg'
|
57 |
+
|
58 |
+
if not os.path.isdir(dirname):
|
59 |
+
os.mkdir(dirname)
|
60 |
+
with open(res_file, 'w') as f:
|
61 |
+
for i, box in enumerate(boxes):
|
62 |
+
poly = np.array(box).astype(np.int32).reshape((-1))
|
63 |
+
strResult = ','.join([str(p) for p in poly]) + '\r\n'
|
64 |
+
f.write(strResult)
|
65 |
+
# bỏ comment đoạn này để có thể bounding box lại b
|
66 |
+
poly = poly.reshape(-1, 2)
|
67 |
+
# cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
|
68 |
+
# ptColor = (0, 255, 255)
|
69 |
+
xmin = min(poly[:, 0])
|
70 |
+
xmax = max(poly[:, 0])
|
71 |
+
ymin = min(poly[:, 1])
|
72 |
+
ymax = max(poly[:, 1])
|
73 |
+
width = xmax - xmin
|
74 |
+
height = ymax - ymin
|
75 |
+
# các điểm này từ file txt
|
76 |
+
pts = np.array([[xmin, ymax], [xmax, ymax], [xmax, ymin], [xmin, ymin]])
|
77 |
+
|
78 |
+
word = crop_images.crop(pts, img)
|
79 |
+
|
80 |
+
folder = '/'.join(filename.split('/')[:-1])
|
81 |
+
# đầu tiên đây là folder cropWord
|
82 |
+
dir = 'cropWord/'
|
83 |
+
if os.path.isdir(os.path.join(dir + folder)) == False:
|
84 |
+
os.makedirs(os.path.join(dir + folder))
|
85 |
+
try:
|
86 |
+
file_name = os.path.join(dir + filename)
|
87 |
+
cv2.imwrite(file_name + str(i) + '.jpg', word)
|
88 |
+
except:
|
89 |
+
continue
|
90 |
+
|
91 |
+
if verticals is not None:
|
92 |
+
if verticals[i]:
|
93 |
+
ptColor = (255, 0, 0)
|
94 |
+
|
95 |
+
if texts is not None:
|
96 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
97 |
+
font_scale = 0.5
|
98 |
+
cv2.putText(img, "{}".format(texts[i]), (poly[0][0] + 1, poly[0][1] + 1), font, font_scale, (0, 0, 0),
|
99 |
+
thickness=1)
|
100 |
+
cv2.putText(img, "{:.2f}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255),
|
101 |
+
thickness=1)
|
102 |
+
|
103 |
+
# Save result image
|
104 |
+
# cv2.imwrite(res_img_file, img)
|
git
ADDED
File without changes
|
imgproc.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2019-present NAVER Corp.
|
3 |
+
MIT License
|
4 |
+
"""
|
5 |
+
|
6 |
+
# -*- coding: utf-8 -*-
|
7 |
+
import numpy as np
|
8 |
+
from skimage import io
|
9 |
+
import cv2
|
10 |
+
|
11 |
+
|
12 |
+
def loadImage(img_file):
|
13 |
+
img = io.imread(img_file) # RGB order
|
14 |
+
if img.shape[0] == 2: img = img[0]
|
15 |
+
if len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
16 |
+
if img.shape[2] == 4: img = img[:, :, :3]
|
17 |
+
img = np.array(img)
|
18 |
+
|
19 |
+
return img
|
20 |
+
|
21 |
+
|
22 |
+
def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
|
23 |
+
# should be RGB order
|
24 |
+
img = in_img.copy().astype(np.float32)
|
25 |
+
|
26 |
+
img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
|
27 |
+
img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
|
28 |
+
return img
|
29 |
+
|
30 |
+
|
31 |
+
def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
|
32 |
+
# should be RGB order
|
33 |
+
img = in_img.copy()
|
34 |
+
img *= variance
|
35 |
+
img += mean
|
36 |
+
img *= 255.0
|
37 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
38 |
+
return img
|
39 |
+
|
40 |
+
|
41 |
+
def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
|
42 |
+
height, width, channel = img.shape
|
43 |
+
|
44 |
+
# magnify image size
|
45 |
+
target_size = mag_ratio * max(height, width)
|
46 |
+
|
47 |
+
# set original image size
|
48 |
+
if target_size > square_size:
|
49 |
+
target_size = square_size
|
50 |
+
|
51 |
+
ratio = target_size / max(height, width)
|
52 |
+
|
53 |
+
target_h, target_w = int(height * ratio), int(width * ratio)
|
54 |
+
proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
|
55 |
+
|
56 |
+
# make canvas and paste image
|
57 |
+
target_h32, target_w32 = target_h, target_w
|
58 |
+
if target_h % 32 != 0:
|
59 |
+
target_h32 = target_h + (32 - target_h % 32)
|
60 |
+
if target_w % 32 != 0:
|
61 |
+
target_w32 = target_w + (32 - target_w % 32)
|
62 |
+
resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
|
63 |
+
resized[0:target_h, 0:target_w, :] = proc
|
64 |
+
target_h, target_w = target_h32, target_w32
|
65 |
+
|
66 |
+
size_heatmap = (int(target_w / 2), int(target_h / 2))
|
67 |
+
|
68 |
+
return resized, ratio, size_heatmap
|
69 |
+
|
70 |
+
|
71 |
+
def cvt2HeatmapImg(img):
|
72 |
+
img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
|
73 |
+
img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
|
74 |
+
return img
|
refinenet.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2019-present NAVER Corp.
|
3 |
+
MIT License
|
4 |
+
"""
|
5 |
+
|
6 |
+
# -*- coding: utf-8 -*-
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.autograd import Variable
|
11 |
+
from basenet.vgg16_bn import init_weights
|
12 |
+
|
13 |
+
|
14 |
+
class RefineNet(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(RefineNet, self).__init__()
|
17 |
+
|
18 |
+
self.last_conv = nn.Sequential(
|
19 |
+
nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
|
20 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
|
21 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
|
22 |
+
)
|
23 |
+
|
24 |
+
self.aspp1 = nn.Sequential(
|
25 |
+
nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
26 |
+
nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
27 |
+
nn.Conv2d(128, 1, kernel_size=1)
|
28 |
+
)
|
29 |
+
|
30 |
+
self.aspp2 = nn.Sequential(
|
31 |
+
nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
32 |
+
nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
33 |
+
nn.Conv2d(128, 1, kernel_size=1)
|
34 |
+
)
|
35 |
+
|
36 |
+
self.aspp3 = nn.Sequential(
|
37 |
+
nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
38 |
+
nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
39 |
+
nn.Conv2d(128, 1, kernel_size=1)
|
40 |
+
)
|
41 |
+
|
42 |
+
self.aspp4 = nn.Sequential(
|
43 |
+
nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
44 |
+
nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
|
45 |
+
nn.Conv2d(128, 1, kernel_size=1)
|
46 |
+
)
|
47 |
+
|
48 |
+
init_weights(self.last_conv.modules())
|
49 |
+
init_weights(self.aspp1.modules())
|
50 |
+
init_weights(self.aspp2.modules())
|
51 |
+
init_weights(self.aspp3.modules())
|
52 |
+
init_weights(self.aspp4.modules())
|
53 |
+
|
54 |
+
def forward(self, y, upconv4):
|
55 |
+
refine = torch.cat([y.permute(0,3,1,2), upconv4], dim=1)
|
56 |
+
refine = self.last_conv(refine)
|
57 |
+
|
58 |
+
aspp1 = self.aspp1(refine)
|
59 |
+
aspp2 = self.aspp2(refine)
|
60 |
+
aspp3 = self.aspp3(refine)
|
61 |
+
aspp4 = self.aspp4(refine)
|
62 |
+
|
63 |
+
#out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
|
64 |
+
out = aspp1 + aspp2 + aspp3 + aspp4
|
65 |
+
return out.permute(0, 2, 3, 1) # , refine.permute(0,2,3,1)
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
pandas
|
3 |
+
matplotlib
|
4 |
+
scikit-learn
|
5 |
+
tensorflow
|
6 |
+
seaborn
|
7 |
+
beautifulsoup4
|
8 |
+
requests
|
9 |
+
plotly
|
10 |
+
keras
|
11 |
+
vietocr
|
12 |
+
streamlit
|
13 |
+
torch --index-url https://download.pytorch.org/whl/cu121
|
14 |
+
torchvision --index-url https://download.pytorch.org/whl/cu121
|
15 |
+
torchaudio --index-url https://download.pytorch.org/whl/cu121
|