from collections import OrderedDict
from typing import Tuple, Union
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from timm.models.layers import DropPath, trunc_normal_

from .registry import register_lang_encoder
from ..Utils import is_main_process
from ..Utils import register_norm_module

logger = logging.getLogger(__name__)

@register_norm_module
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then cast back to the input dtype.
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias

class QuickGELU(nn.Module):
    """Sigmoid-based approximation of GELU, as used in CLIP-style transformers."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
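# Illustrative check (not part of the model): x * sigmoid(1.702 * x) closely tracks the
# exact GELU, which is why it is used here as a cheap drop-in.
#
#   x = torch.linspace(-3, 3, steps=101)
#   gap = (QuickGELU()(x) - F.gelu(x)).abs().max()   # on the order of 1e-2
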
class ResidualAttentionBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        # Move the (optional) causal mask to the input's dtype/device before use.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None

        return self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=self.attn_mask
        )[0]

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        # Pre-norm residual block; x has shape [seq_len, batch, d_model]
        # (nn.MultiheadAttention's default, non-batch-first layout).
        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x
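# Shape sanity check for the block (illustrative only; the sizes below are arbitrary):
#
#   blk = ResidualAttentionBlock(d_model=512, n_head=8)
#   x = torch.randn(77, 2, 512)          # [seq_len, batch, d_model]
#   assert blk(x).shape == (77, 2, 512)  # the residual block preserves the shape
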
class Transformer(nn.Module):
    def __init__(self,
                 context_length: int,
                 vocab_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 drop_path: float = 0.0,
                 autogressive: bool = True,
                 key_padding_token: int = 0,
                 ):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.key_padding_token = key_padding_token

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width)
        )

        self.width = width
        self.layers = layers
        self.autogressive = autogressive
        attn_mask = self.build_attention_mask() if autogressive else None
        # Stochastic depth: the drop-path rate grows linearly from 0 to drop_path over the layers.
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
                for i in range(layers)
            ]
        )

        self.ln_final = LayerNorm(width)

        trunc_normal_(self.positional_embedding, std=.02)
        trunc_normal_(self.token_embedding.weight, std=.02)
        self.apply(self._init_weights)

    @property
    def dim_out(self):
        return self.width

    def build_attention_mask(self):
        # Causal mask: each position may attend only to itself and earlier positions.
        # PyTorch expects an additive mask, so blocked entries are -inf and allowed entries are 0.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and below, keep -inf strictly above it
        return mask
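    # For example, with context_length = 3 the mask returned above is:
    #
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
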
    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            if is_main_process():
                logger.info('=> init weight of Linear/Conv2d from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                if is_main_process():
                    logger.info('=> init bias of Linear/Conv2d to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)

    def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logger.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            # Keep only the checkpoint weights that exist in this model.
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or (len(pretrained_layers) > 0 and pretrained_layers[0] == '*')
                )
                if need_init:
                    if verbose:
                        logger.info(f'=> init {k} from {pretrained}')

                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names the optimizer should exclude from weight decay.
        return {
            'positional_embedding',
            'token_embedding',
        }

    def forward(self, input_ids, attention_mask=None):
        input_ids = input_ids.to(self.positional_embedding.device, non_blocking=True)

        # For bidirectional (non-autoregressive) encoding, build a key-padding mask from
        # attention_mask (1 = keep, 0 = pad); fall back to the padding token if no mask is given.
        key_padding_mask = None
        if not self.autogressive:
            if attention_mask is not None:
                key_padding_mask = (attention_mask == 0).to(input_ids.device)
            else:
                key_padding_mask = (input_ids == self.key_padding_token)

        x = self.token_embedding(input_ids)  # [batch, seq_len, width]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND for nn.MultiheadAttention
        for block in self.resblocks:
            x = block(x, key_padding_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_final(x)

        return {'last_hidden_state': x}

@register_lang_encoder
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    transformer = Transformer(
        context_length=config_encoder['CONTEXT_LENGTH'],
        vocab_size=tokenizer.vocab_size,
        width=config_encoder['WIDTH'],
        layers=config_encoder['LAYERS'],
        heads=config_encoder['HEADS'],
        autogressive=config_encoder.get('AUTOGRESSIVE', True),
        key_padding_token=config_encoder.get('KEY_PADDING_TOKEN', 0),
    )

    if config_encoder['LOAD_PRETRAINED']:
        transformer.load_pretrained()

    return transformer
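
# Minimal usage sketch (illustrative only; the config keys mirror the ones read above, but the
# concrete values and the `tokenizer` object are assumptions, not project defaults):
#
#   cfg = {'CONTEXT_LENGTH': 77, 'WIDTH': 512, 'LAYERS': 12, 'HEADS': 8, 'LOAD_PRETRAINED': False}
#   model = lang_encoder(cfg, tokenizer, verbose=False)
#   tokens = torch.randint(0, tokenizer.vocab_size, (2, cfg['CONTEXT_LENGTH']))
#   out = model(tokens)
#   out['last_hidden_state'].shape   # -> torch.Size([2, 77, 512])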