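# Spaces: Sleeping
# (The line above appears to be a web-page status banner captured together
# with this file; it is not part of the module.)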
# Module preamble: imports and the module-level logger.
# The trailing " | |" extraction residue has been removed from every line;
# no import was added, dropped, or reordered.
import logging

import torch
from torch import nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

from .backbones.internvideo2 import InternVideo2, LLaMA, Tokenizer
from .criterions import VTC_VTM_Loss

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class InternVideo2_CLIP(nn.Module):
    """Contrastive vision-text model built from InternVideo2 and LLaMA.

    The module owns an ``InternVideo2`` vision encoder, a ``LLaMA`` text
    encoder, a learnable temperature ``temp`` and a ``VTC_VTM_Loss``
    criterion. ``forward`` returns the vision-text contrastive (VTC) loss.

    Args:
        config: Configuration object. This class reads ``config.model.*``
            (tokenizer path, encoder settings, temperature, freeze flags and
            checkpoint paths).
        tokenizer: Optional tokenizer. When ``None`` a ``Tokenizer`` is built
            from ``config.model.tokenizer_path``.
        is_pretrain (bool): Stored on the instance as ``self.is_pretrain``;
            not otherwise used inside this class.
    """

    def __init__(self, config, tokenizer=None, is_pretrain=True):
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain

        # create modules.
        if tokenizer is None:
            self.tokenizer = Tokenizer(config.model.tokenizer_path)
        self.vision_encoder = self.build_vision_encoder()
        self.text_encoder = self.build_text_encoder()
        # adopt 1 / 100. as in ViCLIP
        self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
        self.temp_min = config.model.temp_min

        # freeze model
        if self.config.model.freeze_vision:
            for name, p in self.vision_encoder.named_parameters():
                if self.config.model.open_vision_clip_projector and name.startswith('clip_projector'):
                    # Left trainable on request.
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False
        if self.config.model.freeze_text:
            for name, p in self.text_encoder.named_parameters():
                if self.config.model.open_text_projection and name.startswith('text_projection'):
                    logger.info(f"Unfreeze {name}")
                elif self.config.model.open_text_lora and 'lora' in name:
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False

        # Preprocessing pipeline: bicubic resize to a square of side
        # `img_size`, scale by 1/255, then normalize per channel.
        # The Lambda calls `.float()`, so the input must already be a tensor.
        img_size = self.config.model.vision_encoder.img_size
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (img_size, img_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.Lambda(lambda x: x.float().div(255.0)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

        # load pretrained models
        self.load_checkpoint(
            config.model.vision_ckpt_path, config.model.text_ckpt_path,
            config.model.get("extra_ckpt_path", None)
        )

        # criterions
        self.clip_loss = VTC_VTM_Loss(False)

    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay.

        Returns:
            set[str]: ``"temp"``, the vision encoder's own no-weight-decay
            names prefixed with ``"vision_encoder."``, and every text-encoder
            parameter name prefixed with ``"text_encoder."``.
        """
        ret = {"temp"}
        ret.update(
            {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
        )
        # no weight decay for LLM if training
        ret.update(
            {"text_encoder." + k for k, _ in self.text_encoder.named_parameters()}
        )
        return ret

    @torch.no_grad()
    def clip_contrastive_temperature(self):
        """Clamp the temperature in place so it never drops below ``temp_min``.

        Runs under ``torch.no_grad()``: ``self.temp`` is a leaf parameter, and
        an in-place op on a leaf tensor that requires grad raises a
        ``RuntimeError`` while autograd is recording.
        """
        self.temp.clamp_(min=self.temp_min)

    def forward(self, image, text, idx):
        """Encode both modalities and compute the contrastive loss.

        Args:
            image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
            text: Input for the text encoder; passed unchanged to
                ``encode_text``.
            idx: Passed unchanged to ``VTC_VTM_Loss.vtc_loss``
                (presumably per-sample indices -- TODO confirm).

        Returns:
            dict: ``{"loss_vtc": <loss>}``.
        """
        self.clip_contrastive_temperature()
        vision_embeds = self.encode_vision(image)
        text_embeds = self.encode_text(text)

        # VTC loss
        loss_vtc = self.clip_loss.vtc_loss(
            vision_embeds, text_embeds, idx, self.temp, all_gather=True
        )

        return dict(
            loss_vtc=loss_vtc,
        )

    def encode_vision(self, image, test=False):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
            test (bool): Whether testing. Currently unused by this method.

        Returns:
            The output of the vision encoder (documented upstream as a
            [B,C] feature tensor -- TODO confirm against the backbone).
        """
        T = image.shape[1]
        # A single frame is treated as an image rather than a video.
        use_image = T == 1
        image = image.permute(0, 2, 1, 3, 4)  # [B,T,C,H,W] -> [B,C,T,H,W]
        vision_embeds = self.vision_encoder(image, use_image=use_image)
        return vision_embeds

    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
                - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
                - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
                - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".

        Returns:
            The output of the text encoder (documented upstream as a
            [B,C] feature tensor -- TODO confirm against the backbone).
        """
        text_embeds = self.text_encoder(text)
        return text_embeds

    def build_vision_encoder(self):
        """Build the vision encoder from ``config.model.vision_encoder``.

        Returns:
            InternVideo2: The vision encoder.
        """
        cfg = self.config.model.vision_encoder
        vision_encoder = InternVideo2(
            in_chans=cfg.in_chans,
            patch_size=cfg.patch_size,
            img_size=cfg.img_size,
            qkv_bias=cfg.qkv_bias,
            drop_path_rate=cfg.drop_path_rate,
            head_drop_path_rate=cfg.head_drop_path_rate,
            embed_dim=cfg.embed_dim,
            num_heads=cfg.num_heads,
            mlp_ratio=cfg.mlp_ratio,
            init_values=cfg.init_values,
            qk_normalization=cfg.qk_normalization,
            depth=cfg.depth,
            use_flash_attn=cfg.use_flash_attn,
            use_fused_rmsnorm=cfg.use_fused_rmsnorm,
            use_fused_mlp=cfg.use_fused_mlp,
            fused_mlp_heuristic=cfg.fused_mlp_heuristic,
            attn_pool_num_heads=cfg.attn_pool_num_heads,
            clip_embed_dim=cfg.clip_embed_dim,
            layerscale_no_force_fp32=cfg.layerscale_no_force_fp32,
            num_frames=cfg.num_frames,
            tubelet_size=cfg.tubelet_size,
            sep_pos_embed=cfg.sep_pos_embed,
            use_checkpoint=cfg.use_checkpoint,
            checkpoint_num=cfg.checkpoint_num,
        )
        return vision_encoder

    def build_text_encoder(self):
        """Build the text encoder from ``config.model.text_encoder``.

        Returns:
            LLaMA: The text encoder.
        """
        cfg = self.config.model.text_encoder
        text_encoder = LLaMA(
            use_flash_attn=cfg.use_flash_attn,
            transformer_width=cfg.transformer_width,
            llama_path=cfg.llama_path,
            use_lora=cfg.use_lora,
        )
        return text_encoder

    @staticmethod
    def _unwrap_checkpoint(ckpt, wrapper_keys):
        """Return ``ckpt[key]`` for the first ``key`` present, else ``ckpt``."""
        for key in wrapper_keys:
            if key in ckpt:
                return ckpt[key]
        return ckpt

    def load_checkpoint(self, vision_ckpt_path=None, text_ckpt_path=None, extra_ckpt_path=None):
        """Merge pretrained weights into one state dict and load it (non-strict).

        Args:
            vision_ckpt_path: Path of the vision-encoder checkpoint (required).
            text_ckpt_path: Path of the text-encoder checkpoint (required).
            extra_ckpt_path: Optional checkpoint whose keys are copied as-is
                and override the entries gathered from the two above.

        Raises:
            AssertionError: If either required path is ``None``.
        """
        assert vision_ckpt_path is not None, "No vision_encoder checkpoint"
        assert text_ckpt_path is not None, "No text_encoder checkpoint"

        new_ckpt = {}

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.

        # load vision_encoder
        logger.info(f"Load vision_encoder checkpoint from {vision_ckpt_path}")
        vision_ckpt = torch.load(vision_ckpt_path, map_location='cpu')
        vision_ckpt = self._unwrap_checkpoint(vision_ckpt, ('module', 'model'))

        if self.config.model.get('load_vision_ckpt_from_internvideo2_stage2', False):
            from .backbones.internvideo2.pos_embed import interpolate_pos_embed

            orig_t_size = self.config.model.get('vision_ckpt_t_size', 4)
            interpolate_pos_embed(vision_ckpt, self.vision_encoder, orig_t_size=orig_t_size)  # 4 for InternVideo2 stage2
            # Keys here already carry the 'vision_encoder.' prefix; keep
            # those, minus decoder and clip/img position-embedding entries.
            for k, v in vision_ckpt.items():
                if not k.startswith('vision_encoder.'):
                    continue
                if 'clip_decoder' in k or 'final_clip_decoder' in k:
                    continue
                if 'clip_pos_embed' in k or 'clip_img_pos_embed' in k or 'img_pos_embed' in k:
                    continue
                new_ckpt[k] = v
        else:
            # Keys are un-prefixed; drop decoder weights and their
            # position embeddings, then add the 'vision_encoder.' prefix.
            for k, v in vision_ckpt.items():
                if k.startswith(('clip_decoder.', 'mae_decoder.', 'final_clip_decoder.')):
                    continue
                if k in ('clip_pos_embed', 'mae_pos_embed'):
                    continue
                new_ckpt['vision_encoder.' + k] = v

        # load text_encoder
        logger.info(f"Load text_encoder checkpoint from {text_ckpt_path}")
        text_ckpt = torch.load(text_ckpt_path, map_location='cpu')
        text_ckpt = self._unwrap_checkpoint(text_ckpt, ('module',))
        for k, v in text_ckpt.items():
            # Only the transformer weights and the projection are taken.
            if k.startswith('transformer.') or k == 'text_projection':
                new_ckpt["text_encoder." + k] = v

        # load extra checkpoint
        # often when post-pretrain after previous pretraining, thus the keys are same
        if extra_ckpt_path is not None:
            logger.info(f"Load extra checkpoint from {extra_ckpt_path}")
            extra_ckpt = torch.load(extra_ckpt_path, map_location='cpu')
            extra_ckpt = self._unwrap_checkpoint(extra_ckpt, ('module',))
            for k, v in extra_ckpt.items():
                new_ckpt[k] = v

        # strict=False: missing / unexpected keys are reported, not raised.
        msg = self.load_state_dict(new_ckpt, strict=False)
        logger.info(msg)