thinkersloop committed on
Commit
ce636e9
1 Parent(s): bf2e5d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -11
app.py CHANGED
@@ -1,25 +1,50 @@
1
- import gradio as gr
2
  import torch
3
- from PIL import Image
4
 
5
- from donut import DonutModel
 
 
 
 
6
 
7
  def demo_process(input_img):
8
- global pretrained_model, task_prompt, task_name
9
  # input_img = Image.fromarray(input_img)
10
- output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
11
- return output
12
 
13
- task_prompt = f"<s_cord-v2>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  image = Image.open("./sample_1.jpg")
16
  image.save("cord_sample_1.png")
17
  image = Image.open("./sample_2.jpg")
18
  image.save("cord_sample_2.png")
19
 
20
- pretrained_model = DonutModel.from_pretrained("thinkersloop/donut-demo")
21
- pretrained_model.encoder.to(torch.bfloat16)
22
- pretrained_model.eval()
23
 
24
  demo = gr.Interface(
25
  fn=demo_process,
@@ -31,4 +56,4 @@ demo = gr.Interface(
31
  cache_examples=False,
32
  )
33
 
34
- demo.launch()
 
 
1
  import torch
2
+ import re
3
 
4
+ import gradio as gr
5
+
6
+ from PIL import Image
7
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
8
+
9
 
10
def _load_donut():
    """Load and cache the Donut processor/model pair.

    The model is loaded once on first call and reused afterwards; loading it
    inside every request (as the original code did) re-instantiates the full
    checkpoint per inference.

    Returns:
        tuple: (processor, model, device) where device is "cuda" when
        available, else "cpu".
    """
    if not hasattr(_load_donut, "_cache"):
        processor = DonutProcessor.from_pretrained("thinkersloop/donut-demo")
        model = VisionEncoderDecoderModel.from_pretrained("thinkersloop/donut-demo")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        _load_donut._cache = (processor, model, device)
    return _load_donut._cache


def demo_process(input_img):
    """Run the Donut CORD-v2 model on *input_img* and return the parsed result.

    Args:
        input_img: the image supplied by the Gradio input component
            (a PIL.Image — TODO confirm against the Interface input type).

    Returns:
        The structured output produced by ``processor.token2json`` for the
        generated token sequence.
    """
    processor, pretrained_model, device = _load_donut()

    # BUG FIX: encode the caller-supplied image, not the module-level `image`
    # variable left over from the sample-preparation code below — the original
    # ran inference on the last sample file for every request.
    pixel_values = processor(input_img, return_tensors="pt").pixel_values

    # Donut is prompted with a task tag; <s_cord-v2> selects receipt parsing.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    )["input_ids"]

    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,  # greedy decoding
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
        output_scores=True,
    )

    # Decode, strip the eos/pad special tokens, then drop the leading task tag
    # (first "<...>" occurrence only) before converting tokens to JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

    return processor.token2json(sequence)
40
+
41
+ # task_prompt = f"<s_cord-v2>"
42
 
43
# Convert the two bundled sample receipts to PNG copies for the demo.
# NOTE: this deliberately rebinds the module-level name `image` on each
# iteration, exactly as the original sequence of statements did.
for _src, _dst in (
    ("./sample_1.jpg", "cord_sample_1.png"),
    ("./sample_2.jpg", "cord_sample_2.png"),
):
    image = Image.open(_src)
    image.save(_dst)
47
 
 
 
 
48
 
49
  demo = gr.Interface(
50
  fn=demo_process,
 
56
  cache_examples=False,
57
  )
58
 
59
+ demo.launch()