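"""Streamlit demo: upload a PDF or image, run docTR text detection and recognition on a
selected page, and display the OCR output, a page reconstitution, the extracted text,
and a downloadable JSON export."""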
# import cv2
# import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import json
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page
from backend.pytorch import DET_ARCHS, RECO_ARCHS, load_predictor  # forward_image unused (heatmap code disabled below)
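
# Run inference on the GPU when one is available, otherwise fall back to the CPU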
forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def main(det_archs, reco_archs):
    """Build a streamlit layout"""
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("Document Text Extraction")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # Set the columns
    # cols = st.columns((1, 1, 1, 1))
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    # cols[1].subheader("Segmentation heatmap")
    cols[1].subheader("OCR output")
    cols[2].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        if uploaded_file.name.endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # # For newline
    # st.sidebar.write("\n")
    # # Only straight pages or possible rotation
    # st.sidebar.title("Parameters")
    # assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
    # st.sidebar.write("\n")
    # # Straighten pages
    # straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
    # st.sidebar.write("\n")
    # # Binarization threshold
    # bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
    # st.sidebar.write("\n")
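
    # The pipeline runs only when the button below is clicked; the commented-out
    # sidebar widgets above would otherwise supply the parameters defaulted further down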
    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner("Loading model..."):
                # Default values
                assume_straight_pages, straighten_pages, bin_thresh = True, False, 0.3
                predictor = load_predictor(
                    det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
                )
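
            # Run detection + recognition on the selected page and render the results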
            with st.spinner("Analyzing..."):
                # # Forward the image to the model
                # seg_map = forward_image(predictor, page, forward_device)
                # seg_map = np.squeeze(seg_map)
                # seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
                # # Plot the raw heatmap
                # fig, ax = plt.subplots()
                # ax.imshow(seg_map)
                # ax.axis("off")
                # cols[1].pyplot(fig)

                # Plot OCR output
                out = predictor([page])
                fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)
                cols[1].pyplot(fig)

                # Page reconstitution under input page
                page_export = out.pages[0].export()
                if assume_straight_pages or straighten_pages:
                    img = out.pages[0].synthesize()
                    cols[2].image(img, clamp=True)
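
                # page_export is a plain dict with a blocks -> lines -> words hierarchy;
                # each word entry carries its recognized "value"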
                # Debug output to the console
                print("out", out)
                print("\n")
                print("page_export", page_export)
                print("\n")

                # Rebuild the plain text, word by word, line by line
                all_text = ""
                for block in page_export["blocks"]:
                    for line in block["lines"]:
                        for word in line["words"]:
                            all_text += word["value"]
                            all_text += " "
                        all_text += "\n"
                print("all_text", all_text)
                print("\n")

                # Display Text
                st.markdown("\n## **Here is your text:**")
                st.write(all_text)

                # Display JSON
                json_string = json.dumps(page_export)
                st.markdown("\n## **Here are your analysis results in JSON format:**")
                st.download_button(label="Download JSON", data=json_string, file_name="data.json", mime="application/json")
                st.json(page_export, expanded=False)


if __name__ == "__main__":
    main(DET_ARCHS, RECO_ARCHS)
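
# To try this app locally (assuming the file is saved as app.py and streamlit,
# python-doctr, and torch are installed), launch it with:
#   streamlit run app.py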