File size: 4,489 Bytes
1689fdb
 
 
 
 
 
 
b2c2fa2
 
 
1689fdb
52ca21c
 
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32d3bb8
1689fdb
 
 
1869205
ec7e02b
1869205
1689fdb
 
 
b2c2fa2
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
32d3bb8
1689fdb
 
 
 
 
b2c2fa2
3df418f
b2c2fa2
 
ab85ffe
 
 
b2c2fa2
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c2fa2
 
1689fdb
 
b2c2fa2
 
1689fdb
b2c2fa2
5751f79
 
 
1689fdb
228599b
1689fdb
 
b2c2fa2
1689fdb
 
228599b
1689fdb
32e270a
aa52b57
5e67ac5
228599b
32e270a
 
 
 
1689fdb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from huggingface_hub import hf_hub_download
import re
from PIL import Image
import requests
from nougat.dataset.rasterize import rasterize_paper
from transformers import NougatProcessor, VisionEncoderDecoderModel
import torch
import gradio as gr
import uuid
import os

# Load the small Nougat checkpoint once at import time; every request below
# reuses the same processor and model instances.
_NOUGAT_CHECKPOINT = "facebook/nougat-small"
processor = NougatProcessor.from_pretrained(_NOUGAT_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_NOUGAT_CHECKPOINT)

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
model.to(device)


def get_pdf(pdf_link, timeout=60):
  """Download the PDF at ``pdf_link`` into the current working directory.

  Args:
    pdf_link: URL of the PDF to fetch.
    timeout: Seconds ``requests`` waits to connect / between received bytes
      before giving up, so an unresponsive server cannot hang the call forever.

  Returns:
    Path of the written file. The name embeds a random UUID so concurrent
    downloads never overwrite each other.

  Raises:
    ValueError: if the server does not answer with HTTP status 200.
    requests.RequestException: on network errors or timeouts.
  """
  unique_filename = os.path.join(os.getcwd(), f"downloaded_paper_{uuid.uuid4().hex}.pdf")

  response = requests.get(pdf_link, timeout=timeout)

  if response.status_code != 200:
    print("Failed to download the PDF.")
    # Returning the path of a file that was never written would only push the
    # failure downstream; report it here instead.
    raise ValueError(f"Failed to download the PDF (HTTP {response.status_code}): {pdf_link}")

  with open(unique_filename, 'wb') as pdf_file:
    pdf_file.write(response.content)
  print("PDF downloaded successfully.")
  return unique_filename



def predict(image, max_new_tokens=1500):
  """Transcribe a single rasterized PDF page into Nougat markup.

  Args:
    image: Anything ``PIL.Image.open`` accepts for one page; here it is each
      page returned by ``rasterize_paper(..., return_pil=True)``.
    max_new_tokens: Upper bound on the number of tokens generated for the page.

  Returns:
    The decoded, post-processed text of the page.
  """
  # Prepare the page image for the model.
  page = Image.open(image)
  pixel_values = processor(page, return_tensors="pt").pixel_values

  # Generate the transcription (at most ``max_new_tokens`` tokens). The
  # tokenizer's unknown token is banned so it never appears in the output.
  outputs = model.generate(
      pixel_values.to(device),
      min_length=1,
      max_new_tokens=max_new_tokens,
      bad_words_ids=[[processor.tokenizer.unk_token_id]],
  )

  page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
  page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
  return page_sequence



def inference(pdf_file, pdf_link, file_btn):
  """Run Nougat OCR over every page of a PDF and return the result.

  The PDF comes from the uploaded file when one is given, otherwise it is
  downloaded from ``pdf_link``.

  Args:
    pdf_file: Value of the ``gr.File`` input, or None when nothing was
      uploaded. Both a file wrapper exposing ``.name`` and a plain path string
      are accepted.
    pdf_link: URL of a PDF, used only when no file was uploaded. None or a
      blank string counts as "no link".
    file_btn: When truthy, the result is also written to ``output.txt`` in the
      current working directory.

  Returns:
    A ``(content, file_path)`` pair, one value per output component: the
    markup text (or an error message) and the path of the written text file,
    or None when no file was requested.
  """
  if pdf_file is None:
    if not pdf_link or not pdf_link.strip():
      print("No file is uploaded and No link is provided")
      # Two output components are wired to this handler, so return two values.
      return "No data provided. Upload a pdf file or provide a pdf link and try again!", None
    file_name = get_pdf(pdf_link.strip())
    downloaded = True
  else:
    # Older Gradio passes a tempfile wrapper with ``.name``; newer versions may pass the path itself.
    file_name = getattr(pdf_file, "name", pdf_file)
    downloaded = False

  try:
    images = rasterize_paper(file_name, return_pil=True)
  finally:
    if downloaded:
      # Best-effort cleanup so downloaded PDFs do not accumulate on disk.
      # NOTE: assumes rasterize_paper(..., return_pil=True) has finished
      # reading the file by the time it returns.
      try:
        os.remove(file_name)
      except OSError:
        pass

  # Infer every page and concatenate the per-page results.
  sequence = "".join(predict(image) for image in images)

  # Rewrite LaTeX \( \) and \[ \] math delimiters as $ and $$.
  content = (
      sequence.replace(r'\(', '$')
      .replace(r'\)', '$')
      .replace(r'\[', '$$')
      .replace(r'\]', '$$')
  )

  file_path = None
  if file_btn:
    file_path = f"{os.getcwd()}/output.txt"
    with open(file_path, "w", encoding="utf-8") as f:
      f.write(content)

  return content, file_path


css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""


def _clear_inputs_and_outputs():
  """Reset every component wired to the Clear button.

  Returns exactly one value per output component, in order: PDF upload,
  link textbox, "download as file" checkbox, markup output, output file.
  """
  return None, "", False, None, None


with gr.Blocks(css=css) as demo:
  gr.HTML("<h1><center>Nougat: Neural Optical Understanding for Academic Documents 🍫</center></h1>")
  gr.HTML("<h3><center>Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a></center></h3>")
  gr.HTML("<h3><center>This demo is based on transformers implementation of Nougat 🤗</center></h3>")

  # Display-only captions above the two input widgets.
  with gr.Row():
    gr.Markdown('<h4><center>Upload a PDF</center></h4>', scale=1)
    gr.Markdown('<h4><center><i>OR</i></center></h4>', scale=1)
    gr.Markdown('<h4><center>Provide a PDF link</center></h4>', scale=1)

  with gr.Row(equal_height=True):
    pdf_file = gr.File(label='PDF 📑', file_count='single', scale=1)
    pdf_link = gr.Textbox(placeholder='Enter an arxiv link here', label='Link to Paper🔗', scale=1)
  with gr.Row():
    file_btn = gr.Checkbox(label='Download output as file 📑')
  with gr.Row():
    btn = gr.Button('Run Nougat 🍫')
  with gr.Row():
    clr = gr.Button('Clear Inputs & Outputs 🧼')

  output_headline = gr.Markdown("## PDF converted to markup language through Nougat-OCR👇")
  with gr.Row():
    # elem_id matches the #mkd rule in ``css`` (fixed height, scrollable).
    parsed_output = gr.Markdown(elem_id='mkd', value='Output Text 📝')
    # Gradio expects file extensions with a leading dot in file_types.
    output_file = gr.File(file_types=[".txt"], label="Output File 📑")

  btn.click(inference, [pdf_file, pdf_link, file_btn], [parsed_output, output_file])
  # The handler must return one value for each of these five components.
  clr.click(
      _clear_inputs_and_outputs,
      [],
      [pdf_file, pdf_link, file_btn, parsed_output, output_file],
  )
  gr.Examples(
      [["nougat.pdf", "", True], [None, "https://arxiv.org/pdf/2308.08316.pdf", True]],
      inputs=[pdf_file, pdf_link, file_btn],
      outputs=[parsed_output, output_file],
      fn=inference,
      cache_examples=True,
      label='Click on any Examples below to get Nougat OCR results quickly:'
  )


demo.queue()
demo.launch(debug=True)