MinxuanQin committed on
Commit 6cb5353
Parent: 502b0e8

first trial with blip

Files changed (2)
  1. app.py +34 -0
  2. model_loader.py +203 -0
app.py ADDED
@@ -0,0 +1,34 @@
+ import sys
+ sys.path.append(".")
+
+ import streamlit as st
+ import pandas as pd
+
+ from vqa_demo.model_loader import *
+
+
+ # load dataset
+ ds = load_dataset("test")
+
+ # define selector
+ model_name = st.sidebar.selectbox(
+     "Select a model: ",
+     ('vilt', 'git', 'blip', 'vbert')
+ )
+
+ image_selector_unspecific = st.number_input(
+     "Select an image id: ",
+     0, len(ds) - 1  # last valid index is len(ds) - 1
+ )
+
+ # select and display
+ sample = ds[image_selector_unspecific]
+ image = sample['image']
+ st.image(image)
+
+ # inference
+ question = st.text_input(f"Ask the model a question related to the image: \n"
+                          f"(e.g. \"{sample['question']}\")")
+ args = load_model(model_name)  # TODO: cache
+ answer = get_answer(args, image, question, model_name)
+ st.write(answer)
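
The "# TODO: cache" note above can be resolved with Streamlit's resource cache, so the processor/model pair is loaded once per selected model instead of on every rerun of the script. A minimal sketch, assuming Streamlit >= 1.18 (for st.cache_resource) and the load_model defined in model_loader.py below; the wrapper name cached_load_model is illustrative, not part of the commit:

    import streamlit as st
    from vqa_demo.model_loader import load_model  # same import path app.py uses

    @st.cache_resource  # keeps the returned object in memory across reruns
    def cached_load_model(name: str):
        # load_model returns a (processor, model) tuple; caching it avoids
        # re-instantiating the weights every time a widget changes
        return load_model(name)

    model_name = "blip"  # in app.py this comes from the sidebar selectbox
    args = cached_load_model(model_name)

On older Streamlit versions, st.experimental_singleton plays the same role.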
model_loader.py ADDED
@@ -0,0 +1,203 @@
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+ import torch
+ import numpy as np
+ # alias the HF loader so the local load_dataset() below does not shadow it
+ from datasets import load_dataset as hf_load_dataset, get_dataset_split_names
+
+ import os
+ import requests
+ from tqdm import tqdm
+ import timm
+ import time
+
+ from transformers import ViltProcessor, ViltForQuestionAnswering
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from transformers import BlipProcessor, BlipForQuestionAnswering
+ from transformers import VisualBertForMultipleChoice, VisualBertForQuestionAnswering, BertTokenizerFast, AutoTokenizer
+ from nltk.corpus import wordnet
+
+ # VLMO: modify in vlmo/config.py: set test_only -> True
+
+ import torchvision
+ from torchvision.models import resnet50
+ import torchvision.transforms as transforms
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ VQA_URL = "https://dl.fbaipublicfiles.com/pythia/data/answers_vqa.txt"
+
+
+ # load processor and model
+ def load_model(name):
+     if name == "vilt":
+         processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+         model = ViltForQuestionAnswering.from_pretrained("CARETS/vilt_neg_model")
+     elif name == "git":
+         processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
+         model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
+     elif name == "blip":
+         processor = BlipProcessor.from_pretrained('Salesforce/blip-vqa-base')
+         model = BlipForQuestionAnswering.from_pretrained('Salesforce/blip-vqa-base')
+     elif name == "vbert":
+         processor = AutoTokenizer.from_pretrained("bert-base-uncased")
+         model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")
+     else:
+         raise ValueError(f"invalid model name: {name}")
+
+     return (processor, model)
+
+
+ def load_dataset(split):
+     # intentionally shadows datasets.load_dataset, hence the hf_load_dataset alias above
+     if split == "train":
+         return hf_load_dataset("HuggingFaceM4/VQAv2", split="train", streaming=False)
+     elif split == "test":
+         return hf_load_dataset("HuggingFaceM4/VQAv2", split="validation", streaming=False)
+     else:
+         raise ValueError(f"invalid dataset split: {split}")
+
+
+ def tokenize_function(examples, processor):
+     sample = {}
+     sample['inputs'] = processor(images=examples['image'], text=examples['question'], return_tensors="pt")
+     sample['outputs'] = examples['multiple_choice_answer']
+     return sample
+
+
+ def label_count_list(labels):
+     # count occurrences of each label, e.g. ['yes', 'yes', 'no'] -> {'yes': 2, 'no': 1}
+     res = {}
+     keys = set(labels)
+     for key in keys:
+         res[key] = labels.count(key)
+     return res
+
+
+ def get_item(image, question, tokenizer, image_model, image_model_name):
+     # build a VisualBERT-style input dict: tokenized question plus visual embeddings
+     inputs = tokenizer(
+         question,
+         # padding='max_length',
+         # truncation=True,
+         # max_length=128,
+         return_tensors='pt'
+     )
+     visual_embeds = get_img_feats(image, image_model=image_model, name=image_model_name)\
+         .squeeze(3).squeeze(2).unsqueeze(0)
+     visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
+     visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+     upd_dict = {
+         "visual_embeds": visual_embeds,
+         "visual_token_type_ids": visual_token_type_ids,
+         "visual_attention_mask": visual_attention_mask,
+     }
+     inputs.update(upd_dict)
+
+     return upd_dict, inputs
+
+
+ def get_img_feats(image, image_model, new_size=None, name='resnet50'):
+     if name == "resnet50":
+         # drop the final fc layer so the backbone outputs pooled 2048-d features
+         image_model = torch.nn.Sequential(*list(image_model.children())[:-1])
+
+     # apply transforms when necessary
+     if new_size is not None:
+         transform_f = transforms.Resize((new_size, new_size), interpolation=transforms.InterpolationMode.LANCZOS)
+         image = transform_f(image)
+
+     transform = transforms.Compose([
+         transforms.ToTensor(),  # convert the PIL image to a float tensor
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     # get features
+     image = transform(image)
+     if name == "resnet50":
+         image_features = image_model(image.unsqueeze(0))
+     elif name == "vitb16":
+         image_features = image_model.forward_features(image.unsqueeze(0))
+     else:
+         raise ValueError(f"unsupported image model: {name}")
+     return image_features
+
+
+ def get_data(query, delim=","):
+     # download (or read) the VQA answer vocabulary; adapted from the VisualBERT demo utilities
+     assert isinstance(query, str)
+     if os.path.isfile(query):
+         with open(query) as f:
+             data = eval(f.read())
+     else:
+         req = requests.get(query)
+         try:
+             data = req.json()
+         except Exception:
+             data = req.content.decode()
+             assert data is not None, "could not connect"
+             try:
+                 data = eval(data)
+             except Exception:
+                 data = data.split("\n")
+         req.close()
+     return data
+
+
+ def err_msg():
+     print("Load error, try again")
+     return "[ERROR]"
+
+
+ def get_answer(model_loader_args, img, question, model_name):
+     processor, model = model_loader_args[0], model_loader_args[1]
+     if model_name == "vilt":
+         try:
+             encoding = processor(images=img, text=question, return_tensors="pt")
+         except Exception:
+             return err_msg()
+         else:
+             outputs = model(**encoding)
+             logits = outputs.logits
+             idx = logits.argmax(-1).item()
+             pred = model.config.id2label[idx]
+
+     elif model_name == "git":
+         try:
+             pixel_values = processor(images=img, return_tensors="pt").pixel_values
+             input_ids = processor(text=question, add_special_tokens=False).input_ids
+             input_ids = [processor.tokenizer.cls_token_id] + input_ids
+             input_ids = torch.tensor(input_ids).unsqueeze(0)
+         except Exception:
+             return err_msg()
+         else:
+             generate_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
+             output = processor.batch_decode(generate_ids, skip_special_tokens=True)
+             output = output[0]
+             # GIT generates "<question>? <answer>", so keep only the part after the question
+             pred = output.split('?')[-1]
+             pred = pred.strip()
+
+     elif model_name == "vbert":
+         vqa_answers = get_data(VQA_URL)
+         try:
+             # VisualBERT needs external visual features; a pooled ResNet-50 feature is used here
+             # (better: load this backbone once in load_model instead of on every call)
+             image_model = resnet50(weights="IMAGENET1K_V1")
+             image_model.eval()
+             # load question and image (processor = tokenizer)
+             _, inputs = get_item(img, question, processor, image_model, "resnet50")
+             outputs = model(**inputs)
+         except Exception:
+             return err_msg()
+         else:
+             answer_idx = torch.argmax(outputs.logits, dim=1).item()  # index into the 3129 VQA answers
+             pred = vqa_answers[answer_idx]
+
+     elif model_name == "blip":
+         try:
+             pixel_values = processor(images=img, return_tensors="pt").pixel_values
+             # prepend the CLS token manually, then tokenize without extra special tokens
+             blip_ques = processor.tokenizer.cls_token + question
+             batch_input_ids = processor(text=blip_ques, add_special_tokens=False).input_ids
+             batch_input_ids = torch.tensor(batch_input_ids).unsqueeze(0)
+
+             generate_ids = model.generate(pixel_values=pixel_values, input_ids=batch_input_ids, max_length=50)
+             blip_output = processor.batch_decode(generate_ids, skip_special_tokens=True)
+         except Exception:
+             return err_msg()
+         else:
+             pred = blip_output[0]
+     else:
+         return "Invalid model name"
+
+     return pred
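
For a quick check of the BLIP path (the focus of this commit) outside Streamlit, the two entry points above can be called directly. This is a sketch; the image file name and question are placeholder inputs, not part of the commit, and the import assumes the script runs next to model_loader.py:

    from PIL import Image
    from model_loader import load_model, get_answer

    # placeholder inputs: any RGB image and a free-form question
    img = Image.open("example.jpg").convert("RGB")
    question = "What color is the bus?"

    processor_and_model = load_model("blip")  # (BlipProcessor, BlipForQuestionAnswering)
    print(get_answer(processor_and_model, img, question, "blip"))

Since get_answer falls back to the string "[ERROR]" when preprocessing fails, the result is always a plain string that is safe to write to the page.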