Spaces:
Runtime error
Runtime error
File size: 4,489 Bytes
1689fdb b2c2fa2 1689fdb 52ca21c 1689fdb 32d3bb8 1689fdb 1869205 ec7e02b 1869205 1689fdb b2c2fa2 1689fdb 32d3bb8 1689fdb b2c2fa2 3df418f b2c2fa2 ab85ffe b2c2fa2 1689fdb b2c2fa2 1689fdb b2c2fa2 1689fdb b2c2fa2 5751f79 1689fdb 228599b 1689fdb b2c2fa2 1689fdb 228599b 1689fdb 32e270a aa52b57 5e67ac5 228599b 32e270a 1689fdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from huggingface_hub import hf_hub_download
import re
from PIL import Image
import requests
from nougat.dataset.rasterize import rasterize_paper
from transformers import NougatProcessor, VisionEncoderDecoderModel
import torch
import gradio as gr
import uuid
import os
# Choose the compute device once; it is reused when moving inputs to the model.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained "facebook/nougat-small" processor and model at import
# time so every request shares the same instances.
processor = NougatProcessor.from_pretrained("facebook/nougat-small")
model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-small")
model.to(device)
def get_pdf(pdf_link, timeout=30):
    """Download the PDF at ``pdf_link`` into the current working directory.

    Args:
        pdf_link: URL of the PDF to fetch.
        timeout: Seconds to wait for the server before ``requests`` gives up.
            Without a timeout a stalled server would block the caller
            indefinitely.

    Returns:
        The path the PDF was (or would have been) saved to. The file is only
        written when the server answers with HTTP 200; on any other status a
        failure message is printed and the returned path does not exist, so
        callers must be prepared for a missing file.
    """
    # A random hex suffix keeps separate downloads from overwriting each other.
    unique_filename = os.path.join(
        os.getcwd(), f"downloaded_paper_{uuid.uuid4().hex}.pdf"
    )

    response = requests.get(pdf_link, timeout=timeout)
    if response.status_code == 200:
        with open(unique_filename, 'wb') as pdf_file:
            pdf_file.write(response.content)
        print("PDF downloaded successfully.")
    else:
        print("Failed to download the PDF.")
    return unique_filename
def predict(image):
    """Transcribe one rasterized PDF page with the Nougat model.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object) holding a single page image.

    Returns:
        The post-processed text generated for the page.
    """
    # Decode the page and turn it into the pixel tensor the model consumes.
    page = Image.open(image)
    encoded = processor(page, return_tensors="pt")
    page_pixels = encoded.pixel_values.to(device)

    # Generate up to 1500 new tokens and never emit the unknown token.
    generation_options = dict(
        min_length=1,
        max_new_tokens=1500,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    token_ids = model.generate(page_pixels, **generation_options)

    # One page goes in, so only the first decoded sequence is used.
    decoded_pages = processor.batch_decode(token_ids, skip_special_tokens=True)
    return processor.post_process_generation(decoded_pages[0], fix_markdown=False)
def inference(pdf_file, pdf_link, file_btn):
    """Run Nougat OCR over a PDF and return the transcription.

    An uploaded file takes precedence over a link when both are supplied.

    Args:
        pdf_file: Uploaded file object exposing the local path as ``.name``,
            or ``None`` when nothing was uploaded.
        pdf_link: URL of a PDF to download; ``None``, empty or
            whitespace-only means "no link".
        file_btn: When truthy, the transcription is also written to
            ``output.txt`` in the current working directory.

    Returns:
        A ``(content, file_path)`` pair. ``content`` is the transcription
        with LaTeX ``\\(...\\)`` / ``\\[...\\]`` delimiters rewritten to
        ``$...$`` / ``$$...$$``. ``file_path`` is the path of the written
        text file, or ``None`` when no file was requested. When no input is
        given, ``content`` is an explanatory message and ``file_path`` is
        ``None``.
    """
    link = pdf_link.strip() if pdf_link else ""

    if pdf_file is not None:
        file_name = pdf_file.name
    elif link:
        file_name = get_pdf(link)
    else:
        print("No file is uploaded and No link is provided")
        # Always return two values: the handler is wired to two outputs.
        return "No data provided. Upload a pdf file or provide a pdf link and try again!", None

    images = rasterize_paper(file_name, return_pil=True)

    # Infer every page and concatenate the per-page transcriptions.
    sequence = "".join(predict(image) for image in images)

    # Rewrite \( \) and \[ \] math delimiters to $ and $$.
    content = (
        sequence.replace(r'\(', '$')
        .replace(r'\)', '$')
        .replace(r'\[', '$$')
        .replace(r'\]', '$$')
    )

    # Stay None unless a file is requested, so the return below never hits an
    # unbound name.
    file_path = None
    if file_btn:
        file_path = f"{os.getcwd()}/output.txt"
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(content)
    return content, file_path
# Give the Markdown output pane (elem_id 'mkd') a fixed height with scrolling.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""

# NOTE(review): the emoji in the labels below look mis-encoded in this copy of
# the file; they are left untouched here - confirm against the original.
with gr.Blocks(css=css) as demo:
    # Page header; the closing tags are now properly terminated.
    gr.HTML("<h1><center>Nougat: Neural Optical Understanding for Academic Documents ๐ซ</center></h1>")
    gr.HTML("<h3><center>Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a></center></h3>")
    gr.HTML("<h3><center>This demo is based on transformers implementation of Nougat ๐ค</center></h3>")

    # Captions for the two alternative inputs.
    with gr.Row():
        mkd = gr.Markdown('<h4><center>Upload a PDF</center></h4>',scale=1)
        mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>',scale=1)
        mkd = gr.Markdown('<h4><center>Provide a PDF link</center></h4>',scale=1)
    with gr.Row(equal_height=True):
        pdf_file = gr.File(label='PDF ๐', file_count='single', scale=1)
        pdf_link = gr.Textbox(placeholder='Enter an arxiv link here', label='Link to Paper๐', scale=1)
    with gr.Row():
        file_btn = gr.Checkbox(label='Download output as file ๐')
    with gr.Row():
        btn = gr.Button('Run Nougat ๐ซ')
    with gr.Row():
        clr = gr.Button('Clear Inputs & Outputs ๐งผ')

    output_headline = gr.Markdown("## PDF converted to markup language through Nougat-OCR๐")
    with gr.Row():
        parsed_output = gr.Markdown(elem_id='mkd', value='Output Text ๐')
        output_file = gr.File(file_types = ["txt"], label="Output File ๐")

    btn.click(inference, [pdf_file, pdf_link, file_btn], [parsed_output, output_file])

    # Build one reset update per component so the number of returned values
    # always matches the number of outputs.
    clear_targets = [pdf_file, pdf_link, file_btn, parsed_output, output_file]
    clr.click(
        lambda: tuple(gr.update(value=None) for _ in clear_targets),
        [],
        clear_targets,
    )

    gr.Examples(
        [["nougat.pdf", "", True], [None, "https://arxiv.org/pdf/2308.08316.pdf", True]],
        inputs = [pdf_file, pdf_link, file_btn],
        outputs = [parsed_output, output_file],
        fn=inference,
        cache_examples=True,
        label='Click on any Examples below to get Nougat OCR results quickly:'
    )

demo.queue()
demo.launch(debug=True)