import abc
import os
import re
import timeit
from typing import Union

import torch
import torchvision
from PIL import Image
from torch import hub
from torch.nn import functional as F
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

class BaseModel(abc.ABC):
    to_batch = False
    seconds_collect_data = 1.5
    max_batch_size = 10
    requires_gpu = True
    num_gpus = 1
    load_order = 0

    def __init__(self, gpu_number):
        self.dev = f'cuda:{gpu_number}' if device == 'cuda' else device

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """
        If to_batch is True, every arg and kwarg will be a list of inputs, and the output should be a list of
        outputs. Under the hood, arguments with defaults that are not explicitly specified still take their
        default values, but they are also passed to forward as lists.
        """
        pass

    @classmethod
    @abc.abstractmethod
    def name(cls) -> str:
        """The name of the model, to be given by the subclass."""
        pass

    @classmethod
    def list_processes(cls):
        """
        A single model can be run in multiple processes, for example if there are different tasks to be done with it.
        If multiple processes are used, override this method to return a list of strings.
        Remember the @classmethod decorator.
        If a list of processes is specified, the forward() method has to accept a "process_name" parameter, which is
        passed in automatically.
        See GPT3Model for an example; the commented sketch below this class illustrates the same idea.
        """
        return [cls.name]
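
# Illustrative sketch (hypothetical, and kept commented out so that it is not picked up as a
# loadable model): a minimal subclass that overrides list_processes(). Because several process
# names are returned, forward() receives a `process_name` argument telling it which task to run.
# With `to_batch = True`, forward() would instead receive every argument as a list and would have
# to return a list of outputs of the same length.
#
# class EchoModel(BaseModel):
#     name = 'echo'
#
#     @classmethod
#     def list_processes(cls):
#         return ['echo_upper', 'echo_lower']
#
#     def forward(self, text, process_name):
#         return text.upper() if process_name == 'echo_upper' else text.lower()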

class ObjectDetector(BaseModel):
    name = 'object_detector'

    def __init__(self, gpu_number=0):
        super().__init__(gpu_number)

        detection_model = hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).to(self.dev)
        detection_model.eval()

        self.detection_model = detection_model

    @torch.no_grad()
    def forward(self, image: torch.Tensor):
        """get_object_detection_bboxes"""
        input_batch = image.to(self.dev).unsqueeze(0)
        detections = self.detection_model(input_batch)
        p = detections['pred_boxes']
        p = torch.stack([p[..., 0], 1 - p[..., 3], p[..., 2], 1 - p[..., 1]], -1)
        detections['pred_boxes'] = p
        return detections

class DepthEstimationModel(BaseModel):
    name = 'depth'

    def __init__(self, gpu_number=0, model_type='DPT_Large'):
        super().__init__(gpu_number)

        depth_estimation_model = hub.load('intel-isl/MiDaS', model_type, pretrained=True).to(self.dev)
        depth_estimation_model.eval()

        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

        if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
            self.transform = midas_transforms.dpt_transform
        else:
            self.transform = midas_transforms.small_transform

        self.depth_estimation_model = depth_estimation_model

    @torch.no_grad()
    def forward(self, image: torch.Tensor):
        """Estimate depth map"""
        image_numpy = image.cpu().permute(1, 2, 0).numpy() * 255
        input_batch = self.transform(image_numpy).to(self.dev)
        prediction = self.depth_estimation_model(input_batch)

        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image_numpy.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        # MiDaS predicts (relative) inverse depth, so invert it to obtain a depth-like map
        to_return = 1 / prediction
        to_return = to_return.cpu()
        return to_return

class CLIPModel(BaseModel):
    name = 'clip'

    def __init__(self, gpu_number=0, version="ViT-L/14@336px"):
        super().__init__(gpu_number)

        import clip
        self.clip = clip

        model, preprocess = clip.load(version, device=self.dev)
        model.eval()
        model.requires_grad_(False)
        self.model = model
        self.negative_text_features = None
        self.transform = self.get_clip_transforms_from_tensor(336 if "336" in version else 224)

    def _convert_image_to_rgb(self, image):
        return image.convert("RGB")

    def get_clip_transforms_from_tensor(self, n_px=336):
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(n_px, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(n_px),
            self._convert_image_to_rgb,
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    @torch.no_grad()
    def binary_score(self, image: torch.Tensor, prompt, negative_categories=None):
        is_video = isinstance(image, torch.Tensor) and image.ndim == 4
        if is_video:
            image = torch.stack([self.transform(image[i]) for i in range(image.shape[0])], dim=0)
        else:
            image = self.transform(image).unsqueeze(0).to(self.dev)

        prompt_prefix = "photo of "
        prompt = prompt_prefix + prompt

        if negative_categories is None:
            if self.negative_text_features is None:
                self.negative_text_features = self.clip_negatives(prompt_prefix)
            negative_text_features = self.negative_text_features
        else:
            negative_text_features = self.clip_negatives(prompt_prefix, negative_categories)

        text = self.clip.tokenize([prompt]).to(self.dev)

        image_features = self.model.encode_image(image.to(self.dev))
        image_features = F.normalize(image_features, dim=-1)

        pos_text_features = self.model.encode_text(text)
        pos_text_features = F.normalize(pos_text_features, dim=-1)

        text_features = torch.cat([pos_text_features, negative_text_features], dim=0)

        # Score the positive prompt against every negative category and average the pairwise softmax probabilities
        sim = (100.0 * image_features @ text_features.T).squeeze(dim=0)
        if is_video:
            query = sim[..., 0].unsqueeze(-1).broadcast_to(sim.shape[0], sim.shape[-1] - 1)
            others = sim[..., 1:]
            res = F.softmax(torch.stack([query, others], dim=-1), dim=-1)[..., 0].mean(-1)
        else:
            res = F.softmax(torch.cat((sim[0].broadcast_to(1, sim.shape[0] - 1),
                                       sim[1:].unsqueeze(0)), dim=0), dim=0)[0].mean()
        return res

    @torch.no_grad()
    def clip_negatives(self, prompt_prefix, negative_categories=None):
        if negative_categories is None:
            with open('useful_lists/random_negatives.txt') as f:
                negative_categories = [x.strip() for x in f.read().split()]

        negative_categories = [prompt_prefix + x for x in negative_categories]
        negative_tokens = self.clip.tokenize(negative_categories).to(self.dev)

        negative_text_features = self.model.encode_text(negative_tokens)
        negative_text_features = F.normalize(negative_text_features, dim=-1)

        return negative_text_features

    @torch.no_grad()
    def classify(self, image: Union[torch.Tensor, list], categories: list[str], return_index=True):
        is_list = isinstance(image, list)
        if is_list:
            assert len(image) == len(categories)
            image = [self.transform(x).unsqueeze(0) for x in image]
            image_clip = torch.cat(image, dim=0).to(self.dev)
        elif len(image.shape) == 3:
            image_clip = self.transform(image).to(self.dev).unsqueeze(0)
        else:
            image_clip = torch.stack([self.transform(x) for x in image], dim=0).to(self.dev)

        prompt_prefix = "photo of "
        categories = [prompt_prefix + x for x in categories]
        categories = self.clip.tokenize(categories).to(self.dev)

        text_features = self.model.encode_text(categories)
        text_features = F.normalize(text_features, dim=-1)

        image_features = self.model.encode_image(image_clip)
        image_features = F.normalize(image_features, dim=-1)

        if image_clip.shape[0] == 1:
            # Single image: compare it against every category
            softmax_arg = image_features @ text_features.T
        else:
            if is_list:
                # One category per image: keep only each image's similarity with its own category
                softmax_arg = (image_features @ text_features.T).diag().unsqueeze(0)
            else:
                # Several images against the same set of categories
                softmax_arg = (image_features @ text_features.T)

        similarity = (100.0 * softmax_arg).softmax(dim=-1).squeeze(0)
        if not return_index:
            return similarity
        else:
            result = torch.argmax(similarity, dim=-1)
            if result.shape == ():
                result = result.item()
            return result

    @torch.no_grad()
    def compare(self, images: list[torch.Tensor], prompt, return_scores=False):
        images = [self.transform(im).unsqueeze(0).to(self.dev) for im in images]
        images = torch.cat(images, dim=0)

        prompt_prefix = "photo of "
        prompt = prompt_prefix + prompt

        text = self.clip.tokenize([prompt]).to(self.dev)

        image_features = self.model.encode_image(images.to(self.dev))
        image_features = F.normalize(image_features, dim=-1)

        text_features = self.model.encode_text(text)
        text_features = F.normalize(text_features, dim=-1)

        sim = (image_features @ text_features.T).squeeze(dim=-1)

        if return_scores:
            return sim
        res = sim.argmax()
        return res

    def forward(self, image, prompt, task='score', return_index=True, negative_categories=None, return_scores=False):
        if task == 'classify':
            categories = prompt
            clip_sim = self.classify(image, categories, return_index=return_index)
            out = clip_sim
        elif task == 'score':
            clip_score = self.binary_score(image, prompt, negative_categories=negative_categories)
            out = clip_score
        else:
            # Remaining task: compare a list of images against a single prompt
            idx = self.compare(image, prompt, return_scores)
            out = idx
        if not isinstance(out, int):
            out = out.cpu()
        return out

class MaskRCNNModel(BaseModel):
    name = 'maskrcnn'

    def __init__(self, gpu_number=0, threshold=0.8):
        super().__init__(gpu_number)
        obj_detect = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights='COCO_V1').to(self.dev)
        obj_detect.eval()
        obj_detect.requires_grad_(False)
        self.categories = torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1.meta['categories']
        self.obj_detect = obj_detect
        self.threshold = threshold

    def prepare_image(self, image):
        image = image.to(self.dev)
        return image

    @torch.no_grad()
    def detect(self, images: Union[torch.Tensor, list], confidence_threshold: float = None):
        if not isinstance(images, list):
            images = [images]
        threshold = confidence_threshold if confidence_threshold is not None else self.threshold

        images = [self.prepare_image(im) for im in images]
        detections = self.obj_detect(images)
        scores = []
        for i in range(len(images)):
            scores.append(detections[i]['scores'][detections[i]['scores'] > threshold])

            height = detections[i]['masks'].shape[-2]

            d_i = detections[i]['boxes'][detections[i]['scores'] > threshold]
            # Flip the vertical axis so that box coordinates are measured from the bottom of the image
            detections[i] = torch.stack([d_i[:, 0], height - d_i[:, 3], d_i[:, 2], height - d_i[:, 1]], dim=1)

        return detections, scores

    def forward(self, image, confidence_threshold: float = None):
        obj_detections, obj_scores = self.detect(image, confidence_threshold=confidence_threshold)
        # Move everything to CPU before returning
        obj_detections = [(v.to('cpu') if isinstance(v, torch.Tensor) else list(v)) for v in obj_detections]
        obj_scores = [(v.to('cpu') if isinstance(v, torch.Tensor) else list(v)) for v in obj_scores]
        return obj_detections, obj_scores

class GLIPModel(BaseModel):
    name = 'glip'

    def __init__(self, model_size='large', gpu_number=0, *args):
        BaseModel.__init__(self, gpu_number)

        from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo, to_image_list, create_positive_map, \
            create_positive_map_label_to_token_from_positive_map

        working_dir = 'pretrained_models/GLIP/'
        if model_size == 'tiny':
            config_file = working_dir + "configs/glip_Swin_T_O365_GoldG.yaml"
            weight_file = working_dir + "checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth"
        else:
            config_file = working_dir + "configs/glip_Swin_L.yaml"
            weight_file = working_dir + "checkpoints/glip_large_model.pth"

        class OurGLIPDemo(GLIPDemo):

            def __init__(self, dev, *args_demo):

                kwargs = {
                    'min_image_size': 800,
                    'confidence_threshold': 0.5,
                    'show_mask_heatmaps': False
                }

                self.dev = dev

                from maskrcnn_benchmark.config import cfg

                # Manually override some config options
                cfg.local_rank = 0
                cfg.num_gpus = 1
                cfg.merge_from_file(config_file)
                cfg.merge_from_list(["MODEL.WEIGHT", weight_file])
                cfg.merge_from_list(["MODEL.DEVICE", self.dev])

                from transformers.utils import logging
                logging.set_verbosity_error()

                GLIPDemo.__init__(self, cfg, *args_demo, **kwargs)
                if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
                    plus = 1
                else:
                    plus = 0
                self.plus = plus
                self.color = 255

            @torch.no_grad()
            def compute_prediction(self, original_image, original_caption, custom_entity=None):
                image = self.transforms(original_image)

                image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
                image_list = image_list.to(self.dev)

                if isinstance(original_caption, list):

                    if len(original_caption) > 40:
                        # Process the categories in chunks of 40 and merge the resulting predictions
                        all_predictions = None
                        for loop_num, i in enumerate(range(0, len(original_caption), 40)):
                            list_step = original_caption[i:i + 40]
                            prediction_step = self.compute_prediction(original_image, list_step, custom_entity=None)
                            if all_predictions is None:
                                all_predictions = prediction_step
                            else:
                                all_predictions.bbox = torch.cat((all_predictions.bbox, prediction_step.bbox), dim=0)
                                for k in all_predictions.extra_fields:
                                    all_predictions.extra_fields[k] = \
                                        torch.cat((all_predictions.extra_fields[k],
                                                   prediction_step.extra_fields[k] + loop_num), dim=0)
                        return all_predictions

                    # A list of category names was given; build a single caption string out of it
                    caption_string = ""
                    tokens_positive = []
                    separation_tokens = " . "
                    for word in original_caption:
                        tokens_positive.append([len(caption_string), len(caption_string) + len(word)])
                        caption_string += word
                        caption_string += separation_tokens

                    tokenized = self.tokenizer([caption_string], return_tensors="pt")

                    tokens_positive = [[v] for v in tokens_positive]

                    original_caption = caption_string

                else:
                    tokenized = self.tokenizer([original_caption], return_tensors="pt")
                    if custom_entity is None:
                        tokens_positive = self.run_ner(original_caption)

                positive_map = create_positive_map(tokenized, tokens_positive)

                positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map,
                                                                                                   plus=self.plus)
                self.positive_map_label_to_token = positive_map_label_to_token
                tic = timeit.time.perf_counter()

                # Compute the predictions and move them to CPU
                predictions = self.model(image_list, captions=[original_caption],
                                         positive_map=positive_map_label_to_token)
                predictions = [o.to(self.cpu_device) for o in predictions]

                # Only a single image is passed at a time
                prediction = predictions[0]

                # Resize the prediction (a BoxList) back to the original image size
                height, width = original_image.shape[-2:]

                prediction = prediction.resize((width, height))

                if prediction.has_field("mask"):
                    # Paste the masks in the right position in the image, as defined by the bounding boxes
                    masks = prediction.get_field("mask")

                    masks = self.masker([masks], [prediction])[0]
                    prediction.add_field("mask", masks)

                return prediction

            @staticmethod
            def to_left_right_upper_lower(bboxes):
                return [(bbox[1], bbox[3], bbox[0], bbox[2]) for bbox in bboxes]

            @staticmethod
            def to_xmin_ymin_xmax_ymax(bboxes):
                # Inverse of to_left_right_upper_lower
                return [(bbox[2], bbox[0], bbox[3], bbox[1]) for bbox in bboxes]

            @staticmethod
            def prepare_image(image):
                image = image[[2, 1, 0]]  # Reverse the channel order (RGB -> BGR)
                return image

            @torch.no_grad()
            def forward(self, image: torch.Tensor, obj: Union[str, list], confidence_threshold=None):
                # Temporarily override the confidence threshold if one is given
                if confidence_threshold is not None:
                    original_confidence_threshold = self.confidence_threshold
                    self.confidence_threshold = confidence_threshold

                image = self.prepare_image(image)

                # For very elongated images, reduce the resize target accordingly
                ratio = image.shape[1] / image.shape[2]
                ratio = max(ratio, 1 / ratio)
                original_min_image_size = self.min_image_size
                if ratio > 10:
                    self.min_image_size = int(original_min_image_size * 10 / ratio)
                    self.transforms = self.build_transform()

                with torch.cuda.device(self.dev):
                    inference_output = self.inference(image, obj)

                bboxes = inference_output.bbox.cpu().numpy().astype(int)

                if ratio > 10:
                    self.min_image_size = original_min_image_size
                    self.transforms = self.build_transform()

                bboxes = torch.tensor(bboxes)

                # Flip the vertical axis so that box coordinates are measured from the bottom of the image
                height = image.shape[-2]
                bboxes = torch.stack([bboxes[:, 0], height - bboxes[:, 3], bboxes[:, 2], height - bboxes[:, 1]], dim=1)

                if confidence_threshold is not None:
                    self.confidence_threshold = original_confidence_threshold

                return bboxes, inference_output.get_field("scores")

        self.glip_demo = OurGLIPDemo(*args, dev=self.dev)

    def forward(self, *args, **kwargs):
        return self.glip_demo.forward(*args, **kwargs)

class BLIPModel(BaseModel):
    name = 'blip'
    to_batch = True
    max_batch_size = 32
    seconds_collect_data = 0.2

    def __init__(self, gpu_number=0, half_precision=True, blip_v2_model_type="blip2-flan-t5-xl"):
        super().__init__(gpu_number)

        from transformers import Blip2Processor, Blip2ForConditionalGeneration

        assert blip_v2_model_type in ['blip2-flan-t5-xxl', 'blip2-flan-t5-xl', 'blip2-opt-2.7b', 'blip2-opt-6.7b',
                                      'blip2-opt-2.7b-coco', 'blip2-flan-t5-xl-coco', 'blip2-opt-6.7b-coco']

        with torch.cuda.device(self.dev):
            max_memory = {gpu_number: torch.cuda.mem_get_info(self.dev)[0]}

            self.processor = Blip2Processor.from_pretrained(f"Salesforce/{blip_v2_model_type}")

            try:
                self.model = Blip2ForConditionalGeneration.from_pretrained(
                    f"Salesforce/{blip_v2_model_type}", load_in_8bit=half_precision,
                    torch_dtype=torch.float16 if half_precision else "auto",
                    device_map="sequential", max_memory=max_memory
                )
            except Exception as e:
                # Raise a clearer error if the model did not fit in GPU memory and was offloaded to disk
                if "had weights offloaded to the disk" in e.args[0]:
                    extra_text = ' You may want to consider setting half_precision to True.' if not half_precision else ''
                    raise MemoryError(f"Not enough GPU memory in GPU {self.dev} to load the model.{extra_text}")
                else:
                    raise e

        self.qa_prompt = "Question: {} Short answer:"
        self.caption_prompt = "a photo of"
        self.half_precision = half_precision
        self.max_words = 50

    @torch.no_grad()
    def caption(self, image, prompt=None):
        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.dev, torch.float16)
        generation_output = self.model.generate(**inputs, length_penalty=1., num_beams=5, max_length=30, min_length=1,
                                                do_sample=False, top_p=0.9, repetition_penalty=1.0,
                                                num_return_sequences=1, temperature=1,
                                                return_dict_in_generate=True, output_scores=True)
        generated_text = [cap.strip() for cap in self.processor.batch_decode(
            generation_output.sequences, skip_special_tokens=True)]
        return generated_text, generation_output.sequences_scores.cpu().numpy().tolist()

    def pre_question(self, question):
        # Strip punctuation and lower-case the question
        question = re.sub(
            r"([.!\"()*#:;~])",
            "",
            question.lower(),
        )
        question = question.rstrip(" ")

        # Truncate the question to max_words
        question_words = question.split(" ")
        if len(question_words) > self.max_words:
            question = " ".join(question_words[: self.max_words])

        return question

    @torch.no_grad()
    def qa(self, image, question):
        inputs = self.processor(images=image, text=question, return_tensors="pt", padding="longest").to(self.dev)
        if self.half_precision:
            inputs['pixel_values'] = inputs['pixel_values'].half()
        generation_output = self.model.generate(**inputs, length_penalty=-1, num_beams=5, max_length=10, min_length=1,
                                                do_sample=False, top_p=0.9, repetition_penalty=1.0,
                                                num_return_sequences=1, temperature=1,
                                                return_dict_in_generate=True, output_scores=True)
        generated_text = self.processor.batch_decode(generation_output.sequences, skip_special_tokens=True)
        return generated_text, generation_output.sequences_scores.cpu().numpy().tolist()

    def forward(self, image, question=None, task='caption'):
        if not self.to_batch:
            image, question, task = [image], [question], [task]

        # Scale float images in [0, 1] up to the [0, 255] range
        if len(image) > 0 and 'float' in str(image[0].dtype) and image[0].max() <= 1:
            image = [im * 255 for im in image]

        # Separate the batch into QA and captioning queries
        prompts_qa = [self.qa_prompt.format(self.pre_question(q)) for q, t in zip(question, task) if t == 'qa']
        images_qa = [im for i, im in enumerate(image) if task[i] == 'qa']
        images_caption = [im for i, im in enumerate(image) if task[i] == 'caption']

        with torch.cuda.device(self.dev):
            response_qa, scores_qa = self.qa(images_qa, prompts_qa) if len(images_qa) > 0 else ([], [])
            response_caption, scores_caption = self.caption(images_caption) if len(images_caption) > 0 else ([], [])

        # Re-assemble the responses in the original order
        response = []
        for t in task:
            if t == 'qa':
                response.append([response_qa.pop(0), scores_qa.pop(0)])
            else:
                response.append([response_caption.pop(0), scores_caption.pop(0)])

        if not self.to_batch:
            response = response[0]
        return response

class XVLMModel(BaseModel):
    name = 'xvlm'

    def __init__(self, gpu_number=0, path_checkpoint='pretrained_models/xvlm/retrieval_mscoco_checkpoint_9.pth'):

        from xvlm.xvlm import XVLMBase
        from transformers import BertTokenizer

        super().__init__(gpu_number)

        image_res = 384
        self.max_words = 30
        config_xvlm = {
            'image_res': image_res,
            'patch_size': 32,
            'text_encoder': 'bert-base-uncased',
            'block_num': 9,
            'max_tokens': 40,
            'embed_dim': 256,
        }

        vision_config = {
            'vision_width': 1024,
            'image_res': 384,
            'window_size': 12,
            'embed_dim': 128,
            'depths': [2, 2, 18, 2],
            'num_heads': [4, 8, 16, 32]
        }
        model = XVLMBase(config_xvlm, use_contrastive_loss=True, vision_config=vision_config)
        checkpoint = torch.load(path_checkpoint, map_location='cpu')
        state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
        msg = model.load_state_dict(state_dict, strict=False)
        if len(msg.missing_keys) > 0:
            print('XVLM Missing keys: ', msg.missing_keys)

        model = model.to(self.dev)
        model.eval()

        self.model = model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        with open('useful_lists/random_negatives.txt') as f:
            self.negative_categories = [x.strip() for x in f.read().split()]

    @staticmethod
    def pre_caption(caption, max_words):
        caption = re.sub(
            r"([,.'!?\"()*#:;~])",
            '',
            caption.lower(),
        ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

        caption = re.sub(
            r"\s{2,}",
            ' ',
            caption,
        )
        caption = caption.rstrip('\n')
        caption = caption.strip(' ')

        caption_words = caption.split(' ')
        if len(caption_words) > max_words:
            caption = ' '.join(caption_words[:max_words])

        if not len(caption):
            raise ValueError("pre_caption yields invalid text")

        return caption

    @torch.no_grad()
    def score(self, images, texts):

        if isinstance(texts, str):
            texts = [texts]

        if not isinstance(images, list):
            images = [images]

        images = [self.transform(image) for image in images]
        images = torch.stack(images, dim=0).to(self.dev)

        texts = [self.pre_caption(text, self.max_words) for text in texts]
        text_input = self.tokenizer(texts, padding='longest', return_tensors="pt").to(self.dev)

        image_embeds, image_atts = self.model.get_vision_embeds(images)
        text_ids, text_atts = text_input.input_ids, text_input.attention_mask
        text_embeds = self.model.get_text_embeds(text_ids, text_atts)

        image_feat, text_feat = self.model.get_features(image_embeds, text_embeds)
        logits = image_feat @ text_feat.t()

        return logits

    @torch.no_grad()
    def binary_score(self, image, text, negative_categories):
        texts = [text] + negative_categories
        sim = 100 * self.score(image, texts)[0]
        res = F.softmax(torch.cat((sim[0].broadcast_to(1, sim.shape[0] - 1),
                                   sim[1:].unsqueeze(0)), dim=0), dim=0)[0].mean()
        return res

    def forward(self, image, text, task='score', negative_categories=None):
        if task == 'score':
            score = self.score(image, text)
        else:
            score = self.binary_score(image, text, negative_categories=negative_categories)
        return score.cpu()