import sys
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from models.blip_vqa import blip_vqa
from models.blip_itm import blip_itm
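
# Thin demo wrappers around BLIP's visual question answering (blip_vqa) and
# image-text matching (blip_itm) heads. Each wrapper loads a local checkpoint,
# preprocesses a single image, and runs one forward pass on CPU or CUDA.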

class VQA:
    """Wrapper around the BLIP VQA head for single-image question answering."""

    def __init__(self, model_path, image_size=480):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.model.eval()
        self.model = self.model.to(self.device)

    def load_demo_image(self, image_size, img_path, device):
        """Load an image from disk and preprocess it into a 1xCxHxW tensor."""
        raw_image = Image.open(img_path).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # CLIP-style normalization statistics used by BLIP.
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])
        image = transform(raw_image).unsqueeze(0).to(device)
        return raw_image, image

    def vqa(self, img_path, question):
        raw_image, image = self.load_demo_image(image_size=480, img_path=img_path, device=self.device)
        with torch.no_grad():
            answer = self.model(image, question, train=False, inference='generate')
        return answer[0]
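
# Minimal usage sketch for the VQA wrapper (checkpoint path taken from the
# __main__ block below; the question is illustrative):
#
#   vqa_object = VQA(model_path='checkpoints/model_base_vqa_capfilt_large.pth')
#   answer = vqa_object.vqa('sample.jpg', 'What is the color of the horse?')
#   print(answer)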

class ITM:
    """Wrapper around the BLIP ITM head for image-text matching scores."""

    def __init__(self, model_path, image_size=384):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = blip_itm(pretrained=model_path, image_size=image_size, vit='base')
        self.model.eval()
        # Keep the model on the same device the input images are sent to.
        self.model = self.model.to(self.device)

    def load_demo_image(self, image_size, img_path, device):
        """Load an image from disk and preprocess it into a 1xCxHxW tensor."""
        raw_image = Image.open(img_path).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # CLIP-style normalization statistics used by BLIP.
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])
        image = transform(raw_image).unsqueeze(0).to(device)
        return raw_image, image

    def itm(self, img_path, caption):
        raw_image, image = self.load_demo_image(image_size=384, img_path=img_path, device=self.device)
        with torch.no_grad():
            # ITM head: binary match/no-match logits; take P(match) via softmax.
            itm_output = self.model(image, caption, match_head='itm')
            itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
            # ITC head: cosine similarity between image and text features.
            itc_score = self.model(image, caption, match_head='itc')
        return itm_score, itc_score
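
# Minimal usage sketch for the ITM wrapper (checkpoint path taken from the
# __main__ block below; the caption is illustrative):
#
#   itm_object = ITM(model_path='model_base_retrieval_coco.pth')
#   itm_score, itc_score = itm_object.itm('sample.jpg', 'a horse in a field')
#   print(itm_score.item(), itc_score.item())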

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print('Format: python3 vqa.py <path_to_img> <caption>')
        print('Sample: python3 vqa.py sample.jpg "What is the color of the horse?"')
    else:
        model_path = 'checkpoints/model_base_vqa_capfilt_large.pth'
        model2_path = 'model_base_retrieval_coco.pth'
        # vqa_object = VQA(model_path=model_path)
        itm_object = ITM(model_path=model2_path)
        img_path = sys.argv[1]
        caption = sys.argv[2]
        # answer = vqa_object.vqa(img_path, caption)
        itm_score, itc_score = itm_object.itm(img_path, caption)
        # print('Question: {} | Answer: {}'.format(caption, answer))
        print('Caption: {} | The image and text are matched with a probability of {:.4f} | '
              'The image and text features have a cosine similarity of {:.4f}'.format(
                  caption, itm_score.item(), itc_score.item()))