Spaces:
Build error
Build error
Add cache
Browse files
app.py
CHANGED
@@ -23,11 +23,15 @@ from paddleocr import PaddleOCR
|
|
23 |
import postprocess
|
24 |
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
-
|
|
|
31 |
|
32 |
detection_class_names = ['table', 'table rotated', 'no object']
|
33 |
structure_class_names = [
|
@@ -62,7 +66,7 @@ def cv_to_PIL(cv_img):
|
|
62 |
return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
|
63 |
|
64 |
|
65 |
-
def table_detection(pil_img):
|
66 |
image = PIL_to_cv(pil_img)
|
67 |
pred = detection_model(image, size=imgsz)
|
68 |
pred = pred.xywhn[0]
|
@@ -70,7 +74,7 @@ def table_detection(pil_img):
|
|
70 |
return result
|
71 |
|
72 |
|
73 |
-
def table_structure(pil_img):
|
74 |
image = PIL_to_cv(pil_img)
|
75 |
pred = structure_model(image, size=imgsz)
|
76 |
pred = pred.xywhn[0]
|
|
|
23 |
import postprocess
|
24 |
|
25 |
|
26 |
+
@st.cache_resource(ttl=3600)
|
27 |
+
def load_models():
|
28 |
+
ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
|
29 |
+
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
|
30 |
+
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
|
31 |
+
return ocr_instance, detection_model, structure_model
|
32 |
|
33 |
+
|
34 |
+
ocr_instance, detection_model, structure_model = load_models()
|
35 |
|
36 |
detection_class_names = ['table', 'table rotated', 'no object']
|
37 |
structure_class_names = [
|
|
|
66 |
return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
|
67 |
|
68 |
|
69 |
+
def table_detection(pil_img, imgsz=640):
|
70 |
image = PIL_to_cv(pil_img)
|
71 |
pred = detection_model(image, size=imgsz)
|
72 |
pred = pred.xywhn[0]
|
|
|
74 |
return result
|
75 |
|
76 |
|
77 |
+
def table_structure(pil_img, imgsz=640):
|
78 |
image = PIL_to_cv(pil_img)
|
79 |
pred = structure_model(image, size=imgsz)
|
80 |
pred = pred.xywhn[0]
|