import logging
import math
from typing import Optional, Tuple

from einops import rearrange
from peft import LoraConfig, get_peft_model
from transformers import CLIPConfig
from transformers.models.clip.modeling_clip import CLIPEncoderLayer as SpatialCLIPEncoderLayer, CLIPAttention, CLIPMLP
import torch
from torch import nn
from torch.nn import functional as F

from training.distributed import is_master

# Module-level state shared with the encoder layers below: CLIPEncoderLayer.forward reads the
# current number of frames per clip from this dict (and PATCH_DROPOUT is stored here for use
# elsewhere in the repo) instead of threading extra arguments through the HuggingFace forward
# signatures.
aaa = {'NUM_FRAMES': 1, 'PATCH_DROPOUT': 0.0}


def set_global_value(k, v):
    global aaa
    aaa[k] = v


def get_global_value():
    global aaa
    return aaa
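

# Illustrative usage sketch (not part of the original module): callers are expected to set the
# frame count here before running a video batch, since CLIPEncoderLayer.forward below reads it
# from this dict rather than from its inputs. The value 8 is an assumption for the example.
def _example_set_num_frames():
    set_global_value('NUM_FRAMES', 8)   # e.g. an 8-frame clip
    assert get_global_value()['NUM_FRAMES'] == 8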


# @dataclass
# class CLIPVisionCfg:
#     layers: Union[Tuple[int, int, int, int], int] = 12
#     width: int = 768
#     head_width: int = 64
#     mlp_ratio: float = 4.0
#     patch_size: int = 16
#     image_size: Union[Tuple[int, int], int] = 224
#     cast_dtype: str = None
#     num_frames: int = 2
#
#     ls_init_value: Optional[float] = None  # layer scale initial value
#     patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
#     input_patchnorm: bool = False  # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exists in original clip vit design
#     global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
#     attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
#     n_queries: int = 256  # n_queries for attentional pooler
#     attn_pooler_heads: int = 8  # n heads for attentional_pooling
#     output_tokens: bool = False
#
#     timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
#     timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
#     timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
#     timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
#     timm_proj_bias: bool = False  # enable bias for final projection
#     timm_drop: float = 0.  # head dropout
#     timm_drop_path: Optional[float] = None  # backbone stochastic depth


# class Video_VisionTransformer(nn.Module):
#     output_tokens: torch.jit.Final[bool]
#
#     def __init__(
#             self,
#             num_frames: int,
#             image_size: int,
#             patch_size: int,
#             width: int,
#             layers: int,
#             heads: int,
#             mlp_ratio: float,
#             ls_init_value: float = None,
#             global_average_pool: bool = False,
#             attentional_pool: bool = False,
#             n_queries: int = 256,
#             attn_pooler_heads: int = 8,
#             output_dim: int = 512,
#             patch_dropout: float = 0.,
#             input_patchnorm: bool = False,
#             act_layer: Callable = nn.GELU,
#             norm_layer: Callable = LayerNorm,
#             output_tokens: bool = False
#     ):
#         super().__init__()
#         self.output_tokens = output_tokens
#         image_height, image_width = self.image_size = to_2tuple(image_size)
#         patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
#         self.grid_size = (image_height // patch_height, image_width // patch_width)
#         self.output_dim = output_dim
#
#         # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
#         self.input_patchnorm = input_patchnorm
#
#         if input_patchnorm:
#             patch_input_dim = patch_height * patch_width * 3
#             self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
#             self.conv1 = nn.Linear(patch_input_dim, width)
#         else:
#             self.patchnorm_pre_ln = nn.Identity()
#             self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size,
#                                    bias=False)
#
#         # class embeddings and positional embeddings
#         self.scale = scale = width ** -0.5
#         self.class_embedding = nn.Parameter(scale * torch.randn(width))
#         self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
#
#         self.temporal_embedding = nn.Parameter(torch.zeros(1, num_frames, width))
#         # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
#         self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
#
#         self.ln_pre = norm_layer(width)
#         self.transformer = Transformer(
#             width,
#             layers,
#             heads,
#             mlp_ratio,
#             ls_init_value=ls_init_value,
#             act_layer=act_layer,
#             norm_layer=norm_layer,
#         )
#
#         self.global_average_pool = global_average_pool
#         if attentional_pool:
#             self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
#             self.ln_post = norm_layer(output_dim)
#             self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
#         else:
#             self.attn_pool = None
#             self.ln_post = norm_layer(width)
#             self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
#
#         self.init_parameters()
#
#     def lock(self, unlocked_groups=0, freeze_bn_stats=False):
#         for param in self.parameters():
#             param.requires_grad = False
#
#         if unlocked_groups != 0:
#             groups = [
#                 [
#                     self.conv1,
#                     self.positional_embedding,
#                     self.ln_pre,
#                 ],
#                 *zip(self.transformer.resblocks[:-1], [self.class_embedding for i in range(len(self.transformer.resblocks[:-1]))]),
#                 [
#                     self.class_embedding,
#                     self.transformer.resblocks[-1],
#                     self.ln_post,
#                 ],
#                 [self.proj, self.temporal_embedding]
#             ]
#
#             def _unlock(x):
#                 if isinstance(x, Sequence):
#                     for g in x:
#                         _unlock(g)
#                 else:
#                     if isinstance(x, torch.nn.Parameter):
#                         x.requires_grad = True
#                     else:
#                         for p in x.parameters():
#                             p.requires_grad = True
#
#             _unlock(groups[-unlocked_groups:])
#
#     def init_parameters(self):
#         # FIXME OpenAI CLIP did not define an init for the VisualTransformer
#         # TODO experiment if default PyTorch init, below, or alternate init is best.
#
#         nn.init.normal_(self.temporal_embedding, std=self.scale)
#         # nn.init.normal_(self.class_embedding, std=self.scale)
#         # nn.init.normal_(self.positional_embedding, std=self.scale)
#         #
#         # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
#         # attn_std = self.transformer.width ** -0.5
#         # fc_std = (2 * self.transformer.width) ** -0.5
#         # for block in self.transformer.resblocks:
#         #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
#         #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
#         #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
#         #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
#         #
#         # if self.text_projection is not None:
#         #     nn.init.normal_(self.text_projection, std=self.scale)
#         # pass
#
#     @torch.jit.ignore
#     def set_grad_checkpointing(self, enable=True):
#         self.transformer.grad_checkpointing = enable
#
#     def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         if self.global_average_pool:
#             return x.mean(dim=1), x
#         else:
#             return x[:, 0], x[:, 1:]
#
#     def forward(self, x: torch.Tensor):
#         # print('input img', x.shape)
#         B, _, T, _, _ = x.shape
#         x = rearrange(x, 'b c t h w -> (b t) c h w')
#         # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
#         if self.input_patchnorm:
#             # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
#             x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1],
#                           self.patch_size[1])
#             x = x.permute(0, 2, 4, 1, 3, 5)
#             x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
#             x = self.patchnorm_pre_ln(x)
#             x = self.conv1(x)
#         else:
#             x = self.conv1(x)  # shape = [*, width, grid, grid]
#             x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
#             x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
#
#         # print('embed img', x.shape)
#         # class embeddings and positional embeddings
#         x = torch.cat(
#             [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
#              x], dim=1)  # shape = [*, grid ** 2 + 1, width]
#         x = x + self.positional_embedding.to(x.dtype)
#
#         n = x.shape[1]
#         x = rearrange(x, '(b t) n d -> (b n) t d', t=T)
#         x = x + self.temporal_embedding[:, :T, :]
#         x = rearrange(x, '(b n) t d -> (b t) n d', n=n)
#
#         # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
#         x = self.patch_dropout(x)
#         x = self.ln_pre(x)
#
#         # print('patch_dropout img', x.shape)
#         x = x.permute(1, 0, 2)  # NLD -> LND
#         # print('permute img', x.shape)
#         x = self.transformer(x)
#         x = x.permute(1, 0, 2)  # LND -> NLD
#
#         if self.attn_pool is not None:
#             x = self.attn_pool(x)
#             x = self.ln_post(x)
#             pooled, tokens = self._global_pool(x)
#         else:
#             pooled, tokens = self._global_pool(x)
#             pooled = self.ln_post(pooled)  # bt, d
#
#         pooled = pooled.reshape(B, T, -1).mean(1)
#         if self.proj is not None:
#             pooled = pooled @ self.proj
#
#         if self.output_tokens:
#             return pooled, tokens
#
#         return pooled
#
#
# def _build_vision_tower(
#         embed_dim: int,
#         vision_cfg: CLIPVisionCfg,
#         quick_gelu: bool = False,
#         cast_dtype: Optional[torch.dtype] = None
# ):
#     if isinstance(vision_cfg, dict):
#         vision_cfg = CLIPVisionCfg(**vision_cfg)
#
#     # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
#     # memory efficient in recent PyTorch releases (>= 1.10).
#     # NOTE: timm models always use native GELU regardless of quick_gelu flag.
#     act_layer = QuickGELU if quick_gelu else nn.GELU
#
#     vision_heads = vision_cfg.width // vision_cfg.head_width
#     norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
#     visual = Video_VisionTransformer(
#         num_frames=vision_cfg.num_frames,
#         image_size=vision_cfg.image_size,
#         patch_size=vision_cfg.patch_size,
#         width=vision_cfg.width,
#         layers=vision_cfg.layers,
#         heads=vision_heads,
#         mlp_ratio=vision_cfg.mlp_ratio,
#         ls_init_value=vision_cfg.ls_init_value,
#         patch_dropout=vision_cfg.patch_dropout,
#         input_patchnorm=vision_cfg.input_patchnorm,
#         global_average_pool=vision_cfg.global_average_pool,
#         attentional_pool=vision_cfg.attentional_pool,
#         n_queries=vision_cfg.n_queries,
#         attn_pooler_heads=vision_cfg.attn_pooler_heads,
#         output_tokens=vision_cfg.output_tokens,
#         output_dim=embed_dim,
#         act_layer=act_layer,
#         norm_layer=norm_layer,
#     )
#
#     return visual


class CLIPEncoderLayer(SpatialCLIPEncoderLayer):
    """CLIP encoder layer extended with a temporal attention/MLP branch that runs over the
    frame axis before the usual spatial attention."""

    def __init__(self, config: CLIPConfig):
        super().__init__(config)
        self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
        nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)

        self.embed_dim = config.hidden_size
        # Temporal branch: a second attention/MLP pair with its own layer norms, initialised
        # from the pretrained spatial weights in add_time_attn_block below.
        self.temporal_attn = CLIPAttention(config)
        self.temporal_mlp = CLIPMLP(config)
        # self.t_attn_gate = nn.Parameter(torch.tensor([-20.]))
        # self.t_ffn_gate = nn.Parameter(torch.tensor([-20.]))
        self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: torch.Tensor,
            causal_attention_mask: torch.Tensor,
            output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch * num_frames, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        bt, n, d = hidden_states.shape
        t = get_global_value()['NUM_FRAMES']

        # time embed
        if t != 1:
            n = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
            hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
            hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

            # time attn
            residual = hidden_states
            hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
            # hidden_states = self.layer_norm1(hidden_states)  # share layernorm
            hidden_states = self.temporal_layer_norm1(hidden_states)
            hidden_states, attn_weights = self.temporal_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

            residual = hidden_states
            hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
            # hidden_states = self.layer_norm2(hidden_states)  # share layernorm
            hidden_states = self.temporal_layer_norm2(hidden_states)
            hidden_states = self.temporal_mlp(hidden_states)
            hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

        # spatial attn
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
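

# Minimal usage sketch, not part of the original module: hidden states arrive flattened over
# time as (batch * frames, tokens, dim), and the frame count comes from the module-level dict
# above. The config sizes are assumptions, and the call relies on the transformers version this
# module already targets (CLIPAttention accepting `causal_attention_mask`).
def _example_time_attn_layer():
    from transformers import CLIPVisionConfig

    cfg = CLIPVisionConfig(hidden_size=64, intermediate_size=128, num_attention_heads=4)
    cfg.num_frames = 8                 # assumption: normally injected by the surrounding model code
    set_global_value('NUM_FRAMES', 8)

    layer = CLIPEncoderLayer(cfg)
    x = torch.randn(2 * 8, 50, 64)     # (batch * frames, tokens, hidden_size)
    out = layer(x, attention_mask=None, causal_attention_mask=None)[0]
    assert out.shape == x.shape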


# class ResidualAttentionBlock(SpatialResidualAttentionBlock):
#     def __init__(self,
#                  num_frames: int,
#                  d_model: int,
#                  n_head: int,
#                  mlp_ratio: float = 4.0,
#                  ls_init_value: float = None,
#                  act_layer: Callable = nn.GELU,
#                  norm_layer: Callable = LayerNorm,
#                  is_cross_attention: bool = False,):
#         super().__init__(d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer, is_cross_attention)
#
#         self.num_frames = num_frames
#         self.time_ln_1 = norm_layer(d_model)
#         self.time_attn = nn.MultiheadAttention(d_model, n_head)
#         self.time_ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
#
#     def time_attention(
#             self,
#             q_x: torch.Tensor,
#             k_x: Optional[torch.Tensor] = None,
#             v_x: Optional[torch.Tensor] = None,
#             attn_mask: Optional[torch.Tensor] = None,
#     ):
#         k_x = k_x if k_x is not None else q_x
#         v_x = v_x if v_x is not None else q_x
#
#         attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
#         return self.time_attn(
#             q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask
#         )[0]
#
#     def forward(
#             self,
#             q_x: torch.Tensor,
#             k_x: Optional[torch.Tensor] = None,
#             v_x: Optional[torch.Tensor] = None,
#             attn_mask: Optional[torch.Tensor] = None,
#     ):
#         k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
#         v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
#
#         n, bt, d = q_x.shape
#         t = get_global_value()['NUM_FRAMES']
#
#         # time attn
#         # print('q_x', q_x.shape)
#         xt = rearrange(q_x, 'n (b t) d -> t (b n) d', t=t)
#         # print('xt', xt.shape)
#         xt = self.time_ls_1(self.time_attention(q_x=self.time_ln_1(xt), k_x=None, v_x=None, attn_mask=None))
#         # print('time_attention xt', xt.shape)
#         q_x = q_x + rearrange(xt, 't (b n) d -> n (b t) d', n=n)
#         # print('time_attention q_x', xt.shape)
#
#         # spatial attn
#         x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
#
#         x = x + self.ls_2(self.mlp(self.ln_2(x)))
#         return x


def print_trainable_parameters(model, msg=''):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    logging.info(f"{msg} Trainable params: {trainable_params} || all params: {all_param} || "
                 f"trainable: {100 * trainable_params / all_param:.2f}%")


def convert_model_to_lora(args, model):
    # With temporal attention enabled on the video-language tower, LoRA is attached only to the
    # newly added temporal attention/MLP projections; otherwise it targets the standard
    # q/k/v/out projections of the pretrained encoder.
    if args.clip_type == 'vl' and args.add_time_attn:
        target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
                          "temporal_attn.q_proj", "temporal_attn.out_proj",
                          "temporal_mlp.fc1", "temporal_mlp.fc2"]
    else:
        target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
    config = LoraConfig(
        r=args.lora_r,                   # e.g. 16
        lora_alpha=args.lora_alpha,      # e.g. 16
        target_modules=target_modules,
        lora_dropout=args.lora_dropout,  # e.g. 0.1
        bias="none",
        modules_to_save=[],
    )
    model.vision_model.encoder.is_gradient_checkpointing = False
    model.vision_model.encoder = get_peft_model(model.vision_model.encoder, config)
    if is_master(args):
        print_trainable_parameters(model.vision_model.encoder, msg='The model.vision_model.encoder: ')

    # model.text_model.encoder.is_gradient_checkpointing = False
    # model.text_model.encoder = get_peft_model(model.text_model.encoder, config)
    # if is_master(args):
    #     print_trainable_parameters(model.text_model.encoder, msg='The model.text_model.encoder: ')
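

# Usage sketch, not from the original training scripts: an argparse-style namespace with the
# fields read above (plus `rank`, which `is_master` is assumed to check) and a HuggingFace CLIP
# vision tower. The checkpoint name and hyper-parameters are illustrative only.
def _example_convert_to_lora():
    from types import SimpleNamespace
    from transformers import CLIPVisionModel

    args = SimpleNamespace(clip_type='vl', add_time_attn=False,
                           lora_r=16, lora_alpha=16, lora_dropout=0.1, rank=0)
    model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
    convert_model_to_lora(args, model)   # only the q/k/v/out LoRA adapters remain trainable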


def add_time_attn_block(m: nn.ModuleList, device):
    config = m.config
    for i, sub_m in enumerate(m.layers):
        if isinstance(sub_m, SpatialCLIPEncoderLayer):
            oup = CLIPEncoderLayer(config).to(device)
            state_dict = sub_m.state_dict()
            new_state_dict = {}
            # Copy every pretrained spatial weight, and duplicate the attention, MLP and
            # layer-norm weights into the corresponding temporal_* modules.
            for k, v in state_dict.items():
                if 'self_attn' in k:
                    new_state_dict[k] = v
                    # if 'out_proj' in k:
                    #     v = torch.zeros_like(v, dtype=v.dtype, device=v.device)
                    new_k = 'temporal_attn.' + '.'.join(k.split('.')[1:])
                    new_state_dict[new_k] = v
                elif 'mlp' in k:
                    new_state_dict[k] = v
                    # if 'out_proj' in k:
                    #     v = torch.zeros_like(v, dtype=v.dtype, device=v.device)
                    new_k = 'temporal_mlp.' + '.'.join(k.split('.')[1:])
                    new_state_dict[new_k] = v
                elif 'layer_norm1' in k:
                    new_state_dict[k] = v
                    new_k = 'temporal_layer_norm1.' + '.'.join(k.split('.')[1:])
                    new_state_dict[new_k] = v
                elif 'layer_norm2' in k:
                    new_state_dict[k] = v
                    new_k = 'temporal_layer_norm2.' + '.'.join(k.split('.')[1:])
                    new_state_dict[new_k] = v
                else:
                    new_state_dict[k] = v
            missing_keys, unexpected_keys = oup.load_state_dict(new_state_dict, strict=False)
            # assert missing_keys == ["t_attn_gate", "t_ffn_gate"]
            assert missing_keys == ['temporal_embedding']  # only the new temporal embedding stays randomly initialised
            assert unexpected_keys == []
            m.layers[i] = oup
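

# Sketch of upgrading a pretrained spatial encoder in place, not taken from the original
# pipeline: `num_frames` must be present on the config before the swap, and the checkpoint name
# and frame count below are assumptions for illustration.
def _example_add_time_attn():
    from transformers import CLIPVisionModel

    model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
    model.vision_model.encoder.config.num_frames = 8
    add_time_attn_block(model.vision_model.encoder, device='cpu')
    set_global_value('NUM_FRAMES', 8)   # forward passes now expect batch * 8 stacked frames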


def resize_pos(m: nn.Module, args):
    # convert embedding
    if args.clip_type == 'al':
        # Audio towers consume (num_mel_bins, target_length) spectrograms instead of square images.
        m.image_size = [args.num_mel_bins, args.target_length]
    m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
    # m.config.num_channels = 1
    # new_patch_embedding = nn.Conv2d(
    #     in_channels=m.config.num_channels,
    #     out_channels=m.embed_dim,
    #     kernel_size=m.patch_size,
    #     stride=m.patch_size,
    #     bias=False,
    # )
    # state_dict = m.patch_embedding.state_dict()
    # for k, v in state_dict.items():
    #     state_dict[k] = torch.mean(v, dim=1, keepdim=True).to(v.dtype)
    # m.patch_embedding = new_patch_embedding
    # m.patch_embedding.load_state_dict(state_dict)

    # pos resize
    old_pos_embed_state_dict = m.position_embedding.state_dict()
    old_pos_embed = old_pos_embed_state_dict['weight']
    dtype = old_pos_embed.dtype
    grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        m.to(args.device)
        return

    m.num_patches = grid_size[0] * grid_size[1]
    m.num_positions = m.num_patches + 1
    m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
    new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2

    if is_master(args):
        logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)

    # Interpolate the pretrained (square) grid of position embeddings to the new grid size.
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode='bicubic',
        antialias=True,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img

    old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
    m.position_embedding = new_position_embedding
    m.position_embedding.load_state_dict(old_pos_embed_state_dict)
    m.to(args.device)
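

# Sketch of repurposing the image position embedding for audio spectrograms, not from the
# original pipeline: with clip_type='al' the grid becomes (num_mel_bins / patch, target_length
# / patch) and the pretrained embedding is bicubically interpolated to the new length. The
# spectrogram size, checkpoint name and `rank` field are assumptions for illustration.
def _example_resize_pos_for_audio():
    from types import SimpleNamespace
    from transformers import CLIPVisionModel

    args = SimpleNamespace(clip_type='al', num_mel_bins=128, target_length=1024,
                           device='cpu', rank=0)
    model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
    resize_pos(model.vision_model.embeddings, args)
    # (128 // 32) * (1024 // 32) + 1 cls token = 4 * 32 + 1 = 129 positions after resizing
    assert model.vision_model.embeddings.position_embedding.num_embeddings == 129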


# def i2v_linear_resize_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = True):
#     # Rescale the grid of position embeddings when loading from state_dict
#     old_pos_embed = state_dict.get('visual.positional_embedding', None)
#     if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
#         return
#     # grid_size = to_2tuple(model.visual.grid_size)
#     grid_size = model.visual.grid_size
#     extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
#     # new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
#     new_seq_len = grid_size[0] * grid_size[1] * grid_size[2] + extra_tokens
#     if new_seq_len == old_pos_embed.shape[0]:
#         return
#
#     if extra_tokens:
#         pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
#     else:
#         pos_emb_tok, pos_emb_img = None, old_pos_embed
#     # old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
#
#     logging.info('Resizing position embedding grid-size from %s to %s', old_pos_embed.shape[0], new_seq_len)
#     # pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
#     pos_emb_img = pos_emb_img.unsqueeze(0).permute(0, 2, 1)
#     pos_emb_img = F.interpolate(
#         pos_emb_img,
#         # size=grid_size,
#         size=new_seq_len - extra_tokens,
#         mode=interpolation,
#         # antialias=antialias,
#         # align_corners=False,
#     )
#     # pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
#     pos_emb_img = pos_emb_img.permute(0, 2, 1)[0]
#     if pos_emb_tok is not None:
#         new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
#     else:
#         new_pos_embed = pos_emb_img
#     state_dict['visual.positional_embedding'] = new_pos_embed
#
#
# def inflate_patch_embed(state_dict, model):
#     old_patch_embed_shape = model.visual.conv1.weight.shape
#     new_patch_embed_shape = state_dict['visual.conv1.weight'].shape
#     if old_patch_embed_shape == new_patch_embed_shape:
#         return
#     expanded_weight = state_dict['visual.conv1.weight'].unsqueeze(2).repeat(1, 1, 2, 1, 1)
#     state_dict['visual.conv1.weight'] = expanded_weight
#
#
# def load_checkpoint(model, pretrained, strict=True):
#     state_dict = load_state_dict(pretrained)
#     # detect old format and make compatible with new format
#     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
#         state_dict = convert_to_custom_text_state_dict(state_dict)
#     i2v_linear_resize_pos_embed(state_dict, model)
#     inflate_patch_embed(state_dict, model)
#     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
#     return incompatible_keys