osanseviero HF staff commited on
Commit
2ef412f
1 Parent(s): fdaccba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import matplotlib.pyplot as plt
4
+
5
+ from doctr.io import DocumentFile
6
+ from doctr.models import ocr_predictor
7
+ from doctr.utils.visualization import visualize_page
8
+
9
+ DET_ARCHS = ["db_resnet50"]
10
+ RECO_ARCHS = ["crnn_vgg16_bn", "master", "sar_resnet31"]
11
+
12
+ def main():
13
+
14
+ # Wide mode
15
+ st.set_page_config(layout="wide")
16
+
17
+ # Designing the interface
18
+ st.title("DocTR: Document Text Recognition")
19
+ # For newline
20
+ st.write('\n')
21
+ #
22
+ st.write('Find more info at: https://github.com/mindee/doctr')
23
+ # For newline
24
+ st.write('\n')
25
+ # Instructions
26
+ st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
27
+ # Set the columns
28
+ cols = st.beta_columns((1, 1, 1, 1))
29
+ cols[0].subheader("Input page")
30
+ cols[1].subheader("Segmentation heatmap")
31
+ cols[2].subheader("OCR output")
32
+ cols[3].subheader("Page reconstitution")
33
+
34
+ # Sidebar
35
+ # File selection
36
+ st.sidebar.title("Document selection")
37
+ # Disabling warning
38
+ st.set_option('deprecation.showfileUploaderEncoding', False)
39
+ # Choose your own image
40
+ uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
41
+ if uploaded_file is not None:
42
+ if uploaded_file.name.endswith('.pdf'):
43
+ doc = DocumentFile.from_pdf(uploaded_file.read()).as_images()
44
+ else:
45
+ doc = DocumentFile.from_images(uploaded_file.read())
46
+ page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
47
+ cols[0].image(doc[page_idx])
48
+
49
+ # Model selection
50
+ st.sidebar.title("Model selection")
51
+ det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
52
+ reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
53
+
54
+ # For newline
55
+ st.sidebar.write('\n')
56
+
57
+ if st.sidebar.button("Analyze page"):
58
+
59
+ if uploaded_file is None:
60
+ st.sidebar.write("Please upload a document")
61
+
62
+ else:
63
+ with st.spinner('Loading model...'):
64
+ predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)
65
+
66
+ with st.spinner('Analyzing...'):
67
+
68
+ # Forward the image to the model
69
+ processed_batches = predictor.det_predictor.pre_processor([doc[page_idx]])
70
+ out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
71
+ seg_map = out["out_map"]
72
+ seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
73
+ seg_map = cv2.resize(seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]),
74
+ interpolation=cv2.INTER_LINEAR)
75
+ # Plot the raw heatmap
76
+ fig, ax = plt.subplots()
77
+ ax.imshow(seg_map)
78
+ ax.axis('off')
79
+ cols[1].pyplot(fig)
80
+
81
+ # Plot OCR output
82
+ out = predictor([doc[page_idx]])
83
+ fig = visualize_page(out.pages[0].export(), doc[page_idx], interactive=False)
84
+ cols[2].pyplot(fig)
85
+
86
+ # Page reconsitution under input page
87
+ page_export = out.pages[0].export()
88
+ img = out.pages[0].synthesize()
89
+ cols[3].image(img, clamp=True)
90
+
91
+ # Display JSON
92
+ st.markdown("\nHere are your analysis results in JSON format:")
93
+ st.json(page_export)
94
+
95
+
96
+ if __name__ == '__main__':
97
+ main()