Spaces:
Build error
Build error
from collections import OrderedDict | |
from typing import Tuple, Union | |
import logging | |
import os | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from timm.models.layers import DropPath, trunc_normal_ | |
from .registry import register_lang_encoder | |
logger = logging.getLogger(__name__) | |
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style (epsilon inside the square root).

    Mean and variance are computed in float32 whatever the input dtype; the
    normalized values are cast back to the input dtype before the learnable
    scale (``weight``) and shift (``bias``) are applied.

    Args:
        hidden_size: length of the last dimension and of ``weight``/``bias``.
        eps: constant added to the variance before taking the square root.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        input_dtype = x.dtype
        x32 = x.float()
        # Center once and reuse it for both the variance and the output.
        centered = x32 - x32.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized.to(input_dtype) + self.bias
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each added residually.

    Args:
        d_model: token embedding width.
        n_head: number of heads for ``nn.MultiheadAttention``.
        attn_mask: optional mask forwarded as ``attn_mask`` to
            ``nn.MultiheadAttention``; ``None`` disables it.
        drop_path: stochastic-depth rate; ``DropPath`` is used only when > 0,
            otherwise an ``nn.Identity``.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden = d_model * 4
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        # The mask is stored as a plain attribute (not a registered buffer), so
        # it is cast/moved to match the input here and cached back on the module.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        outputs = self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=self.attn_mask,
        )
        return outputs[0]

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        attn_out = self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.drop_path(mlp_out)
class Transformer(nn.Module):
    """Text transformer encoder.

    Pipeline: token embedding + learned positional embedding, a stack of
    ``ResidualAttentionBlock``s, then a final ``LayerNorm``.

    Args:
        context_length: sequence length. ``forward`` adds a positional
            embedding of exactly this length, so inputs must match it.
        vocab_size: number of rows in the token embedding table.
        width: embedding / hidden width of every block.
        layers: number of residual attention blocks.
        heads: attention heads per block.
        drop_path: maximum stochastic-depth rate. Per-block rates increase
            linearly from 0 to this value.
        autogressive: if True, every block gets the causal mask from
            ``build_attention_mask`` and no padding mask is used. If False,
            there is no causal mask and token id 0 is treated as padding.
    """

    def __init__(self,
                 context_length: int,
                 vocab_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 drop_path: float = 0.0,
                 autogressive: bool = True):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width)
        )

        self.width = width
        self.layers = layers
        self.autogressive = autogressive
        attn_mask = self.build_attention_mask() if autogressive else None
        # Stochastic depth decay rule: linearly increasing drop-path rates.
        dpr = [rate.item() for rate in torch.linspace(0, drop_path, layers)]
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(width, heads, attn_mask, rate)
                for rate in dpr
            ]
        )

        self.ln_final = LayerNorm(width)

        trunc_normal_(self.positional_embedding, std=.02)
        trunc_normal_(self.token_embedding.weight, std=.02)
        self.apply(self._init_weights)

    def dim_out(self):
        """Return the output feature width."""
        return self.width

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` causal additive mask.

        Entries strictly above the diagonal are ``-inf``, which blocks attention
        to future positions. The diagonal and everything below it are 0.
        A float ``attn_mask`` is added to the attention scores by
        ``nn.MultiheadAttention``.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero elsewhere
        return mask

    def _init_weights(self, m):
        """Initialize ``nn.Linear``/``nn.Conv2d`` and norm sub-modules.

        Used via ``self.apply``. Linear/Conv2d weights get a truncated normal
        with std 0.02 and their biases are zeroed. ``nn.LayerNorm`` and
        ``nn.BatchNorm2d`` biases are zeroed. The custom ``LayerNorm`` used in
        this file is not an ``nn.LayerNorm``, so that branch does not match it.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            logger.info('=> init weight of Linear/Conv2d from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                logger.info('=> init bias of Linear/Conv2d to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)

    def load_pretrained(self, pretrained='', pretrained_layers=None, verbose=True):
        """Load matching weights from a checkpoint file into this module.

        Only checkpoint keys that also exist in this module's ``state_dict``
        are considered. The selected entries are loaded with ``strict=False``.

        Args:
            pretrained: path to a file holding a state dict, loaded with
                ``torch.load(..., map_location='cpu')``. If it is not an
                existing file, nothing is loaded. A warning is logged when a
                non-empty path was given.
            pretrained_layers: top-level name prefixes (the part of a key
                before the first ``.``) to load. If the first entry is ``'*'``,
                every matching key is loaded. ``None`` (the default) means
                ``['*']``. An empty list loads nothing.
            verbose: log each key as it is loaded.
        """
        if not os.path.isfile(pretrained):
            if pretrained:
                logger.warning('=> pretrained model %s not found, skip loading', pretrained)
            return

        if pretrained_layers is None:
            pretrained_layers = ['*']
        # Guard the index so an empty list means "load nothing" rather than
        # raising IndexError.
        load_all = bool(pretrained_layers) and pretrained_layers[0] == '*'

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        pretrained_dict = torch.load(pretrained, map_location='cpu')
        logger.info('=> loading pretrained model %s', pretrained)
        model_dict = self.state_dict()

        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            if k not in model_dict:
                continue
            if load_all or k.split('.')[0] in pretrained_layers:
                if verbose:
                    logger.info('=> init %s from %s', k, pretrained)
                need_init_state_dict[k] = v
        self.load_state_dict(need_init_state_dict, strict=False)

    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {
            'positional_embedding',
            'token_embedding',
        }

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: integer token ids, ``(batch, seq_len)``. The positional
                embedding is added without slicing, so ``seq_len`` must match
                ``context_length``.
            attention_mask: accepted for interface compatibility but unused.
                In non-autoregressive mode the padding mask is derived from
                ``input_ids == 0`` instead.

        Returns:
            dict with ``'last_hidden_state'`` of shape ``(batch, seq_len, width)``.
        """
        key_padding_mask = (input_ids == 0) if not self.autogressive else None
        x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND (nn.MultiheadAttention default layout)
        for block in self.resblocks:
            x = block(x, key_padding_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_final(x)

        return {'last_hidden_state': x}
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    """Build a text ``Transformer`` from an encoder config mapping.

    Args:
        config_encoder: mapping with the following keys:
            - ``CONTEXT_LENGTH``, ``WIDTH``, ``LAYERS``, ``HEADS``: required.
            - ``AUTOGRESSIVE``: optional, default True.
            - ``LOAD_PRETRAINED``: optional, default False.
            - ``PRETRAINED``: checkpoint path, default ``''``.
            - ``PRETRAINED_LAYERS``: default ``['*']``.
        tokenizer: object exposing ``vocab_size``.
        verbose: forwarded to ``Transformer.load_pretrained``, which logs each
            loaded key when it is true.
        **kwargs: accepted and ignored.

    Returns:
        The constructed ``Transformer``.
    """
    # NOTE(review): ``register_lang_encoder`` is imported at the top of this
    # file but not applied here; confirm whether this builder should be
    # decorated with it so the registry can find it.
    transformer = Transformer(
        context_length=config_encoder['CONTEXT_LENGTH'],
        vocab_size=tokenizer.vocab_size,
        width=config_encoder['WIDTH'],
        layers=config_encoder['LAYERS'],
        heads=config_encoder['HEADS'],
        autogressive=config_encoder.get('AUTOGRESSIVE', True)
    )

    if config_encoder.get('LOAD_PRETRAINED', False):
        # Pass the checkpoint path explicitly: load_pretrained() with its
        # default empty path loads nothing. The PRETRAINED / PRETRAINED_LAYERS
        # key names are assumed; confirm against the config schema.
        transformer.load_pretrained(
            config_encoder.get('PRETRAINED', ''),
            config_encoder.get('PRETRAINED_LAYERS', ['*']),
            verbose=verbose,
        )
    return transformer