ViTamin-XL-384px / model.py
bbexx's picture
add dependency
16d57ae
raw
history blame
28 kB
""" ViTamin
Paper: Designing Scalable Vision Models in the Vision-Language Era
@misc{chen2023designing,
title={Designing Scalable Vision Models in the Vision-Language Era},
author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Chieh Chen},
year={2023},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin
by Jieneng Chen 2024
Reference: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial
from open_clip.hf_model import HFTextEncoder
from open_clip.modified_resnet import ModifiedResNet
from open_clip.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from open_clip.utils import to_2tuple
import time
import timm
from timm.models.vision_transformer import _create_vision_transformer
from .timm_model import TimmModel
from .vitamin import *
# from .vitamin import HybridEmbed, MbConvStages, VitCfg, VitConvCfg
from .vitamin import GeGluMlp, ViTamin, HybridEmbed, MbConvStages, VitCfg, VitConvCfg
from transformers.modeling_utils import PreTrainedModel
from .configuration_vitamin import ViTaminConfig, ViTaminVisionConfig
@dataclass
class CLIPVisionCfg:
    """Hyper-parameters for the image tower built by ``_build_vision_tower``.

    The fields also decide which backbone is built: a truthy
    ``timm_model_name`` selects ``TimmModel``; otherwise a tuple/list
    ``layers`` selects ``ModifiedResNet``; otherwise a ``VisionTransformer``
    is created.
    """

    # int -> transformer depth; tuple/list -> per-stage counts (ModifiedResNet path).
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    # Head count is derived from it: width // head_width
    # (width * 32 // head_width on the ModifiedResNet path).
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None
    # Forwarded to VisionTransformer as-is; on the timm path it is only
    # forwarded (as ``patch_drop``) when strictly positive.
    patch_dropout: float = 0.
    input_patchnorm: bool = False
    global_average_pool: bool = False
    attentional_pool: bool = False
    n_queries: int = 256
    attn_pooler_heads: int = 8
    output_tokens: bool = False

    # --- timm backbone options; only read when ``timm_model_name`` is truthy ---
    timm_model_name: Optional[str] = None
    timm_model_pretrained: bool = False
    timm_pool: str = 'avg'
    timm_proj: str = 'linear'
    timm_proj_bias: bool = False
    timm_drop: float = 0.
    timm_drop_path: Optional[float] = None
@dataclass
class CLIPTextCfg:
    """Hyper-parameters for the text tower built by ``_build_text_tower``.

    A truthy ``hf_model_name`` selects ``HFTextEncoder``; otherwise a
    ``TextTransformer`` is built from the transformer fields below.
    """

    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value

    # --- HuggingFace encoder options; only read when ``hf_model_name`` is truthy ---
    hf_model_name: Optional[str] = None
    # Not read by ``_build_text_tower`` in this file; presumably consumed by
    # tokenizer setup elsewhere -- TODO confirm.
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'

    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False
    # Not read by ``_build_text_tower`` in this file.
    text_mask: str = 'first'  # default first truncate in bpe_tokenizer
def get_cast_dtype(precision: str):
    """Map a precision string to the dtype weights should be cast to.

    Returns ``torch.bfloat16`` for ``'bf16'``, ``torch.float16`` for
    ``'fp16'`` and ``None`` for every other value.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None
def get_input_dtype(precision: str):
    """Map a precision string to the dtype inputs should be converted to.

    ``'bf16'``/``'pure_bf16'`` give ``torch.bfloat16``, ``'fp16'``/``'pure_fp16'``
    give ``torch.float16``; any other value yields ``None``.
    """
    if precision in ('bf16', 'pure_bf16'):
        return torch.bfloat16
    if precision in ('fp16', 'pure_fp16'):
        return torch.float16
    return None
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the image encoder described by ``vision_cfg``.

    Args:
        embed_dim: Output embedding size handed to the chosen backbone.
        vision_cfg: A ``CLIPVisionCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` (ViT path only).
        cast_dtype: When fp16/bf16, the ViT path uses ``LayerNormFp32``.

    Returns:
        A ``TimmModel`` if ``timm_model_name`` is set, a ``ModifiedResNet``
        if ``layers`` is a tuple/list, otherwise a ``VisionTransformer``.
    """
    cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

    # 1) timm backbone takes precedence over everything else.
    if cfg.timm_model_name:
        # Patch dropout is only forwarded when it is actually enabled.
        patch_drop = cfg.patch_dropout if cfg.patch_dropout > 0 else None
        return TimmModel(
            cfg.timm_model_name,
            pretrained=cfg.timm_model_pretrained,
            pool=cfg.timm_pool,
            proj=cfg.timm_proj,
            proj_bias=cfg.timm_proj_bias,
            drop=cfg.timm_drop,
            drop_path=cfg.timm_drop_path,
            patch_drop=patch_drop,
            embed_dim=embed_dim,
            image_size=cfg.image_size,
        )

    # 2) A sequence of stage depths means a ResNet-style tower.
    if isinstance(cfg.layers, (tuple, list)):
        resnet_heads = cfg.width * 32 // cfg.head_width
        return ModifiedResNet(
            layers=cfg.layers,
            output_dim=embed_dim,
            heads=resnet_heads,
            image_size=cfg.image_size,
            width=cfg.width,
        )

    # 3) Plain vision transformer.
    if quick_gelu:
        activation = QuickGELU
    else:
        activation = nn.GELU
    if cast_dtype in (torch.float16, torch.bfloat16):
        normalization = LayerNormFp32
    else:
        normalization = LayerNorm
    return VisionTransformer(
        image_size=cfg.image_size,
        patch_size=cfg.patch_size,
        width=cfg.width,
        layers=cfg.layers,
        heads=cfg.width // cfg.head_width,
        mlp_ratio=cfg.mlp_ratio,
        ls_init_value=cfg.ls_init_value,
        patch_dropout=cfg.patch_dropout,
        input_patchnorm=cfg.input_patchnorm,
        global_average_pool=cfg.global_average_pool,
        attentional_pool=cfg.attentional_pool,
        n_queries=cfg.n_queries,
        attn_pooler_heads=cfg.attn_pooler_heads,
        output_tokens=cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalization,
    )
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text encoder described by ``text_cfg``.

    Args:
        embed_dim: Output embedding size handed to the chosen encoder.
        text_cfg: A ``CLIPTextCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` (transformer path only).
        cast_dtype: When fp16/bf16, the transformer path uses ``LayerNormFp32``.

    Returns:
        An ``HFTextEncoder`` when ``hf_model_name`` is set, otherwise a
        ``TextTransformer``.
    """
    cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

    # HuggingFace-backed encoder.
    if cfg.hf_model_name:
        return HFTextEncoder(
            cfg.hf_model_name,
            output_dim=embed_dim,
            proj=cfg.proj,
            pooler_type=cfg.pooler_type,
            pretrained=cfg.hf_model_pretrained,
            output_tokens=cfg.output_tokens,
        )

    # Native text transformer.
    if quick_gelu:
        activation = QuickGELU
    else:
        activation = nn.GELU
    if cast_dtype in (torch.float16, torch.bfloat16):
        normalization = LayerNormFp32
    else:
        normalization = LayerNorm
    return TextTransformer(
        context_length=cfg.context_length,
        vocab_size=cfg.vocab_size,
        width=cfg.width,
        heads=cfg.heads,
        layers=cfg.layers,
        ls_init_value=cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=cfg.embed_cls,
        output_tokens=cfg.output_tokens,
        pad_id=cfg.pad_id,
        act_layer=activation,
        norm_layer=normalization,
    )
class CLIP(nn.Module):
    """Two-tower CLIP model whose text-tower parts live directly on the model.

    The text tower is built with ``_build_text_tower`` and its pieces
    (transformer, embeddings, final norm, projection, attention mask) are
    re-attached as attributes of this module rather than kept under a
    ``text`` sub-module.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """Build both towers; ``output_dict`` selects dict vs. tuple output of ``forward``."""
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Unpack the text tower onto this module (attribute order kept as-is).
        text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text_tower.transformer
        self.context_length = text_tower.context_length
        self.vocab_size = text_tower.vocab_size
        self.token_embedding = text_tower.token_embedding
        self.positional_embedding = text_tower.positional_embedding
        self.ln_final = text_tower.ln_final
        self.text_projection = text_tower.text_projection
        self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

        # Learnable temperature, stored in log space: log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Keep the tower's own lock routine so lock_text_tower can delegate to it.
        self.method_lock_text_tower = text_tower.lock
        self.text_no_grad = False

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to ``self.visual.lock`` (LiT-style image-tower locking)."""
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
        """Lock the text tower and make ``forward`` encode text under ``no_grad``.

        ``unlock_text_proj`` is accepted but not forwarded to the stored lock
        routine.
        """
        # added by jieneng
        self.method_lock_text_tower(unlocked_layers, freeze_layer_norm)
        self.text_no_grad = True

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, enable_text=True):
        """Toggle gradient checkpointing on the image tower and text transformer."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable_text

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalize along the last dim."""
        features = self.visual(image)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def encode_text(self, text, normalize: bool = False):
        """Embed token ids and pool at each sequence's arg-max token position."""
        cast_dtype = self.transformer.get_cast_dtype()

        hidden = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.positional_embedding.to(cast_dtype)
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        pooled = hidden[batch_index, eot_index] @ self.text_projection
        if normalize:
            return F.normalize(pooled, dim=-1)
        return pooled

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Return normalized image/text features and the exponentiated logit scale.

        Either input may be ``None``, in which case its features are ``None``.
        When the text tower is locked, text features are computed without
        gradients and detached.
        """
        # torch.cuda.synchronize()
        image_features = None
        if image is not None:
            image_features = self.encode_image(image, normalize=True)

        text_features = None
        if text is not None:
            if self.text_no_grad:
                with torch.no_grad():
                    text_features = self.encode_text(text, normalize=True).detach()
            else:
                text_features = self.encode_text(text, normalize=True)

        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()
# class CustomTextCLIP(nn.Module):
class CustomTextCLIP(nn.Module):
    """Two-tower CLIP model that keeps the text encoder as the ``text`` sub-module."""
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """Build both towers; ``output_dict`` selects dict vs. tuple output of ``forward``."""
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        # Learnable temperature, stored in log space: log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.text_no_grad = False

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to ``self.visual.lock`` (LiT-style image-tower locking)."""
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj = False):
        """Delegate to ``self.text.lock`` and set the ``text_no_grad`` flag."""
        self.text.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj)
        self.text_no_grad = True

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, enable_text=True):
        """Toggle gradient checkpointing on each tower independently."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable_text)

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalize along the last dim."""
        embedding = self.visual(image)
        if normalize:
            return F.normalize(embedding, dim=-1)
        return embedding

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; optionally L2-normalize along the last dim."""
        embedding = self.text(text)
        if normalize:
            return F.normalize(embedding, dim=-1)
        return embedding

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Return normalized image/text features and the exponentiated logit scale.

        Either input may be ``None``, in which case its features are ``None``.
        Note that ``text_no_grad`` is not consulted here: text is always
        encoded with gradient tracking left as the caller set it.
        """
        image_features = None
        if image is not None:
            image_features = self.encode_image(image, normalize=True)

        text_features = None
        if text is not None:
            text_features = self.encode_text(text, normalize=True)

        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()
class ViTaminPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    It only binds the HuggingFace ``PreTrainedModel`` machinery to the ViTamin
    configuration; no layers are defined here.
    """
    # Configuration class associated with this model family.
    config_class = ViTaminConfig
    # Prefix HuggingFace uses to refer to the base model.
    base_model_prefix = 'vitamin'
# hack CLIPVisionModel for llava: https://github.com/huggingface/transformers/blob/9acce7de1cb8229304a467938ebb47727d60cdb2/src/transformers/models/clip/modeling_clip.py#L878
class ViTaminVisionModel(PreTrainedModel):
    """Vision-only wrapper that exposes intermediate ViTamin token features.

    The config object itself is handed to ``_build_vision_tower`` as the
    vision config, so it must expose the attributes that function reads.
    ``forward`` walks ``self.visual.trunk`` by hand, so the built tower is
    expected to have a ``trunk`` with ``patch_embed`` (``backbone`` + ``proj``),
    ``patch_drop``, ``norm_pre`` and ``blocks``.
    """
    config_class = ViTaminVisionConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: ViTaminVisionConfig):
        super().__init__(config)
        self.visual = _build_vision_tower(config.embed_dim, config)

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            select_layer = -2,
    ):
        """Return the token sequence produced by transformer block ``select_layer``.

        Args:
            pixel_values: 4-D image batch (asserted below).
            select_layer: Index into ``trunk.blocks`` using Python indexing
                rules; blocks up to and including that index are applied.
                ``-2`` (default) stops before the last block, ``-1`` runs
                every block.

        Returns:
            The output of the selected block; no final norm or head is applied.
        """
        assert len(pixel_values.shape) == 4, f'wrong pixel_values size: {pixel_values.shape}'
        trunk = self.visual.trunk
        patch_embed = trunk.patch_embed
        backbone = patch_embed.backbone

        # Convolutional stem + the first two stages, then pooling and projection.
        x = backbone.stem(pixel_values)
        x = backbone.stages[0](x)
        x = backbone.stages[1](x)
        x = backbone.pool(x)
        x = patch_embed.proj(x)
        # Collapse dims 2+ into one token axis and move channels last.
        x = x.flatten(2).transpose(1, 2)
        x = trunk.patch_drop(x)
        x = trunk.norm_pre(x)

        # ``blocks[:select_layer + 1]`` collapses to the empty slice ``[:0]`` for
        # select_layer == -1, which would skip every block; run them all instead.
        stop = select_layer + 1
        blocks = trunk.blocks if stop == 0 else trunk.blocks[:stop]
        x = blocks(x)
        return x
class ViTaminCLIP(ViTaminPreTrainedModel):
    """HuggingFace-loadable CLIP model built from a ``ViTaminConfig``.

    Both towers are built from ``config.vision_cfg`` / ``config.text_cfg``
    with ``quick_gelu=False`` and ``cast_dtype=None``; ``forward`` returns a
    tuple because ``output_dict`` is fixed to ``False`` at construction.
    """
    output_dict: torch.jit.Final[bool]
    # Was a bare annotation (``config_class: ViTaminConfig``), which assigns
    # nothing; make the binding explicit.
    config_class = ViTaminConfig

    def __init__(
            self,
            config: ViTaminConfig
    ):
        super().__init__(config)
        embed_dim = config.embed_dim  #: int,
        vision_cfg = config.vision_cfg  #: CLIPVisionCfg,
        text_cfg = config.text_cfg  #: CLIPTextCfg,
        quick_gelu = False
        cast_dtype = None
        output_dict = False

        self.config = config
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        # Learnable temperature, stored in log space: log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.text_no_grad = False

    def _embed_image_tokens(self, pixel_values):
        """Run stem, first two stages, pool, proj, patch-drop and pre-norm; return tokens."""
        trunk = self.visual.trunk
        patch_embed = trunk.patch_embed
        backbone = patch_embed.backbone
        x = backbone.stem(pixel_values)
        x = backbone.stages[0](x)
        x = backbone.stages[1](x)
        x = backbone.pool(x)
        x = patch_embed.proj(x)
        # Collapse dims 2+ into one token axis and move channels last.
        x = x.flatten(2).transpose(1, 2)
        x = trunk.patch_drop(x)
        return trunk.norm_pre(x)

    def forward_visual4llava(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            select_layer = -2,
    ):
        """Return the token sequence produced by transformer block ``select_layer``.

        Args:
            pixel_values: 4-D image batch (asserted below).
            select_layer: Index into ``trunk.blocks`` using Python indexing
                rules; blocks up to and including that index are applied.
                ``-2`` (default) stops before the last block, ``-1`` runs
                every block.

        Returns:
            The output of the selected block; no final norm or head is applied.
        """
        assert len(pixel_values.shape) == 4, f'wrong pixel_values size: {pixel_values.shape}'
        x = self._embed_image_tokens(pixel_values)
        # ``blocks[:select_layer + 1]`` collapses to the empty slice ``[:0]`` for
        # select_layer == -1, which would skip every block; run them all instead.
        stop = select_layer + 1
        all_blocks = self.visual.trunk.blocks
        blocks = all_blocks if stop == 0 else all_blocks[:stop]
        x = blocks(x)
        return x

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; optionally L2-normalize along the last dim."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; optionally L2-normalize along the last dim."""
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward_pixel(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Like ``forward`` but projects every image token instead of a pooled one.

        The trunk is walked by hand: all blocks, then ``fc_norm`` and
        ``visual.head.proj`` are applied to the token sequence with no pooling
        step in between. ``image`` is used unconditionally, so it must not be
        ``None`` despite the default.
        """
        x = self._embed_image_tokens(image)
        x = self.visual.trunk.blocks(x)
        x = self.visual.trunk.fc_norm(x)
        x = self.visual.head.proj(x)
        image_features = F.normalize(x, dim=-1)
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Return normalized image/text features and the exponentiated logit scale.

        Either input may be ``None``, in which case its features are ``None``.
        ``text_no_grad`` is not consulted here.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)

    Walks every sub-module via ``model.apply`` and casts, in place:
    conv/linear weights and biases, the projection/bias tensors of attention
    modules, ``text_projection`` on ``CLIP``/``TextTransformer`` and ``proj``
    on ``VisionTransformer``. Nothing is returned.
    """
    attention_attrs = (
        *(f"{s}_proj_weight" for s in ("in", "q", "k", "v")),
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in attention_attrs:
                # Not every attention implementation defines all of these
                # tensors; a missing attribute is skipped instead of raising
                # AttributeError.
                tensor = getattr(l, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(l, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(l, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(l, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    """Re-key an old-format state dict so text-tower entries live under ``text.``.

    A dict containing a top-level ``'text_projection'`` key is treated as the
    old format: every key starting with one of the text-tower prefixes gets
    ``'text.'`` prepended and a new dict is returned. Any other dict is
    returned unchanged (same object).
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # old format state_dict, move text tower -> .text
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + name if name.startswith(text_prefixes) else name): tensor
        for name, tensor in state_dict.items()
    }
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Rebuild a ``CLIP`` model whose architecture is inferred from ``state_dict``.

    All sizes (widths, depths, patch/image size, context length, vocabulary)
    are read off tensor shapes and key names. The keys ``input_resolution``,
    ``context_length`` and ``vocab_size`` are popped from ``state_dict`` (the
    caller's dict is mutated) before loading. The returned model has been
    converted with ``convert_weights_to_fp16`` and put in eval mode.
    """
    # A ViT checkpoint carries "visual.proj"; otherwise the ResNet layout is assumed.
    is_vit = "visual.proj" in state_dict
    if is_vit:
        conv1_weight = state_dict["visual.conv1.weight"]
        vision_width = conv1_weight.shape[0]
        vision_layers = sum(
            1 for name in state_dict.keys()
            if name.startswith("visual.") and name.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = conv1_weight.shape[-1]
        # Positional table has one extra row on top of the square patch grid.
        num_positions = state_dict["visual.positional_embedding"].shape[0]
        grid_size = round((num_positions - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # Number of distinct block indices under each of the four residual stages.
        vision_layers = tuple(
            len({name.split(".")[2] for name in state_dict if name.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        attnpool_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((attnpool_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        {name.split(".")[2] for name in state_dict if name.startswith("transformer.resblocks")}
    )

    model = CLIP(
        embed_dim,
        vision_cfg=CLIPVisionCfg(
            layers=vision_layers,
            width=vision_width,
            patch_size=vision_patch_size,
            image_size=image_size,
        ),
        text_cfg=CLIPTextCfg(
            context_length=context_length,
            vocab_size=vocab_size,
            width=transformer_width,
            heads=transformer_heads,
            layers=transformer_layers,
        ),
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # These entries are metadata, not parameters of the model built above.
    for metadata_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(metadata_key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` with dummy inputs.

    The model is switched to eval mode, traced with an all-ones image batch
    of side ``model.visual.image_size`` and an all-zeros int token batch of
    length ``model.context_length``, and the traced module is returned with
    ``visual.image_size`` set on it again.
    """
    model.eval()
    side = model.visual.image_size
    dummy_images = torch.ones((batch_size, 3, side, side), device=device)
    dummy_tokens = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)

    method_inputs = {
        'forward': (dummy_images, dummy_tokens),
        'encode_text': (dummy_tokens,),
        'encode_image': (dummy_images,),
    }
    traced = torch.jit.trace_module(model, inputs=method_inputs)

    # Re-attach the image size on the traced module's visual tower.
    traced.visual.image_size = side
    return traced
def resize_pos_embed_timm(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Resize ``state_dict['visual.trunk.pos_embed']`` to the model's patch grid, in place.

    The checkpoint embedding is indexed as ``[1, seq_len, dim]``. Nothing is
    done (and ``None`` is returned) when the key is absent, when the trunk has
    a non-``None`` ``cls_token`` (extra tokens are not handled here), or when
    the sequence length already matches the model grid. Otherwise the
    embedding is interpolated from its (assumed square) grid to
    ``model.visual.trunk.patch_embed.grid_size`` and written back with the
    same ``[1, seq_len, dim]`` layout.
    """
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.trunk.pos_embed', None)  # [1, 196, 1024]
    if old_pos_embed is None:
        return

    trunk = model.visual.trunk
    grid_size = to_2tuple(trunk.patch_embed.grid_size)
    if getattr(trunk, 'cls_token', None) is not None:
        # extra_tokens are not supported; leave the checkpoint untouched.
        return

    new_seq_len = grid_size[0] * grid_size[1]
    # Sequence length is dim 1; dim 0 is the leading batch axis of size 1, so
    # comparing against shape[0] (as before) never detected a matching grid.
    old_seq_len = old_pos_embed.shape[1]
    if new_seq_len == old_seq_len:
        return

    old_grid_size = to_2tuple(int(math.sqrt(old_seq_len)))
    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)

    # [1, N, C] -> [1, C, H, W] for 2-D interpolation, then back to [1, N', C].
    pos_emb_img = old_pos_embed.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)
    state_dict['visual.trunk.pos_embed'] = pos_emb_img
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Resize the visual positional embedding in ``state_dict`` to the model's grid, in place.

    Looks for ``'visual.positional_embedding'`` first, then
    ``'visual.trunk.pos_embed'``. The target grid comes from
    ``model.visual.grid_size`` or, failing that,
    ``model.visual.trunk.patch_embed.grid_size``. Returns ``None`` without
    touching ``state_dict`` when no embedding or no grid size is found, or
    when the sequence length already matches.

    Both ``[seq_len, dim]`` embeddings and embeddings with a leading batch
    axis of size 1 (``[1, seq_len, dim]``) are handled; the stored result
    keeps the layout of the input. The old patch grid is assumed square.
    """
    # Rescale the grid of position embeddings when loading from state_dict
    pe_key_name = 'visual.positional_embedding'
    old_pos_embed = state_dict.get(pe_key_name, None)
    if old_pos_embed is None:
        pe_key_name = 'visual.trunk.pos_embed'
        old_pos_embed = state_dict.get(pe_key_name, None)  # [1, 196, 1024]
    if old_pos_embed is None:
        return

    visual = model.visual
    # Not every tower has a ``trunk``; dereferencing it unconditionally raised
    # AttributeError for towers that only expose ``grid_size`` themselves.
    trunk = getattr(visual, 'trunk', None)
    if hasattr(visual, 'grid_size'):
        grid_size = to_2tuple(visual.grid_size)
    elif hasattr(getattr(trunk, 'patch_embed', None), 'grid_size'):
        grid_size = to_2tuple(trunk.patch_embed.grid_size)
    else:
        return

    if trunk is None:
        # NOTE(review): with no trunk to inspect, one class token is assumed
        # (the upstream open_clip default) -- confirm for other tower types.
        extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    elif getattr(trunk, 'cls_token', None) is not None:
        extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    else:
        extra_tokens = 0

    # Work on a [seq_len, dim] view; remember whether to restore the batch axis.
    # Without this, a [1, N, C] tensor was read as a single token ("(1, 1)" grid).
    has_batch_dim = old_pos_embed.dim() == 3 and old_pos_embed.shape[0] == 1
    if has_batch_dim:
        old_pos_embed = old_pos_embed[0]

    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    # [N, C] -> [1, C, H, W] for 2-D interpolation, then back to [N', C].
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]

    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    if has_batch_dim:
        new_pos_embed = new_pos_embed.unsqueeze(0)
    state_dict[pe_key_name] = new_pos_embed
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    """Resize ``state_dict['positional_embedding']`` to the model's text length, in place.

    The target length is taken from ``model.positional_embedding`` or, if
    that is missing/``None``, from ``model.text.positional_embedding``.
    Returns ``None`` without changes when the key is absent or the lengths
    already agree. The embedding width must match (asserted).
    """
    checkpoint_embed = state_dict.get('positional_embedding', None)
    if checkpoint_embed is None:
        return

    # FIXME add support for text cls_token
    target_embed = getattr(model, 'positional_embedding', None)
    if target_embed is None:
        target_embed = getattr(model.text, 'positional_embedding', None)

    old_num_pos, old_width = checkpoint_embed.shape[0], checkpoint_embed.shape[1]
    num_pos, width = target_embed.shape[0], target_embed.shape[1]
    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    # [L, C] -> [1, C, L] so interpolation runs along the position axis.
    as_channels_first = checkpoint_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    resized = F.interpolate(
        as_channels_first,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # Back to [L', C].
    state_dict['positional_embedding'] = resized.permute(0, 2, 1)[0]