"""
Mostly copy-paste from LLaVA-HR
https://github.com/luogen1996/LLaVA-HR
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel


def forward_embeddings(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Patch + class embedding with on-the-fly position-embedding resampling;
    bound onto the CLIP embeddings module in HRCLIPVisionTower.load_model."""
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    position_embeddings = self.position_embedding(self.position_ids)
    # resample the pretrained position table when the token count differs
    # (e.g. a higher-resolution input than the model was pretrained on)
    if position_embeddings.shape[1] != embeddings.shape[1]:
        position_embeddings = resample_pos_embed(position_embeddings, embeddings.shape[1])
    embeddings = embeddings + position_embeddings
    return embeddings
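
# Shape walkthrough for forward_embeddings (values assume CLIP ViT-L/14
# pretrained at 336px and fed 448px inputs; illustrative, not from the source):
#   pixel_values: [B, 3, 448, 448]
#   patch_embeds: [B, 1024, 32, 32] -> flatten/transpose -> [B, 1024, 1024]
#   embeddings (with cls token): [B, 1025, 1024]
#   the 577-token position table is resampled to 1025 tokens by the function below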


def resample_pos_embed(
    posemb,
    new_size: int,
    num_prefix_tokens: int = 1,
    interpolation: str = 'bicubic',
    antialias: bool = True,
    verbose: bool = False,
):
    """Bicubically resample a position-embedding table to `new_size` tokens
    (`new_size` counts the prefix/class tokens, which are passed through)."""
    grid_size = int(math.sqrt(new_size - num_prefix_tokens))
    num_pos_tokens = posemb.shape[1] - num_prefix_tokens
    old_size = int(math.sqrt(num_pos_tokens))
    bs = posemb.shape[0]

    # split off prefix (class, etc.) tokens; only patch tokens are resampled
    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix = None

    # do the interpolation on the 2D patch grid
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(bs, old_size, old_size, -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=(grid_size, grid_size), mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(bs, -1, embed_dim)
    posemb = posemb.to(dtype=orig_dtype)

    # add back extra (class, etc.) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    if not torch.jit.is_scripting() and verbose:
        print(f'Resized position embedding: {old_size} to {grid_size}.')
    return posemb
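
# A minimal sanity check for resample_pos_embed (shapes assume CLIP ViT-L/14:
# 577 = 1 cls + 24*24 patches at 336px, 1025 = 1 cls + 32*32 patches at 448px;
# these numbers are illustrative assumptions, not from the original code):
#   posemb = torch.randn(1, 577, 1024)
#   resample_pos_embed(posemb, 1025).shape  # -> torch.Size([1, 1025, 1024])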


class HRCLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            # load the config only; weights are fetched later via load_model()
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        # gradient checkpointing for CLIP
        self.vision_tower.vision_model.encoder.gradient_checkpointing = True

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # replace the stock embedding forward with forward_embeddings so that
        # position embeddings are resampled for non-native input resolutions
        embeddings = self.vision_tower.vision_model.embeddings
        bound_method = forward_embeddings.__get__(embeddings, embeddings.__class__)
        setattr(embeddings, 'forward', bound_method)

        if self.input_image_size is not None:
            # note: newer transformers versions expect `size` as a dict,
            # e.g. {'shortest_edge': self.input_image_size}
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size,
            }

        self.is_loaded = True
    def forward(self, x):
        # e.g. a 448px input: position embeddings are resampled in forward_embeddings
        blks = self.vision_tower.vision_model.encoder.layers
        x = self.vision_tower.vision_model.embeddings(x)
        # drop the class token; `pre_layrnorm` is the attribute's actual
        # (misspelled) name in the HF CLIP implementation
        x = self.vision_tower.vision_model.pre_layrnorm(x[:, 1:])
        # inference of fast branch, checkpointing each block during training
        for blk in blks:
            if self.training:
                x = checkpoint(
                    blk.__call__,
                    x,
                    None,
                    None,
                )[0]
            else:
                x = blk(x, None, None)[0]
        return x
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def num_attention_heads(self):
        return self.config.num_attention_heads

    @property
    def num_layers(self):
        return self.config.num_hidden_layers

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
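

# A minimal usage sketch (hypothetical: the checkpoint name and the argument
# values are assumptions for illustration; `args` mirrors the attributes read
# in __init__ above).
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        freeze_vision=True,
        input_image_size=448,
        mm_vision_select_layer=-2,
        mm_vision_select_feature='patch',
    )
    tower = HRCLIPVisionTower('openai/clip-vit-large-patch14-336', args)
    tower.eval()
    with torch.no_grad():
        images = torch.randn(2, 3, 448, 448)
        feats = tower(images)
    print(feats.shape)  # expected: [2, 1024, 1024] -> (batch, patch tokens, hidden)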