# VDebugger-generalist-for-VQA / vision_models.py
import abc
import os
import re
import timeit
from typing import Union
import torch
import torchvision
from PIL import Image
from torch import hub
from torch.nn import functional as F
from torchvision import transforms
device = "cuda" if torch.cuda.is_available() else "cpu"
class BaseModel(abc.ABC):
to_batch = False
seconds_collect_data = 1.5 # Window of seconds to group inputs, if to_batch is True
max_batch_size = 10 # Maximum batch size, if to_batch is True. Maximum allowed by OpenAI
requires_gpu = True
num_gpus = 1 # Number of required GPUs
load_order = 0 # Order in which the model is loaded. Lower is first. By default, models are loaded alphabetically
def __init__(self, gpu_number):
self.dev = f'cuda:{gpu_number}' if device == 'cuda' else device
@abc.abstractmethod
def forward(self, *args, **kwargs):
"""
If to_batch is True, every arg and kwarg will be a list of inputs, and the output should be a list of outputs.
        Under the hood, arguments that are not explicitly provided take their default values, but they are still
        passed to the forward method as lists.
"""
pass
@classmethod
@abc.abstractmethod
def name(cls) -> str:
"""The name of the model has to be given by the subclass"""
pass
@classmethod
def list_processes(cls):
"""
A single model can be run in multiple processes, for example if there are different tasks to be done with it.
If multiple processes are used, override this method to return a list of strings.
Remember the @classmethod decorator.
If we specify a list of processes, the self.forward() method has to have a "process_name" parameter that gets
automatically passed in.
See GPT3Model for an example.
"""
return [cls.name]
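# A minimal sketch of how a concrete model is expected to subclass BaseModel (illustrative
# only; `DummyModel` and its forward signature are hypothetical, not part of this repo):
#
#   class DummyModel(BaseModel):
#       name = 'dummy'
#
#       def __init__(self, gpu_number=0):
#           super().__init__(gpu_number)
#
#       def forward(self, image):
#           return image.shape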
# ------------------------------ Specific models ---------------------------- #
class ObjectDetector(BaseModel):
name = 'object_detector'
def __init__(self, gpu_number=0):
super().__init__(gpu_number)
detection_model = hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).to(self.dev)
detection_model.eval()
self.detection_model = detection_model
@torch.no_grad()
def forward(self, image: torch.Tensor):
"""get_object_detection_bboxes"""
input_batch = image.to(self.dev).unsqueeze(0) # create a mini-batch as expected by the model
detections = self.detection_model(input_batch)
p = detections['pred_boxes']
p = torch.stack([p[..., 0], 1 - p[..., 3], p[..., 2], 1 - p[..., 1]], -1) # [left, lower, right, upper]
detections['pred_boxes'] = p
return detections
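# Hypothetical usage sketch for ObjectDetector (assumes `img` is a CHW float image tensor
# loaded elsewhere; the variable names below are illustrative):
#
#   detector = ObjectDetector(gpu_number=0)
#   detections = detector.forward(img)
#   boxes = detections['pred_boxes']  # one [left, lower, right, upper] box per DETR query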
class DepthEstimationModel(BaseModel):
name = 'depth'
def __init__(self, gpu_number=0, model_type='DPT_Large'):
super().__init__(gpu_number)
# Model options: MiDaS_small, DPT_Hybrid, DPT_Large
depth_estimation_model = hub.load('intel-isl/MiDaS', model_type, pretrained=True).to(self.dev)
depth_estimation_model.eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
self.transform = midas_transforms.dpt_transform
else:
self.transform = midas_transforms.small_transform
self.depth_estimation_model = depth_estimation_model
@torch.no_grad()
def forward(self, image: torch.Tensor):
"""Estimate depth map"""
image_numpy = image.cpu().permute(1, 2, 0).numpy() * 255
input_batch = self.transform(image_numpy).to(self.dev)
prediction = self.depth_estimation_model(input_batch)
# Resize to original size
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=image_numpy.shape[:2],
mode="bicubic",
align_corners=False,
).squeeze()
# We compute the inverse because the model returns inverse depth
to_return = 1 / prediction
to_return = to_return.cpu()
return to_return # To save: plt.imsave(path_save, prediction.cpu().numpy())
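# Hypothetical usage sketch for DepthEstimationModel (assumes `img` is a CHW float tensor
# with values in [0, 1], since the MiDaS transform is fed a 0-255 copy of it):
#
#   depth_model = DepthEstimationModel(gpu_number=0, model_type='DPT_Large')
#   depth = depth_model.forward(img)  # HxW tensor of relative depth; larger means farther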
class CLIPModel(BaseModel):
name = 'clip'
def __init__(self, gpu_number=0, version="ViT-L/14@336px"): # @336px
super().__init__(gpu_number)
import clip
self.clip = clip
model, preprocess = clip.load(version, device=self.dev)
model.eval()
        model.requires_grad_(False)
self.model = model
self.negative_text_features = None
self.transform = self.get_clip_transforms_from_tensor(336 if "336" in version else 224)
# @staticmethod
def _convert_image_to_rgb(self, image):
return image.convert("RGB")
# @staticmethod
def get_clip_transforms_from_tensor(self, n_px=336):
return transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(n_px, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(n_px),
self._convert_image_to_rgb,
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
@torch.no_grad()
def binary_score(self, image: torch.Tensor, prompt, negative_categories=None):
is_video = isinstance(image, torch.Tensor) and image.ndim == 4
if is_video: # video
image = torch.stack([self.transform(image[i]) for i in range(image.shape[0])], dim=0)
else:
image = self.transform(image).unsqueeze(0).to(self.dev)
prompt_prefix = "photo of "
prompt = prompt_prefix + prompt
if negative_categories is None:
if self.negative_text_features is None:
self.negative_text_features = self.clip_negatives(prompt_prefix)
negative_text_features = self.negative_text_features
else:
negative_text_features = self.clip_negatives(prompt_prefix, negative_categories)
text = self.clip.tokenize([prompt]).to(self.dev)
image_features = self.model.encode_image(image.to(self.dev))
image_features = F.normalize(image_features, dim=-1)
pos_text_features = self.model.encode_text(text)
pos_text_features = F.normalize(pos_text_features, dim=-1)
text_features = torch.concat([pos_text_features, negative_text_features], axis=0)
# run competition where we do a binary classification
# between the positive and all the negatives, then take the mean
sim = (100.0 * image_features @ text_features.T).squeeze(dim=0)
if is_video:
query = sim[..., 0].unsqueeze(-1).broadcast_to(sim.shape[0], sim.shape[-1] - 1)
others = sim[..., 1:]
res = F.softmax(torch.stack([query, others], dim=-1), dim=-1)[..., 0].mean(-1)
else:
res = F.softmax(torch.cat((sim[0].broadcast_to(1, sim.shape[0] - 1),
sim[1:].unsqueeze(0)), dim=0), dim=0)[0].mean()
return res
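    # In effect, binary_score() runs one softmax "duel" between the query prompt and each
    # negative prompt: for every negative n, softmax([sim_query, sim_n]) is the probability
    # that the image matches the query rather than n, and the returned score is the mean of
    # these probabilities over all negatives (computed per frame for video inputs).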
@torch.no_grad()
def clip_negatives(self, prompt_prefix, negative_categories=None):
if negative_categories is None:
with open('useful_lists/random_negatives.txt') as f:
negative_categories = [x.strip() for x in f.read().split()]
# negative_categories = negative_categories[:1000]
# negative_categories = ["a cat", "a lamp"]
negative_categories = [prompt_prefix + x for x in negative_categories]
negative_tokens = self.clip.tokenize(negative_categories).to(self.dev)
negative_text_features = self.model.encode_text(negative_tokens)
negative_text_features = F.normalize(negative_text_features, dim=-1)
return negative_text_features
@torch.no_grad()
def classify(self, image: Union[torch.Tensor, list], categories: list[str], return_index=True):
is_list = isinstance(image, list)
if is_list:
assert len(image) == len(categories)
image = [self.transform(x).unsqueeze(0) for x in image]
image_clip = torch.cat(image, dim=0).to(self.dev)
elif len(image.shape) == 3:
image_clip = self.transform(image).to(self.dev).unsqueeze(0)
else: # Video (process images separately)
image_clip = torch.stack([self.transform(x) for x in image], dim=0).to(self.dev)
# if len(image_clip.shape) == 3:
# image_clip = image_clip.unsqueeze(0)
prompt_prefix = "photo of "
categories = [prompt_prefix + x for x in categories]
categories = self.clip.tokenize(categories).to(self.dev)
text_features = self.model.encode_text(categories)
text_features = F.normalize(text_features, dim=-1)
image_features = self.model.encode_image(image_clip)
image_features = F.normalize(image_features, dim=-1)
if image_clip.shape[0] == 1:
# get category from image
softmax_arg = image_features @ text_features.T # 1 x n
else:
if is_list:
# get highest category-image match with n images and n corresponding categories
softmax_arg = (image_features @ text_features.T).diag().unsqueeze(0) # n x n -> 1 x n
else:
softmax_arg = (image_features @ text_features.T)
similarity = (100.0 * softmax_arg).softmax(dim=-1).squeeze(0)
if not return_index:
return similarity
else:
result = torch.argmax(similarity, dim=-1)
if result.shape == ():
result = result.item()
return result
@torch.no_grad()
def compare(self, images: list[torch.Tensor], prompt, return_scores=False):
images = [self.transform(im).unsqueeze(0).to(self.dev) for im in images]
images = torch.cat(images, dim=0)
prompt_prefix = "photo of "
prompt = prompt_prefix + prompt
text = self.clip.tokenize([prompt]).to(self.dev)
image_features = self.model.encode_image(images.to(self.dev))
image_features = F.normalize(image_features, dim=-1)
text_features = self.model.encode_text(text)
text_features = F.normalize(text_features, dim=-1)
sim = (image_features @ text_features.T).squeeze(dim=-1) # Only one text, so squeeze
if return_scores:
return sim
res = sim.argmax()
return res
def forward(self, image, prompt, task='score', return_index=True, negative_categories=None, return_scores=False):
if task == 'classify':
categories = prompt
clip_sim = self.classify(image, categories, return_index=return_index)
out = clip_sim
elif task == 'score':
clip_score = self.binary_score(image, prompt, negative_categories=negative_categories)
out = clip_score
else: # task == 'compare'
idx = self.compare(image, prompt, return_scores)
out = idx
if not isinstance(out, int):
out = out.cpu()
return out
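# Hypothetical usage sketches for CLIPModel (variable names are illustrative):
#
#   clip_model = CLIPModel(gpu_number=0)
#   score = clip_model.forward(img, 'a red car', task='score')               # scalar in [0, 1]
#   idx = clip_model.forward(img, ['cat', 'dog', 'bird'], task='classify')   # index of best category
#   best = clip_model.forward([img1, img2], 'a red car', task='compare')     # index of best image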
class MaskRCNNModel(BaseModel):
name = 'maskrcnn'
def __init__(self, gpu_number=0, threshold=0.8):
super().__init__(gpu_number)
obj_detect = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights='COCO_V1').to(self.dev)
obj_detect.eval()
obj_detect.requires_grad_(False)
self.categories = torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1.meta['categories']
self.obj_detect = obj_detect
self.threshold = threshold
def prepare_image(self, image):
image = image.to(self.dev)
return image
@torch.no_grad()
def detect(self, images: torch.Tensor, confidence_threshold: float = None):
        if not isinstance(images, list):
images = [images]
threshold = confidence_threshold if confidence_threshold is not None else self.threshold
images = [self.prepare_image(im) for im in images]
detections = self.obj_detect(images)
scores = []
for i in range(len(images)):
scores.append(detections[i]['scores'][detections[i]['scores'] > threshold])
height = detections[i]['masks'].shape[-2]
# Just return boxes (no labels no masks, no scores) with scores > threshold
d_i = detections[i]['boxes'][detections[i]['scores'] > threshold]
# Return [left, lower, right, upper] instead of [left, upper, right, lower]
detections[i] = torch.stack([d_i[:, 0], height - d_i[:, 3], d_i[:, 2], height - d_i[:, 1]], dim=1)
return detections, scores
def forward(self, image, confidence_threshold: float = None):
obj_detections, obj_scores = self.detect(image, confidence_threshold=confidence_threshold)
# Move to CPU before sharing. Alternatively we can try cloning tensors in CUDA, but may not work
obj_detections = [(v.to('cpu') if isinstance(v, torch.Tensor) else list(v)) for v in obj_detections]
obj_scores = [(v.to('cpu') if isinstance(v, torch.Tensor) else list(v)) for v in obj_scores]
return obj_detections, obj_scores
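# Hypothetical usage sketch for MaskRCNNModel (assumes `img` is a CHW float tensor in [0, 1]):
#
#   maskrcnn = MaskRCNNModel(gpu_number=0, threshold=0.8)
#   boxes_per_image, scores_per_image = maskrcnn.forward(img)
#   boxes = boxes_per_image[0]  # (N, 4) tensor of [left, lower, right, upper] pixel coordinates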
class GLIPModel(BaseModel):
name = 'glip'
def __init__(self, model_size='large', gpu_number=0, *args):
BaseModel.__init__(self, gpu_number)
# with contextlib.redirect_stderr(open(os.devnull, "w")): # Do not print nltk_data messages when importing
from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo, to_image_list, create_positive_map, \
create_positive_map_label_to_token_from_positive_map
working_dir = 'pretrained_models/GLIP/'
if model_size == 'tiny':
config_file = working_dir + "configs/glip_Swin_T_O365_GoldG.yaml"
weight_file = working_dir + "checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth"
else: # large
config_file = working_dir + "configs/glip_Swin_L.yaml"
weight_file = working_dir + "checkpoints/glip_large_model.pth"
class OurGLIPDemo(GLIPDemo):
def __init__(self, dev, *args_demo):
kwargs = {
'min_image_size': 800,
'confidence_threshold': 0.5,
'show_mask_heatmaps': False
}
self.dev = dev
from maskrcnn_benchmark.config import cfg
# manual override some options
cfg.local_rank = 0
cfg.num_gpus = 1
cfg.merge_from_file(config_file)
cfg.merge_from_list(["MODEL.WEIGHT", weight_file])
cfg.merge_from_list(["MODEL.DEVICE", self.dev])
from transformers.utils import logging
logging.set_verbosity_error()
GLIPDemo.__init__(self, cfg, *args_demo, **kwargs)
if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
plus = 1
else:
plus = 0
self.plus = plus
self.color = 255
@torch.no_grad()
def compute_prediction(self, original_image, original_caption, custom_entity=None):
image = self.transforms(original_image)
# image = [image, image.permute(0, 2, 1)]
image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
image_list = image_list.to(self.dev)
# caption
if isinstance(original_caption, list):
if len(original_caption) > 40:
all_predictions = None
for loop_num, i in enumerate(range(0, len(original_caption), 40)):
list_step = original_caption[i:i + 40]
prediction_step = self.compute_prediction(original_image, list_step, custom_entity=None)
if all_predictions is None:
all_predictions = prediction_step
else:
# Aggregate predictions
all_predictions.bbox = torch.cat((all_predictions.bbox, prediction_step.bbox), dim=0)
for k in all_predictions.extra_fields:
all_predictions.extra_fields[k] = \
torch.cat((all_predictions.extra_fields[k],
prediction_step.extra_fields[k] + loop_num), dim=0)
return all_predictions
# we directly provided a list of category names
caption_string = ""
tokens_positive = []
                    separation_tokens = " . "
                    for word in original_caption:
                        tokens_positive.append([len(caption_string), len(caption_string) + len(word)])
                        caption_string += word
                        caption_string += separation_tokens
tokenized = self.tokenizer([caption_string], return_tensors="pt")
# tokens_positive = [tokens_positive] # This was wrong
tokens_positive = [[v] for v in tokens_positive]
original_caption = caption_string
# print(tokens_positive)
else:
tokenized = self.tokenizer([original_caption], return_tensors="pt")
if custom_entity is None:
tokens_positive = self.run_ner(original_caption)
# print(tokens_positive)
# process positive map
positive_map = create_positive_map(tokenized, tokens_positive)
positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map,
plus=self.plus)
self.positive_map_label_to_token = positive_map_label_to_token
tic = timeit.time.perf_counter()
# compute predictions
predictions = self.model(image_list, captions=[original_caption],
positive_map=positive_map_label_to_token)
predictions = [o.to(self.cpu_device) for o in predictions]
# print("inference time per image: {}".format(timeit.time.perf_counter() - tic))
# always single image is passed at a time
prediction = predictions[0]
# reshape prediction (a BoxList) into the original image size
height, width = original_image.shape[-2:]
# if self.tensor_inputs:
# else:
# height, width = original_image.shape[:-1]
prediction = prediction.resize((width, height))
if prediction.has_field("mask"):
# if we have masks, paste the masks in the right position
# in the image, as defined by the bounding boxes
masks = prediction.get_field("mask")
# always single image is passed at a time
masks = self.masker([masks], [prediction])[0]
prediction.add_field("mask", masks)
return prediction
@staticmethod
def to_left_right_upper_lower(bboxes):
return [(bbox[1], bbox[3], bbox[0], bbox[2]) for bbox in bboxes]
@staticmethod
def to_xmin_ymin_xmax_ymax(bboxes):
# invert the previous method
return [(bbox[2], bbox[0], bbox[3], bbox[1]) for bbox in bboxes]
@staticmethod
def prepare_image(image):
                image = image[[2, 1, 0]]  # convert RGB to BGR, the OpenCV-style channel order that GLIP expects
return image
@torch.no_grad()
def forward(self, image: torch.Tensor, obj: Union[str, list], confidence_threshold=None):
if confidence_threshold is not None:
original_confidence_threshold = self.confidence_threshold
self.confidence_threshold = confidence_threshold
# if isinstance(object, list):
# object = ' . '.join(object) + ' .' # add separation tokens
image = self.prepare_image(image)
# Avoid the resizing creating a huge image in a pathological case
ratio = image.shape[1] / image.shape[2]
ratio = max(ratio, 1 / ratio)
original_min_image_size = self.min_image_size
if ratio > 10:
self.min_image_size = int(original_min_image_size * 10 / ratio)
self.transforms = self.build_transform()
with torch.cuda.device(self.dev):
inference_output = self.inference(image, obj)
bboxes = inference_output.bbox.cpu().numpy().astype(int)
# bboxes = self.to_left_right_upper_lower(bboxes)
if ratio > 10:
self.min_image_size = original_min_image_size
self.transforms = self.build_transform()
bboxes = torch.tensor(bboxes)
# Convert to [left, lower, right, upper] instead of [left, upper, right, lower]
height = image.shape[-2]
bboxes = torch.stack([bboxes[:, 0], height - bboxes[:, 3], bboxes[:, 2], height - bboxes[:, 1]], dim=1)
if confidence_threshold is not None:
self.confidence_threshold = original_confidence_threshold
# subtract 1 because it's 1-indexed for some reason
# return bboxes, inference_output.get_field("labels").cpu().numpy() - 1
return bboxes, inference_output.get_field("scores")
self.glip_demo = OurGLIPDemo(*args, dev=self.dev)
def forward(self, *args, **kwargs):
return self.glip_demo.forward(*args, **kwargs)
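# Hypothetical usage sketch for GLIPModel (assumes the GLIP configs and checkpoints exist
# under pretrained_models/GLIP/ and `img` is a CHW image tensor):
#
#   glip = GLIPModel(model_size='large', gpu_number=0)
#   boxes, scores = glip.forward(img, 'dog')  # boxes are [left, lower, right, upper] in pixels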
class BLIPModel(BaseModel):
name = 'blip'
to_batch = True
max_batch_size = 32
    seconds_collect_data = 0.2  # On top of this window, inputs also accumulate while the previous forward pass runs
def __init__(self, gpu_number=0, half_precision=True, blip_v2_model_type="blip2-flan-t5-xl"):
super().__init__(gpu_number)
# from lavis.models import load_model_and_preprocess
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# https://huggingface.co/models?sort=downloads&search=Salesforce%2Fblip2-
assert blip_v2_model_type in ['blip2-flan-t5-xxl', 'blip2-flan-t5-xl', 'blip2-opt-2.7b', 'blip2-opt-6.7b',
'blip2-opt-2.7b-coco', 'blip2-flan-t5-xl-coco', 'blip2-opt-6.7b-coco']
with torch.cuda.device(self.dev):
max_memory = {gpu_number: torch.cuda.mem_get_info(self.dev)[0]}
self.processor = Blip2Processor.from_pretrained(f"Salesforce/{blip_v2_model_type}")
# Device_map must be sequential for manual GPU selection
try:
self.model = Blip2ForConditionalGeneration.from_pretrained(
f"Salesforce/{blip_v2_model_type}", load_in_8bit=half_precision,
torch_dtype=torch.float16 if half_precision else "auto",
device_map="sequential", max_memory=max_memory
)
except Exception as e:
# Clarify error message. The problem is that it tries to load part of the model to disk.
if "had weights offloaded to the disk" in e.args[0]:
                    extra_text = ' You may want to consider setting half_precision to True.' if not half_precision else ''
raise MemoryError(f"Not enough GPU memory in GPU {self.dev} to load the model.{extra_text}")
else:
raise e
self.qa_prompt = "Question: {} Short answer:"
self.caption_prompt = "a photo of"
self.half_precision = half_precision
self.max_words = 50
@torch.no_grad()
def caption(self, image, prompt=None):
inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.dev, torch.float16)
generation_output = self.model.generate(**inputs, length_penalty=1., num_beams=5, max_length=30, min_length=1,
do_sample=False, top_p=0.9, repetition_penalty=1.0,
num_return_sequences=1, temperature=1,
return_dict_in_generate=True, output_scores=True)
generated_text = [cap.strip() for cap in self.processor.batch_decode(
generation_output.sequences, skip_special_tokens=True)]
return generated_text, generation_output.sequences_scores.cpu().numpy().tolist()
def pre_question(self, question):
# from LAVIS blip_processors
question = re.sub(
r"([.!\"()*#:;~])",
"",
question.lower(),
)
question = question.rstrip(" ")
# truncate question
question_words = question.split(" ")
if len(question_words) > self.max_words:
question = " ".join(question_words[: self.max_words])
return question
@torch.no_grad()
def qa(self, image, question):
inputs = self.processor(images=image, text=question, return_tensors="pt", padding="longest").to(self.dev)
if self.half_precision:
inputs['pixel_values'] = inputs['pixel_values'].half()
generation_output = self.model.generate(**inputs, length_penalty=-1, num_beams=5, max_length=10, min_length=1,
do_sample=False, top_p=0.9, repetition_penalty=1.0,
num_return_sequences=1, temperature=1,
return_dict_in_generate=True, output_scores=True)
generated_text = self.processor.batch_decode(generation_output.sequences, skip_special_tokens=True)
return generated_text, generation_output.sequences_scores.cpu().numpy().tolist()
def forward(self, image, question=None, task='caption'):
if not self.to_batch:
image, question, task = [image], [question], [task]
if len(image) > 0 and 'float' in str(image[0].dtype) and image[0].max() <= 1:
image = [im * 255 for im in image]
# Separate into qa and caption batches.
prompts_qa = [self.qa_prompt.format(self.pre_question(q)) for q, t in zip(question, task) if t == 'qa']
images_qa = [im for i, im in enumerate(image) if task[i] == 'qa']
images_caption = [im for i, im in enumerate(image) if task[i] == 'caption']
with torch.cuda.device(self.dev):
response_qa, scores_qa = self.qa(images_qa, prompts_qa) if len(images_qa) > 0 else ([], [])
response_caption, scores_caption = self.caption(images_caption) if len(images_caption) > 0 else ([], [])
response = []
for t in task:
if t == 'qa':
response.append([response_qa.pop(0), scores_qa.pop(0)])
else:
response.append([response_caption.pop(0), scores_caption.pop(0)])
if not self.to_batch:
response = response[0]
return response
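# Hypothetical usage sketch for BLIPModel. Because to_batch is True, the process manager
# normally passes (and returns) parallel lists; calling it directly looks like:
#
#   blip = BLIPModel(gpu_number=0)
#   results = blip.forward([img], ['what color is the car?'], ['qa'])
#   answer, score = results[0]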
class XVLMModel(BaseModel):
name = 'xvlm'
def __init__(self, gpu_number=0, path_checkpoint='pretrained_models/xvlm/retrieval_mscoco_checkpoint_9.pth'):
from xvlm.xvlm import XVLMBase
from transformers import BertTokenizer
super().__init__(gpu_number)
image_res = 384
self.max_words = 30
config_xvlm = {
'image_res': image_res,
'patch_size': 32,
'text_encoder': 'bert-base-uncased',
'block_num': 9,
'max_tokens': 40,
'embed_dim': 256,
}
vision_config = {
'vision_width': 1024,
'image_res': 384,
'window_size': 12,
'embed_dim': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32]
}
model = XVLMBase(config_xvlm, use_contrastive_loss=True, vision_config=vision_config)
checkpoint = torch.load(path_checkpoint, map_location='cpu')
state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
msg = model.load_state_dict(state_dict, strict=False)
if len(msg.missing_keys) > 0:
print('XVLM Missing keys: ', msg.missing_keys)
model = model.to(self.dev)
model.eval()
self.model = model
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
self.transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC),
transforms.ToTensor(),
normalize,
])
with open('useful_lists/random_negatives.txt') as f:
self.negative_categories = [x.strip() for x in f.read().split()]
@staticmethod
def pre_caption(caption, max_words):
caption = re.sub(
r"([,.'!?\"()*#:;~])",
'',
caption.lower(),
).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
caption = re.sub(
r"\s{2,}",
' ',
caption,
)
caption = caption.rstrip('\n')
caption = caption.strip(' ')
# truncate caption
caption_words = caption.split(' ')
if len(caption_words) > max_words:
caption = ' '.join(caption_words[:max_words])
if not len(caption):
raise ValueError("pre_caption yields invalid text")
return caption
@torch.no_grad()
def score(self, images, texts):
if isinstance(texts, str):
texts = [texts]
if not isinstance(images, list):
images = [images]
images = [self.transform(image) for image in images]
images = torch.stack(images, dim=0).to(self.dev)
texts = [self.pre_caption(text, self.max_words) for text in texts]
text_input = self.tokenizer(texts, padding='longest', return_tensors="pt").to(self.dev)
image_embeds, image_atts = self.model.get_vision_embeds(images)
text_ids, text_atts = text_input.input_ids, text_input.attention_mask
text_embeds = self.model.get_text_embeds(text_ids, text_atts)
image_feat, text_feat = self.model.get_features(image_embeds, text_embeds)
logits = image_feat @ text_feat.t()
return logits
@torch.no_grad()
def binary_score(self, image, text, negative_categories):
# Compare with a pre-defined set of negatives
texts = [text] + negative_categories
sim = 100 * self.score(image, texts)[0]
res = F.softmax(torch.cat((sim[0].broadcast_to(1, sim.shape[0] - 1),
sim[1:].unsqueeze(0)), dim=0), dim=0)[0].mean()
return res
def forward(self, image, text, task='score', negative_categories=None):
if task == 'score':
score = self.score(image, text)
else: # binary
score = self.binary_score(image, text, negative_categories=negative_categories)
return score.cpu()
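# Hypothetical usage sketch for XVLMModel (assumes the retrieval checkpoint exists at the
# default path_checkpoint and `img` is a CHW image tensor):
#
#   xvlm = XVLMModel(gpu_number=0)
#   sim = xvlm.forward(img, 'a man riding a horse', task='score')    # raw image-text similarity
#   prob = xvlm.forward(img, 'a man riding a horse', task='binary',
#                       negative_categories=xvlm.negative_categories)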