import torch
import torch.nn as nn
import numpy as np
from functools import partial
from lib.model_zoo.common.get_model import register
import torch.nn.functional as F
symbol = 'clip'
class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError
from transformers import CLIPTokenizer, CLIPTextModel
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self
###############
# for vd next #
###############
from transformers import CLIPModel
@register('clip_text_context_encoder')
class CLIPTextContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.max_length = max_length
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get the device without storing it explicitly.
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        # Bind self so calls such as `self.train()` or `self.train(mode)` become
        # no-ops and never switch the frozen model back to training mode.
        self.train = partial(disabled_train, self)
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, text):
        batch_encoding = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        # Project every token embedding, then normalize by the norm of the
        # projected pooled (EOS) embedding.
        z = self.model.text_projection(outputs.last_hidden_state)
        z_pooled = self.model.text_projection(outputs.pooler_output)
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return z
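
# A minimal, hedged usage sketch (not called anywhere in the module): it only
# illustrates the call pattern and output shape of the text context encoder.
# Assumptions: the "openai/clip-vit-large-patch14" checkpoint is reachable, and
# the function name below is illustrative rather than part of the module API.
def _example_text_context_usage():
    encoder = CLIPTextContextEncoder(max_length=77)
    context = encoder.encode(["a photo of a dog", "a watercolor landscape"])
    # For ViT-L/14 the text projection dim is 768, so context is [2, 77, 768].
    return context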
from transformers import CLIPProcessor
@register('clip_image_context_encoder')
class CLIPImageContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get the device without storing it explicitly.
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        # Bind self so calls such as `self.train()` or `self.train(mode)` become
        # no-ops and never switch the frozen model back to training mode.
        self.train = partial(disabled_train, self)
        for param in self.parameters():
            param.requires_grad = False

    def _encode(self, images):
        if isinstance(images, torch.Tensor):
            import torchvision.transforms as tvtrans
            images = [tvtrans.ToPILImage()(i) for i in images]
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        outputs = self.model.vision_model(pixel_values=pixels)
        z = outputs.last_hidden_state
        z = self.model.vision_model.post_layernorm(z)
        z = self.model.visual_projection(z)
        # Normalize every token by the norm of the projected class token.
        z_pooled = z[:, 0:1]
        z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z
    @torch.no_grad()
    def _encode_wmask(self, images, masks):
        assert isinstance(masks, torch.Tensor)
        assert (len(masks.shape) == 4) and (masks.shape[1] == 1)
        masks = torch.clamp(masks, 0, 1)
        masks = masks.float()
        masks = F.interpolate(masks, [224, 224], mode='bilinear')
        if masks.sum() == masks.numel():
            # The mask covers the whole image; fall back to the unmasked path.
            return self._encode(images)

        device = images.device
        dtype = images.dtype
        # Global mask coverage, one scalar per image, applied to the class token.
        gscale = masks.mean(dim=[1, 2, 3], keepdim=True).flatten(2)

        # Pool the pixel mask into one coverage fraction per visual token by
        # convolving with an all-ones kernel that matches the patch embedding.
        vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
        vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
        mask_kernel = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
        vtoken_mask = F.conv2d(masks, mask_kernel, stride=vtoken_stride).flatten(2).transpose(1, 2)
        vtoken_mask = vtoken_mask / np.prod(vtoken_kernel_size)
        vtoken_mask = torch.cat([gscale, vtoken_mask], dim=1)

        import types

        def customized_embedding_forward(self, pixel_values):
            batch_size = pixel_values.shape[0]
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
            patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
            class_embeds = self.class_embedding.expand(batch_size, 1, -1)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            embeddings = embeddings + self.position_embedding(self.position_ids)
            embeddings = embeddings * vtoken_mask.to(embeddings.dtype)
            return embeddings

        # Temporarily patch the embedding forward so masked tokens are scaled
        # before entering the transformer, then restore the original forward.
        old_forward = self.model.vision_model.embeddings.forward
        self.model.vision_model.embeddings.forward = types.MethodType(
            customized_embedding_forward, self.model.vision_model.embeddings)
        z = self._encode(images)
        self.model.vision_model.embeddings.forward = old_forward
        z = z * vtoken_mask.to(dtype)
        return z
    def encode(self, images, masks=None):
        if masks is None:
            return self._encode(images)
        else:
            return self._encode_wmask(images, masks)
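
# A minimal, hedged smoke test for the masked image encoder. Assumptions: the
# "openai/clip-vit-large-patch14" checkpoint can be downloaded, inputs are
# dummy random tensors, and the quoted shape corresponds to ViT-L/14
# (257 tokens, projection dim 768). This is a sketch, not part of the API.
if __name__ == "__main__":
    encoder = CLIPImageContextEncoder()
    images = torch.rand(2, 3, 224, 224)        # dummy RGB batch in [0, 1]
    masks = torch.zeros(2, 1, 224, 224)
    masks[:, :, 64:160, 64:160] = 1.0          # keep only a central region
    ctx = encoder.encode(images, masks)        # masked visual context tokens
    print(ctx.shape)                           # e.g. torch.Size([2, 257, 768])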