|
|
from collections import OrderedDict |
|
|
from dataclasses import asdict |
|
|
from functools import partial |
|
|
from logging import getLogger |
|
|
from typing import Callable, Optional, Literal |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from einops import rearrange |
|
|
from timm.layers import DropPath |
|
|
from torch.nn import functional as F |
|
|
from torch.nn.init import constant_, xavier_uniform_ |
|
|
from torch.nn.parameter import Parameter |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
import types |
|
|
from core.vision_encoder.rope import Rope2D |
|
|
from core.vision_encoder.config import PEConfig, PETextConfig, PE_VISION_CONFIG, PE_TEXT_CONFIG, fetch_pe_checkpoint |
|
|
|
|
|
|
|
|
|
|
|
logger = getLogger() |
|
|
|
|
|
|
|
|
|
|
|
class LayerScale(nn.Module): |
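    """Per-channel learnable scaling of a residual branch.

    Note: ``self.gamma`` is not created in ``__init__``; it is allocated in
    ``init_tensors()``, which ``VisionTransformer.init_tensors()`` calls
    recursively after the model is constructed.
    """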
|
|
def __init__(self, dim, init_values=1e-5, inplace=False): |
|
|
super().__init__() |
|
|
self.inplace = inplace |
|
|
self.dim = dim |
|
|
self.init_values = init_values |
|
|
|
|
|
def forward(self, x): |
|
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma |
|
|
|
|
|
def init_tensors(self): |
|
|
self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim)) |
|
|
|
|
|
|
|
|
class AttentionPooling(nn.Module): |
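    """Probe-based attention pooling.

    A learned probe of ``num_probe`` query tokens cross-attends to the input
    sequence, and the result is refined by a residual MLP, producing
    ``num_probe`` pooled vectors per example.
    """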
|
|
def __init__( |
|
|
self, |
|
|
embed_dim: int, |
|
|
num_heads: int, |
|
|
num_probe: int = 1, |
|
|
mlp_ratio: int = 4, |
|
|
act_layer: Callable = nn.GELU, |
|
|
norm_layer: Callable = nn.LayerNorm, |
|
|
): |
|
|
super().__init__() |
|
|
self.embed_dim = embed_dim |
|
|
self.num_heads = num_heads |
|
|
self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim)) |
|
|
self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True) |
|
|
self.layernorm = norm_layer(embed_dim) |
|
|
self.mlp_width = int(embed_dim * mlp_ratio) |
|
|
self.mlp = nn.Sequential( |
|
|
OrderedDict( |
|
|
[ |
|
|
("c_fc", nn.Linear(self.embed_dim, self.mlp_width)), |
|
|
("gelu", act_layer()), |
|
|
("c_proj", nn.Linear(self.mlp_width, self.embed_dim)), |
|
|
] |
|
|
) |
|
|
) |
|
|
self._is_converted = False |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
|
|
|
batch, _, _ = x.shape |
|
|
q = self.probe.repeat((batch, 1, 1)).to(x.dtype) |
|
|
x_attn = self.attn(q, x, x, need_weights=False)[0] |
|
|
x = x_attn + self.mlp(self.layernorm(x_attn)) |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class SelfAttention(nn.Module): |
|
|
r""" |
|
|
    Implements sequence-packed attention with optional RoPE (rotary position embeddings).
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
embed_dim: int, |
|
|
num_heads: int, |
|
|
rope: Optional[nn.Module] = None, |
|
|
): |
|
|
        super().__init__()
|
|
self.embed_dim = embed_dim |
|
|
|
|
|
self.num_heads = num_heads |
|
|
self.head_dim = embed_dim // num_heads |
|
|
assert ( |
|
|
self.head_dim * num_heads == self.embed_dim |
|
|
), "embed_dim must be divisible by num_heads" |
|
|
|
|
|
|
|
|
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) |
|
|
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) |
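        # NOTE: in_proj_weight / in_proj_bias are the legacy packed projection
        # parameters that released checkpoints provide. After loading a
        # state_dict, migrate_weights() copies them into the in_proj Linear
        # below and del_muda() deletes them.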
|
|
|
|
|
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) |
|
|
|
|
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
|
|
|
self.rope = rope |
|
|
self.scale = self.head_dim ** (-0.5) |
|
|
|
|
|
def init_tensors(self): |
|
|
xavier_uniform_(self.in_proj_weight) |
|
|
constant_(self.in_proj_bias, 0.0) |
|
|
constant_(self.out_proj.bias, 0.0) |
|
|
|
|
|
|
|
|
def del_muda(self): |
|
|
del self.in_proj_weight |
|
|
del self.in_proj_bias |
|
|
|
|
|
def migrate_weights(self): |
|
|
""" |
|
|
MUST be called *after* loading the state_dict. |
|
|
This copies the weights from the old Parameters to the new nn.Linear layer. |
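        Call ``del_muda()`` afterwards to drop the now-redundant parameters.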
|
|
""" |
|
|
|
|
|
with torch.no_grad(): |
|
|
self.in_proj.weight.copy_(self.in_proj_weight) |
|
|
self.in_proj.bias.copy_(self.in_proj_bias) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, attn_mask=None, need_weights=False): |
|
|
batch, seq, embed_dim = x.shape |
|
|
|
|
|
|
|
|
proj = self.in_proj(x) |
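        # Split the packed projection (b, s, 3 * embed_dim) into q, k, v, each (b, s, embed_dim).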
|
|
|
|
|
proj = ( |
|
|
proj.unflatten(-1, (3, embed_dim)) |
|
|
.unsqueeze(0) |
|
|
.transpose(0, -2) |
|
|
.squeeze(-2) |
|
|
.contiguous() |
|
|
) |
|
|
q, k, v = proj[0], proj[1], proj[2] |
|
|
|
|
|
|
|
|
q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads) |
|
|
k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads) |
|
|
v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads) |
|
|
|
|
|
if self.rope: |
|
|
q, k = self.rope(q, k) |
|
|
|
|
|
if not need_weights: |
|
|
|
|
|
attn = F.scaled_dot_product_attention( |
|
|
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=self.scale
|
|
) |
|
|
attn = rearrange(attn, "b h s d -> b s (h d)") |
|
|
return self.out_proj(attn) |
|
|
else: |
|
|
|
|
|
q_scaled = q * self.scale |
|
|
|
|
|
attn_weights = torch.matmul(q_scaled, k.transpose(-2, -1)) |
|
|
|
|
|
if attn_mask is not None: |
|
|
attn_weights += attn_mask |
|
|
|
|
|
attn_weights = F.softmax(attn_weights, dim=-1) |
|
|
|
|
|
attn_output = torch.matmul(attn_weights, v) |
|
|
attn_output = rearrange(attn_output, "b h s d -> b s (h d)") |
|
|
|
|
|
output = self.out_proj(attn_output) |
|
|
return output, attn_weights |
|
|
|
|
|
|
|
|
class ResidualAttentionBlock(nn.Module): |
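    """Pre-norm transformer block: self-attention and MLP sublayers with
    optional LayerScale and DropPath applied to each residual branch.
    """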
|
|
def __init__( |
|
|
self, |
|
|
d_model: int, |
|
|
n_head: int, |
|
|
mlp_ratio: float = 4.0, |
|
|
ls_init_value: float = None, |
|
|
act_layer: Callable = nn.GELU, |
|
|
norm_layer: Callable = nn.LayerNorm, |
|
|
drop_path: float = 0.0, |
|
|
rope: Optional[nn.Module] = None, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
if rope: |
|
|
self.attn = SelfAttention(d_model, n_head, rope=rope) |
|
|
else: |
|
|
self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) |
|
|
|
|
|
self.ls_1 = ( |
|
|
LayerScale(d_model, ls_init_value) |
|
|
if ls_init_value is not None |
|
|
else nn.Identity() |
|
|
) |
|
|
self.ls_2 = ( |
|
|
LayerScale(d_model, ls_init_value) |
|
|
if ls_init_value is not None |
|
|
else nn.Identity() |
|
|
) |
|
|
|
|
|
self.ln_1 = norm_layer(d_model) |
|
|
self.ln_2 = norm_layer(d_model) |
|
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
|
|
|
mlp_width = int(d_model * mlp_ratio) |
|
|
self.mlp = nn.Sequential( |
|
|
OrderedDict( |
|
|
[ |
|
|
("c_fc", nn.Linear(d_model, mlp_width)), |
|
|
("gelu", act_layer()), |
|
|
("c_proj", nn.Linear(mlp_width, d_model)), |
|
|
] |
|
|
) |
|
|
) |
|
|
|
|
|
def _call_attn( |
|
|
self, |
|
|
q_x: torch.Tensor, |
|
|
attn_mask: Optional[torch.Tensor] = None, |
|
|
need_weights: bool = False, |
|
|
): |
|
|
|
|
|
if attn_mask is not None: |
|
|
if not attn_mask.dtype == torch.bool: |
|
|
attn_mask = attn_mask.to(q_x.dtype) |
|
|
|
|
|
if isinstance(self.attn, SelfAttention): |
|
|
|
|
|
return self.attn(q_x, attn_mask=attn_mask, need_weights=need_weights) |
|
|
else: |
|
|
|
|
|
            attn_output, attn_weights = self.attn(
                q_x, q_x, q_x, attn_mask=attn_mask, need_weights=need_weights
            )
            return (attn_output, attn_weights) if need_weights else attn_output
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
attn_mask: Optional[torch.Tensor] = None, |
|
|
need_weights: bool = False, |
|
|
): |
|
|
attn_result = self._call_attn(self.ln_1(x), attn_mask=attn_mask, need_weights=need_weights) |
|
|
|
|
|
attn_weights = None |
|
|
if need_weights: |
|
|
|
|
|
attn_output, attn_weights = attn_result |
|
|
else: |
|
|
attn_output = attn_result |
|
|
|
|
|
x = x + self.drop_path1(self.ls_1(attn_output)) |
|
|
x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x)))) |
|
|
|
|
|
if need_weights: |
|
|
return x, attn_weights |
|
|
|
|
|
return x |
|
|
|
|
|
def del_muda(self): |
|
|
self.attn.del_muda() |
|
|
|
|
|
class Transformer(nn.Module): |
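    """Stack of ResidualAttentionBlocks with optional gradient checkpointing,
    early exit via ``layer_idx``, and in-place truncation via ``truncate()``.
    """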
|
|
def __init__( |
|
|
self, |
|
|
width: int, |
|
|
layers: int, |
|
|
heads: int, |
|
|
mlp_ratio: float = 4.0, |
|
|
ls_init_value: float = None, |
|
|
act_layer: Callable = nn.GELU, |
|
|
norm_layer: Callable = nn.LayerNorm, |
|
|
drop_path: float = 0.0, |
|
|
rope: Optional[nn.Module] = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.width = width |
|
|
self.layers = layers |
|
|
self.grad_checkpointing = False |
|
|
|
|
|
self.resblocks = nn.ModuleList( |
|
|
[ |
|
|
ResidualAttentionBlock( |
|
|
width, |
|
|
heads, |
|
|
mlp_ratio, |
|
|
ls_init_value=ls_init_value, |
|
|
act_layer=act_layer, |
|
|
norm_layer=norm_layer, |
|
|
drop_path=drop_path, |
|
|
rope=rope, |
|
|
) |
|
|
for _ in range(layers) |
|
|
] |
|
|
) |
|
|
|
|
|
@torch.jit.ignore |
|
|
def set_grad_checkpointing(self, enable=True): |
|
|
self.grad_checkpointing = enable |
|
|
|
|
|
@torch.jit.ignore |
|
|
def truncate(self, layer_idx: int): |
|
|
""" Delete layers so the last layer is the given layer index. """ |
|
|
self.layers = ((self.layers + layer_idx) % self.layers) + 1 |
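        # e.g. with 32 blocks: layer_idx=-1 keeps all 32, layer_idx=-2 keeps 31,
        # and layer_idx=10 keeps blocks 0..10.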
|
|
self.resblocks = nn.ModuleList(self.resblocks[:self.layers]) |
|
|
|
|
|
def del_muda(self): |
|
|
for resblock in self.resblocks: |
|
|
resblock.del_muda() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
attn_mask: Optional[torch.Tensor] = None, |
|
|
layer_idx: int = -1, |
|
|
need_weights: bool = False, |
|
|
): |
|
|
stop_idx = (self.layers + layer_idx) % self.layers |
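        # layer_idx=-1 (the default) runs every block; other indices stop after that block.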
|
|
|
|
|
attention_maps = [] |
|
|
|
|
|
for i, r in enumerate(self.resblocks): |
|
|
if self.grad_checkpointing and not torch.jit.is_scripting(): |
|
|
if need_weights: |
|
|
raise ValueError("Cannot get attention maps with gradient checkpointing enabled.") |
|
|
x = checkpoint(r, x, attn_mask, use_reentrant=False) |
|
|
else: |
|
|
if need_weights: |
|
|
x, attn_map = r(x, attn_mask=attn_mask, need_weights=True) |
|
|
attention_maps.append(attn_map) |
|
|
else: |
|
|
x = r(x, attn_mask=attn_mask, need_weights=False) |
|
|
|
|
|
if i == stop_idx: |
|
|
break |
|
|
|
|
|
if need_weights: |
|
|
return x, attention_maps |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class VisionTransformer(nn.Module): |
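    """Perception Encoder vision tower.

    Minimal usage sketch (the config name below is an assumption; pick any
    entry returned by ``VisionTransformer.available_configs()``)::

        model = VisionTransformer.from_config("PE-Core-L14-336", pretrained=True)
        feats = model(torch.randn(1, 3, model.image_size, model.image_size))
    """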
|
|
def __init__( |
|
|
self, |
|
|
patch_size: int, |
|
|
width: int, |
|
|
layers: int, |
|
|
heads: int, |
|
|
mlp_ratio: float, |
|
|
act_layer: Callable = nn.GELU, |
|
|
norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5), |
|
|
use_ln_pre: bool = True, |
|
|
use_ln_post: bool = True, |
|
|
ls_init_value: float = None, |
|
|
drop_path: float = 0.0, |
|
|
image_size: int = 448, |
|
|
use_abs_posemb: bool = True, |
|
|
use_rope2d: bool = True, |
|
|
use_cls_token: bool = False, |
|
|
output_dim: Optional[int] = 1280, |
|
|
attn_pooler_heads: int = 8, |
|
|
pool_type: Literal["attn", "tok", "avg", "none"] = "attn", |
|
|
): |
|
|
super().__init__() |
|
|
assert pool_type in ("attn", "tok", "avg", "none") |
|
|
self.pool_type = pool_type |
|
|
self.patch_size = patch_size |
|
|
|
|
|
self.output_dim = output_dim or width |
|
|
self.proj_dim = output_dim |
|
|
self.heads = heads |
|
|
self.width = width |
|
|
self.layers = layers |
|
|
|
|
|
self.use_abs_posemb = use_abs_posemb |
|
|
self.use_cls_token = use_cls_token |
|
|
self.use_rope2d = use_rope2d |
|
|
self.image_size = image_size |
|
|
|
|
|
self.conv1 = nn.Conv2d( |
|
|
in_channels=3, |
|
|
out_channels=width, |
|
|
kernel_size=patch_size, |
|
|
stride=patch_size, |
|
|
bias=False, |
|
|
) |
|
|
self.rope = ( |
|
|
Rope2D( |
|
|
dim=width // heads, |
|
|
use_cls_token=self.use_cls_token, |
|
|
) |
|
|
if self.use_rope2d |
|
|
else None |
|
|
) |
|
|
|
|
|
self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity() |
|
|
self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity() |
|
|
|
|
|
self.transformer = Transformer( |
|
|
width, |
|
|
layers, |
|
|
heads, |
|
|
mlp_ratio, |
|
|
ls_init_value=ls_init_value, |
|
|
act_layer=act_layer, |
|
|
norm_layer=norm_layer, |
|
|
drop_path=drop_path, |
|
|
rope=self.rope, |
|
|
) |
|
|
|
|
|
if pool_type == "attn": |
|
|
self.attn_pool = AttentionPooling( |
|
|
embed_dim=width, |
|
|
num_heads=attn_pooler_heads, |
|
|
act_layer=act_layer, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
else: |
|
|
self.attn_pool = None |
|
|
|
|
|
self.init_tensors() |
|
|
|
|
|
|
|
|
def del_muda(self): |
|
|
self.transformer.del_muda() |
|
|
|
|
|
def delete_attn_pool(self): |
|
|
del self.attn_pool |
|
|
|
|
|
|
|
|
def init_tensors(self): |
|
|
def init_submodule_tensors(module): |
|
|
for name, child in module.named_children(): |
|
|
if hasattr(child, "init_tensors"): |
|
|
logger.debug(f"Initializing tensors for submodule: {name}") |
|
|
child.init_tensors() |
|
|
init_submodule_tensors(child) |
|
|
|
|
|
init_submodule_tensors(self) |
|
|
        if self.rope is not None:
            self.rope.init_tensors()
|
|
|
|
|
|
|
|
init_scale = self.width**-0.5 |
|
|
|
|
|
if self.use_cls_token: |
|
|
self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width)) |
|
|
|
|
|
if self.use_abs_posemb: |
|
|
self.posemb_grid_size = self.image_size // self.patch_size |
|
|
self.positional_embedding = nn.Parameter( |
|
|
init_scale |
|
|
* torch.randn( |
|
|
int(self.use_cls_token) + self.posemb_grid_size**2, self.width |
|
|
) |
|
|
) |
|
|
|
|
|
if self.proj_dim is not None: |
|
|
self.proj = nn.Parameter( |
|
|
init_scale * torch.randn(self.width, self.proj_dim) |
|
|
) |
|
|
|
|
|
|
|
|
def load_ckpt(self, ckpt_path: str, verbose: bool = True): |
|
|
        _sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
|
if "state_dict" in _sd: |
|
|
_sd = _sd["state_dict"] |
|
|
elif "weights" in _sd: |
|
|
_sd = _sd["weights"] |
|
|
|
|
|
|
|
|
_sd = {k.replace("module.", ""): v for k, v in _sd.items()} |
|
|
if any(k.startswith("visual.") for k in _sd): |
|
|
_sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k} |
|
|
|
|
|
m, u = self.load_state_dict(_sd, strict=False) |
|
|
|
|
|
if verbose or (m or u): |
|
|
logger.info(f"Missing keys for loading vision encoder: {m}") |
|
|
logger.info(f"Unexpected keys for loading vision encoder: {u}") |
|
|
print(f"Missing keys for loading vision encoder: {m}") |
|
|
print(f"Unexpected keys for loading vision encoder: {u}") |
|
|
|
|
|
|
|
|
def truncate(self, layer_idx: int): |
|
|
""" Delete layers so the last layer is the given layer index. """ |
|
|
self.transformer.truncate(layer_idx) |
|
|
self.layers = self.transformer.layers |
|
|
|
|
|
|
|
|
@classmethod |
|
|
def from_config( |
|
|
cls, |
|
|
name: str, |
|
|
pretrained: bool = False, |
|
|
checkpoint_path: Optional[str] = None, |
|
|
**kwdargs |
|
|
): |
|
|
if name not in PE_VISION_CONFIG: |
|
|
raise RuntimeError(f"{name} not found in configs.") |
|
|
|
|
|
args = asdict(PE_VISION_CONFIG[name]) |
|
|
args.update(kwdargs) |
|
|
|
|
|
model = cls(**args) |
|
|
if pretrained: |
|
|
model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path)) |
|
|
|
|
|
return model |
|
|
|
|
|
@classmethod |
|
|
def available_configs(cls): |
|
|
return list(PE_VISION_CONFIG.keys()) |
|
|
|
|
|
|
|
|
@torch.jit.ignore |
|
|
def set_grad_checkpointing(self, enable=True): |
|
|
self.transformer.set_grad_checkpointing(enable=enable) |
|
|
|
|
|
def _sample_abs_posemb(self, grid_h: int, grid_w: int): |
|
|
"""Interpolates the absolute position embedding if necessary.""" |
|
|
if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w: |
|
|
return self.positional_embedding[None, ...] |
|
|
|
|
|
pos_embed = self.positional_embedding |
|
|
if self.use_cls_token: |
|
|
cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:] |
|
|
|
|
|
pos_embed = ( |
|
|
pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1) |
|
|
.permute(0, 3, 1, 2) |
|
|
.contiguous() |
|
|
) |
|
|
pos_embed = F.interpolate( |
|
|
pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False |
|
|
) |
|
|
pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous() |
|
|
|
|
|
if self.use_cls_token: |
|
|
pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0) |
|
|
|
|
|
return pos_embed[None, ...] |
|
|
|
|
|
def _pool(self, x: torch.Tensor): |
|
|
if self.pool_type == "tok": |
|
|
return x[:, 0] |
|
|
elif self.pool_type == "avg": |
|
|
return x.mean(dim=1) |
|
|
elif self.pool_type == "attn": |
|
|
return self.attn_pool(x).squeeze(1) |
|
|
elif self.pool_type == "none": |
|
|
return x |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
def forward_features( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
norm: bool = False, |
|
|
layer_idx: int = -1, |
|
|
strip_cls_token: bool = False, |
|
|
need_weights: bool = False, |
|
|
): |
|
|
batch, _, h, w = x.shape |
|
|
grid_h, grid_w = h // self.patch_size, w // self.patch_size |
|
|
|
|
|
x = self.conv1(x) |
|
|
x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width) |
|
|
|
|
|
if self.use_cls_token: |
|
|
x = torch.cat( |
|
|
[self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x], |
|
|
dim=1, |
|
|
) |
|
|
|
|
|
if self.use_abs_posemb: |
|
|
x = x + self._sample_abs_posemb(grid_h, grid_w) |
|
|
|
|
|
if self.use_rope2d: |
|
|
self.rope.update_grid(x.device, grid_h, grid_w) |
|
|
|
|
|
x = self.ln_pre(x) |
|
|
|
|
|
|
|
|
transformer_output = self.transformer(x, layer_idx=layer_idx, need_weights=need_weights) |
|
|
|
|
|
attention_maps = None |
|
|
if need_weights: |
|
|
x, attention_maps = transformer_output |
|
|
else: |
|
|
x = transformer_output |
|
|
|
|
|
if norm: |
|
|
x = self.ln_post(x) |
|
|
|
|
|
if strip_cls_token and self.use_cls_token: |
|
|
x = x[:, 1:, :] |
|
|
|
|
|
if need_weights: |
|
|
return x, attention_maps |
|
|
|
|
|
return x |
|
|
|
|
|
def forward(self, x: torch.Tensor, **kwargs): |
|
|
x = self.forward_features(x, norm=True, **kwargs) |
|
|
x = self._pool(x) |
|
|
|
|
|
if self.proj_dim is not None: |
|
|
x = x @ self.proj |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class TextTransformer(nn.Module): |
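    """CLIP-style text encoder with a causal attention mask by default.

    ``forward()`` expects integer token ids of shape (batch, seq_len) and
    returns pooled text features of shape (batch, output_dim), plus the
    per-token features when ``output_tokens=True``.
    """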
|
|
def __init__( |
|
|
self, |
|
|
context_length: int = 72, |
|
|
vocab_size: int = 49408, |
|
|
width: int = 512, |
|
|
heads: int = 8, |
|
|
layers: int = 12, |
|
|
mlp_ratio: float = 4.0, |
|
|
ls_init_value: float = None, |
|
|
output_dim: int = 1280, |
|
|
no_causal_mask: bool = False, |
|
|
pad_id: int = 0, |
|
|
pool_type: str = "argmax", |
|
|
proj_bias: bool = False, |
|
|
act_layer: Callable = nn.GELU, |
|
|
norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5), |
|
|
output_tokens: bool = False, |
|
|
use_ln_post: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
assert pool_type in ("first", "last", "argmax", "none") |
|
|
self.pool_type = pool_type |
|
|
self.output_tokens = output_tokens |
|
|
self.num_pos = self.context_length = context_length |
|
|
self.vocab_size = vocab_size |
|
|
self.width = width |
|
|
self.output_dim = output_dim |
|
|
self.heads = heads |
|
|
self.pad_id = pad_id |
|
|
self.layers = layers |
|
|
|
|
|
self.token_embedding = nn.Embedding(vocab_size, width) |
|
|
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) |
|
|
|
|
|
self.transformer = Transformer( |
|
|
width=width, |
|
|
layers=layers, |
|
|
heads=heads, |
|
|
mlp_ratio=mlp_ratio, |
|
|
ls_init_value=ls_init_value, |
|
|
act_layer=act_layer, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
|
|
|
self.ln_final = norm_layer(width) if use_ln_post else nn.Identity() |
|
|
|
|
|
if no_causal_mask: |
|
|
self.attn_mask = None |
|
|
else: |
|
|
self.register_buffer( |
|
|
"attn_mask", self.build_causal_mask(), persistent=False |
|
|
) |
|
|
|
|
|
if pool_type == "attn" or pool_type == "attn_eos": |
|
|
self.attn_pool = AttentionPooling( |
|
|
embed_dim=width, |
|
|
num_heads=heads, |
|
|
act_layer=act_layer, |
|
|
norm_layer=norm_layer, |
|
|
) |
|
|
else: |
|
|
self.attn_pool = None |
|
|
|
|
|
if proj_bias: |
|
|
self.text_projection = nn.Linear(width, output_dim) |
|
|
else: |
|
|
self.text_projection = nn.Parameter(torch.empty(width, output_dim)) |
|
|
|
|
|
def build_causal_mask(self): |
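        # Additive causal mask: -inf above the diagonal, 0 on and below it.
        # forward() slices it down to the actual sequence length.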
|
|
|
|
|
|
|
|
mask = torch.empty(self.num_pos, self.num_pos) |
|
|
mask.fill_(float("-inf")) |
|
|
mask.triu_(1) |
|
|
return mask |
|
|
|
|
|
def load_ckpt(self, ckpt_path: str, verbose: bool = True): |
|
|
        _sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
|
if "state_dict" in _sd: |
|
|
_sd = _sd["state_dict"] |
|
|
elif "weights" in _sd: |
|
|
_sd = _sd["weights"] |
|
|
|
|
|
_sd = {k.replace("module.", ""): v for k, v in _sd.items()} |
|
|
|
|
|
m, u = self.load_state_dict(_sd, strict=False) |
|
|
|
|
|
if verbose or (m or u): |
|
|
logger.info(f"Missing keys for loading model: {m}") |
|
|
logger.info(f"Unexpected keys for loading model: {u}") |
|
|
print(f"Missing keys for loading model: {m}") |
|
|
print(f"Unexpected keys for loading model: {u}") |
|
|
|
|
|
def build_cls_mask(self, text): |
|
|
cls_mask = (text != self.pad_id).unsqueeze(1) |
|
|
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) |
|
|
additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device) |
|
|
additive_mask.fill_(0) |
|
|
additive_mask.masked_fill_(~cls_mask, float("-inf")) |
|
|
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) |
|
|
return additive_mask |
|
|
|
|
|
def text_global_pool( |
|
|
self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax" |
|
|
): |
|
|
if pool_type == "first": |
|
|
pooled, tokens = x[:, 0], x[:, 1:] |
|
|
elif pool_type == "last": |
|
|
pooled, tokens = x[:, -1], x[:, :-1] |
|
|
elif pool_type == "argmax": |
|
|
|
|
|
assert text is not None |
|
|
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x |
|
|
else: |
|
|
pooled = tokens = x |
|
|
|
|
|
return pooled, tokens |
|
|
|
|
|
def forward(self, text): |
|
|
seq_len = text.shape[1] |
|
|
x = self.token_embedding( |
|
|
text |
|
|
) |
|
|
attn_mask = self.attn_mask |
|
|
if attn_mask is not None: |
|
|
attn_mask = attn_mask[:seq_len, :seq_len] |
|
|
|
|
|
x = x + self.positional_embedding[:seq_len] |
|
|
x = self.transformer(x, attn_mask=attn_mask) |
|
|
|
|
|
x = self.ln_final(x) |
|
|
pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type) |
|
|
|
|
|
if self.text_projection is not None: |
|
|
if isinstance(self.text_projection, nn.Linear): |
|
|
pooled = self.text_projection(pooled) |
|
|
else: |
|
|
pooled = pooled @ self.text_projection |
|
|
|
|
|
if self.output_tokens: |
|
|
return pooled, tokens |
|
|
|
|
|
return pooled |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIP(TextTransformer): |
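    """Dual-encoder CLIP model pairing a PE VisionTransformer with the TextTransformer above.

    Minimal usage sketch (the config name is an assumption; pick any entry
    returned by ``CLIP.available_configs()``, and tokenize text with the
    tokenizer shipped alongside this repo)::

        model = CLIP.from_config("PE-Core-B16-224", pretrained=True)
        image_feats, text_feats, logit_scale = model(images, token_ids)
    """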
|
|
def __init__( |
|
|
self, |
|
|
vision_cfg: PEConfig, |
|
|
text_cfg: PETextConfig, |
|
|
init_logit_scale: float = np.log(1 / 0.07) |
|
|
): |
|
|
        super().__init__(**asdict(text_cfg))
|
|
self.visual = VisionTransformer(**asdict(vision_cfg)) |
|
|
self.image_size = self.visual.image_size |
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) |
|
|
|
|
|
|
|
|
def encode_image(self, image, normalize: bool = False): |
|
|
x = self.visual(image) |
|
|
return F.normalize(x, dim=-1) if normalize else x |
|
|
|
|
|
def encode_video(self, video, normalize: bool = False): |
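        # Encode every frame independently, then mean-pool the per-frame features over time.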
|
|
b, n, c, h, w = video.shape |
|
|
frms = video.reshape(b * n, c, h, w) |
|
|
frm_feats = self.encode_image(frms, normalize=normalize) |
|
|
video_feats = frm_feats.reshape(b, n, -1) |
|
|
video_feats = video_feats.mean(dim=1) |
|
|
return video_feats |
|
|
|
|
|
def encode_text(self, text, normalize: bool = False): |
|
|
x = super().forward(text) |
|
|
return F.normalize(x, dim=-1) if normalize else x |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
image: Optional[torch.Tensor] = None, |
|
|
text: Optional[torch.Tensor] = None, |
|
|
): |
|
|
image_features = ( |
|
|
self.encode_image(image, normalize=True) if image is not None else None |
|
|
) |
|
|
text_features = ( |
|
|
self.encode_text(text, normalize=True) if text is not None else None |
|
|
) |
|
|
return image_features, text_features, self.logit_scale.exp() |
|
|
|
|
|
|
|
|
@classmethod |
|
|
def from_config( |
|
|
cls, |
|
|
name: str, |
|
|
pretrained: bool = False, |
|
|
checkpoint_path: Optional[str] = None |
|
|
): |
|
|
if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG: |
|
|
raise RuntimeError(f"{name} not found in configs.") |
|
|
|
|
|
model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name]) |
|
|
if pretrained: |
|
|
model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path)) |
|
|
|
|
|
return model |
|
|
|
|
|
@classmethod |
|
|
def available_configs(cls): |
|
|
return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG] |