domenicCarter
committed on
Commit
•
7013115
1
Parent(s):
e57ccf6
feat: first commit
Browse files
app.py
CHANGED
@@ -1,7 +1,93 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from rapidocr_onnxruntime import RapidOCR
|
5 |
|
6 |
+
# Shared OCR engine, created once at import time and reused by the functions
# below.
engine = RapidOCR()

# Fixed regions read from the aligned image, one entry per extracted field.
# Each value is [x1, y1, x2, y2]; process_images uses (x1, y1) and (x2, y2)
# as opposite rectangle corners and crops rows y1:y2, columns x1:x2.
info_points = dict(
    customer_name=[156, 109, 928, 168],
    amount=[157, 397, 606, 461],
    price=[155, 341, 607, 399],
    plateNumber=[740, 173, 928, 227],
)
|
14 |
+
|
15 |
+
def find_reference_points(template_image, target_image):
    """Collect matching anchor points between the template and target images.

    Both images are run through the module-level OCR ``engine``. Every
    template text block is paired with the first target text block whose
    recognized text is exactly equal, and the point at index 1 of each
    block's box is recorded (presumably the top-right corner -- confirm
    against the RapidOCR box ordering).

    Args:
        template_image: Reference image handed to the OCR engine.
        target_image: Image to be matched, handed to the OCR engine.

    Returns:
        A tuple ``(template_points, target_points)`` of NumPy arrays holding
        one ``(x, y)`` row per matched text, in template order. Both arrays
        are empty when nothing matches or OCR finds no text.
    """
    template_result, _ = engine(template_image)
    target_result, _ = engine(target_image)

    # The engine may return None rather than an empty list when no text is
    # detected; normalise so the loops below simply do nothing.
    template_result = template_result or []
    target_result = target_result or []

    # Index the target by text, keeping only the first occurrence of each
    # string, so every template word resolves with a single lookup.
    first_target_point = {}
    for target_word in target_result:
        target_x, target_y = target_word[0][1]
        first_target_point.setdefault(target_word[1], (target_x, target_y))

    reference_points_template = []
    reference_points_target = []
    for template_word in template_result:
        template_x, template_y = template_word[0][1]
        match = first_target_point.get(template_word[1])
        if match is not None:
            reference_points_template.append((template_x, template_y))
            reference_points_target.append(match)

    return np.array(reference_points_template), np.array(reference_points_target)
|
38 |
+
|
39 |
+
def align_images(template_image, target_image):
    """Warp the target image into the template's coordinate frame.

    Matching text anchors from ``find_reference_points`` are used to estimate
    a homography (RANSAC, reprojection threshold 5.0) that maps target points
    onto template points; the target is then warped to the template's size.

    Args:
        template_image: Reference image; its height and width set the output
            size.
        target_image: Image to be aligned.

    Returns:
        The warped image, or ``target_image`` unchanged when fewer than four
        anchor pairs are available or no homography can be estimated.
    """
    src_pts, dst_pts = find_reference_points(template_image, target_image)

    # A homography needs at least four point correspondences.
    if len(src_pts) < 4 or len(dst_pts) < 4:
        return target_image

    M, _ = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 5.0)

    # findHomography yields no matrix when estimation fails (e.g. degenerate
    # or collinear anchors); fall back to the unaligned image instead of
    # passing None to warpPerspective.
    if M is None:
        return target_image

    height, width = template_image.shape[:2]
    return cv2.warpPerspective(target_image, M, (width, height))
|
53 |
+
|
54 |
+
def process_images(template_image, target_image):
    """Align the target to the template and OCR each configured region.

    Args:
        template_image: Template image as delivered by the Gradio image
            input (converted here from RGB to BGR for OpenCV).
        target_image: Target image as delivered by the Gradio image input.

    Returns:
        A tuple ``(annotated_image, info_dict)``: the aligned RGB image with
        a green rectangle drawn around every region in ``info_points``, and a
        dict mapping each ``info_points`` key to the first text recognized in
        that region (empty string when nothing is recognized).

    Raises:
        gr.Error: If either image is missing.
    """
    if template_image is None or target_image is None:
        # Surface a readable message in the UI instead of an OpenCV error.
        raise gr.Error("请同时上传模板图像和目标图像。")

    # Gradio supplies RGB arrays; the OpenCV-based steps work on BGR.
    template_image = cv2.cvtColor(template_image, cv2.COLOR_RGB2BGR)
    target_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)

    aligned_image = align_images(template_image, target_image)

    # Convert back to RGB so Gradio displays the colours correctly.
    aligned_image = cv2.cvtColor(aligned_image, cv2.COLOR_BGR2RGB)

    # OCR reads from an untouched copy: the rectangles drawn for display are
    # two pixels thick and some regions overlap, so drawing first would put
    # green borders inside the crops.
    clean_image = aligned_image.copy()

    info_dict = {}
    for key, (x1, y1, x2, y2) in info_points.items():
        crop = clean_image[y1:y2, x1:x2]

        # The crop is empty when the region lies outside the image (possible
        # when alignment fell back to a smaller, unaligned target).
        ocr_result = None
        if crop.size:
            ocr_result, _ = engine(crop)

        # The engine may report no result when the region has no text.
        info_dict[key] = ocr_result[0][1] if ocr_result else ""

        cv2.rectangle(aligned_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    return aligned_image, info_dict
|
76 |
+
|
77 |
+
# Build the Gradio interface: two image inputs (template, target) feeding
# process_images, with the aligned image and the recognized text as outputs.
demo = gr.Interface(
    process_images,
    [gr.Image(label="模板图像"), gr.Image(label="目标图像")],
    [gr.Image(label="对齐后的图像"), gr.Textbox(label="识别信息")],
    title="磅单提取工具",
    description="上传一张模板图像和一张目标图像,提取关键信息。",
)

# Start the Gradio app.
demo.launch()
|