odulcy-mindee commited on
Commit
4eb7c20
β€’
1 Parent(s): 049c6c7

Switch HF Spaces to Torch (credit: @Felix92 )

Browse files
Files changed (4) hide show
  1. README.md +12 -11
  2. app.py +46 -50
  3. backend/pytorch.py +86 -0
  4. requirements.txt +1 -2
README.md CHANGED
@@ -4,35 +4,36 @@ emoji: πŸ“‘
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: streamlit
7
- sdk_version: 0.84.2
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  # Configuration
13
 
14
- `title`: _string_
15
  Display title for the Space
16
 
17
- `emoji`: _string_
18
  Space emoji (emoji-only character allowed)
19
 
20
- `colorFrom`: _string_
21
  Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
 
23
- `colorTo`: _string_
24
  Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
 
26
- `sdk`: _string_
27
  Can be either `gradio` or `streamlit`
28
 
29
- `sdk_version` : _string_
30
- Only applicable for `streamlit` SDK.
31
  See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
 
33
- `app_file`: _string_
34
- Path to your main application file (which contains either `gradio` or `streamlit` Python code).
35
  Path is relative to the root of the repository.
36
 
37
- `pinned`: _boolean_
38
  Whether the Space stays on top of your list.
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: streamlit
7
+ sdk_version: 1.30.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
  # Configuration
14
 
15
+ `title`: _string_
16
  Display title for the Space
17
 
18
+ `emoji`: _string_
19
  Space emoji (emoji-only character allowed)
20
 
21
+ `colorFrom`: _string_
22
  Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
23
 
24
+ `colorTo`: _string_
25
  Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
26
 
27
+ `sdk`: _string_
28
  Can be either `gradio` or `streamlit`
29
 
30
+ `sdk_version` : _string_
31
+ Only applicable for `streamlit` SDK.
32
  See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
33
 
34
+ `app_file`: _string_
35
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
36
  Path is relative to the root of the repository.
37
 
38
+ `pinned`: _boolean_
39
  Whether the Space stays on top of your list.
app.py CHANGED
@@ -1,47 +1,35 @@
1
- # Copyright (C) 2021, Mindee.
2
 
3
- # This program is licensed under the Apache License version 2.
4
- # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5
-
6
- import os
7
 
 
8
  import matplotlib.pyplot as plt
 
9
  import streamlit as st
10
-
11
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
12
-
13
- import cv2
14
- import tensorflow as tf
15
-
16
- gpu_devices = tf.config.experimental.list_physical_devices('GPU')
17
- if any(gpu_devices):
18
- tf.config.experimental.set_memory_growth(gpu_devices[0], True)
19
 
20
  from doctr.io import DocumentFile
21
- from doctr.models import ocr_predictor
22
  from doctr.utils.visualization import visualize_page
23
 
24
- DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
25
- RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]
26
 
 
27
 
28
- def main():
29
 
 
 
30
  # Wide mode
31
  st.set_page_config(layout="wide")
32
 
33
  # Designing the interface
34
  st.title("docTR: Document Text Recognition")
35
  # For newline
36
- st.write('\n')
37
- #
38
- st.write('Find more info at: https://github.com/mindee/doctr')
39
- # For newline
40
- st.write('\n')
41
  # Instructions
42
  st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
43
  # Set the columns
44
- cols = st.beta_columns((1, 1, 1, 1))
45
  cols[0].subheader("Input page")
46
  cols[1].subheader("Segmentation heatmap")
47
  cols[2].subheader("OCR output")
@@ -50,64 +38,72 @@ def main():
50
  # Sidebar
51
  # File selection
52
  st.sidebar.title("Document selection")
53
- # Disabling warning
54
- st.set_option('deprecation.showfileUploaderEncoding', False)
55
  # Choose your own image
56
- uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
57
  if uploaded_file is not None:
58
- if uploaded_file.name.endswith('.pdf'):
59
  doc = DocumentFile.from_pdf(uploaded_file.read())
60
  else:
61
  doc = DocumentFile.from_images(uploaded_file.read())
62
  page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
63
- cols[0].image(doc[page_idx])
 
64
 
65
  # Model selection
66
  st.sidebar.title("Model selection")
67
- det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
68
- reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
69
 
70
  # For newline
71
- st.sidebar.write('\n')
 
 
 
 
 
 
 
 
 
 
72
 
73
  if st.sidebar.button("Analyze page"):
74
-
75
  if uploaded_file is None:
76
  st.sidebar.write("Please upload a document")
77
 
78
  else:
79
- with st.spinner('Loading model...'):
80
- predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)
81
-
82
- with st.spinner('Analyzing...'):
83
 
 
84
  # Forward the image to the model
85
- processed_batches = predictor.det_predictor.pre_processor([doc[page_idx]])
86
- out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
87
- seg_map = out["out_map"]
88
- seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
89
- seg_map = cv2.resize(seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]),
90
- interpolation=cv2.INTER_LINEAR)
91
  # Plot the raw heatmap
92
  fig, ax = plt.subplots()
93
  ax.imshow(seg_map)
94
- ax.axis('off')
95
  cols[1].pyplot(fig)
96
 
97
  # Plot OCR output
98
- out = predictor([doc[page_idx]])
99
- fig = visualize_page(out.pages[0].export(), doc[page_idx], interactive=False)
100
  cols[2].pyplot(fig)
101
 
102
  # Page reconsitution under input page
103
  page_export = out.pages[0].export()
104
- img = out.pages[0].synthesize()
105
- cols[3].image(img, clamp=True)
 
106
 
107
  # Display JSON
108
  st.markdown("\nHere are your analysis results in JSON format:")
109
  st.json(page_export)
110
 
111
 
112
- if __name__ == '__main__':
113
- main()
1
+ # Copyright (C) 2021-2024, Mindee.
2
 
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
5
 
6
+ import cv2
7
  import matplotlib.pyplot as plt
8
+ import numpy as np
9
  import streamlit as st
10
+ import torch
 
 
 
 
 
 
 
 
11
 
12
  from doctr.io import DocumentFile
 
13
  from doctr.utils.visualization import visualize_page
14
 
15
+ from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
 
16
 
17
+ forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
 
 
19
 
20
+ def main(det_archs, reco_archs):
21
+ """Build a streamlit layout"""
22
  # Wide mode
23
  st.set_page_config(layout="wide")
24
 
25
  # Designing the interface
26
  st.title("docTR: Document Text Recognition")
27
  # For newline
28
+ st.write("\n")
 
 
 
 
29
  # Instructions
30
  st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
31
  # Set the columns
32
+ cols = st.columns((1, 1, 1, 1))
33
  cols[0].subheader("Input page")
34
  cols[1].subheader("Segmentation heatmap")
35
  cols[2].subheader("OCR output")
38
  # Sidebar
39
  # File selection
40
  st.sidebar.title("Document selection")
 
 
41
  # Choose your own image
42
+ uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
43
  if uploaded_file is not None:
44
+ if uploaded_file.name.endswith(".pdf"):
45
  doc = DocumentFile.from_pdf(uploaded_file.read())
46
  else:
47
  doc = DocumentFile.from_images(uploaded_file.read())
48
  page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
49
+ page = doc[page_idx]
50
+ cols[0].image(page)
51
 
52
  # Model selection
53
  st.sidebar.title("Model selection")
54
+ det_arch = st.sidebar.selectbox("Text detection model", det_archs)
55
+ reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
56
 
57
  # For newline
58
+ st.sidebar.write("\n")
59
+ # Only straight pages or possible rotation
60
+ st.sidebar.title("Parameters")
61
+ assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
62
+ st.sidebar.write("\n")
63
+ # Straighten pages
64
+ straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
65
+ st.sidebar.write("\n")
66
+ # Binarization threshold
67
+ bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
68
+ st.sidebar.write("\n")
69
 
70
  if st.sidebar.button("Analyze page"):
 
71
  if uploaded_file is None:
72
  st.sidebar.write("Please upload a document")
73
 
74
  else:
75
+ with st.spinner("Loading model..."):
76
+ predictor = load_predictor(
77
+ det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
78
+ )
79
 
80
+ with st.spinner("Analyzing..."):
81
  # Forward the image to the model
82
+ seg_map = forward_image(predictor, page, forward_device)
83
+ seg_map = np.squeeze(seg_map)
84
+ seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
85
+
 
 
86
  # Plot the raw heatmap
87
  fig, ax = plt.subplots()
88
  ax.imshow(seg_map)
89
+ ax.axis("off")
90
  cols[1].pyplot(fig)
91
 
92
  # Plot OCR output
93
+ out = predictor([page])
94
+ fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)
95
  cols[2].pyplot(fig)
96
 
97
  # Page reconsitution under input page
98
  page_export = out.pages[0].export()
99
+ if assume_straight_pages or (not assume_straight_pages and straighten_pages):
100
+ img = out.pages[0].synthesize()
101
+ cols[3].image(img, clamp=True)
102
 
103
  # Display JSON
104
  st.markdown("\nHere are your analysis results in JSON format:")
105
  st.json(page_export)
106
 
107
 
108
+ if __name__ == "__main__":
109
+ main(DET_ARCHS, RECO_ARCHS)
backend/pytorch.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from doctr.models import ocr_predictor
10
+ from doctr.models.predictor import OCRPredictor
11
+
12
+ DET_ARCHS = [
13
+ "db_resnet50",
14
+ "db_resnet34",
15
+ "db_mobilenet_v3_large",
16
+ "linknet_resnet18",
17
+ "linknet_resnet34",
18
+ "linknet_resnet50",
19
+ ]
20
+ RECO_ARCHS = [
21
+ "crnn_vgg16_bn",
22
+ "crnn_mobilenet_v3_small",
23
+ "crnn_mobilenet_v3_large",
24
+ "master",
25
+ "sar_resnet31",
26
+ "vitstr_small",
27
+ "vitstr_base",
28
+ "parseq",
29
+ ]
30
+
31
+
32
+ def load_predictor(
33
+ det_arch: str,
34
+ reco_arch: str,
35
+ assume_straight_pages: bool,
36
+ straighten_pages: bool,
37
+ bin_thresh: float,
38
+ device: torch.device,
39
+ ) -> OCRPredictor:
40
+ """Load a predictor from doctr.models
41
+
42
+ Args:
43
+ ----
44
+ det_arch: detection architecture
45
+ reco_arch: recognition architecture
46
+ assume_straight_pages: whether to assume straight pages or not
47
+ straighten_pages: whether to straighten rotated pages or not
48
+ bin_thresh: binarization threshold for the segmentation map
49
+ device: torch.device, the device to load the predictor on
50
+
51
+ Returns:
52
+ -------
53
+ instance of OCRPredictor
54
+ """
55
+ predictor = ocr_predictor(
56
+ det_arch,
57
+ reco_arch,
58
+ pretrained=True,
59
+ assume_straight_pages=assume_straight_pages,
60
+ straighten_pages=straighten_pages,
61
+ export_as_straight_boxes=straighten_pages,
62
+ detect_orientation=not assume_straight_pages,
63
+ ).to(device)
64
+ predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
65
+ return predictor
66
+
67
+
68
+ def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
69
+ """Forward an image through the predictor
70
+
71
+ Args:
72
+ ----
73
+ predictor: instance of OCRPredictor
74
+ image: image to process
75
+ device: torch.device, the device to process the image on
76
+
77
+ Returns:
78
+ -------
79
+ segmentation map
80
+ """
81
+ with torch.no_grad():
82
+ processed_batches = predictor.det_predictor.pre_processor([image])
83
+ out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
84
+ seg_map = out["out_map"].to("cpu").numpy()
85
+
86
+ return seg_map
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
- -e git+https://github.com/mindee/doctr.git#egg=python-doctr[tf]
2
  streamlit>=1.0.0
3
- PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12,!=1.19.5
1
+ -e git+https://github.com/mindee/doctr.git#egg=python-doctr[torch]
2
  streamlit>=1.0.0