gwkrsrch commited on
Commit
77a91d3
1 Parent(s): e3079c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -42
app.py CHANGED
@@ -5,60 +5,32 @@ MIT License
5
 
6
  https://github.com/clovaai/donut
7
  """
8
- import argparse
9
-
10
  import gradio as gr
11
  import torch
12
  from PIL import Image
13
 
14
  from donut import DonutModel
15
 
16
-
17
- def demo_process_vqa(input_img, question):
18
- global pretrained_model, task_prompt, task_name
19
- input_img = Image.fromarray(input_img)
20
- user_prompt = task_prompt.replace("{user_input}", question)
21
- output = pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
22
- return output
23
-
24
-
25
  def demo_process(input_img):
26
  global pretrained_model, task_prompt, task_name
27
  input_img = Image.fromarray(input_img)
28
  output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
29
  return output
30
 
 
31
 
32
- if __name__ == "__main__":
33
- parser = argparse.ArgumentParser()
34
- parser.add_argument("--task", type=str, default="cord-v2")
35
- parser.add_argument("--pretrained_path", type=str, default="naver-clova-ix/donut-base-finetuned-cord-v2")
36
- parser.add_argument("--port", type=int, default=None)
37
- parser.add_argument("--url", type=str, default=None)
38
- parser.add_argument("--sample_img_path", type=str, default="./sample_image_cord_test_receipt_00004.jpg")
39
- args, left_argv = parser.parse_known_args()
40
 
41
- task_name = args.task
42
- if "docvqa" == task_name:
43
- task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
44
- else: # rvlcdip, cord, ...
45
- task_prompt = f"<s_{task_name}>"
46
-
47
- example_sample = []
48
- if args.sample_img_path:
49
- image = Image.open(args.sample_img_path)
50
- image.save("cord_sample_receipt.jpg")
51
- example_sample.append("cord_sample_receipt.jpg")
52
 
53
- pretrained_model = DonutModel.from_pretrained(args.pretrained_path)
54
- pretrained_model.encoder.to(torch.bfloat16)
55
- pretrained_model.eval()
56
-
57
- demo = gr.Interface(
58
- fn=demo_process_vqa if task_name == "docvqa" else demo_process,
59
- inputs=["image", "text"] if task_name == "docvqa" else "image",
60
- outputs="json",
61
- title=f"Donut 🍩 demonstration for `{task_name}` task",
62
- examples=[example_sample] if example_sample else None,
63
- )
64
- demo.launch(server_name=args.url, server_port=args.port)
 
5
 
6
  https://github.com/clovaai/donut
7
  """
 
 
8
  import gradio as gr
9
  import torch
10
  from PIL import Image
11
 
12
  from donut import DonutModel
13
 
 
 
 
 
 
 
 
 
 
14
def demo_process(input_img):
    """Run Donut inference on one uploaded image and return the first parsed prediction.

    Args:
        input_img: image data convertible via ``PIL.Image.fromarray`` —
            i.e. a numpy array as delivered by a Gradio image input.

    Returns:
        The first entry of the model's ``"predictions"`` list (JSON-like dict).
    """
    # Read-only access to module-level state. NOTE(review): the original also
    # listed `task_name` here, but this commit removed its module-level
    # definition, so the stale name is dropped from the global statement.
    global pretrained_model, task_prompt
    input_img = Image.fromarray(input_img)
    output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
    return output
19
 
20
# Fixed task for this demo; the argparse-driven CLI was removed in this commit.
# BUG FIX: `task_name` was still referenced in the Interface title below but no
# longer defined anywhere, which raised a NameError at startup — define it here.
task_name = "cord-v2"
task_prompt = f"<s_{task_name}>"

# Copy the bundled sample receipt to a stable filename used by the examples list.
image = Image.open("./sample_image_cord_test_receipt_00004.jpg")
image.save("cord_sample_receipt.jpg")

# Load the fine-tuned Donut checkpoint and put it in eval mode; the encoder is
# cast to bfloat16 to cut memory use.
pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
pretrained_model.encoder.to(torch.bfloat16)
pretrained_model.eval()

demo = gr.Interface(
    fn=demo_process,
    # NOTE(review): type="pil" hands demo_process a PIL.Image, but demo_process
    # calls Image.fromarray(...), which expects a numpy array — confirm whether
    # type="numpy" was intended.
    inputs=gr.inputs.Image(type="pil"),
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
    examples=[["cord_sample_receipt.jpg"]],
)
demo.launch(debug=True)