deepsh2207 committed on
Commit
640c986
·
0 Parent(s):

base code added

Browse files
Files changed (6) hide show
  1. README.md +39 -0
  2. app.py +109 -0
  3. backend/pytorch.py +86 -0
  4. gitattributes +27 -0
  5. packages.txt +1 -0
  6. requirements.txt +2 -0
README.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: docTR
3
+ emoji: 📑
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: streamlit
7
+ sdk_version: 1.30.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ # Configuration
14
+
15
+ `title`: _string_
16
+ Display title for the Space
17
+
18
+ `emoji`: _string_
19
+ Space emoji (only emoji characters are allowed)
20
+
21
+ `colorFrom`: _string_
22
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
23
+
24
+ `colorTo`: _string_
25
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
26
+
27
+ `sdk`: _string_
28
+ Can be either `gradio` or `streamlit`
29
+
30
+ `sdk_version`: _string_
31
+ Only applicable for `streamlit` SDK.
32
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
33
+
34
+ `app_file`: _string_
35
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
36
+ Path is relative to the root of the repository.
37
+
38
+ `pinned`: _boolean_
39
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import streamlit as st
10
+ import torch
11
+
12
+ from doctr.io import DocumentFile
13
+ from doctr.utils.visualization import visualize_page
14
+
15
+ from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
16
+
17
# Device used for every forward pass: the first CUDA device when CUDA is available, the CPU otherwise
forward_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
18
+
19
+
20
def main(det_archs, reco_archs):
    """Build the Streamlit layout and run the OCR analysis on demand

    The page shows four columns (input page, segmentation heatmap, OCR output, page reconstitution)
    and a sidebar used to upload a document, select the models and tune the parameters. The analysis
    only runs when the "Analyze page" button is pressed.

    Args:
    ----
        det_archs: text detection architectures offered in the sidebar
        reco_archs: text recognition architectures offered in the sidebar
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR: Document Text Recognition")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        # Compare the extension case-insensitively so that "SCAN.PDF" is not handed to the image decoder
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # Pages are numbered from 1 in the UI and from 0 in the document
        page_idx = st.sidebar.selectbox("Page selection", list(range(1, len(doc) + 1))) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # For newline
    st.sidebar.write("\n")
    # Only straight pages or possible rotation
    st.sidebar.title("Parameters")
    assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
    st.sidebar.write("\n")
    # Straighten pages
    straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
    st.sidebar.write("\n")
    # Binarization threshold
    bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
    st.sidebar.write("\n")

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")

        else:
            with st.spinner("Loading model..."):
                predictor = load_predictor(
                    det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
                )

            with st.spinner("Analyzing..."):
                # Forward the image to the model
                seg_map = forward_image(predictor, page, forward_device)
                seg_map = np.squeeze(seg_map)
                # cv2.resize expects the target size as (width, height)
                seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)

                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis("off")
                cols[1].pyplot(fig)
                # pyplot keeps a reference to every figure it creates: release it once rendered
                plt.close(fig)

                # Plot OCR output
                out = predictor([page])
                # Export once and reuse the result for both the visualization and the JSON display
                page_export = out.pages[0].export()
                fig = visualize_page(page_export, out.pages[0].page, interactive=False, add_labels=False)
                cols[2].pyplot(fig)
                plt.close(fig)

                # Page reconstitution under input page
                if assume_straight_pages or straighten_pages:
                    img = out.pages[0].synthesize()
                    cols[3].image(img, clamp=True)

                # Display JSON
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export, expanded=False)
106
+
107
+
108
# Script entry point: launch the demo with the architectures exposed by the backend
if __name__ == "__main__":
    main(det_archs=DET_ARCHS, reco_archs=RECO_ARCHS)
backend/pytorch.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from doctr.models import ocr_predictor
10
+ from doctr.models.predictor import OCRPredictor
11
+
12
# Text detection architecture names, grouped by name prefix
DET_ARCHS = [
    # db_* family
    "db_resnet50", "db_resnet34", "db_mobilenet_v3_large",
    # linknet_* family
    "linknet_resnet18", "linknet_resnet34", "linknet_resnet50",
]  # fmt: skip

# Text recognition architecture names, grouped by name prefix
RECO_ARCHS = [
    # crnn_* family
    "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large",
    # standalone architectures
    "master", "sar_resnet31",
    # vitstr_* family
    "vitstr_small", "vitstr_base",
    "parseq",
]  # fmt: skip
30
+
31
+
32
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    bin_thresh: float,
    device: torch.device,
) -> OCRPredictor:
    """Build a pretrained OCR predictor, move it to a device and set its binarization threshold

    Args:
    ----
        det_arch: detection architecture
        reco_arch: recognition architecture
        assume_straight_pages: whether to assume straight pages or not
        straighten_pages: whether to straighten rotated pages or not
        bin_thresh: binarization threshold for the segmentation map
        device: torch.device, the device to load the predictor on

    Returns:
    -------
        instance of OCRPredictor
    """
    options = {
        "pretrained": True,
        "assume_straight_pages": assume_straight_pages,
        "straighten_pages": straighten_pages,
        # Straight boxes are exported exactly when the pages get straightened
        "export_as_straight_boxes": straighten_pages,
        # Orientation detection is enabled only when pages are not assumed to be straight
        "detect_orientation": not assume_straight_pages,
    }
    ocr_model = ocr_predictor(det_arch, reco_arch, **options)
    ocr_model = ocr_model.to(device)
    # The threshold is not a constructor argument here: it is set on the detection postprocessor afterwards
    ocr_model.det_predictor.model.postprocessor.bin_thresh = bin_thresh
    return ocr_model
66
+
67
+
68
def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
    """Run the detection model of the predictor on a single image

    Args:
    ----
        predictor: instance of OCRPredictor, only its detection predictor is used
        image: image to process
        device: torch.device, the device the preprocessed batch is moved to

    Returns:
    -------
        the "out_map" entry of the detection model output, as a NumPy array on the CPU
    """
    detector = predictor.det_predictor
    # Gradients are not needed for inference
    with torch.no_grad():
        # A single image is submitted, only the first preprocessed batch is forwarded
        first_batch = detector.pre_processor([image])[0]
        model_output = detector.model(first_batch.to(device), return_model_output=True)
        return model_output["out_map"].to("cpu").numpy()
gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ -e git+https://github.com/mindee/doctr.git#egg=python-doctr[torch]
2
+ streamlit>=1.0.0