""" ViTamin Paper: Designing Scalable Vison Models in the Vision-Language Era @misc{chen2023designing, title={Designing Scalable Vison Models in the Vision-Language Era}, author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen}, year={2023}, archivePrefix={arXiv}, primaryClass={cs.CV} } Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin by Jieneng Chen 2024 Reference: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ from dataclasses import dataclass import logging import math from typing import Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from functools import partial from open_clip.hf_model import HFTextEncoder from open_clip.modified_resnet import ModifiedResNet from open_clip.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer from open_clip.utils import to_2tuple import time import timm from timm.models.vision_transformer import _create_vision_transformer from .timm_model import TimmModel from .vitamin import * # from .vitamin import HybridEmbed, MbConvStages, VitCfg, VitConvCfg from .vitamin import GeGluMlp, ViTamin, HybridEmbed, MbConvStages, VitCfg, VitConvCfg from transformers.modeling_utils import PreTrainedModel from .configuration_vitamin import ViTaminConfig, ViTaminVisionConfig @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 width: int = 768 head_width: int = 64 mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None patch_dropout: float = 0. input_patchnorm: bool = False global_average_pool: bool = False attentional_pool: bool = False n_queries: int = 256 attn_pooler_heads: int = 8 output_tokens: bool = False timm_model_name: str = None timm_model_pretrained: bool = False timm_pool: str = 'avg' timm_proj: str = 'linear' timm_proj_bias: bool = False timm_drop: float = 0. 
    timm_drop_path: Optional[float] = None  # backbone stochastic depth


@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    hf_model_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False
    text_mask: str = 'first'  # default first truncate in bpe_tokenizer


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str):
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            input_patchnorm=vision_cfg.input_patchnorm,
            global_average_pool=vision_cfg.global_average_pool,
            attentional_pool=vision_cfg.attentional_pool,
            n_queries=vision_cfg.n_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            output_tokens=text_cfg.output_tokens,
            pad_id=text_cfg.pad_id,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text
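

# Illustrative usage sketch (an assumption, not part of the original file): the two
# builder helpers accept either the dataclass configs above or plain dicts. The field
# values below are made-up placeholders; real configs come from the ViTamin / open_clip
# model JSON files.
#
#   vision_cfg = dict(timm_model_name='vitamin_large', timm_pool='', timm_proj='linear', image_size=256)
#   text_cfg = dict(context_length=77, vocab_size=49408, width=768, heads=12, layers=12)
#   visual = _build_vision_tower(embed_dim=768, vision_cfg=vision_cfg)
#   text = _build_text_tower(embed_dim=768, text_cfg=text_cfg)
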
class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.method_lock_text_tower = text.lock
        self.text_no_grad = False

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
        # added by jieneng
        self.method_lock_text_tower(unlocked_layers, freeze_layer_norm)
        self.text_no_grad = True

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, enable_text=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable_text

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None

        if self.text_no_grad:
            with torch.no_grad():
                text_features = self.encode_text(text, normalize=True).detach() if text is not None else None
        else:
            text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


class CustomTextCLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.text_no_grad = False

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
        self.text.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj)
        self.text_no_grad = True

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, enable_text=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable_text)

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()
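

# Illustrative usage sketch (an assumption, not part of the original file): turning the
# (image_features, text_features, logit_scale) triple returned by CLIP / CustomTextCLIP
# into similarity logits, following the usual CLIP recipe. Inputs are random placeholders.
#
#   model = CLIP(embed_dim=512, vision_cfg=CLIPVisionCfg(), text_cfg=CLIPTextCfg())
#   images = torch.randn(2, 3, 224, 224)
#   tokens = torch.randint(0, 49408, (2, 77))
#   image_features, text_features, logit_scale = model(images, tokens)
#   logits_per_image = logit_scale * image_features @ text_features.T
#   text_probs = logits_per_image.softmax(dim=-1)
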
class ViTaminPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading
    and loading pretrained models.
    """
    config_class = ViTaminConfig
    base_model_prefix = 'vitamin'


# hack CLIPVisionModel for llava: https://github.com/huggingface/transformers/blob/9acce7de1cb8229304a467938ebb47727d60cdb2/src/transformers/models/clip/modeling_clip.py#L878
class ViTaminVisionModel(PreTrainedModel):
    config_class = ViTaminVisionConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: ViTaminVisionConfig):
        super().__init__(config)
        self.visual = _build_vision_tower(config.embed_dim, config)

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            select_layer=-2,
    ):
        assert len(pixel_values.shape) == 4, f'wrong pixel_values size: {pixel_values.shape}'
        # run the MBConv stem/stages of the hybrid patch embed, then the transformer blocks,
        # truncated at `select_layer` (e.g. -2 stops at the penultimate block output)
        x = self.visual.trunk.patch_embed.backbone.stem(pixel_values)
        x = self.visual.trunk.patch_embed.backbone.stages[0](x)
        x = self.visual.trunk.patch_embed.backbone.stages[1](x)
        x = self.visual.trunk.patch_embed.backbone.pool(x)
        x = self.visual.trunk.patch_embed.proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.visual.trunk.patch_drop(x)
        x = self.visual.trunk.norm_pre(x)
        x = self.visual.trunk.blocks[:select_layer + 1](x)
        return x


class ViTaminCLIP(ViTaminPreTrainedModel):
    output_dict: torch.jit.Final[bool]
    config_class = ViTaminConfig

    def __init__(
            self,
            config: ViTaminConfig
    ):
        super().__init__(config)
        embed_dim = config.embed_dim
        vision_cfg = config.vision_cfg
        text_cfg = config.text_cfg
        quick_gelu = False
        cast_dtype = None
        output_dict = False

        self.config = config
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.text_no_grad = False

    def forward_visual4llava(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            select_layer=-2,
    ):
        assert len(pixel_values.shape) == 4, f'wrong pixel_values size: {pixel_values.shape}'
        x = self.visual.trunk.patch_embed.backbone.stem(pixel_values)
        x = self.visual.trunk.patch_embed.backbone.stages[0](x)
        x = self.visual.trunk.patch_embed.backbone.stages[1](x)
        x = self.visual.trunk.patch_embed.backbone.pool(x)
        x = self.visual.trunk.patch_embed.proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.visual.trunk.patch_drop(x)
        x = self.visual.trunk.norm_pre(x)
        x = self.visual.trunk.blocks[:select_layer + 1](x)
        return x

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward_pixel(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        # manual pass through the ViTamin vision trunk and projection head
        x = self.visual.trunk.patch_embed.backbone.stem(image)
        x = self.visual.trunk.patch_embed.backbone.stages[0](x)
        x = self.visual.trunk.patch_embed.backbone.stages[1](x)
        x = self.visual.trunk.patch_embed.backbone.pool(x)
        x = self.visual.trunk.patch_embed.proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.visual.trunk.patch_drop(x)
        x = self.visual.trunk.norm_pre(x)
        x = self.visual.trunk.blocks(x)
        x = self.visual.trunk.fc_norm(x)
        x = self.visual.head.proj(x)
        image_features = F.normalize(x, dim=-1)

        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()
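

# Illustrative usage sketch (an assumption, not part of the original file): extracting
# patch tokens for LLaVA-style training via forward_visual4llava. The checkpoint path is
# a placeholder; select_layer=-2 stops at the penultimate transformer block.
#
#   config = ViTaminConfig.from_pretrained('path/to/vitamin-checkpoint')
#   model = ViTaminCLIP(config)
#   pixel_values = torch.randn(1, 3, 224, 224)  # use the image size the checkpoint expects
#   patch_tokens = model.forward_visual4llava(pixel_values, select_layer=-2)  # (1, num_patches, width)
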
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(l, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(l, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(l, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict
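

# Illustrative usage sketch (an assumption, not part of the original file): remapping an
# old single-tower-format checkpoint for CustomTextCLIP and casting weights to fp16.
# The checkpoint path is a placeholder.
#
#   state_dict = torch.load('checkpoint.pt', map_location='cpu')
#   state_dict = convert_to_custom_text_state_dict(state_dict)  # moves text tower keys under 'text.'
#   model = CustomTextCLIP(embed_dim=512, vision_cfg=CLIPVisionCfg(), text_cfg=CLIPTextCfg())
#   model.load_state_dict(state_dict)
#   convert_weights_to_lp(model, dtype=torch.float16)  # or torch.bfloat16
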
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    model.eval()
    image_size = model.visual.image_size
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed_timm(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.trunk.pos_embed', None)  # shape [1, seq_len, width]
    if old_pos_embed is None:
        return
    grid_size = to_2tuple(model.visual.trunk.patch_embed.grid_size)
    if hasattr(model.visual.trunk, 'cls_token') and model.visual.trunk.cls_token is not None:
        # extra (e.g. cls) tokens are not handled by this helper
        return
    new_seq_len = grid_size[0] * grid_size[1]
    if new_seq_len == old_pos_embed.shape[1]:
        return

    pos_emb_img = old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img[0]))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)
    state_dict['visual.trunk.pos_embed'] = pos_emb_img
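

# Illustrative usage sketch (an assumption, not part of the original file): rebuilding a
# CLIP model from an OpenAI-format (jit) checkpoint and tracing it for inference. The
# checkpoint path is a placeholder.
#
#   openai_sd = torch.jit.load('ViT-B-16.pt', map_location='cpu').state_dict()
#   model = build_model_from_openai_state_dict(openai_sd)  # returns an eval-mode fp16 CLIP
#   traced = trace_model(model, batch_size=8, device=torch.device('cpu'))
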
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    pe_key_name = 'visual.positional_embedding'
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None:
        pe_key_name = 'visual.trunk.pos_embed'
        old_pos_embed = state_dict.get('visual.trunk.pos_embed', None)  # shape [1, seq_len, width] for timm trunks
    if old_pos_embed is None:
        return

    if hasattr(model.visual, 'grid_size'):
        grid_size = to_2tuple(model.visual.grid_size)
    elif hasattr(model.visual.trunk.patch_embed, 'grid_size'):
        grid_size = to_2tuple(model.visual.trunk.patch_embed.grid_size)
    else:
        return

    if hasattr(model.visual.trunk, 'cls_token') and model.visual.trunk.cls_token is not None:
        extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    else:
        extra_tokens = 0

    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed

    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict[pe_key_name] = new_pos_embed


def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    old_pos_embed = state_dict.get('positional_embedding', None)
    if old_pos_embed is None:
        return
    # FIXME add support for text cls_token
    model_pos_embed = getattr(model, 'positional_embedding', None)
    if model_pos_embed is None:
        model_pos_embed = getattr(model.text, 'positional_embedding', None)

    old_num_pos = old_pos_embed.shape[0]
    old_width = old_pos_embed.shape[1]
    num_pos = model_pos_embed.shape[0]
    width = model_pos_embed.shape[1]

    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    old_pos_embed = F.interpolate(
        old_pos_embed,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
    new_pos_embed = old_pos_embed

    state_dict['positional_embedding'] = new_pos_embed
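

# Illustrative usage sketch (an assumption, not part of the original file): adapting a
# checkpoint trained at one resolution to a model built for another before loading it.
# The checkpoint path is a placeholder.
#
#   state_dict = torch.load('checkpoint.pt', map_location='cpu')
#   resize_pos_embed(state_dict, model)       # image tower position embeddings
#   resize_text_pos_embed(state_dict, model)  # text tower position embeddings
#   model.load_state_dict(state_dict)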