from typing import List, Dict, Union, Tuple

from PIL import Image, ImageDraw, ImageFilter, ImageOps, ImageEnhance
import spacy
import hashlib
import os
import torch
import torchvision
import torchvision.transforms as transforms
import clip
from transformers import BertTokenizer, RobertaTokenizerFast
import ruamel.yaml as yaml
import copy
from interpreter import Box
import pycocotools.mask as mask_utils
import alpha_clip
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pickle


class Executor:
    def __init__(self,
                 device: str = "cpu",
                 box_representation_method: str = "crop",
                 method_aggregator: str = "max",
                 enlarge_boxes: int = 0,
                 expand_position_embedding: bool = False,
                 square_size: bool = False,
                 blur_std_dev: int = 100,
                 cache_path: str = None,
                 input_file: str = None) -> None:
        IMPLEMENTED_METHODS = ["blur", "full", "gray"]
        if any(m not in IMPLEMENTED_METHODS for m in box_representation_method.split(",")):
            raise NotImplementedError
        IMPLEMENTED_AGGREGATORS = ["max", "sum"]
        if method_aggregator not in IMPLEMENTED_AGGREGATORS:
            raise NotImplementedError
        self.box_representation_method = box_representation_method
        self.method_aggregator = method_aggregator
        self.enlarge_boxes = enlarge_boxes
        self.device = device
        self.expand_position_embedding = expand_position_embedding
        self.square_size = square_size
        self.blur_std_dev = blur_std_dev
        self.cache_path = cache_path

    def preprocess_image(self, image: Image) -> List[torch.Tensor]:
        return [preprocess(image) for preprocess in self.preprocesses]

    def preprocess_mask(self, mask: Image) -> List[torch.Tensor]:
        preprocess = self.preprocesses[0]
        return preprocess.transforms[1](preprocess.transforms[0](mask))

    def preprocess_text(self, text: str) -> torch.Tensor:
        raise NotImplementedError

    def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        raise NotImplementedError

    def tensorize_inputs(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth: str = None) -> Tuple[List[torch.Tensor], torch.Tensor]:
        images = []
        for preprocess in self.preprocesses:
            images.append([])
        if 'aclip' in self.clip_type:
            self.all_masks = []
            read_save = False
            if self.mask_path is not None:
                # load mask if cached
                file_name = image_pth.split('/')[-1].split('.')[0] + '.pkl'
                if os.path.exists(os.path.join(self.mask_path, file_name)):
                    all_rles = pickle.load(open(os.path.join(self.mask_path, file_name), 'rb'))
                    for rle in all_rles:
                        mask = np.array(mask_utils.decode(rle), dtype=bool)
                        self.all_masks.append(mask)
                    read_save = True
            if not read_save:
                # use SAM to generate masks
                self.predictor.set_image(np.array(image.convert('RGB')))
                all_rles = []
                for i in range(len(boxes)):
                    box = [
                        max(boxes[i].left - self.enlarge_boxes, 0),
                        max(boxes[i].top - self.enlarge_boxes, 0),
                        min(boxes[i].right + self.enlarge_boxes, image.width),
                        min(boxes[i].bottom + self.enlarge_boxes, image.height)
                    ]
                    # box prompt
                    input_box = np.array(box)
                    masks, _, _ = self.predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=input_box[None, :],
                        multimask_output=False,
                    )
                    self.all_masks.append(masks[0])
                    rle = mask_utils.encode(np.array(masks[0][:, :, None], order='F', dtype="uint8"))[0]
                    rle["counts"] = rle["counts"].decode("utf-8")
                    all_rles.append(rle)
                if self.mask_path is not None:
                    # save mask
                    os.makedirs(self.mask_path, exist_ok=True)
                    pickle.dump(all_rles, open(os.path.join(self.mask_path, file_name), 'wb'))
        if self.cache_path is None or any([
            not os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name, method_name + ".pt"))
            for model_name in self.model_names
            for method_name in self.box_representation_method.split(',')
        ]):
            if "full" in self.box_representation_method:
                # original full image with alpha-map
                for i in range(len(boxes)):
                    image_i = image.copy()
                    preprocessed_images = self.preprocess_image(image_i)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            if "blur" in self.box_representation_method:
                for i in range(len(boxes)):
                    image_i = image.copy()
                    mask = Image.new('L', image_i.size, 0)
                    draw = ImageDraw.Draw(mask)
                    box = (
                        max(boxes[i].left - self.enlarge_boxes, 0),
                        max(boxes[i].top - self.enlarge_boxes, 0),
                        min(boxes[i].right + self.enlarge_boxes, image_i.width),
                        min(boxes[i].bottom + self.enlarge_boxes, image_i.height)
                    )
                    if 'aclip' in self.clip_type:
                        width, height = image.size
                        for y in range(height):
                            for x in range(width):
                                if self.all_masks[i][y][x] == 1:
                                    draw.point((x, y), fill=255)
                    else:
                        draw.rectangle([box[:2], box[2:]], fill=255)
                    blurred = image_i.filter(ImageFilter.GaussianBlur(self.blur_std_dev))
                    blurred.paste(image_i, mask=mask)
                    preprocessed_images = self.preprocess_image(blurred)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            if "gray" in self.box_representation_method:
                for i in range(len(boxes)):
                    image_i = image.copy()
                    mask_i = self.all_masks[i]
                    width, height = image.size
                    pixels = image_i.load()
                    for y in range(height):
                        for x in range(width):
                            if mask_i[y][x] == 0:
                                pixel_value = pixels[x, y]
                                gray_value = int(0.2989 * pixel_value[0] + 0.5870 * pixel_value[1] + 0.1140 * pixel_value[2])
                                pixels[x, y] = (gray_value, gray_value, gray_value)
                    preprocessed_images = self.preprocess_image(image_i)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            imgs = [torch.stack(image_list) for image_list in images]
        else:
            imgs = [[] for _ in self.models]
        text_tensor = self.preprocess_text(caption.lower()).to(self.device)
        return imgs, text_tensor

    @torch.no_grad()
    def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
        images, text_tensor = self.tensorize_inputs(caption, image, boxes, image_name, image_pth)
        all_logits_per_image = []
        all_logits_per_text = []
        box_representation_methods = self.box_representation_method.split(',')
        caption_hash = hashlib.md5(caption.encode('utf-8')).hexdigest()
        for model, images_t, model_name in zip(self.models, images, self.model_names):
            self.image_feat_path = ""
            if self.cache_path is not None:
                text_cache_path = os.path.join(self.cache_path, "refcoco_val", model_name, "text" + ("_shade" if self.box_representation_method == "shade" else ""))
                image_feat_path = os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name)
                self.image_feat_path = image_feat_path
            image_features = None
            text_features = None
            if self.cache_path is not None and os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name)):
                if os.path.exists(os.path.join(text_cache_path, caption_hash + ".pt")):
                    text_features = torch.load(os.path.join(text_cache_path, caption_hash + ".pt"), map_location=self.device)
                if os.path.exists(image_feat_path):
                    if all([os.path.exists(os.path.join(image_feat_path, method_name + ".pt")) for method_name in box_representation_methods]):
                        image_features = []
                        for method_name in box_representation_methods:
                            features = torch.load(os.path.join(image_feat_path, method_name + ".pt"), map_location=self.device)
                            image_features.append(torch.stack([
                                features[(box.x, box.y, box.w, box.h)]
                                for box in boxes
                            ]))
                        image_features = torch.stack(image_features)
                        image_features = image_features.view(-1, image_features.shape[-1])
            logits_per_image, logits_per_text, image_features, text_features = self.call_model(model, images_t, text_tensor, image_features=image_features, text_features=text_features, boxes=boxes, image_pth=image_pth)
            all_logits_per_image.append(logits_per_image)
            all_logits_per_text.append(logits_per_text)
            if self.cache_path is not None and image_name is not None and image_features is not None:
                image_features = image_features.view(len(box_representation_methods), len(boxes), image_features.shape[-1])
                if not os.path.exists(image_feat_path):
                    os.makedirs(image_feat_path)
                for i in range(image_features.shape[0]):
                    method_name = box_representation_methods[i]
                    if not os.path.exists(os.path.join(image_feat_path, method_name + ".pt")):
                        image_features_dict = {(box.x, box.y, box.w, box.h): image_features[i, j, :].cpu() for j, box in enumerate(boxes)}
                        torch.save(image_features_dict, os.path.join(image_feat_path, method_name + ".pt"))
            if self.cache_path is not None and not os.path.exists(os.path.join(text_cache_path, caption_hash + ".pt")) and text_features is not None:
                assert text_features.shape[0] == 1
                if not os.path.exists(text_cache_path):
                    os.makedirs(text_cache_path)
                torch.save(text_features.cpu(), os.path.join(text_cache_path, caption_hash + ".pt"))
        all_logits_per_image = torch.stack(all_logits_per_image).sum(0)
        all_logits_per_text = torch.stack(all_logits_per_text).sum(0)
        if self.method_aggregator == "max":
            all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).max(dim=0, keepdim=True)[0]
        elif self.method_aggregator == "sum":
            all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).sum(dim=0, keepdim=True)
        return all_logits_per_text.view(-1)


class ClipExecutor(Executor):
    def __init__(self,
                 clip_model: str = "ViT-B/32",
                 device: str = "cpu",
                 box_representation_method: str = "crop",
                 method_aggregator: str = "max",
                 enlarge_boxes: int = 0,
                 expand_position_embedding: bool = False,
                 square_size: bool = False,
                 blur_std_dev: int = 100,
                 cache_path: str = None,
                 input_file: str = None,
                 clip_type: str = None) -> None:
        super().__init__(device, box_representation_method, method_aggregator, enlarge_boxes, expand_position_embedding, square_size, blur_std_dev, cache_path)
        self.clip_models = clip_model.split(",")
        self.model_names = [model_name.replace("/", "_") for model_name in self.clip_models]
        self.models = []
        self.preprocesses = []
        self.data_name = input_file.split('/')[-1].split('.')[0]
        self.mask_path = None
        self.clip_type = clip_type
        if self.cache_path is not None:
            self.mask_path = os.path.join(self.cache_path, "refcoco_val", 'det_masks')
        # SAM turns each box proposal into a segmentation mask.
        sam_checkpoint = "./ckpt/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        self.predictor = SamPredictor(sam)
        for model_name in self.clip_models:
            if 'aclip' in self.clip_type:  # using alpha-clip
                self.mask_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((224, 224)),
                    transforms.Normalize(0.5, 0.26)
                ])
                if model_name == 'ViT-B/16':
                    model, preprocess = alpha_clip.load("ViT-B/16", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_b16_grit+mim_fultune_4xe.pth", device=device)
                elif model_name == 'ViT-L/14':
                    model, preprocess = alpha_clip.load("ViT-L/14", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_l14_grit+mim_fultune_6xe.pth", device=device)
            else:
                model, preprocess = clip.load(model_name, device=device, jit=False)
            self.models.append(model)
            if self.square_size:
                print("Square size!")
                preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), interpolation=transforms.InterpolationMode.BICUBIC)
            self.preprocesses.append(preprocess)
        self.models = torch.nn.ModuleList(self.models)

    def preprocess_text(self, text: str) -> torch.Tensor:
        if "aclip" in self.box_representation_method:
            return alpha_clip.tokenize([text.lower()])
        if "shade" in self.box_representation_method:
            return clip.tokenize([text.lower() + " is in red color."])
        return clip.tokenize(["a photo of " + text.lower()])

    def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: torch.Tensor, image_features: torch.Tensor = None, text_features: torch.Tensor = None, boxes=None, image_pth=None) -> torch.Tensor:
        if image_features is None:
            print('computing image features')
            if 'aclip' not in self.clip_type:
                image_features = model.encode_image(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            else:
                # Alpha-CLIP: encode each candidate with its SAM mask as the alpha channel.
                image_features = []
                if 'full' in self.box_representation_method:
                    aclip_images = images[:len(boxes)]
                    alphas = []
                    if os.path.exists(os.path.join(self.image_feat_path, 'full.pt')):
                        features = torch.load(os.path.join(self.image_feat_path, 'full.pt'), map_location=self.device)
                        aclip_image_features = torch.stack([
                            features[(box.x, box.y, box.w, box.h)]
                            for box in boxes
                        ])
                    else:
                        for i in range(len(self.all_masks)):
                            binary_mask = self.all_masks[i]
                            alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
                            alpha = alpha.half().cuda().unsqueeze(dim=0)
                            alphas.append(alpha)
                        alphas = torch.cat(alphas, dim=0)
                        aclip_images = aclip_images.half()
                        aclip_image_features = model.visual(aclip_images, alphas)  # using alpha channels
                    images = images[len(boxes):]
                    image_features.append(aclip_image_features)
                if 'blur' in self.box_representation_method:
                    if os.path.exists(os.path.join(self.image_feat_path, 'blur.pt')):
                        features = torch.load(os.path.join(self.image_feat_path, 'blur.pt'), map_location=self.device)
                        ablur_images_features = torch.stack([
                            features[(box.x, box.y, box.w, box.h)]
                            for box in boxes
                        ])
                    else:
                        ablur_images = images[:len(boxes)]
                        alphas = []
                        for i in range(len(self.all_masks)):
                            binary_mask = self.all_masks[i]
                            alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
                            alpha = alpha.half().cuda().unsqueeze(dim=0)
                            alphas.append(alpha)
                        alphas = torch.cat(alphas, dim=0)
                        ablur_images = ablur_images.half()
                        ablur_images_features = model.visual(ablur_images, alphas)
                    images = images[len(boxes):]
                    image_features.append(ablur_images_features)
                if 'gray' in self.box_representation_method:
                    if os.path.exists(os.path.join(self.image_feat_path, 'gray.pt')):
                        features = torch.load(os.path.join(self.image_feat_path, 'gray.pt'), map_location=self.device)
                        gray_images_features = torch.stack([
                            features[(box.x, box.y, box.w, box.h)]
                            for box in boxes
                        ])
                    else:
                        gray_images = images[:len(boxes)]
                        alphas = []
                        for i in range(len(self.all_masks)):
                            binary_mask = self.all_masks[i]
                            alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
                            alpha = alpha.half().cuda().unsqueeze(dim=0)
                            alphas.append(alpha)
                        alphas = torch.cat(alphas, dim=0)
                        gray_images = gray_images.half()
                        gray_images_features = model.visual(gray_images, alphas)
                    images = images[len(boxes):]
                    image_features.append(gray_images_features)
                image_features = torch.cat(image_features, dim=0)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        if text_features is None:
            print('computing text features')
            text_features = model.encode_text(text)
            # normalized features
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text, image_features, text_features

    def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
        if self.expand_position_embedding:
            original_preprocesses = self.preprocesses
            new_preprocesses = []
            original_position_embeddings = []
            for model_name, model, preprocess in zip(self.clip_models, self.models, self.preprocesses):
                if "RN" in model_name:
                    # ResNet backbone: interpolate the attention-pool positional embedding
                    # to match the full image resolution.
                    model_spatial_dim = int((model.visual.attnpool.positional_embedding.shape[0] - 1) ** 0.5)
                    patch_size = model.visual.input_resolution // model_spatial_dim
                    original_positional_embedding = model.visual.attnpool.positional_embedding.clone()
                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
                        model.visual.attnpool.positional_embedding[1:, :].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
                        size=(image.height // patch_size, image.width // patch_size),
                        mode='bicubic',
                        align_corners=False
                    ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.cat((
                        original_positional_embedding[:1, :],
                        model.visual.attnpool.positional_embedding
                    ), dim=0))
                    transform = transforms.Compose([
                        transforms.Resize(((image.height // patch_size) * patch_size, (image.width // patch_size) * patch_size), interpolation=Image.BICUBIC),
                        lambda image: image.convert("RGB"),
                        transforms.ToTensor(),
                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                    ])
                else:
                    # ViT backbone: same interpolation on the visual positional embedding.
                    model_spatial_dim = int((model.visual.positional_embedding.shape[0] - 1) ** 0.5)
                    patch_size = model.visual.input_resolution // model_spatial_dim
                    original_positional_embedding = model.visual.positional_embedding.clone()
                    model.visual.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
                        model.visual.positional_embedding[1:, :].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
                        size=(image.height // patch_size, image.width // patch_size),
                        mode='bicubic',
                        align_corners=False
                    ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
                    model.visual.positional_embedding = torch.nn.Parameter(torch.cat((
                        original_positional_embedding[:1, :],
                        model.visual.positional_embedding
                    ), dim=0))
                    transform = transforms.Compose([
                        transforms.Resize(((image.height // patch_size) * patch_size, (image.width // patch_size) * patch_size), interpolation=Image.BICUBIC),
                        lambda image: image.convert("RGB"),
                        transforms.ToTensor(),
                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                    ])
                new_preprocesses.append(transform)
                original_position_embeddings.append(original_positional_embedding)
            self.preprocesses = new_preprocesses
        result = super().__call__(caption, image, boxes, image_name, image_pth)
        if self.expand_position_embedding:
            self.preprocesses = original_preprocesses
            for model, model_name, pos_embedding in zip(self.models, self.clip_models, original_position_embeddings):
                if "RN" in model_name:
                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(pos_embedding)
                else:
                    model.visual.positional_embedding = torch.nn.Parameter(pos_embedding)
        return result
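

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# Assumptions: the SAM and Alpha-CLIP checkpoints referenced in ClipExecutor
# live under ./ckpt/, a CUDA device is available (the Alpha-CLIP path moves
# alpha maps with .cuda()), interpreter.Box can be built from (x, y, w, h),
# and the image / input-file paths below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    executor = ClipExecutor(
        clip_model="ViT-B/16",
        device="cuda",
        box_representation_method="full,blur,gray",  # comma-separated subset of blur/full/gray
        method_aggregator="sum",
        cache_path=None,                              # set a directory to cache masks and features
        input_file="./data/refcoco_val.jsonl",        # placeholder path
        clip_type="aclip",                            # enables the Alpha-CLIP + SAM branch
    )
    image = Image.open("./data/example.jpg").convert("RGB")  # placeholder image
    boxes = [Box(30, 40, 120, 200), Box(200, 60, 90, 150)]   # assumed (x, y, w, h) constructor
    scores = executor("the person on the left", image, boxes,
                      image_name="example", image_pth="./data/example.jpg")
    print("predicted box index:", scores.argmax().item())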