AlhitawiMohammed22 commited on
Commit
c46149c
1 Parent(s): ff135d3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Backend selection MUST happen before any doctr import: doctr reads these
# environment variables at import time to pick its framework, and TrOCR
# (transformers) is used here with the PyTorch backend only.
import os
os.environ["USE_TORCH"] = "1"
os.environ["USE_TF"] = "0"
import torch
from torch.utils.data.dataloader import DataLoader

# Project-local modules: document assembly and the TrOCR dataset/model helpers.
from builder import DocumentBuilder
from trocr import IAMDataset, device, get_processor_model
# NOTE(review): visualize_page is imported but not used in the visible code —
# confirm before removing.
from doctr.utils.visualization import visualize_page
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models import db_resnet50, db_mobilenet_v3_large

from doctr.io import DocumentFile
import numpy as np
import cv2
import matplotlib.pyplot as plt
import streamlit as st

# Text-detection backbones offered in the sidebar (docTR pretrained weights).
DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
# TrOCR checkpoints offered for recognition (HuggingFace model ids).
RECO_ARCHS = ["microsoft/trocr-large-printed", "microsoft/trocr-large-stage1", "microsoft/trocr-large-handwritten"]
23
+
24
+
25
def _build_det_predictor(det_arch):
    """Build a docTR DetectionPredictor for the chosen architecture.

    Args:
        det_arch: one of DET_ARCHS ("db_resnet50" or "db_mobilenet_v3_large").

    Returns:
        A DetectionPredictor wrapping the pretrained detection model.
    """
    if det_arch == "db_resnet50":
        det_model = db_resnet50(pretrained=True)
    else:
        det_model = db_mobilenet_v3_large(pretrained=True)
    # Normalization statistics match the docTR detection training pipeline.
    return DetectionPredictor(
        PreProcessor((1024, 1024), batch_size=1,
                     mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)),
        det_model,
    )


def _recognize_crops(crops, rec_processor, rec_model):
    """Run TrOCR over word crops and return the decoded strings.

    Args:
        crops: list of word-crop images for a single page.
        rec_processor: TrOCR processor (feature extraction + tokenizer).
        rec_model: TrOCR vision-encoder-decoder model, already on `device`.

    Returns:
        list[str] of recognized words, in crop order.
    """
    test_dataset = IAMDataset(crops, rec_processor)
    test_dataloader = DataLoader(test_dataset, batch_size=16)
    text = []
    with torch.no_grad():  # inference only — no autograd bookkeeping
        for batch in test_dataloader:
            pixel_values = batch["pixel_values"].to(device)
            generated_ids = rec_model.generate(pixel_values)
            text.extend(
                rec_processor.batch_decode(generated_ids, skip_special_tokens=True)
            )
    return text


def main():
    """Streamlit app: detect text with docTR, recognize it with TrOCR.

    Flow: upload a PDF/image, pick a page and models in the sidebar, then
    "Analyze page" renders the detection heatmap and the recognized text.
    """
    # Wide mode
    st.set_page_config(layout="wide")
    # Designing the interface
    st.title("docTR + TrOCR")
    st.write('\n')
    st.write('For Detection DocTR: https://github.com/mindee/doctr')
    st.write('\n')
    st.write('For Recognition TrOCR: https://github.com/microsoft/unilm/tree/master/trocr')
    st.write('\n')
    st.write('Any Issue please dm')
    st.write('\n')
    # Instructions
    st.markdown(
        "*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")

    # Sidebar: file selection
    st.sidebar.title("Document selection")
    # NOTE(review): this option was removed in recent Streamlit releases —
    # confirm the pinned Streamlit version still accepts it.
    st.set_option('deprecation.showfileUploaderEncoding', False)
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    doc, page_idx = None, 0
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.pdf'):
            doc = DocumentFile.from_pdf(uploaded_file.read()).as_images()
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # Selectbox is 1-based for humans; convert back to a 0-based index.
        page_idx = st.sidebar.selectbox(
            "Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])

    # Sidebar: model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    rec_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
    st.sidebar.write('\n')

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner('Loading model...'):
                det_predictor = _build_det_predictor(det_arch)
                rec_processor, rec_model = get_processor_model(rec_arch)
            with st.spinner('Analyzing...'):
                page = doc[page_idx]
                # Forward the page through the detection model.
                processed_batches = det_predictor.pre_processor([page])
                out = det_predictor.model(processed_batches[0], return_model_output=True)
                seg_map = out["out_map"]
                seg_map = torch.squeeze(seg_map[0, ...], axis=0)
                seg_map = cv2.resize(seg_map.detach().numpy(),
                                     (page.shape[1], page.shape[0]),
                                     interpolation=cv2.INTER_LINEAR)
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)

                # Localize text elements
                loc_preds = out["preds"]
                # Crop mode: channels-last for plain numpy pages.
                channels_last = isinstance(page, np.ndarray)

                # BUG FIX: crop the *selected* page, not the whole document.
                # Passing `doc` paired the single-page loc_preds with page 0
                # inside _prepare_crops, so crops always came from page 1.
                crops, loc_preds = _OCRPredictor._prepare_crops(
                    [page], loc_preds, channels_last=channels_last, assume_straight_pages=True
                )
                # Robustness: a blank page yields no crops; bail out cleanly
                # instead of crashing the recognition dataset.
                if not crops[0]:
                    st.write("No text detected on this page")
                    return

                text = _recognize_crops(crops[0], rec_processor, rec_model)
                boxes, text_preds = _OCRPredictor._process_predictions(
                    loc_preds, text)

                doc_builder = DocumentBuilder()
                out = doc_builder(
                    boxes,
                    text_preds,
                    [
                        # type: ignore[misc]
                        page.shape[:2] if channels_last else page.shape[-2:]
                    ]
                )

                for df in out:
                    st.markdown("text")
                    st.write(" ".join(df["word"].to_list()))
                    st.write('\n')
                    st.markdown("\n Dataframe Output- similar to Tesseract:")
                    st.dataframe(df)
141
+
142
+
143
+
144
# Script entry point: run the Streamlit app when executed directly.
if __name__ == '__main__':
    main()