Files changed (1)
  1. app.py +97 -1
app.py CHANGED
@@ -1,8 +1,104 @@
 import gradio as gr
+import torch, os, json, requests
+from PIL import Image
+from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
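+# DonutProcessor bundles Donut's image processor and tokenizer;
+# VisionEncoderDecoderModel/-Config are the transformers classes the
+# ndtran/donut_ocr-vqa-200k checkpoint is loaded into below.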
 
+def load_image_from_URL(url):
+    # Stream the download and reuse the same response, instead of issuing a second GET for the bytes.
+    res = requests.get(url, stream=True)
+
+    if res.status_code == 200:
+        img = Image.open(res.raw)
+
+        # Donut's processor expects 3-channel input, so flatten any alpha channel.
+        if img.mode == "RGBA":
+            img = img.convert("RGB")
+
+        return img
+
+    return None
+
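+# Usage sketch (hypothetical URL):
+#   load_image_from_URL("https://example.com/cover.jpg") -> RGB PIL.Image, or None on a non-200 response.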
+class OCRVQAModel(torch.nn.Module):
+    def add_tokens(self, list_of_tokens):
+        self.added_tokens.update(list_of_tokens)
+        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
+
+        # Grow the decoder embedding matrix to cover the newly added tokens.
+        if newly_added_num > 0:
+            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.model_name_or_path = config['donut']
+        self.processor_name_or_path = config['processor']
+        self.config_name_or_path = config['config']
+
+        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
+        self.donut_config.encoder.image_size = [800, 600]
+        self.donut_config.decoder.max_length = 64
+
+        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
+        self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)
+
+        self.added_tokens = set()
+        self.setup()
+
+    def setup(self):
+        self.add_tokens(["<yes/>", "<no/>"])
+        # The encoder config stores (height, width); the feature extractor wants (width, height).
+        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
+        self.processor.feature_extractor.do_align_long_axis = False
+
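+    # inference() follows Donut's DocVQA prompting: the question is wrapped as
+    # <s_docvqa><s_question>...</s_question><s_answer>, generation continues from
+    # <s_answer> until EOS, and token2json() parses the decoded sequence.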
+    def inference(self, image_src, prompt, device):
+        # Accept either a local file path or a URL.
+        if os.path.exists(image_src):
+            image = Image.open(image_src)
+        else:
+            image = load_image_from_URL(image_src)
+
+        # Was `if not Image:`, which tested the always-truthy PIL module, so the
+        # error branch could never run.
+        if image is None:
+            return {
+                'question': prompt,
+                'answer': 'Some error occurred during inference time.'
+            }
+
+        self.donut.eval()
+        with torch.no_grad():
+            image_ids = self.processor(image, return_tensors="pt").pixel_values.to(device)
+
+            question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
+
+            embedded_question = self.processor.tokenizer(
+                question,
+                add_special_tokens=False,
+                return_tensors="pt"
+            )["input_ids"].to(device)
+
+            outputs = self.donut.generate(
+                image_ids,
+                decoder_input_ids=embedded_question,
+                max_length=self.donut.decoder.config.max_position_embeddings,
+                early_stopping=True,
+                pad_token_id=self.processor.tokenizer.pad_token_id,
+                eos_token_id=self.processor.tokenizer.eos_token_id,
+                use_cache=True,
+                num_beams=1,
+                # Never emit <unk>.
+                bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+                return_dict_in_generate=True
+            )
+
+            return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
+
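+# Example call (hypothetical path and question):
+#   model.inference("book_cover.png", "Who is the author?", device)
+#   -> {'question': 'Who is the author?', 'answer': '...'}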
+# __init__ takes a single config dict, not two positional strings; the same hub
+# checkpoint serves as the model, processor, and config source here.
+model = OCRVQAModel({
+    'donut': 'ndtran/donut_ocr-vqa-200k',
+    'processor': 'ndtran/donut_ocr-vqa-200k',
+    'config': 'ndtran/donut_ocr-vqa-200k',
+})
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device)
 
 def get_answer(image, question) -> str:
-    return "I don't know"
+    global model, device
+    result = model.inference(image, question, device)
+    return result.get('answer', "I don't know :<")
 
 with gr.Blocks() as demo:
     with gr.Row():
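
The hunk ends at the top of the Blocks layout. A minimal sketch of how get_answer could be wired into the rest of the UI, assuming filepath-based image input (component names and layout here are assumptions, not the file's actual code):

with gr.Blocks() as demo:
    with gr.Row():
        # type="filepath" hands get_answer a path, matching the os.path.exists
        # check inside OCRVQAModel.inference.
        image = gr.Image(type="filepath", label="Document image")
        with gr.Column():
            question = gr.Textbox(label="Question")
            answer = gr.Textbox(label="Answer")
            ask = gr.Button("Ask")
    ask.click(fn=get_answer, inputs=[image, question], outputs=answer)

demo.launch()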