from huggingface_hub import hf_hub_download |
from PIL import Image |
import torch |
from datasets import load_dataset, get_dataset_split_names |
import numpy as np |
import requests |
import streamlit as st |
from transformers import ViltProcessor, ViltForQuestionAnswering |
from transformers import AutoProcessor, AutoModelForCausalLM |
from transformers import BlipProcessor, BlipForQuestionAnswering |
import os |
import requests |
from tqdm import tqdm |
import timm |
from datasets import load_dataset, get_dataset_split_names |
import torch |
import torchvision |
from torchvision.models import resnet50 |
import torchvision.transforms as transforms |
from transformers import VisualBertForMultipleChoice, VisualBertForQuestionAnswering, BertTokenizerFast, AutoTokenizer, ViltForQuestionAnswering |
from PIL import Image |
import time |
# Use the first CUDA GPU when torch reports one is available; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Remote text resource passed to get_data() by the "vbert" branch of get_answer();
# the loaded sequence is indexed by the predicted answer id.
VQA_URL = "https://dl.fbaipublicfiles.com/pythia/data/answers_vqa.txt"
def load_model(name):
    """Load the pretrained processor (or tokenizer) and VQA model for *name*.

    Supported names are "vilt", "vilt_finetuned", "git", "blip" and "vbert".

    Args:
        name: Short identifier selecting which checkpoint pair to load.

    Returns:
        A ``(processor, model)`` tuple. For "vbert" the first element is a
        tokenizer loaded from "bert-base-uncased" rather than a multimodal
        processor.

    Raises:
        ValueError: If *name* is not one of the supported identifiers.
    """
    if name == "vilt":
        processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
    elif name == "vilt_finetuned":
        # Same processor as the stock ViLT checkpoint; only the model weights differ.
        processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
    elif name == "git":
        processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
        model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
    elif name == "blip":
        processor = BlipProcessor.from_pretrained('Salesforce/blip-vqa-base')
        model = BlipForQuestionAnswering.from_pretrained('Salesforce/blip-vqa-base')
    elif name == "vbert":
        processor = AutoTokenizer.from_pretrained("bert-base-uncased")
        model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")
    else:
        # Build one message string: passing the name as a second positional
        # argument makes str(exc) render as a tuple repr.
        raise ValueError(f"invalid model name: {name}")
    return processor, model
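# NOTE: the wrapper below is disabled (kept only as a bare string literal).
# If it is restored as written, the definition would shadow the `load_dataset`
# imported from `datasets`, so the calls inside it would recurse into itself.
# Rename the wrapper before re-enabling it.
'''
def load_dataset(type):
    if type == "train":
        return load_dataset("HuggingFaceM4/VQAv2", split="train", cache_dir="cache", streaming=False)
    elif type == "test":
        return load_dataset("HuggingFaceM4/VQAv2", split="validation", cache_dir="cache", streaming=False)
    else:
        raise ValueError("invalid dataset: ", type)
'''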
def load_img_model(name):
    """Load an image backbone used for visual feature extraction.

    Args:
        name: Either "resnet50" (torchvision, 'DEFAULT' weights) or "vitb16"
            (timm 'vit_base_patch16_224', pretrained, created with
            ``num_classes=0``).

    Returns:
        A ``(model, name)`` tuple: the loaded model first, then the name that
        was passed in.

    Raises:
        ValueError: If *name* is not a supported backbone.
    """
    if name == "resnet50":
        model = resnet50(weights='DEFAULT')
    elif name == "vitb16":
        model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
    else:
        # Build one message string: passing the name as a second positional
        # argument makes str(exc) render as a tuple repr.
        raise ValueError(f"undefined model name: {name}")
    return model, name
def label_count_list(labels):
    """Count how many times each distinct label occurs.

    Args:
        labels: Iterable of hashable labels.

    Returns:
        A plain ``dict`` mapping each distinct label to its number of
        occurrences. An empty input yields an empty dict.
    """
    # Function-scope import so this block stays self-contained.
    from collections import Counter

    # One pass over the input instead of one full ``count`` scan per distinct label.
    return dict(Counter(labels))
def get_item(image, question, tokenizer, image_model, model_name):
    """Build the text and visual inputs for one (image, question) pair.

    Args:
        image: Image handed to ``get_img_feats``.
        question: Question text handed to *tokenizer*.
        tokenizer: Callable invoked as ``tokenizer(question, return_tensors='pt')``;
            its result must support ``update``.
        image_model: Backbone forwarded to ``get_img_feats``.
        model_name: Backbone name forwarded to ``get_img_feats``.

    Returns:
        A ``(visual_inputs, inputs)`` tuple: the dict holding only the three
        visual tensors, and the tokenizer output updated with those entries.
    """
    encoded = tokenizer(question, return_tensors='pt')

    features = get_img_feats(image, image_model=image_model, name=model_name)
    # Drop dims 2 and 3 where they have size 1, then add a new leading dimension.
    # NOTE(review): assumes the backbone output is 4-D (e.g. pooled ResNet
    # features) — TODO confirm for other backbones.
    visual_embeds = features.squeeze(2, 3).unsqueeze(0)

    # Token-type ids and attention mask are all ones, shaped like the
    # embeddings without their last dimension.
    leading_shape = visual_embeds.shape[:-1]
    visual_inputs = {
        "visual_embeds": visual_embeds,
        "visual_token_type_ids": torch.ones(leading_shape, dtype=torch.long),
        "visual_attention_mask": torch.ones(leading_shape, dtype=torch.float),
    }
    encoded.update(visual_inputs)
    return visual_inputs, encoded
def get_img_feats(image, image_model, new_size=None, name='resnet50'):
    """Run *image* through an image backbone and return its features.

    Args:
        image: Input image accepted by the torchvision transforms used here
            (presumably a PIL image — verify against callers).
        image_model: Backbone matching *name*.
        new_size: If given, the image is first resized to
            ``(new_size, new_size)`` with Lanczos interpolation.
        name: "resnet50" or "vitb16"; selects how features are extracted.

    Returns:
        The tensor produced by the backbone for a batch containing the single
        preprocessed image.

    Raises:
        ValueError: If *name* is not a supported backbone.
    """
    # Reject unsupported names up front: they used to fall through both
    # branches and fail on an unbound local instead of a clear error.
    if name not in ("resnet50", "vitb16"):
        raise ValueError(f"undefined model name: {name}")

    if name == "resnet50":
        # Rebuild the network without its last child module, so the forward
        # pass stops before it (for torchvision's ResNet that child is
        # presumably the final `fc` classifier — confirm).
        image_model = torch.nn.Sequential(*list(image_model.children())[:-1])

    if new_size is not None:
        resize = transforms.Resize(
            (new_size, new_size),
            interpolation=transforms.InterpolationMode.LANCZOS,
        )
        image = resize(image)

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension of size 1.
    batch = preprocess(image).unsqueeze(0)

    # NOTE(review): the model's train/eval mode and autograd state are left
    # exactly as the caller set them; nothing here calls eval() or no_grad().
    if name == "resnet50":
        return image_model(batch)
    return image_model.forward_features(batch)
def get_data(query, delim=","):
    """Load data from a local file or, if *query* is not a file, from a URL.

    Args:
        query: Path of an existing file; anything else is fetched with
            ``requests.get``.
        delim: Unused; kept so existing callers keep working.

    Returns:
        For a file: the result of evaluating its contents as a Python
        expression. For a URL: the parsed JSON body when the response is valid
        JSON; otherwise the evaluated body when it is a valid Python
        expression; otherwise the decoded body split into a list on newline
        characters.
    """
    assert isinstance(query, str)

    if os.path.isfile(query):
        with open(query) as f:
            # SECURITY: eval() executes arbitrary code; only use trusted files.
            return eval(f.read())

    req = requests.get(query)
    try:
        try:
            # Call json() on the response object; the `requests` module has no
            # json() function, so that call always failed and JSON bodies were
            # never parsed.
            return req.json()
        except ValueError:
            # Body is not valid JSON; fall back to the raw text.
            data = req.content.decode()
        try:
            # SECURITY: eval() on downloaded content executes arbitrary code;
            # only point this at trusted URLs.
            data = eval(data)
        except Exception:
            data = data.split("\n")
    finally:
        # Release the connection on the error paths too.
        req.close()
    return data
def err_msg():
    """Print a load-failure notice and return the "[ERROR]" sentinel string."""
    sentinel = "[ERROR]"
    print("Load error, try again")
    return sentinel
def get_answer(model_loader_args, img, question, model_name):
    """Answer *question* about *img* with the model selected by *model_name*.

    Args:
        model_loader_args: Indexable whose items 0 and 1 are the processor
            (or tokenizer) and the model, as returned by ``load_model``.
        img: Image passed to the processor / feature extractor.
        question: Question text.
        model_name: One of "vilt", "vilt_finetuned", "git", "vbert", "blip".

    Returns:
        The predicted answer string. Returns the ``err_msg()`` sentinel when
        input preparation fails (for "blip", also when generation or decoding
        fails), and the literal "Invalid model name" for an unrecognised
        *model_name*.
    """
    processor = model_loader_args[0]
    model = model_loader_args[1]

    # Both ViLT variants share the same classification-style inference path.
    if model_name in ("vilt", "vilt_finetuned"):
        try:
            encoding = processor(images=img, text=question, return_tensors="pt")
        except Exception:
            return err_msg()
        best_idx = model(**encoding).logits.argmax(-1).item()
        return model.config.id2label[best_idx]

    if model_name == "git":
        try:
            pixel_values = processor(images=img, return_tensors="pt").pixel_values
            question_ids = processor(text=question, add_special_tokens=False).input_ids
            # Prepend the CLS token id by hand, then add a leading batch dimension.
            prompt_ids = torch.tensor(
                [processor.tokenizer.cls_token_id] + question_ids
            ).unsqueeze(0)
        except Exception:
            return err_msg()
        generated = model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)
        decoded = processor.batch_decode(generated, skip_special_tokens=True)[0]
        # Keep only the text after the last '?' in the decoded string.
        return decoded.split('?')[-1].strip()

    if model_name == "vbert":
        # The answer list is fetched and the ResNet-50 is loaded on every call.
        answers = get_data(VQA_URL)
        backbone, backbone_name = load_img_model("resnet50")
        _, inputs = get_item(img, question, processor, backbone, backbone_name)
        logits = model(**inputs).logits
        return answers[torch.argmax(logits, dim=1).item()]

    if model_name == "blip":
        try:
            pixel_values = processor(images=img, return_tensors="pt").pixel_values
            prompt = processor.tokenizer.cls_token + question
            prompt_ids = processor(text=prompt, add_special_tokens=False).input_ids
            prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)
            generated = model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)
            decoded = processor.batch_decode(generated, skip_special_tokens=True)
        except Exception:
            return err_msg()
        return decoded[0]

    return "Invalid model name"