# LanguageBind / model/process_clip.py
# Last commit: 5c98ca3 by LinB203 — "add project files"
import logging
import math
from typing import Optional, Tuple
from einops import rearrange
from peft import LoraConfig, get_peft_model
from transformers import CLIPConfig
from transformers.models.clip.modeling_clip import CLIPEncoderLayer as SpatialCLIPEncoderLayer, CLIPAttention, CLIPMLP
import torch
from torch import nn
from torch.nn import functional as F
from training.distributed import is_master
# Process-wide settings shared between modules. NUM_FRAMES is the number of frames
# packed into the batch axis of the vision encoder's input (read by the temporal
# encoder layer through get_global_value()).
aaa = {'NUM_FRAMES': 1, 'PATCH_DROPOUT': 0.0}


def set_global_value(k, v):
    """Store ``v`` under key ``k`` in the shared settings dict.

    The dict is mutated in place (never rebound), so no ``global`` statement is needed.
    """
    aaa[k] = v


def get_global_value():
    """Return the shared settings dict itself (not a copy); mutations are visible globally."""
    return aaa
# @dataclass
# class CLIPVisionCfg:
# layers: Union[Tuple[int, int, int, int], int] = 12
# width: int = 768
# head_width: int = 64
# mlp_ratio: float = 4.0
# patch_size: int = 16
# image_size: Union[Tuple[int, int], int] = 224
# cast_dtype: str = None
# num_frames: int = 2
#
# ls_init_value: Optional[float] = None # layer scale initial value
# patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
# input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
# global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
# attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
# n_queries: int = 256 # n_queries for attentional pooler
# attn_pooler_heads: int = 8 # n heads for attentional_pooling
# output_tokens: bool = False
#
# timm_model_name: str = None # a valid model name overrides layers, width, patch_size
# timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
# timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
# timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
# timm_proj_bias: bool = False # enable bias final projection
# timm_drop: float = 0. # head dropout
# timm_drop_path: Optional[float] = None # backbone stochastic depth
# class Video_VisionTransformer(nn.Module):
# output_tokens: torch.jit.Final[bool]
#
# def __init__(
# self,
# num_frames: int,
# image_size: int,
# patch_size: int,
# width: int,
# layers: int,
# heads: int,
# mlp_ratio: float,
# ls_init_value: float = None,
# global_average_pool: bool = False,
# attentional_pool: bool = False,
# n_queries: int = 256,
# attn_pooler_heads: int = 8,
# output_dim: int = 512,
# patch_dropout: float = 0.,
# input_patchnorm: bool = False,
# act_layer: Callable = nn.GELU,
# norm_layer: Callable = LayerNorm,
# output_tokens: bool = False
# ):
# super().__init__()
# self.output_tokens = output_tokens
# image_height, image_width = self.image_size = to_2tuple(image_size)
# patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
# self.grid_size = (image_height // patch_height, image_width // patch_width)
# self.output_dim = output_dim
#
# # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
# self.input_patchnorm = input_patchnorm
#
# if input_patchnorm:
# patch_input_dim = patch_height * patch_width * 3
# self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
# self.conv1 = nn.Linear(patch_input_dim, width)
# else:
# self.patchnorm_pre_ln = nn.Identity()
# self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size,
# bias=False)
#
# # class embeddings and positional embeddings
# self.scale = scale = width ** -0.5
# self.class_embedding = nn.Parameter(scale * torch.randn(width))
# self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
#
# self.temporal_embedding = nn.Parameter(torch.zeros(1, num_frames, width))
# # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
# self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
#
# self.ln_pre = norm_layer(width)
# self.transformer = Transformer(
# width,
# layers,
# heads,
# mlp_ratio,
# ls_init_value=ls_init_value,
# act_layer=act_layer,
# norm_layer=norm_layer,
# )
#
# self.global_average_pool = global_average_pool
# if attentional_pool:
# self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
# self.ln_post = norm_layer(output_dim)
# self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
# else:
# self.attn_pool = None
# self.ln_post = norm_layer(width)
# self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
#
# self.init_parameters()
#
#
# def lock(self, unlocked_groups=0, freeze_bn_stats=False):
# for param in self.parameters():
# param.requires_grad = False
#
# if unlocked_groups != 0:
# groups = [
# [
# self.conv1,
# self.positional_embedding,
# self.ln_pre,
# ],
# *zip(self.transformer.resblocks[:-1], [self.class_embedding for i in range(len(self.transformer.resblocks[:-1]))]),
# [
# self.class_embedding,
# self.transformer.resblocks[-1],
# self.ln_post,
# ],
# [self.proj, self.temporal_embedding]
# ]
#
# def _unlock(x):
# if isinstance(x, Sequence):
# for g in x:
# _unlock(g)
# else:
# if isinstance(x, torch.nn.Parameter):
# x.requires_grad = True
# else:
# for p in x.parameters():
# p.requires_grad = True
#
# _unlock(groups[-unlocked_groups:])
#
# def init_parameters(self):
# # FIXME OpenAI CLIP did not define an init for the VisualTransformer
# # TODO experiment if default PyTorch init, below, or alternate init is best.
#
# nn.init.normal_(self.temporal_embedding, std=self.scale)
# # nn.init.normal_(self.class_embedding, std=self.scale)
# # nn.init.normal_(self.positional_embedding, std=self.scale)
# #
# # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
# # attn_std = self.transformer.width ** -0.5
# # fc_std = (2 * self.transformer.width) ** -0.5
# # for block in self.transformer.resblocks:
# # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
# #
# # if self.text_projection is not None:
# # nn.init.normal_(self.text_projection, std=self.scale)
# # pass
#
# @torch.jit.ignore
# def set_grad_checkpointing(self, enable=True):
# self.transformer.grad_checkpointing = enable
#
# def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# if self.global_average_pool:
# return x.mean(dim=1), x
# else:
# return x[:, 0], x[:, 1:]
#
# def forward(self, x: torch.Tensor):
# # print('input img', x.shape)
# B, _, T, _, _ = x.shape
# x = rearrange(x, 'b c t h w -> (b t) c h w')
# # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
# if self.input_patchnorm:
# # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
# x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1],
# self.patch_size[1])
# x = x.permute(0, 2, 4, 1, 3, 5)
# x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
# x = self.patchnorm_pre_ln(x)
# x = self.conv1(x)
# else:
# x = self.conv1(x) # shape = [*, width, grid, grid]
# x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
# x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
#
# # print('embed img', x.shape)
# # class embeddings and positional embeddings
# x = torch.cat(
# [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
# x], dim=1) # shape = [*, grid ** 2 + 1, width]
# x = x + self.positional_embedding.to(x.dtype)
#
# n = x.shape[1]
# x = rearrange(x, '(b t) n d -> (b n) t d', t=T)
# x = x + self.temporal_embedding[:, :T, :]
# x = rearrange(x, '(b n) t d -> (b t) n d', n=n)
#
# # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
# x = self.patch_dropout(x)
# x = self.ln_pre(x)
#
# # print('patch_dropout img', x.shape)
# x = x.permute(1, 0, 2) # NLD -> LND
# # print('permute img', x.shape)
# x = self.transformer(x)
# x = x.permute(1, 0, 2) # LND -> NLD
#
# if self.attn_pool is not None:
# x = self.attn_pool(x)
# x = self.ln_post(x)
# pooled, tokens = self._global_pool(x)
# else:
# pooled, tokens = self._global_pool(x)
# pooled = self.ln_post(pooled) # bt, d
#
# pooled = pooled.reshape(B, T, -1).mean(1)
# if self.proj is not None:
# pooled = pooled @ self.proj
#
# if self.output_tokens:
# return pooled, tokens
#
# return pooled
#
# def _build_vision_tower(
# embed_dim: int,
# vision_cfg: CLIPVisionCfg,
# quick_gelu: bool = False,
# cast_dtype: Optional[torch.dtype] = None
# ):
# if isinstance(vision_cfg, dict):
# vision_cfg = CLIPVisionCfg(**vision_cfg)
#
# # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# # memory efficient in recent PyTorch releases (>= 1.10).
# # NOTE: timm models always use native GELU regardless of quick_gelu flag.
# act_layer = QuickGELU if quick_gelu else nn.GELU
#
# vision_heads = vision_cfg.width // vision_cfg.head_width
# norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
# visual = Video_VisionTransformer(
# num_frames=vision_cfg.num_frames,
# image_size=vision_cfg.image_size,
# patch_size=vision_cfg.patch_size,
# width=vision_cfg.width,
# layers=vision_cfg.layers,
# heads=vision_heads,
# mlp_ratio=vision_cfg.mlp_ratio,
# ls_init_value=vision_cfg.ls_init_value,
# patch_dropout=vision_cfg.patch_dropout,
# input_patchnorm=vision_cfg.input_patchnorm,
# global_average_pool=vision_cfg.global_average_pool,
# attentional_pool=vision_cfg.attentional_pool,
# n_queries=vision_cfg.n_queries,
# attn_pooler_heads=vision_cfg.attn_pooler_heads,
# output_tokens=vision_cfg.output_tokens,
# output_dim=embed_dim,
# act_layer=act_layer,
# norm_layer=norm_layer,
# )
#
# return visual
class CLIPEncoderLayer(SpatialCLIPEncoderLayer):
    """CLIP encoder layer with a temporal sub-block placed in front of the spatial one.

    The spatial path (``layer_norm1`` -> ``self_attn`` -> ``layer_norm2`` -> ``mlp``) is the
    inherited parent layer. This subclass adds a learned per-frame ``temporal_embedding`` and a
    parallel set of ``temporal_*`` modules that attend across frames at every token position.

    ``config`` must provide ``num_frames`` (the size of the temporal embedding table).
    The number of frames actually used at call time is read from
    ``get_global_value()['NUM_FRAMES']``. When it is 1 the temporal sub-block is skipped
    and the layer computes exactly what the parent layer does.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__(config)
        # Creation order is kept stable (temporal_embedding init before the temporal modules)
        # so RNG consumption and parameter registration order match earlier checkpoints.
        self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
        nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)
        self.embed_dim = config.hidden_size
        self.temporal_attn = CLIPAttention(config)
        self.temporal_mlp = CLIPMLP(config)
        self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch * num_frames, seq_len, embed_dim)`.
                The frames of one sample must be contiguous along the first axis.
            attention_mask (`torch.FloatTensor`, *optional*): spatial attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                Applied to the spatial self-attention only.
            causal_attention_mask (`torch.FloatTensor`, *optional*): spatial causal mask with the same layout as
                `attention_mask`. Applied to the spatial self-attention only.
            output_attentions (`bool`, *optional*):
                Whether or not to return the spatial attention weights of this layer.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        num_frames = get_global_value()['NUM_FRAMES']
        if num_frames != 1:
            hidden_states = self._temporal_forward(hidden_states, num_frames)

        # Spatial sub-block: the parent layer's pre-norm self-attention + MLP, applied per frame.
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

    def _temporal_forward(self, hidden_states: torch.Tensor, num_frames: int) -> torch.Tensor:
        """Add per-frame embeddings, then run pre-norm attention and MLP across frames.

        Input and output are laid out as `(batch * num_frames, seq_len, embed_dim)`.
        Internally every token position becomes its own length-`num_frames` sequence.
        """
        seq_len = hidden_states.shape[1]
        x = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=num_frames)
        x = x + self.temporal_embedding[:, :num_frames, :]

        # The caller's masks index token positions of the (b t) batch, not frames of the
        # (b n) batch, so they cannot be applied here; attention across frames is unmasked.
        attn_out, _ = self.temporal_attn(
            hidden_states=self.temporal_layer_norm1(x),
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=False,
        )
        x = x + attn_out
        x = x + self.temporal_mlp(self.temporal_layer_norm2(x))
        return rearrange(x, '(b n) t d -> (b t) n d', n=seq_len)
# class ResidualAttentionBlock(SpatialResidualAttentionBlock):
# def __init__(self,
# num_frames: int,
# d_model: int,
# n_head: int,
# mlp_ratio: float = 4.0,
# ls_init_value: float = None,
# act_layer: Callable = nn.GELU,
# norm_layer: Callable = LayerNorm,
# is_cross_attention: bool = False,):
# super().__init__(d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer, is_cross_attention)
#
# self.num_frames = num_frames
# self.time_ln_1 = norm_layer(d_model)
# self.time_attn = nn.MultiheadAttention(d_model, n_head)
# self.time_ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
#
# def time_attention(
# self,
# q_x: torch.Tensor,
# k_x: Optional[torch.Tensor] = None,
# v_x: Optional[torch.Tensor] = None,
# attn_mask: Optional[torch.Tensor] = None,
# ):
# k_x = k_x if k_x is not None else q_x
# v_x = v_x if v_x is not None else q_x
#
# attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
# return self.time_attn(
# q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask
# )[0]
#
# def forward(
# self,
# q_x: torch.Tensor,
# k_x: Optional[torch.Tensor] = None,
# v_x: Optional[torch.Tensor] = None,
# attn_mask: Optional[torch.Tensor] = None,
# ):
# k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
# v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
#
# n, bt, d = q_x.shape
# t = get_global_value()['NUM_FRAMES']
#
# # time attn
# # print('q_x', q_x.shape)
# xt = rearrange(q_x, 'n (b t) d -> t (b n) d', t=t)
# # print('xt', xt.shape)
# xt = self.time_ls_1(self.time_attention(q_x=self.time_ln_1(xt), k_x=None, v_x=None, attn_mask=None))
# # print('time_attention xt', xt.shape)
# q_x = q_x + rearrange(xt, 't (b n) d -> n (b t) d', n=n)
# # print('time_attention q_x', xt.shape)
#
# # spatial attn
# x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
#
# x = x + self.ls_2(self.mlp(self.ln_2(x)))
# return x
def print_trainable_parameters(model, msg=''):
    """Log how many of ``model``'s parameters are trainable.

    Emits one INFO record through the root ``logging`` logger, formatted as
    ``"<msg> Trainable params: <n> || all params: <total> || trainable: <pct>%"``.

    Args:
        model: a ``torch.nn.Module``; parameters are counted via ``model.parameters()``
            (shared parameters are counted once).
        msg: prefix placed in front of the log line.

    A model with no parameters is reported as 0.00% trainable instead of raising
    ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_params += numel
    percent = 100 * trainable_params / all_param if all_param else 0.0
    logging.info(
        "%s Trainable params: %d || all params: %d || trainable: %.2f%%",
        msg, trainable_params, all_param, percent,
    )
def convert_model_to_lora(args, model):
    """Replace ``model.vision_model.encoder`` with a PEFT LoRA-wrapped version of itself.

    Args:
        args: run configuration. Reads ``clip_type``, ``add_time_attn``, ``lora_r``,
            ``lora_alpha`` and ``lora_dropout``, and is passed to ``is_master``.
        model: model whose ``vision_model.encoder`` attribute is reassigned in place.

    When ``args.clip_type == 'vl'`` and ``args.add_time_attn`` is set, LoRA targets only
    the temporal attention projections and the temporal MLP. Otherwise it targets the
    ``k_proj``/``v_proj``/``q_proj``/``out_proj`` modules. Only the vision encoder is
    wrapped; the text encoder is not touched.
    """
    temporal_only = args.clip_type == 'vl' and args.add_time_attn
    if not temporal_only:
        target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
    else:
        target_modules = [
            "temporal_attn.k_proj", "temporal_attn.v_proj",
            "temporal_attn.q_proj", "temporal_attn.out_proj",
            "temporal_mlp.fc1", "temporal_mlp.fc2",
        ]

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        modules_to_save=[],
    )

    vision_encoder = model.vision_model.encoder
    # NOTE(review): transformers' CLIPEncoder appears to read `gradient_checkpointing`,
    # not `is_gradient_checkpointing`, so this assignment may have no effect — confirm intent.
    vision_encoder.is_gradient_checkpointing = False
    model.vision_model.encoder = get_peft_model(vision_encoder, lora_config)

    if is_master(args):
        print_trainable_parameters(model.vision_model.encoder, msg='The model.vision_model.encoder: ')
def add_time_attn_block(m: nn.Module, device):
    """Swap every spatial CLIP encoder layer in ``m.layers`` for a temporal ``CLIPEncoderLayer``.

    Args:
        m: encoder module exposing ``config`` and an indexable, assignable ``layers``
            container (only these two attributes are used).
        device: device each new layer is moved to before weights are loaded into it.

    Each new layer keeps the original spatial weights unchanged. Its temporal modules are
    initialised as copies of the matching spatial ones:
    ``self_attn -> temporal_attn``, ``mlp -> temporal_mlp``,
    ``layer_norm1 -> temporal_layer_norm1``, ``layer_norm2 -> temporal_layer_norm2``.
    Only ``temporal_embedding`` keeps its fresh initialisation.

    Layers that already are temporal ``CLIPEncoderLayer`` instances are skipped, so calling
    this function more than once is safe.

    Raises:
        RuntimeError: if the converted weights do not line up with the new layer
            (anything missing other than ``temporal_embedding``, or any unexpected key).
    """
    # Spatial sub-module prefix -> temporal sub-module initialised from its weights.
    temporal_prefix = {
        'self_attn': 'temporal_attn',
        'mlp': 'temporal_mlp',
        'layer_norm1': 'temporal_layer_norm1',
        'layer_norm2': 'temporal_layer_norm2',
    }
    config = m.config
    for i, sub_m in enumerate(m.layers):
        # CLIPEncoderLayer subclasses the spatial layer, so it must be excluded explicitly;
        # re-converting it would fail the key check below.
        if isinstance(sub_m, CLIPEncoderLayer):
            continue
        if not isinstance(sub_m, SpatialCLIPEncoderLayer):
            continue

        new_layer = CLIPEncoderLayer(config).to(device)
        new_state_dict = {}
        for key, value in sub_m.state_dict().items():
            new_state_dict[key] = value
            head, _, tail = key.partition('.')
            if head in temporal_prefix:
                new_state_dict[f'{temporal_prefix[head]}.{tail}'] = value

        missing_keys, unexpected_keys = new_layer.load_state_dict(new_state_dict, strict=False)
        # Explicit check (not `assert`) so it still runs under `python -O`.
        if missing_keys != ['temporal_embedding'] or unexpected_keys:
            raise RuntimeError(
                f'Unexpected state-dict mismatch converting layer {i}: '
                f'missing={missing_keys}, unexpected={unexpected_keys}'
            )
        m.layers[i] = new_layer
def resize_pos(m: nn.Module, args):
    """Resize ``m``'s learned position embedding to match its (possibly new) input size.

    Args:
        m: vision-embedding module exposing ``config.image_size``, ``image_size``,
            ``patch_size``, ``embed_dim`` and a ``position_embedding`` ``nn.Embedding``
            whose first ``extra_tokens`` rows are non-patch (class) tokens.
        args: reads ``clip_type`` and ``device``. When ``clip_type == 'al'`` it also reads
            ``num_mel_bins``/``target_length``, used as the input size
            (presumably an audio spectrogram — confirm).

    Side effects: may overwrite ``m.image_size``, ``m.config.image_size``,
    ``m.num_patches``, ``m.num_positions``, the ``position_ids`` buffer and
    ``m.position_embedding``. Always moves ``m`` to ``args.device``.

    The patch part of the old table is assumed to be a square grid. It is resampled with
    antialiased bicubic interpolation computed in float32, and the new embedding is cast
    back to the old table's dtype.
    """
    # convert embedding
    if args.clip_type == 'al':
        m.image_size = [args.num_mel_bins, args.target_length]
        m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
        # m.config.num_channels = 1
        # new_patch_embedding = nn.Conv2d(
        #     in_channels=m.config.num_channels,
        #     out_channels=m.embed_dim,
        #     kernel_size=m.patch_size,
        #     stride=m.patch_size,
        #     bias=False,
        # )
        # state_dict = m.patch_embedding.state_dict()
        # for k, v in state_dict.items():
        #     state_dict[k] = torch.mean(v, dim=1, keepdim=True).to(v.dtype)
        # m.patch_embedding = new_patch_embedding
        # m.patch_embedding.load_state_dict(state_dict)

    # pos resize
    old_pos_embed_state_dict = m.position_embedding.state_dict()
    old_pos_embed = old_pos_embed_state_dict['weight']
    dtype = old_pos_embed.dtype
    grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        m.to(args.device)
        return

    m.num_patches = grid_size[0] * grid_size[1]
    m.num_positions = m.num_patches + 1
    m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
    new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
    if is_master(args):
        logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)

    # Resample in float32 so low-precision tables are not interpolated in reduced precision.
    pos_emb_img = pos_emb_img.float().reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode='bicubic',
        antialias=True,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok.float(), pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img

    old_pos_embed_state_dict['weight'] = new_pos_embed
    new_position_embedding.load_state_dict(old_pos_embed_state_dict)
    # nn.Embedding is created in the default dtype; cast it back so the resized table keeps
    # the original dtype instead of silently becoming float32.
    m.position_embedding = new_position_embedding.to(dtype=dtype)
    m.to(args.device)
# def i2v_linear_resize_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = True):
# # Rescale the grid of position embeddings when loading from state_dict
# old_pos_embed = state_dict.get('visual.positional_embedding', None)
# if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
# return
# # grid_size = to_2tuple(model.visual.grid_size)
# grid_size = model.visual.grid_size
# extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
# # new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
# new_seq_len = grid_size[0] * grid_size[1] * grid_size[2] + extra_tokens
# if new_seq_len == old_pos_embed.shape[0]:
# return
#
# if extra_tokens:
# pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
# else:
# pos_emb_tok, pos_emb_img = None, old_pos_embed
# # old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
#
# logging.info('Resizing position embedding grid-size from %s to %s', old_pos_embed.shape[0], new_seq_len)
# # pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
# pos_emb_img = pos_emb_img.unsqueeze(0).permute(0, 2, 1)
# pos_emb_img = F.interpolate(
# pos_emb_img,
# # size=grid_size,
# size=new_seq_len - extra_tokens,
# mode=interpolation,
# # antialias=antialias,
# # align_corners=False,
# )
# # pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
# pos_emb_img = pos_emb_img.permute(0, 2, 1)[0]
# if pos_emb_tok is not None:
# new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
# else:
# new_pos_embed = pos_emb_img
# state_dict['visual.positional_embedding'] = new_pos_embed
#
# def inflate_patch_embed(state_dict, model):
# old_patch_embed_shape = model.visual.conv1.weight.shape
# new_patch_embed_shape = state_dict['visual.conv1.weight'].shape
# if old_patch_embed_shape == new_patch_embed_shape:
# return
# expanded_weight = state_dict['visual.conv1.weight'].unsqueeze(2).repeat(1, 1, 2, 1, 1)
# state_dict['visual.conv1.weight'] = expanded_weight
#
#
# def load_checkpoint(model, pretrained, strict=True):
# state_dict = load_state_dict(pretrained)
# # detect old format and make compatible with new format
# if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
# state_dict = convert_to_custom_text_state_dict(state_dict)
# i2v_linear_resize_pos_embed(state_dict, model)
# inflate_patch_embed(state_dict, model)
# incompatible_keys = model.load_state_dict(state_dict, strict=strict)
# return incompatible_keys