Spaces:
Running
Running
Commit
·
4eb7c20
1
Parent(s):
049c6c7
Switch HF Spaces to Torch (credit: @Felix92 )
Browse files- README.md +12 -11
- app.py +46 -50
- backend/pytorch.py +86 -0
- requirements.txt +1 -2
README.md
CHANGED
@@ -4,35 +4,36 @@ emoji: 📑
|
|
4 |
colorFrom: purple
|
5 |
colorTo: pink
|
6 |
sdk: streamlit
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
# Configuration
|
13 |
|
14 |
-
`title`: _string_
|
15 |
Display title for the Space
|
16 |
|
17 |
-
`emoji`: _string_
|
18 |
Space emoji (emoji-only characters allowed)
|
19 |
|
20 |
-
`colorFrom`: _string_
|
21 |
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
22 |
|
23 |
-
`colorTo`: _string_
|
24 |
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
25 |
|
26 |
-
`sdk`: _string_
|
27 |
Can be either `gradio` or `streamlit`
|
28 |
|
29 |
-
`sdk_version` : _string_
|
30 |
-
Only applicable for `streamlit` SDK.
|
31 |
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
32 |
|
33 |
-
`app_file`: _string_
|
34 |
-
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
35 |
Path is relative to the root of the repository.
|
36 |
|
37 |
-
`pinned`: _boolean_
|
38 |
Whether the Space stays on top of your list.
|
|
|
4 |
colorFrom: purple
|
5 |
colorTo: pink
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.30.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
# Configuration
|
14 |
|
15 |
+
`title`: _string_
|
16 |
Display title for the Space
|
17 |
|
18 |
+
`emoji`: _string_
|
19 |
Space emoji (emoji-only characters allowed)
|
20 |
|
21 |
+
`colorFrom`: _string_
|
22 |
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
23 |
|
24 |
+
`colorTo`: _string_
|
25 |
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
26 |
|
27 |
+
`sdk`: _string_
|
28 |
Can be either `gradio` or `streamlit`
|
29 |
|
30 |
+
`sdk_version` : _string_
|
31 |
+
Only applicable for `streamlit` SDK.
|
32 |
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
33 |
|
34 |
+
`app_file`: _string_
|
35 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
36 |
Path is relative to the root of the repository.
|
37 |
|
38 |
+
`pinned`: _boolean_
|
39 |
Whether the Space stays on top of your list.
|
app.py
CHANGED
@@ -1,47 +1,35 @@
|
|
1 |
-
# Copyright (C) 2021, Mindee.
|
2 |
|
3 |
-
# This program is licensed under the Apache License
|
4 |
-
# See LICENSE or go to <https://
|
5 |
-
|
6 |
-
import os
|
7 |
|
|
|
8 |
import matplotlib.pyplot as plt
|
|
|
9 |
import streamlit as st
|
10 |
-
|
11 |
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
12 |
-
|
13 |
-
import cv2
|
14 |
-
import tensorflow as tf
|
15 |
-
|
16 |
-
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
|
17 |
-
if any(gpu_devices):
|
18 |
-
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
|
19 |
|
20 |
from doctr.io import DocumentFile
|
21 |
-
from doctr.models import ocr_predictor
|
22 |
from doctr.utils.visualization import visualize_page
|
23 |
|
24 |
-
DET_ARCHS
|
25 |
-
RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]
|
26 |
|
|
|
27 |
|
28 |
-
def main():
|
29 |
|
|
|
|
|
30 |
# Wide mode
|
31 |
st.set_page_config(layout="wide")
|
32 |
|
33 |
# Designing the interface
|
34 |
st.title("docTR: Document Text Recognition")
|
35 |
# For newline
|
36 |
-
st.write(
|
37 |
-
#
|
38 |
-
st.write('Find more info at: https://github.com/mindee/doctr')
|
39 |
-
# For newline
|
40 |
-
st.write('\n')
|
41 |
# Instructions
|
42 |
st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
|
43 |
# Set the columns
|
44 |
-
cols = st.
|
45 |
cols[0].subheader("Input page")
|
46 |
cols[1].subheader("Segmentation heatmap")
|
47 |
cols[2].subheader("OCR output")
|
@@ -50,64 +38,72 @@ def main():
|
|
50 |
# Sidebar
|
51 |
# File selection
|
52 |
st.sidebar.title("Document selection")
|
53 |
-
# Disabling warning
|
54 |
-
st.set_option('deprecation.showfileUploaderEncoding', False)
|
55 |
# Choose your own image
|
56 |
-
uploaded_file = st.sidebar.file_uploader("Upload files", type=[
|
57 |
if uploaded_file is not None:
|
58 |
-
if uploaded_file.name.endswith(
|
59 |
doc = DocumentFile.from_pdf(uploaded_file.read())
|
60 |
else:
|
61 |
doc = DocumentFile.from_images(uploaded_file.read())
|
62 |
page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
|
63 |
-
|
|
|
64 |
|
65 |
# Model selection
|
66 |
st.sidebar.title("Model selection")
|
67 |
-
det_arch = st.sidebar.selectbox("Text detection model",
|
68 |
-
reco_arch = st.sidebar.selectbox("Text recognition model",
|
69 |
|
70 |
# For newline
|
71 |
-
st.sidebar.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
if st.sidebar.button("Analyze page"):
|
74 |
-
|
75 |
if uploaded_file is None:
|
76 |
st.sidebar.write("Please upload a document")
|
77 |
|
78 |
else:
|
79 |
-
with st.spinner(
|
80 |
-
predictor =
|
81 |
-
|
82 |
-
|
83 |
|
|
|
84 |
# Forward the image to the model
|
85 |
-
|
86 |
-
|
87 |
-
seg_map =
|
88 |
-
|
89 |
-
seg_map = cv2.resize(seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]),
|
90 |
-
interpolation=cv2.INTER_LINEAR)
|
91 |
# Plot the raw heatmap
|
92 |
fig, ax = plt.subplots()
|
93 |
ax.imshow(seg_map)
|
94 |
-
ax.axis(
|
95 |
cols[1].pyplot(fig)
|
96 |
|
97 |
# Plot OCR output
|
98 |
-
out = predictor([
|
99 |
-
fig = visualize_page(out.pages[0].export(),
|
100 |
cols[2].pyplot(fig)
|
101 |
|
102 |
# Page reconstitution under input page
|
103 |
page_export = out.pages[0].export()
|
104 |
-
|
105 |
-
|
|
|
106 |
|
107 |
# Display JSON
|
108 |
st.markdown("\nHere are your analysis results in JSON format:")
|
109 |
st.json(page_export)
|
110 |
|
111 |
|
112 |
-
if __name__ ==
|
113 |
-
main()
|
|
|
1 |
+
# Copyright (C) 2021-2024, Mindee.
|
2 |
|
3 |
+
# This program is licensed under the Apache License 2.0.
|
4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
|
|
|
5 |
|
6 |
+
import cv2
|
7 |
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
import streamlit as st
|
10 |
+
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
from doctr.io import DocumentFile
|
|
|
13 |
from doctr.utils.visualization import visualize_page
|
14 |
|
15 |
+
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
|
|
|
16 |
|
17 |
+
forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
18 |
|
|
|
19 |
|
20 |
+
def main(det_archs, reco_archs):
|
21 |
+
"""Build a streamlit layout"""
|
22 |
# Wide mode
|
23 |
st.set_page_config(layout="wide")
|
24 |
|
25 |
# Designing the interface
|
26 |
st.title("docTR: Document Text Recognition")
|
27 |
# For newline
|
28 |
+
st.write("\n")
|
|
|
|
|
|
|
|
|
29 |
# Instructions
|
30 |
st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
|
31 |
# Set the columns
|
32 |
+
cols = st.columns((1, 1, 1, 1))
|
33 |
cols[0].subheader("Input page")
|
34 |
cols[1].subheader("Segmentation heatmap")
|
35 |
cols[2].subheader("OCR output")
|
|
|
38 |
# Sidebar
|
39 |
# File selection
|
40 |
st.sidebar.title("Document selection")
|
|
|
|
|
41 |
# Choose your own image
|
42 |
+
uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
|
43 |
if uploaded_file is not None:
|
44 |
+
if uploaded_file.name.endswith(".pdf"):
|
45 |
doc = DocumentFile.from_pdf(uploaded_file.read())
|
46 |
else:
|
47 |
doc = DocumentFile.from_images(uploaded_file.read())
|
48 |
page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
|
49 |
+
page = doc[page_idx]
|
50 |
+
cols[0].image(page)
|
51 |
|
52 |
# Model selection
|
53 |
st.sidebar.title("Model selection")
|
54 |
+
det_arch = st.sidebar.selectbox("Text detection model", det_archs)
|
55 |
+
reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
|
56 |
|
57 |
# For newline
|
58 |
+
st.sidebar.write("\n")
|
59 |
+
# Only straight pages or possible rotation
|
60 |
+
st.sidebar.title("Parameters")
|
61 |
+
assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
|
62 |
+
st.sidebar.write("\n")
|
63 |
+
# Straighten pages
|
64 |
+
straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
|
65 |
+
st.sidebar.write("\n")
|
66 |
+
# Binarization threshold
|
67 |
+
bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
|
68 |
+
st.sidebar.write("\n")
|
69 |
|
70 |
if st.sidebar.button("Analyze page"):
|
|
|
71 |
if uploaded_file is None:
|
72 |
st.sidebar.write("Please upload a document")
|
73 |
|
74 |
else:
|
75 |
+
with st.spinner("Loading model..."):
|
76 |
+
predictor = load_predictor(
|
77 |
+
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
|
78 |
+
)
|
79 |
|
80 |
+
with st.spinner("Analyzing..."):
|
81 |
# Forward the image to the model
|
82 |
+
seg_map = forward_image(predictor, page, forward_device)
|
83 |
+
seg_map = np.squeeze(seg_map)
|
84 |
+
seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
|
85 |
+
|
|
|
|
|
86 |
# Plot the raw heatmap
|
87 |
fig, ax = plt.subplots()
|
88 |
ax.imshow(seg_map)
|
89 |
+
ax.axis("off")
|
90 |
cols[1].pyplot(fig)
|
91 |
|
92 |
# Plot OCR output
|
93 |
+
out = predictor([page])
|
94 |
+
fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)
|
95 |
cols[2].pyplot(fig)
|
96 |
|
97 |
# Page reconstitution under input page
|
98 |
page_export = out.pages[0].export()
|
99 |
+
if assume_straight_pages or (not assume_straight_pages and straighten_pages):
|
100 |
+
img = out.pages[0].synthesize()
|
101 |
+
cols[3].image(img, clamp=True)
|
102 |
|
103 |
# Display JSON
|
104 |
st.markdown("\nHere are your analysis results in JSON format:")
|
105 |
st.json(page_export)
|
106 |
|
107 |
|
108 |
+
if __name__ == "__main__":
|
109 |
+
main(DET_ARCHS, RECO_ARCHS)
|
backend/pytorch.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2021-2024, Mindee.
|
2 |
+
|
3 |
+
# This program is licensed under the Apache License 2.0.
|
4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from doctr.models import ocr_predictor
|
10 |
+
from doctr.models.predictor import OCRPredictor
|
11 |
+
|
12 |
+
DET_ARCHS = [
|
13 |
+
"db_resnet50",
|
14 |
+
"db_resnet34",
|
15 |
+
"db_mobilenet_v3_large",
|
16 |
+
"linknet_resnet18",
|
17 |
+
"linknet_resnet34",
|
18 |
+
"linknet_resnet50",
|
19 |
+
]
|
20 |
+
RECO_ARCHS = [
|
21 |
+
"crnn_vgg16_bn",
|
22 |
+
"crnn_mobilenet_v3_small",
|
23 |
+
"crnn_mobilenet_v3_large",
|
24 |
+
"master",
|
25 |
+
"sar_resnet31",
|
26 |
+
"vitstr_small",
|
27 |
+
"vitstr_base",
|
28 |
+
"parseq",
|
29 |
+
]
|
30 |
+
|
31 |
+
|
32 |
+
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    bin_thresh: float,
    device: torch.device,
) -> OCRPredictor:
    """Build a pretrained OCR predictor, move it to ``device`` and set its binarization threshold

    Args:
    ----
        det_arch: name of the text detection architecture
        reco_arch: name of the text recognition architecture
        assume_straight_pages: forwarded as-is; its negation enables orientation detection
        straighten_pages: forwarded as-is and also used for ``export_as_straight_boxes``
        bin_thresh: value assigned to the detection post-processor's ``bin_thresh``
        device: torch.device the predictor is moved to

    Returns:
    -------
        the configured OCRPredictor
    """
    # Keyword options are gathered first so the derived flags sit next to
    # the user-facing switches they are computed from.
    options = {
        "pretrained": True,
        "assume_straight_pages": assume_straight_pages,
        "straighten_pages": straighten_pages,
        "export_as_straight_boxes": straighten_pages,
        "detect_orientation": not assume_straight_pages,
    }
    ocr_model = ocr_predictor(det_arch, reco_arch, **options).to(device)

    # The threshold is overridden on the detection post-processor after construction.
    postprocessor = ocr_model.det_predictor.model.postprocessor
    postprocessor.bin_thresh = bin_thresh
    return ocr_model
|
66 |
+
|
67 |
+
|
68 |
+
def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
    """Run only the detection stage of the predictor on one image

    Args:
    ----
        predictor: instance of OCRPredictor whose ``det_predictor`` is used
        image: image to process; it is wrapped in a one-element list for the pre-processor
        device: torch.device the pre-processed batch is moved to before the forward pass

    Returns:
    -------
        the model's ``"out_map"`` output, moved to CPU and converted to a NumPy array
    """
    detector = predictor.det_predictor

    # Gradients are not needed for inference.
    with torch.no_grad():
        batches = detector.pre_processor([image])
        # Only the first pre-processed batch is forwarded.
        first_batch = batches[0].to(device)
        model_output = detector.model(first_batch, return_model_output=True)
        return model_output["out_map"].to("cpu").numpy()
|
requirements.txt
CHANGED
@@ -1,3 +1,2 @@
|
|
1 |
-
-e git+https://github.com/mindee/doctr.git#egg=python-doctr[
|
2 |
streamlit>=1.0.0
|
3 |
-
PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12,!=1.19.5
|
|
|
1 |
+
-e git+https://github.com/mindee/doctr.git#egg=python-doctr[torch]
|
2 |
streamlit>=1.0.0
|
|