File size: 3,973 Bytes
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32d3bb8
1689fdb
 
 
1869205
ec7e02b
1869205
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32d3bb8
1689fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32e270a
 
 
 
 
 
 
 
1689fdb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from huggingface_hub import hf_hub_download
import re
from PIL import Image
import requests
from nougat.dataset.rasterize import rasterize_paper

from transformers import NougatProcessor, VisionEncoderDecoderModel
import torch

# Load the Nougat processor and the vision encoder-decoder model from the
# "nielsr/nougat" checkpoint once, at import time, so every request reuses them.
processor = NougatProcessor.from_pretrained("nielsr/nougat")
model = VisionEncoderDecoderModel.from_pretrained("nielsr/nougat")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device) 


def get_pdf(pdf_link, timeout=60):
  """Download the PDF at ``pdf_link`` into the current working directory.

  Args:
    pdf_link: URL of the PDF to fetch.
    timeout: Seconds ``requests`` waits on the server before giving up, so an
      unresponsive host cannot block the caller forever.

  Returns:
    The uniquely named destination path. The file is only written when the
    server answers with HTTP 200; for any other status nothing is saved, so
    callers should check that the path exists before reading it.

  Raises:
    requests.RequestException: If the request itself cannot be completed
      (connection error, timeout, malformed URL, ...).
  """
  # A random hex suffix keeps concurrent requests from overwriting each other.
  unique_filename = os.path.join(
      os.getcwd(), f"downloaded_paper_{uuid.uuid4().hex}.pdf"
  )

  response = requests.get(pdf_link, timeout=timeout)

  if response.status_code == 200:
    with open(unique_filename, 'wb') as pdf_file:
      pdf_file.write(response.content)
    print("PDF downloaded successfully.")
  else:
    print("Failed to download the PDF.")
  return unique_filename



def predict(image):
  """Transcribe a single rasterized PDF page with the Nougat model.

  Args:
    image: Anything ``PIL.Image.open`` accepts (a path or a file-like object)
      holding the page image.

  Returns:
    The decoded, post-processed text generated for the page.
  """
  # Turn the page image into the pixel tensor the model expects.
  page_image = Image.open(image)
  model_inputs = processor(page_image, return_tensors="pt")

  # Generate at most 1500 new tokens; the unknown token is never emitted.
  generated_ids = model.generate(
      model_inputs.pixel_values.to(device),
      min_length=1,
      max_new_tokens=1500,
      bad_words_ids=[[processor.tokenizer.unk_token_id]],
  )

  # Only one image was passed in, so the first decoded sequence is the page.
  decoded_pages = processor.batch_decode(generated_ids, skip_special_tokens=True)
  return processor.post_process_generation(decoded_pages[0], fix_markdown=False)



def inference(pdf_file, pdf_link):
  """Run Nougat OCR over an uploaded PDF or a PDF fetched from a link.

  Args:
    pdf_file: Uploaded file object exposing a ``.name`` path, or ``None``.
      When given, it takes precedence over ``pdf_link``.
    pdf_link: URL of a PDF; only used when ``pdf_file`` is ``None``. ``None``,
      empty and whitespace-only values all count as "no link".

  Returns:
    The concatenated transcription of every page with LaTeX ``\\(...\\)`` /
    ``\\[...\\]`` delimiters rewritten to ``$...$`` / ``$$...$$``, or a short
    user-facing message when no input was given or the download failed.
  """
  downloaded = False
  if pdf_file is not None:
    file_name = pdf_file.name
  elif not pdf_link or not pdf_link.strip():
    print("No file is uploaded and No link is provided")
    return "No data provided. Upload a pdf file or provide a pdf link and try again!"
  else:
    try:
      file_name = get_pdf(pdf_link.strip())
    except requests.RequestException as err:
      print(f"Failed to download the PDF: {err}")
      return "Failed to download the PDF. Check the link and try again!"
    downloaded = True
    # get_pdf only writes the file on an HTTP 200 reply.
    if not os.path.isfile(file_name):
      return "Failed to download the PDF. Check the link and try again!"

  try:
    images = rasterize_paper(file_name, return_pil=True)
    # Transcribe every page and concatenate the results in page order.
    sequence = "".join(predict(image) for image in images)
  finally:
    # Only remove files this function downloaded, never the user's upload.
    # NOTE(review): assumes the rasterized pages no longer need the file on
    # disk once rasterize_paper has returned - confirm.
    if downloaded:
      try:
        os.remove(file_name)
      except OSError as err:
        print(f"Could not remove temporary file {file_name}: {err}")

  content = (
      sequence
      .replace(r'\(', '$')
      .replace(r'\)', '$')
      .replace(r'\[', '$$')
      .replace(r'\]', '$$')
  )
  return content

import gradio as gr
import uuid
import os
import requests
import re

# Custom CSS passed to gr.Blocks: the element with id "mkd" (the OCR output
# panel) becomes a fixed-height, scrollable, bordered box.
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

# Gradio UI: a PDF upload or a PDF link goes in, Nougat's markup comes out.
with gr.Blocks(css=css) as demo:
  # Page header. Every <center>/<h1>/<h3> element is properly closed.
  gr.HTML("<h1><center>Nougat: Neural Optical Understanding for Academic Documents 🍫</center></h1>")
  gr.HTML("<h3><center>Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a></center></h3>")
  gr.HTML("<h3><center>This demo is based on transformers implementation of Nougat 🤗</center></h3>")


  # Column captions for the two alternative inputs below.
  with gr.Row():
    mkd = gr.Markdown('<h4><center>Upload a PDF</center></h4>',scale=1)
    mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>',scale=1)
    mkd = gr.Markdown('<h4><center>Provide a PDF link</center></h4>',scale=1)


  with gr.Row(equal_height=True):
    pdf_file = gr.File(label='PDF 📑', file_count='single', scale=1)
    pdf_link = gr.Textbox(placeholder='Enter an arxiv link here', label='Link to Paper🔗', scale=1)

  with gr.Row():
    btn = gr.Button('Run Nougat 🍫')
    clr = gr.Button('Clear 🧼')

  output_headline = gr.Markdown("PDF converted to markup language through Nougat-OCR👇")
  parsed_output = gr.Markdown(elem_id='mkd', value='OCR Output 📝')

  btn.click(inference, [pdf_file, pdf_link], parsed_output)
  # Reset the link textbox to an empty string rather than None, so a run
  # right after "Clear" hands inference the empty value it checks for.
  clr.click(
      lambda: (
          gr.update(value=None),
          gr.update(value=""),
          gr.update(value=None),
      ),
      [],
      [pdf_file, pdf_link, parsed_output],
  )
  gr.Examples(
      [["nougat.pdf", ""], [None, "https://arxiv.org/pdf/2308.08316.pdf"]],
      inputs=[pdf_file, pdf_link],
      outputs=parsed_output,
      fn=inference,
      cache_examples=True,
      label='Click on any Examples below to get Nougat OCR results quickly:'
  )


    
# Turn on Gradio's request queue, then start the app with debug mode enabled.
demo.queue()
demo.launch(debug=True)