Iqra Ali committed on
Commit
a394d66
·
1 Parent(s): e6d304f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -25
app.py CHANGED
@@ -1,35 +1,28 @@
1
 
2
  import re
3
  import gradio as gr
 
4
  import torch
5
  from transformers import DonutProcessor, VisionEncoderDecoderModel
6
- import transformers
7
- from PIL import Image
8
- import random
9
- import numpy as np
10
-
11
- # hide logs
12
- transformers.logging.disable_default_handler()
13
-
14
-
15
- # Load our model from Hugging Face
16
- processor = DonutProcessor.from_pretrained("Iqra56/Donut_Updated")
17
- model = VisionEncoderDecoderModel.from_pretrained("Iqra56/Donut_Updated")
18
 
19
- # Move model to GPU
 
 
 
 
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
  model.to(device)
22
 
23
- # Load random document image from the test set
24
- test_sample = processed_dataset["test"][random.randint(1,7)]
25
-
26
- def run_prediction(sample, model=model, processor=processor):
27
- # prepare inputs
28
- pixel_values = torch.tensor(test_sample["pixel_values"]).unsqueeze(0)
29
  task_prompt = "<s>"
30
  decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
31
-
32
- # run inference
33
  outputs = model.generate(
34
  pixel_values.to(device),
35
  decoder_input_ids=decoder_input_ids.to(device),
@@ -50,18 +43,18 @@ def run_prediction(sample, model=model, processor=processor):
50
 
51
  return processor.token2json(sequence)
52
 
53
- description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on DocVQA (document visual question answering). To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
54
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
55
 
56
  demo = gr.Interface(
57
  fn=process_document,
58
- inputs=["image", "text"],
59
  outputs="json",
60
- title="Demo: Donut 🍩 for DocVQA",
61
  description=description,
62
  article=article,
63
  enable_queue=True,
64
- examples=[["example_1.png", "When is the coffee break?"], ["example_2.jpeg", "What's the population of Stoddard?"]],
65
  cache_examples=False)
66
 
67
  demo.launch()
 
1
 
2
  import re
3
  import gradio as gr
4
+
5
  import torch
6
  from transformers import DonutProcessor, VisionEncoderDecoderModel
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ #processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
9
+ #model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
10
+ #processor = DonutProcessor.from_pretrained("Iqra56/ENGLISHDONUT")
11
+ #model = VisionEncoderDecoderModel.from_pretrained("Iqra56/ENGLISHDONUT")
12
+ processor = DonutProcessor.from_pretrained("Iqra56/DONUTWOKEYS")
13
+ model = VisionEncoderDecoderModel.from_pretrained("Iqra56/DONUTWOKEYS")
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  model.to(device)
16
 
17
+ def process_document(image):
18
+ # prepare encoder inputs
19
+ pixel_values = processor(image, return_tensors="pt").pixel_values
20
+
21
+ # prepare decoder inputs
 
22
  task_prompt = "<s>"
23
  decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
24
+
25
+ # generate answer
26
  outputs = model.generate(
27
  pixel_values.to(device),
28
  decoder_input_ids=decoder_input_ids.to(device),
 
43
 
44
  return processor.token2json(sequence)
45
 
46
+ description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on CORD (document parsing). To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
47
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
48
 
49
  demo = gr.Interface(
50
  fn=process_document,
51
+ inputs="image",
52
  outputs="json",
53
+ title="Demo: Donut 🍩 for Document Parsing",
54
  description=description,
55
  article=article,
56
  enable_queue=True,
57
+ examples=[[""], [""], [""]],
58
  cache_examples=False)
59
 
60
  demo.launch()