from collections import OrderedDict

import torch
import torch.utils.checkpoint  # ensure torch.utils.checkpoint.checkpoint is always resolvable
from torch import nn
from transformers.models.bert.tokenization_bert import BertTokenizer
|
|
|
|
class text_process(object):
    """Tokenize raw text with a locally stored BERT tokenizer, padded/truncated
    to a fixed context length, and return the token-id tensor."""

    def __init__(self, context_length=80, mlm_probability=0.15):
        """
        Args:
            context_length: target token sequence length (tokenizer pads/truncates to this).
            mlm_probability: masking ratio stored for MLM-style training; it is only
                recorded here and not applied anywhere in this class.
        """
        self.context_length = context_length
        self.mlm_probability = mlm_probability

        # NOTE(review): assumes tokenizer files live in ./bert relative to the
        # working directory — confirm against the deployment layout.
        bert_path = './bert'
        self.tokenizer = BertTokenizer.from_pretrained(bert_path, model_max_length=context_length)

    def __call__(self, text):
        """Return the 1-D input_ids tensor (length == context_length) for `text`."""
        encoded = self.tokenizer(_preprocess_text(text), return_tensors="pt",
                                 truncation=True, padding='max_length')
        # Callers consume only the token ids; the attention mask produced by the
        # tokenizer is intentionally discarded (was previously an unused local).
        return encoded['input_ids'][0]

    def __repr__(self):
        # Renamed the local so it no longer shadows the builtin `repr`, and report
        # the real attribute name (`context_length`, not `content_length`).
        out = "(DataAugmentationForBERT,\n"
        out += f" context_length = {self.context_length},\n"
        out += f" mlm_probability = {self.mlm_probability},\n"
        out += ")"
        return out
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """LayerNorm that upcasts half-precision inputs to fp32 for the normalization
    itself, then returns the result in the caller's original dtype."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Normalize in fp32 for numerical stability under fp16 inputs.
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
|
|
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by a
    4x-wide MLP with QuickGELU, each wrapped in a residual connection."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden_dim = d_model * 4
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden_dim)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden_dim, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Re-cache the mask on the activations' dtype/device before every call,
        # matching the original's in-place update of self.attn_mask.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
| |
|
|
class Transformer(nn.Module):
    """Stack of ResidualAttentionBlocks. Uses activation checkpointing while
    training to trade recomputation for activation memory."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        """
        Args:
            width: model (embedding) dimension of every block.
            layers: number of stacked ResidualAttentionBlocks.
            heads: attention heads per block.
            attn_mask: optional mask shared by all blocks.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        # Checkpointing only pays off when gradients are being recorded; running it
        # under eval()/torch.no_grad() just adds overhead and triggers a
        # "none of the inputs have requires_grad" warning. Outputs are identical
        # either way, so gate it on the training/grad state.
        use_checkpoint = self.training and torch.is_grad_enabled()
        for resblock in self.resblocks:
            if use_checkpoint:
                x = torch.utils.checkpoint.checkpoint(resblock, x, use_reentrant=False)
            else:
                x = resblock(x)
        return x
|
|
|
|
class VisualTransformer(nn.Module):
    """ViT image encoder: patchify with a strided conv, prepend a learned class
    token, add positional embeddings, run the transformer, and project.
    Returns (pooled class-token feature, all projected tokens)."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding: kernel == stride == patch_size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        num_patches = (input_resolution // patch_size) ** 2
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        # (B, 3, H, W) -> (B, width, grid, grid) -> (B, grid*grid, width)
        patches = self.conv1(x)
        patches = patches.flatten(2).transpose(1, 2)

        # Prepend the class token, broadcast across the batch.
        cls_token = self.class_embedding.to(patches.dtype) + torch.zeros(
            patches.shape[0], 1, patches.shape[-1], dtype=patches.dtype, device=patches.device)
        x = torch.cat([cls_token, patches], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        # The transformer operates sequence-first: (seq, batch, width).
        x = self.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)

        x = self.ln_post(x)

        if self.proj is not None:
            x = x @ self.proj

        # Class token as the pooled feature, plus the full projected sequence.
        return x[:, 0, :], x
| |
|
|
| def _preprocess_text(text): |
| |
| text = text.lower().replace("“", "\"").replace("”", "\"") |
| return text |