Spaces:
Runtime error
Runtime error
Update app.py
#1
by
ndtran
- opened
app.py
CHANGED
@@ -1,8 +1,104 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
def get_answer(image, question) -> str:
|
5 |
-
|
|
|
|
|
6 |
|
7 |
with gr.Blocks() as demo:
|
8 |
with gr.Row():
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch, os, json, requests
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
|
5 |
|
6 |
+
def load_image_from_URL(url):
|
7 |
+
res = requests.get(url)
|
8 |
+
|
9 |
+
if res.status_code == 200:
|
10 |
+
img = Image.open(requests.get(url, stream = True).raw)
|
11 |
+
|
12 |
+
if img.mode == "RGBA":
|
13 |
+
img = img.convert("RGB")
|
14 |
+
|
15 |
+
return img
|
16 |
+
|
17 |
+
return None
|
18 |
+
|
19 |
+
class OCRVQAModel(torch.nn.Module):
|
20 |
+
def add_tokens(self, list_of_tokens):
|
21 |
+
self.added_tokens.update(list_of_tokens)
|
22 |
+
newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
|
23 |
+
|
24 |
+
if newly_added_num > 0:
|
25 |
+
self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))
|
26 |
+
|
27 |
+
def __init__(self, config):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.model_name_or_path = config['donut']
|
31 |
+
self.processor_name_or_path = config['processor']
|
32 |
+
self.config_name_or_path = config['config']
|
33 |
+
|
34 |
+
self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
|
35 |
+
self.donut_config.encoder.image_size = [800, 600]
|
36 |
+
self.donut_config.decoder.max_length = 64
|
37 |
+
|
38 |
+
self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
|
39 |
+
self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config = self.donut_config)
|
40 |
+
|
41 |
+
self.added_tokens = set([])
|
42 |
+
self.setup()
|
43 |
+
|
44 |
+
def setup(self):
|
45 |
+
self.add_tokens(["<yes/>", "<no/>"])
|
46 |
+
self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
|
47 |
+
self.processor.feature_extractor.do_align_long_axis = False
|
48 |
+
|
49 |
+
def inference(self, image_src, prompt, device):
|
50 |
+
if os.path.exists(image_src):
|
51 |
+
image = Image.open(image_src)
|
52 |
+
else:
|
53 |
+
image = load_image_from_URL(image_src)
|
54 |
+
|
55 |
+
if not Image:
|
56 |
+
return {
|
57 |
+
'question': prompt,
|
58 |
+
'answer': 'Some error occurred during inference time.'
|
59 |
+
}
|
60 |
+
|
61 |
+
self.donut.eval()
|
62 |
+
with torch.no_grad():
|
63 |
+
image_ids = self.processor(image, return_tensors="pt").pixel_values.to(device)
|
64 |
+
|
65 |
+
question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
|
66 |
+
|
67 |
+
embedded_question = self.processor.tokenizer(
|
68 |
+
question,
|
69 |
+
add_special_tokens = False,
|
70 |
+
return_tensors = "pt"
|
71 |
+
)["input_ids"].to(device)
|
72 |
+
|
73 |
+
outputs = self.donut.generate(
|
74 |
+
image_ids,
|
75 |
+
decoder_input_ids=embedded_question,
|
76 |
+
max_length = self.donut.decoder.config.max_position_embeddings,
|
77 |
+
early_stopping = True,
|
78 |
+
pad_token_id = self.processor.tokenizer.pad_token_id,
|
79 |
+
eos_token_id = self.processor.tokenizer.eos_token_id,
|
80 |
+
use_cache = True,
|
81 |
+
num_beams = 1,
|
82 |
+
bad_words_ids = [
|
83 |
+
[self.processor.tokenizer.unk_token_id]
|
84 |
+
],
|
85 |
+
return_dict_in_generate = True
|
86 |
+
)
|
87 |
+
|
88 |
+
return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
|
89 |
+
|
90 |
+
model = OCRVQAModel(
|
91 |
+
'ndtran/donut_ocr-vqa-200k',
|
92 |
+
'ndtran/donut_ocr-vqa-200k'
|
93 |
+
)
|
94 |
+
|
95 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
96 |
+
model = model.to(device)
|
97 |
|
98 |
def get_answer(image, question) -> str:
|
99 |
+
global model, device
|
100 |
+
result = model.inference(image, question, device)
|
101 |
+
return result.get('answer', 'I don\'t know :<')
|
102 |
|
103 |
with gr.Blocks() as demo:
|
104 |
with gr.Row():
|