bachpc commited on
Commit
6a2b711
1 Parent(s): 188eda0
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -23,11 +23,15 @@ from paddleocr import PaddleOCR
23
  import postprocess
24
 
25
 
26
- ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
27
- detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
28
- structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
 
 
 
29
 
30
- imgsz = 640
 
31
 
32
  detection_class_names = ['table', 'table rotated', 'no object']
33
  structure_class_names = [
@@ -62,7 +66,7 @@ def cv_to_PIL(cv_img):
62
  return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
63
 
64
 
65
- def table_detection(pil_img):
66
  image = PIL_to_cv(pil_img)
67
  pred = detection_model(image, size=imgsz)
68
  pred = pred.xywhn[0]
@@ -70,7 +74,7 @@ def table_detection(pil_img):
70
  return result
71
 
72
 
73
- def table_structure(pil_img):
74
  image = PIL_to_cv(pil_img)
75
  pred = structure_model(image, size=imgsz)
76
  pred = pred.xywhn[0]
 
23
  import postprocess
24
 
25
 
26
+ @st.cache_resource(ttl=3600)
27
+ def load_models():
28
+ ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
29
+ detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
30
+ structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True, skip_validation=True, trust_repo=True)
31
+ return ocr_instance, detection_model, structure_model
32
 
33
+
34
+ ocr_instance, detection_model, structure_model = load_models()
35
 
36
  detection_class_names = ['table', 'table rotated', 'no object']
37
  structure_class_names = [
 
66
  return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
67
 
68
 
69
+ def table_detection(pil_img, imgsz=640):
70
  image = PIL_to_cv(pil_img)
71
  pred = detection_model(image, size=imgsz)
72
  pred = pred.xywhn[0]
 
74
  return result
75
 
76
 
77
+ def table_structure(pil_img, imgsz=640):
78
  image = PIL_to_cv(pil_img)
79
  pred = structure_model(image, size=imgsz)
80
  pred = pred.xywhn[0]