from collections import OrderedDict
from typing import Tuple, Union
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from timm.models.layers import DropPath, trunc_normal_

from .registry import register_lang_encoder

logger = logging.getLogger(__name__)


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style (epsilon inside the sqrt).

    Mean and variance are always computed in float32; the normalized values are
    cast back to the input dtype before the learned scale and shift are applied.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).

        Args:
            hidden_size: size of the last (normalized) dimension.
            eps: value added to the variance before taking the square root.
        """
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        pdtype = x.dtype
        # Compute statistics in float32, presumably for stability with
        # reduced-precision inputs.
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual.

    ``forward`` computes ``x = x + drop_path(attn(ln_1(x)))`` followed by
    ``x = x + drop_path(mlp(ln_2(x)))``. ``nn.MultiheadAttention`` is built with
    its default (sequence-first) layout, so ``x`` is expected as (L, N, D); the
    owning ``Transformer`` permutes to that layout before calling the blocks.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive attention mask used on every call. It is a plain
        # attribute (not a registered buffer), so it is not in the state_dict.
        self.attn_mask = attn_mask

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        """Self-attention over ``x``; returns only the attended output (no weights)."""
        # Convert the mask to x's dtype/device and keep the converted tensor so
        # later calls can reuse it.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None

        return self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            attn_mask=self.attn_mask
        )[0]

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    """Token-embedding + positional-embedding transformer text encoder.

    Args:
        context_length: number of learned positions; also the size of the causal mask.
        vocab_size: number of rows in the token embedding.
        width: model dimension.
        layers: number of ``ResidualAttentionBlock``s.
        heads: attention heads per block.
        drop_path: maximum stochastic-depth rate, increased linearly from 0 over
            the blocks.
        autogressive: (spelling kept for interface compatibility) when True every
            block uses a causal mask and no padding mask; when False there is no
            causal mask and ``forward`` masks token id 0 as padding.
    """

    def __init__(self,
                 context_length: int,
                 vocab_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 drop_path: float = 0.0,
                 autogressive: bool = True):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width)
        )

        self.width = width
        self.layers = layers
        self.autogressive = autogressive
        # A single mask tensor is shared by (passed to) every block.
        attn_mask = self.build_attention_mask() if autogressive else None
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
                for i in range(layers)
            ]
        )

        self.ln_final = LayerNorm(width)

        trunc_normal_(self.positional_embedding, std=.02)
        trunc_normal_(self.token_embedding.weight, std=.02)
        # _init_weights only touches Linear/Conv2d (and nn.LayerNorm/BatchNorm2d),
        # so the embedding initializations above are kept.
        self.apply(self._init_weights)

    @property
    def dim_out(self):
        """Size of the last dimension of ``last_hidden_state`` (the model width)."""
        return self.width

    def build_attention_mask(self):
        """Build the additive causal mask of shape (context_length, context_length).

        Entries strictly above the diagonal are -inf and all others are 0, so
        position i can only attend to positions j <= i (PyTorch adds float masks
        to the attention scores).
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero the diagonal and everything below it
        return mask

    def _init_weights(self, m):
        """Per-module initializer applied through ``self.apply`` in ``__init__``."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            logger.info('=> init weight of Linear/Conv2d from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                logger.info('=> init bias of Linear/Conv2d to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # NOTE(review): this model's norms are the custom LayerNorm class
            # above, which this check does not match; those already start with
            # weight=1 and bias=0 from their own __init__.
            nn.init.constant_(m.bias, 0)

    def load_pretrained(self, pretrained='', pretrained_layers=None, verbose=True):
        """Load selected weights from a checkpoint file into this model.

        Args:
            pretrained: path to a ``torch.load``-able mapping of parameter name to
                tensor. If it is empty/None or not an existing file, a warning is
                logged and nothing is loaded.
            pretrained_layers: top-level names (the part of a key before the first
                '.') whose weights should be loaded. If the first entry is '*',
                every checkpoint key that also exists in this model is loaded.
                Defaults to no layers.
            verbose: log each key as it is loaded.
        """
        if not pretrained or not os.path.isfile(pretrained):
            logger.warning(
                f'=> pretrained weights not loaded: {pretrained!r} is not a file'
            )
            return

        if pretrained_layers is None:
            pretrained_layers = []
        # Guard the [0] lookup so an empty list selects nothing instead of
        # raising IndexError.
        load_all = len(pretrained_layers) > 0 and pretrained_layers[0] == '*'

        logger.info(f'=> loading pretrained model {pretrained}')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        pretrained_dict = torch.load(pretrained, map_location='cpu')

        model_dict = self.state_dict()
        # Drop checkpoint entries this model does not have.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict
        }
        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            if load_all or k.split('.')[0] in pretrained_layers:
                if verbose:
                    logger.info(f'=> init {k} from {pretrained}')
                need_init_state_dict[k] = v
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names to exclude from weight decay.

        NOTE(review): these are name prefixes ('token_embedding', not
        'token_embedding.weight'); how they are matched against parameter names
        is up to the optimizer-building caller — confirm.
        """
        return {
            'positional_embedding',
            'token_embedding',
        }

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: (batch, seq) token ids. ``seq`` is expected to equal
                ``context_length`` because the full positional table is added
                without slicing.
            attention_mask: accepted for interface compatibility; unused.

        Returns:
            ``{'last_hidden_state': x}`` with ``x`` of shape (batch, seq, width).
        """
        # Non-causal mode masks token id 0 as padding.
        # NOTE(review): assumes the tokenizer's pad id is 0 — confirm.
        key_padding_mask = (input_ids == 0) if not self.autogressive else None

        x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.resblocks:
            x = block(x, key_padding_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_final(x)

        return {'last_hidden_state': x}


@register_lang_encoder
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    """Build a ``Transformer`` text encoder from a config mapping.

    Reads CONTEXT_LENGTH, WIDTH, LAYERS, HEADS, LOAD_PRETRAINED and the optional
    AUTOGRESSIVE (default True) keys from ``config_encoder``; the vocabulary size
    is taken from ``tokenizer.vocab_size``. ``verbose`` and ``kwargs`` are unused.
    """
    transformer = Transformer(
        context_length=config_encoder['CONTEXT_LENGTH'],
        vocab_size=tokenizer.vocab_size,
        width=config_encoder['WIDTH'],
        layers=config_encoder['LAYERS'],
        heads=config_encoder['HEADS'],
        autogressive=config_encoder.get('AUTOGRESSIVE', True)
    )

    if config_encoder['LOAD_PRETRAINED']:
        # NOTE(review): no checkpoint path is passed, so load_pretrained() finds
        # no file and loads nothing; confirm which config key should supply it.
        transformer.load_pretrained()

    return transformer