Spaces:
Running
Running
File size: 4,770 Bytes
6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 c2ce78d 6936493 6dbbae0 c2ce78d 6936493 c2ce78d 6dbbae0 c2ce78d 6dbbae0 6936493 c2ce78d 6dbbae0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import hashlib
import io
import numpy as np
import pandas as pd
import pypdfium2
import streamlit as st
from PIL import Image
from rapid_latex_ocr import LatexOCR
from streamlit_drawable_canvas import st_canvas
# Upper bounds on the displayed image size: resize_image() passes them to
# Image.thumbnail(), and get_image_size() falls back to them when no image
# is available.
MAX_WIDTH = 800
MAX_HEIGHT = 1000
# Ask Streamlit to render the page with its "wide" layout.
st.set_page_config(layout="wide")
@st.cache_resource()
def load_model_cached():
    """Create the LaTeX OCR model.

    The function is decorated with ``st.cache_resource`` so the constructed
    ``LatexOCR`` instance is handed to Streamlit's resource cache instead of
    being rebuilt on every call.

    Returns:
        A ``LatexOCR`` instance built with its default arguments.
    """
    ocr_model = LatexOCR()
    return ocr_model
def get_canvas_hash(pil_image):
    """Return the hexadecimal MD5 digest of an image's raw bytes.

    Args:
        pil_image: Any object exposing ``tobytes()``.

    Returns:
        The digest as a hexadecimal string.
    """
    hasher = hashlib.md5()
    hasher.update(pil_image.tobytes())
    return hasher.hexdigest()
def open_pdf(pdf_file):
    """Open an uploaded PDF as a ``pypdfium2.PdfDocument``.

    Args:
        pdf_file: Object whose ``getvalue()`` returns the PDF content; the
            bytes are wrapped in an in-memory ``io.BytesIO`` buffer.

    Returns:
        The ``pypdfium2.PdfDocument`` built from that buffer.
    """
    return pypdfium2.PdfDocument(io.BytesIO(pdf_file.getvalue()))
@st.cache_data()
def page_count(pdf_file):
    """Return the length of the PDF document opened from *pdf_file*.

    Args:
        pdf_file: Uploaded PDF, forwarded unchanged to ``open_pdf``.

    Returns:
        ``len()`` of the opened document.
    """
    return len(open_pdf(pdf_file))
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=96):
    """Render a single PDF page and return it as an RGB PIL image.

    Args:
        pdf_file: Uploaded PDF, forwarded unchanged to ``open_pdf``.
        page_num: Page number as shown to the user; one is subtracted to
            obtain the index passed as ``page_indices``.
        dpi: Target resolution; the render scale is ``dpi / 72``.

    Returns:
        The rendered page converted to mode ``"RGB"``.
    """
    document = open_pdf(pdf_file)
    zero_based_index = page_num - 1
    rendered_pages = document.render(
        pypdfium2.PdfBitmap.to_pil,
        page_indices=[zero_based_index],
        scale=dpi / 72,
    )
    # Only one index was requested, so the first result is the page we want.
    page_image = list(rendered_pages)[0]
    return page_image.convert("RGB")
@st.cache_data()
def get_uploaded_image(in_file):
    """Return the uploaded picture as an RGB PIL image.

    Args:
        in_file: Either an existing ``PIL.Image.Image`` or something
            ``Image.open`` accepts.

    Returns:
        The image converted to mode ``"RGB"``.
    """
    already_image = isinstance(in_file, Image.Image)
    source = in_file if already_image else Image.open(in_file)
    return source.convert("RGB")
def resize_image(pil_image):
    """Shrink *pil_image* to fit within ``MAX_WIDTH`` x ``MAX_HEIGHT``.

    The work is delegated to ``Image.thumbnail`` with LANCZOS resampling;
    nothing is returned, so callers keep using the object they passed in.

    Args:
        pil_image: Image to resize, or ``None`` (in which case nothing
            happens).
    """
    if pil_image is not None:
        bounds = (MAX_WIDTH, MAX_HEIGHT)
        pil_image.thumbnail(bounds, Image.Resampling.LANCZOS)
def get_image_size(pil_image):
    """Return the ``(height, width)`` to use for *pil_image*.

    This function is intentionally not wrapped in ``st.cache_data``: the
    cache has to hash the image argument to build its key, which costs far
    more than reading two attributes.

    Args:
        pil_image: Object exposing ``height`` and ``width`` attributes, or
            ``None``.

    Returns:
        ``(pil_image.height, pil_image.width)``, or
        ``(MAX_HEIGHT, MAX_WIDTH)`` when *pil_image* is ``None``.
    """
    if pil_image is None:
        return MAX_HEIGHT, MAX_WIDTH
    return pil_image.height, pil_image.width
if __name__ == "__main__":
    # Page header: project title linking to the GitHub repository.
    st.markdown(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidLatexOCR' style='text-decoration: none'>Rapid ⚡︎ LaTeX OCR</a></h1>",
        unsafe_allow_html=True,
    )
    # Row of project badges rendered as raw HTML.
    st.markdown(
        """
<p align="center">
<a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.12-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
<a href="https://pepy.tech/project/rapid_latex_ocr"><img src="https://static.pepy.tech/personalized-badge/rapid_latex_ocr?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads"></a>
<a href="https://pypi.org/project/rapid_latex_ocr/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapid_latex_ocr"></a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
<a href="https://github.com/RapidAI/RapidLatexOCR"><img src="https://img.shields.io/badge/Github-link-brightgreen.svg"></a>
</p>
""",
        unsafe_allow_html=True,
    )

    in_file = st.file_uploader(
        "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
    )
    if in_file is None:
        # Nothing uploaded yet: end this script run here.
        st.stop()

    filetype = in_file.type
    if "pdf" in filetype:
        # Stored under a distinct name so the module-level page_count()
        # function is not rebound to an integer.
        num_pages = page_count(in_file)
        page_number = st.number_input(
            f"Page number out of {num_pages}:",
            min_value=1,
            value=1,
            max_value=num_pages,
        )
        pil_image = get_page_image(in_file, page_number)
    else:
        pil_image = get_uploaded_image(in_file)

    resize_image(pil_image)
    # The image digest is used as the canvas widget key.
    canvas_hash = get_canvas_hash(pil_image) if pil_image else "canvas"
    model = load_model_cached()

    # Query the size once and reuse it for both canvas dimensions.
    canvas_height, canvas_width = get_image_size(pil_image)
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.1)",
        stroke_width=1,
        stroke_color="#FFAA00",
        background_color="#FFF",
        background_image=pil_image,
        update_streamlit=True,
        height=canvas_height,
        width=canvas_width,
        drawing_mode="rect",
        point_display_radius=0,
        key=canvas_hash,
    )

    if canvas_result.json_data is not None:
        objects = pd.json_normalize(canvas_result.json_data["objects"])
        bbox_list = None
        if objects.shape[0] > 0:
            # Take an explicit copy of the selection so that adding columns
            # below does not write into a slice of `objects`
            # (avoids pandas' SettingWithCopyWarning).
            boxes = objects.loc[
                objects["type"] == "rect", ["left", "top", "width", "height"]
            ].copy()
            boxes["right"] = boxes["left"] + boxes["width"]
            boxes["bottom"] = boxes["top"] + boxes["height"]
            bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist()

        if bbox_list:
            # Recognise each drawn rectangle and show the rendered formula
            # followed by its LaTeX source.
            for i, bbox in enumerate(bbox_list):
                input_img = pil_image.crop(bbox)
                rec_res, elapse = model(np.array(input_img))
                st.markdown(f"#### {i + 1}. (cost: {elapse:.3f}s)")
                st.latex(rec_res)
                st.code(rec_res)
|