# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zejiangh/MILAN
from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from mmengine.logging import MMLogger
from torch import nn

from mmpretrain.registry import MODELS


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function."""
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    """A faster version of GELU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function."""
        return x * torch.sigmoid(1.702 * x)
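

# Usage sketch (illustrative only, not part of the original module): QuickGELU
# approximates nn.GELU with a sigmoid, which is cheaper to compute but not
# numerically identical to the exact GELU.
def _demo_quick_gelu() -> None:
    x = torch.linspace(-3, 3, steps=7)
    approx = QuickGELU()(x)
    exact = nn.GELU()(x)
    assert approx.shape == exact.shape  # values are close but not equal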


class ResidualAttentionBlock(nn.Module):
    """Residual Attention Block (RAB).

    This module implements the same function as the MultiheadAttention,
    but with a different interface, which is mainly used
    in CLIP.

    Args:
        d_model (int): The feature dimension.
        n_head (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
        return_attention (bool): Whether to also return the attention
            weights. Defaults to False.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: Optional[torch.Tensor] = None,
                 return_attention: bool = False) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.return_attention = return_attention

    def attention(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Attention function."""
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        if self.return_attention:
            return self.attn(
                x,
                x,
                x,
                need_weights=self.return_attention,
                attn_mask=self.attn_mask)
        else:
            return self.attn(
                x,
                x,
                x,
                need_weights=self.return_attention,
                attn_mask=self.attn_mask)[0]

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function."""
        if self.return_attention:
            x_, attention = self.attention(self.ln_1(x))
            x = x + x_
            x = x + self.mlp(self.ln_2(x))
            return x, attention
        else:
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x
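

# Usage sketch (illustrative only, not part of the original module): the block
# expects sequence-first input of shape (L, N, D), matching the default layout
# of nn.MultiheadAttention. The sizes below are arbitrary example values.
def _demo_residual_attention_block() -> None:
    block = ResidualAttentionBlock(d_model=64, n_head=4, return_attention=True)
    x = torch.randn(50, 2, 64)  # (seq_len, batch, dim)
    out, attn = block(x)
    assert out.shape == (50, 2, 64)
    assert attn.shape == (2, 50, 50)  # head-averaged attention weights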


class Transformer(nn.Module):
    """Transformer.

    Both visual and text branches use this transformer.

    Args:
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        for _ in range(layers - 1):
            self.resblocks.append(
                ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks.append(
            ResidualAttentionBlock(
                width, heads, attn_mask, return_attention=True))

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward function."""
        z = []
        for idx, blk in enumerate(self.resblocks):
            if idx < self.layers - 1:
                x = blk(x)
                z.append(x.permute(1, 0, 2))
            else:
                x, attention = blk(x)
                z.append(x.permute(1, 0, 2))
        return x, attention, z
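

# Usage sketch (illustrative only, not part of the original module): the
# transformer keeps the sequence-first (L, N, D) layout and returns the final
# tokens, the attention map of the last block, and the per-layer features
# permuted back to (N, L, D). The sizes below are arbitrary example values.
def _demo_transformer() -> None:
    model = Transformer(width=64, layers=3, heads=4)
    x = torch.randn(50, 2, 64)  # (seq_len, batch, dim)
    out, attention, z = model(x)
    assert out.shape == (50, 2, 64)
    assert attention.shape == (2, 50, 50)
    assert len(z) == 3 and z[0].shape == (2, 50, 64)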


class VisionTransformer(nn.Module):
    """Vision Transformer for CLIP.

    Args:
        input_resolution (int): The image size.
        patch_size (int): The patch size.
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        output_dim (int): The output dimension.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average. Defaults to 1.
    """

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 output_dim: int,
                 finetune=False,
                 average_targets: int = 1) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)
        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.finetune = finetune
        if finetune is False:
            self.ln_post = LayerNorm(width)
            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        self.average_targets = average_targets

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function."""
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x, attention, z = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x)
        if self.proj is not None:
            x = x @ self.proj
        return x, attention
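

# Usage sketch (illustrative only, not part of the original module): with the
# default finetune=False branch, the forward pass returns the projected token
# features and the attention map of the last transformer block. The small
# sizes below are arbitrary example values, not the configuration used by CLIP.
def _demo_vision_transformer() -> None:
    vit = VisionTransformer(
        input_resolution=32,
        patch_size=8,
        width=64,
        layers=2,
        heads=4,
        output_dim=32)
    images = torch.randn(2, 3, 32, 32)
    tokens, attention = vit(images)
    assert tokens.shape == (2, 17, 32)  # 4 * 4 patches + 1 class token
    assert attention.shape == (2, 17, 17)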


class CLIP(nn.Module):
    """CLIP.

    Args:
        embed_dim (int): The embedding dimension.
        image_resolution (int): The image size.
        vision_layers (int): The number of layers in the vision transformer.
        vision_width (int): The feature dimension in the vision transformer.
        vision_patch_size (int): The patch size in the vision transformer.
        context_length (int): The context length.
        vocab_size (int): The vocabulary size.
        transformer_width (int): The feature dimension in the text transformer.
        transformer_heads (int): The number of attention heads in the
            text transformer.
        transformer_layers (int): The number of layers in the text transformer.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average. Defaults to 1.
    """

    def __init__(
        self,
        embed_dim: int,
        image_resolution: int,
        vision_layers: Union[Tuple[int, int, int, int], int],
        vision_width: int,
        vision_patch_size: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        finetune: bool = False,
        average_targets: int = 1,
    ) -> None:
        super().__init__()
        self.context_length = context_length
        vision_heads = vision_width // 64
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
            finetune=finetune,
            average_targets=average_targets,
        )
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask())
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(
            torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.initialize_parameters()

    def initialize_parameters(self) -> None:
        """Initialize the parameters.

        The pretrained weights will override the parameters initialized by
        this function.
        """
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers)**-0.5)
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width)**-0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(
                self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self) -> torch.Tensor:
        """Build the attention mask."""
        # Lazily create the causal attention mask, with full attention between
        # the vision tokens. PyTorch uses an additive attention mask, so fill
        # the masked positions with -inf.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the lower diagonal
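        # Illustrative example (not in the original code): with a context
        # length of 4 the resulting additive mask is
        #   [[0., -inf, -inf, -inf],
        #    [0.,   0., -inf, -inf],
        #    [0.,   0.,   0., -inf],
        #    [0.,   0.,   0.,   0.]]
        # i.e. each text token can only attend to itself and earlier tokens.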
        return mask

    @property
    def dtype(self) -> torch.dtype:
        """Get the dtype."""
        return self.visual.conv1.weight.dtype

    def encode_image(self,
                     image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the image.

        Get the feature and attention mask from the last layer of the visual
        branch of CLIP.

        Args:
            image (torch.Tensor): The image tensor with shape NCHW.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask.
        """
        return self.visual(image.type(self.dtype))


def build_clip_model(state_dict: dict,
                     finetune: bool = False,
                     average_targets: int = 1) -> nn.Module:
    """Build the CLIP model.

    Args:
        state_dict (dict): The pretrained state dict.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average. Defaults to 1.

    Returns:
        nn.Module: The CLIP model.
    """
    vit = 'visual.proj' in state_dict

    # Only ViT-based CLIP checkpoints are handled here; the vision_* variables
    # below are undefined for ResNet-based checkpoints.
    if vit:
        vision_width = state_dict['visual.conv1.weight'].shape[0]
        vision_layers = len([
            k for k in state_dict.keys()
            if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
        ])
        vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
        grid_size = round(
            (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
        image_resolution = vision_patch_size * grid_size

    embed_dim = state_dict['text_projection'].shape[1]
    context_length = state_dict['positional_embedding'].shape[0]
    vocab_size = state_dict['token_embedding.weight'].shape[0]
    transformer_width = state_dict['ln_final.weight'].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split('.')[2] for k in state_dict
            if k.startswith('transformer.resblocks')))

    model = CLIP(
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        vision_patch_size,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        finetune,
        average_targets,
    )

    for key in ['input_resolution', 'context_length', 'vocab_size']:
        if key in state_dict:
            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    MMLogger.get_current_instance().info(f'Load CLIP model: {msg}')
    return model.eval()
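

# Usage sketch (illustrative only, not part of the original module): the
# checkpoint path below is a placeholder, and the example assumes the loaded
# object is already a plain OpenAI-style CLIP ViT state dict containing the
# keys referenced above.
def _demo_build_clip_model() -> None:
    state_dict = torch.load('clip_vit_base_16.pth', map_location='cpu')
    clip = build_clip_model(state_dict, finetune=False)
    images = torch.randn(2, 3, clip.visual.input_resolution,
                         clip.visual.input_resolution)
    with torch.no_grad():
        tokens, attention = clip.encode_image(images)
    # tokens: (N, num_patches + 1, embed_dim), attention: (N, L, L)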