Spaces:
Sleeping
Sleeping
deepsh2207
commited on
Commit
β’
640c986
0
Parent(s):
base code added
Browse files- README.md +39 -0
- app.py +109 -0
- backend/pytorch.py +86 -0
- gitattributes +27 -0
- packages.txt +1 -0
- requirements.txt +2 -0
README.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: docTR
|
3 |
+
emoji: π
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: pink
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.30.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
12 |
+
|
13 |
+
# Configuration
|
14 |
+
|
15 |
+
`title`: _string_
|
16 |
+
Display title for the Space
|
17 |
+
|
18 |
+
`emoji`: _string_
|
19 |
+
Space emoji (emoji-only character allowed)
|
20 |
+
|
21 |
+
`colorFrom`: _string_
|
22 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
23 |
+
|
24 |
+
`colorTo`: _string_
|
25 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
26 |
+
|
27 |
+
`sdk`: _string_
|
28 |
+
Can be either `gradio` or `streamlit`
|
29 |
+
|
30 |
+
`sdk_version` : _string_
|
31 |
+
Only applicable for `streamlit` SDK.
|
32 |
+
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
33 |
+
|
34 |
+
`app_file`: _string_
|
35 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
36 |
+
Path is relative to the root of the repository.
|
37 |
+
|
38 |
+
`pinned`: _boolean_
|
39 |
+
Whether the Space stays on top of your list.
|
app.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2021-2024, Mindee.
|
2 |
+
|
3 |
+
# This program is licensed under the Apache License 2.0.
|
4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
import streamlit as st
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from doctr.io import DocumentFile
|
13 |
+
from doctr.utils.visualization import visualize_page
|
14 |
+
|
15 |
+
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
|
16 |
+
|
17 |
+
forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
18 |
+
|
19 |
+
|
20 |
+
def main(det_archs, reco_archs):
|
21 |
+
"""Build a streamlit layout"""
|
22 |
+
# Wide mode
|
23 |
+
st.set_page_config(layout="wide")
|
24 |
+
|
25 |
+
# Designing the interface
|
26 |
+
st.title("docTR: Document Text Recognition")
|
27 |
+
# For newline
|
28 |
+
st.write("\n")
|
29 |
+
# Instructions
|
30 |
+
st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
|
31 |
+
# Set the columns
|
32 |
+
cols = st.columns((1, 1, 1, 1))
|
33 |
+
cols[0].subheader("Input page")
|
34 |
+
cols[1].subheader("Segmentation heatmap")
|
35 |
+
cols[2].subheader("OCR output")
|
36 |
+
cols[3].subheader("Page reconstitution")
|
37 |
+
|
38 |
+
# Sidebar
|
39 |
+
# File selection
|
40 |
+
st.sidebar.title("Document selection")
|
41 |
+
# Choose your own image
|
42 |
+
uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
|
43 |
+
if uploaded_file is not None:
|
44 |
+
if uploaded_file.name.endswith(".pdf"):
|
45 |
+
doc = DocumentFile.from_pdf(uploaded_file.read())
|
46 |
+
else:
|
47 |
+
doc = DocumentFile.from_images(uploaded_file.read())
|
48 |
+
page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
|
49 |
+
page = doc[page_idx]
|
50 |
+
cols[0].image(page)
|
51 |
+
|
52 |
+
# Model selection
|
53 |
+
st.sidebar.title("Model selection")
|
54 |
+
det_arch = st.sidebar.selectbox("Text detection model", det_archs)
|
55 |
+
reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
|
56 |
+
|
57 |
+
# For newline
|
58 |
+
st.sidebar.write("\n")
|
59 |
+
# Only straight pages or possible rotation
|
60 |
+
st.sidebar.title("Parameters")
|
61 |
+
assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
|
62 |
+
st.sidebar.write("\n")
|
63 |
+
# Straighten pages
|
64 |
+
straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
|
65 |
+
st.sidebar.write("\n")
|
66 |
+
# Binarization threshold
|
67 |
+
bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
|
68 |
+
st.sidebar.write("\n")
|
69 |
+
|
70 |
+
if st.sidebar.button("Analyze page"):
|
71 |
+
if uploaded_file is None:
|
72 |
+
st.sidebar.write("Please upload a document")
|
73 |
+
|
74 |
+
else:
|
75 |
+
with st.spinner("Loading model..."):
|
76 |
+
predictor = load_predictor(
|
77 |
+
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
|
78 |
+
)
|
79 |
+
|
80 |
+
with st.spinner("Analyzing..."):
|
81 |
+
# Forward the image to the model
|
82 |
+
seg_map = forward_image(predictor, page, forward_device)
|
83 |
+
seg_map = np.squeeze(seg_map)
|
84 |
+
seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
|
85 |
+
|
86 |
+
# Plot the raw heatmap
|
87 |
+
fig, ax = plt.subplots()
|
88 |
+
ax.imshow(seg_map)
|
89 |
+
ax.axis("off")
|
90 |
+
cols[1].pyplot(fig)
|
91 |
+
|
92 |
+
# Plot OCR output
|
93 |
+
out = predictor([page])
|
94 |
+
fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)
|
95 |
+
cols[2].pyplot(fig)
|
96 |
+
|
97 |
+
# Page reconsitution under input page
|
98 |
+
page_export = out.pages[0].export()
|
99 |
+
if assume_straight_pages or (not assume_straight_pages and straighten_pages):
|
100 |
+
img = out.pages[0].synthesize()
|
101 |
+
cols[3].image(img, clamp=True)
|
102 |
+
|
103 |
+
# Display JSON
|
104 |
+
st.markdown("\nHere are your analysis results in JSON format:")
|
105 |
+
st.json(page_export, expanded=False)
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
main(DET_ARCHS, RECO_ARCHS)
|
backend/pytorch.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2021-2024, Mindee.
|
2 |
+
|
3 |
+
# This program is licensed under the Apache License 2.0.
|
4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from doctr.models import ocr_predictor
|
10 |
+
from doctr.models.predictor import OCRPredictor
|
11 |
+
|
12 |
+
DET_ARCHS = [
|
13 |
+
"db_resnet50",
|
14 |
+
"db_resnet34",
|
15 |
+
"db_mobilenet_v3_large",
|
16 |
+
"linknet_resnet18",
|
17 |
+
"linknet_resnet34",
|
18 |
+
"linknet_resnet50",
|
19 |
+
]
|
20 |
+
RECO_ARCHS = [
|
21 |
+
"crnn_vgg16_bn",
|
22 |
+
"crnn_mobilenet_v3_small",
|
23 |
+
"crnn_mobilenet_v3_large",
|
24 |
+
"master",
|
25 |
+
"sar_resnet31",
|
26 |
+
"vitstr_small",
|
27 |
+
"vitstr_base",
|
28 |
+
"parseq",
|
29 |
+
]
|
30 |
+
|
31 |
+
|
32 |
+
def load_predictor(
|
33 |
+
det_arch: str,
|
34 |
+
reco_arch: str,
|
35 |
+
assume_straight_pages: bool,
|
36 |
+
straighten_pages: bool,
|
37 |
+
bin_thresh: float,
|
38 |
+
device: torch.device,
|
39 |
+
) -> OCRPredictor:
|
40 |
+
"""Load a predictor from doctr.models
|
41 |
+
|
42 |
+
Args:
|
43 |
+
----
|
44 |
+
det_arch: detection architecture
|
45 |
+
reco_arch: recognition architecture
|
46 |
+
assume_straight_pages: whether to assume straight pages or not
|
47 |
+
straighten_pages: whether to straighten rotated pages or not
|
48 |
+
bin_thresh: binarization threshold for the segmentation map
|
49 |
+
device: torch.device, the device to load the predictor on
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
-------
|
53 |
+
instance of OCRPredictor
|
54 |
+
"""
|
55 |
+
predictor = ocr_predictor(
|
56 |
+
det_arch,
|
57 |
+
reco_arch,
|
58 |
+
pretrained=True,
|
59 |
+
assume_straight_pages=assume_straight_pages,
|
60 |
+
straighten_pages=straighten_pages,
|
61 |
+
export_as_straight_boxes=straighten_pages,
|
62 |
+
detect_orientation=not assume_straight_pages,
|
63 |
+
).to(device)
|
64 |
+
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
|
65 |
+
return predictor
|
66 |
+
|
67 |
+
|
68 |
+
def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
|
69 |
+
"""Forward an image through the predictor
|
70 |
+
|
71 |
+
Args:
|
72 |
+
----
|
73 |
+
predictor: instance of OCRPredictor
|
74 |
+
image: image to process
|
75 |
+
device: torch.device, the device to process the image on
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
-------
|
79 |
+
segmentation map
|
80 |
+
"""
|
81 |
+
with torch.no_grad():
|
82 |
+
processed_batches = predictor.det_predictor.pre_processor([image])
|
83 |
+
out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
|
84 |
+
seg_map = out["out_map"].to("cpu").numpy()
|
85 |
+
|
86 |
+
return seg_map
|
gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python3-opencv
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
-e git+https://github.com/mindee/doctr.git#egg=python-doctr[torch]
|
2 |
+
streamlit>=1.0.0
|