""" Transformer-based varitional encoder model. """ import torch import torch.nn as nn import torch.nn.functional as F import math import copy def clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) def build_mask(base_mask): assert len(base_mask.shape) == 2 batch_size, seq_len = base_mask.shape[0], base_mask.shape[-1] # create subsequent token mask sub_mask = torch.tril(torch.ones([seq_len, seq_len], dtype=torch.uint8)).type_as(base_mask) sub_mask = sub_mask.unsqueeze(0).expand(batch_size, -1, -1) base_mask = base_mask.unsqueeze(1).expand(-1, seq_len, -1) return sub_mask & base_mask class Adaptor(nn.Module): def __init__(self, input_dim, tar_dim): super(Adaptor, self).__init__() if tar_dim == 32768: output_channel = 8 elif tar_dim == 16384: output_channel = 4 else: raise NotImplementedError("only support 512px, 256px does not need this") self.tar_dim = tar_dim self.fc1 = nn.Linear(input_dim, 4096) self.ln_fc1 = nn.LayerNorm(4096) self.fc2 = nn.Linear(4096, 4096) self.ln_fc2 = nn.LayerNorm(4096) self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1) self.ln_conv1 = nn.LayerNorm([32, 64, 64]) self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1) self.ln_conv2 = nn.LayerNorm([64, 64, 64]) self.conv3 = nn.Conv2d(in_channels=64, out_channels=output_channel, kernel_size=3, padding=1) def forward(self, x): x = torch.relu(self.ln_fc1(self.fc1(x))) x = torch.relu(self.ln_fc2(self.fc2(x))) x = x.view(-1, 1, 64, 64) x = torch.relu(self.ln_conv1(self.conv1(x))) x = torch.relu(self.ln_conv2(self.conv2(x))) x = self.conv3(x) x = x.view(-1, self.tar_dim) return x class Compressor(nn.Module): def __init__(self, input_dim=4096, tar_dim=2048): super(Compressor, self).__init__() self.fc1 = nn.Linear(input_dim, tar_dim) self.ln_fc1 = nn.LayerNorm(tar_dim) self.fc2 = nn.Linear(tar_dim, tar_dim) def forward(self, x): x = torch.relu(self.ln_fc1(self.fc1(x))) x = self.fc2(x) return x class TransEncoder(nn.Module): def __init__(self, d_model, N, num_token, head_num, d_ff, latten_size, down_sample_block=3, dropout=0.1, last_norm=True): super(TransEncoder, self).__init__() self.N = N if d_model==4096: # for T5-XXL, first use MLP to compress into 1024 self.compressor = Compressor(input_dim=d_model, tar_dim=1024) d_model = 1024 else: self.compressor = None self.layers = clones(EncoderLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout), FeedForward(d_model, d_ff, dropout=dropout), LayerNorm(d_model), LayerNorm(d_model)), N) self.reduction_layers = nn.ModuleList() for _ in range(down_sample_block): self.reduction_layers.append( EncoderReductionLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout), FeedForward(d_model, d_ff, dropout=dropout), nn.Linear(d_model, d_model // 2), LayerNorm(d_model), LayerNorm(d_model))) d_model = d_model // 2 if latten_size == 8192 or latten_size == 4096: self.arc = 0 self.linear = nn.Linear(d_model*num_token, latten_size) self.norm = LayerNorm(latten_size) if last_norm else None else: self.arc = 1 self.adaptor = Adaptor(d_model*num_token, latten_size) def forward(self, x, mask): mask = mask.unsqueeze(1) if self.compressor is not None: x = self.compressor(x) for i, layer in enumerate(self.layers): x = layer(x, mask) for i, layer in enumerate(self.reduction_layers): x = layer(x, mask) if self.arc == 0: x = self.linear(x.view(x.shape[0],-1)) x = self.norm(x) if self.norm else x else: x = self.adaptor(x.view(x.shape[0],-1)) return x class EncoderLayer(nn.Module): def __init__(self, attn, feed_forward, norm1, norm2, dropout=0.1): super(EncoderLayer, self).__init__() self.attn = attn self.feed_forward = feed_forward self.norm1, self.norm2 = norm1, norm2 self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) def forward(self, x, mask): # multihead attn & norm a = self.attn(x, x, x, mask) t = self.norm1(x + self.dropout1(a)) # feed forward & norm z = self.feed_forward(t) # linear(dropout(act(linear(x))))) y = self.norm2(t + self.dropout2(z)) return y class EncoderReductionLayer(nn.Module): def __init__(self, attn, feed_forward, reduction, norm1, norm2, dropout=0.1): super(EncoderReductionLayer, self).__init__() self.attn = attn self.feed_forward = feed_forward self.reduction = reduction self.norm1, self.norm2 = norm1, norm2 self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) def forward(self, x, mask): # multihead attn & norm a = self.attn(x, x, x, mask) t = self.norm1(x + self.dropout1(a)) # feed forward & norm z = self.feed_forward(t) # linear(dropout(act(linear(x))))) y = self.norm2(t + self.dropout2(z)) # reduction # y = self.reduction(y).view(x.shape[0], -1, x.shape[-1]) y = self.reduction(y) return y class MultiHeadAttentioin(nn.Module): def __init__(self, d_model, head_num, dropout=0.1, d_v=None): super(MultiHeadAttentioin, self).__init__() assert d_model % head_num == 0, "d_model must be divisible by head_num" self.d_model = d_model self.head_num = head_num self.d_k = d_model // head_num self.d_v = self.d_k if d_v is None else d_v # d_model = d_k * head_num self.W_Q = nn.Linear(d_model, head_num * self.d_k) self.W_K = nn.Linear(d_model, head_num * self.d_k) self.W_V = nn.Linear(d_model, head_num * self.d_v) self.W_O = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) def scaled_dp_attn(self, query, key, value, mask=None): assert self.d_k == query.shape[-1] # scores: [batch_size, head_num, seq_len, seq_len] scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k) # if torch.isinf(scores).any(): # # to avoid leaking # scores = torch.where(scores == float('-inf'), torch.tensor(-65504.0), scores) # scores = torch.where(scores == float('inf'), torch.tensor(65504.0), scores) if mask is not None: assert mask.ndim == 3, "Mask shape {} doesn't seem right...".format(mask.shape) mask = mask.unsqueeze(1) try: if scores.dtype == torch.float32: scores = scores.masked_fill(mask == 0, -1e9) else: scores = scores.masked_fill(mask == 0, -1e4) except RuntimeError: print("- scores device: {}".format(scores.device)) print("- mask device: {}".format(mask.device)) # attn: [batch_size, head_num, seq_len, seq_len] attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) return torch.matmul(attn, value), attn def forward(self, q, k, v, mask): batch_size = q.shape[0] query = self.W_Q(q).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2) key = self.W_K(k).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2) value = self.W_V(v).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2) heads, attn = self.scaled_dp_attn(query, key, value, mask) heads = heads.transpose(1, 2).contiguous().view(batch_size, -1, self.head_num * self.d_k) assert heads.shape[-1] == self.d_model and heads.shape[0] == batch_size y = self.W_O(heads) assert y.shape == q.shape return y class LayerNorm(nn.Module): def __init__(self, layer_size, eps=1e-5): super(LayerNorm, self).__init__() self.g = nn.Parameter(torch.ones(layer_size)) self.b = nn.Parameter(torch.zeros(layer_size)) self.eps = eps def forward(self, x): mean = x.mean(-1, keepdim=True) std = x.std(-1, keepdim=True) x = (x - mean) / (std + self.eps) return self.g * x + self.b class FeedForward(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1, act='relu', d_output=None): super(FeedForward, self).__init__() self.d_model = d_model self.d_ff = d_ff d_output = d_model if d_output is None else d_output self.ffn_1 = nn.Linear(d_model, d_ff) self.ffn_2 = nn.Linear(d_ff, d_output) if act == 'relu': self.act = nn.ReLU() elif act == 'rrelu': self.act = nn.RReLU() else: raise NotImplementedError self.dropout = nn.Dropout(dropout) def forward(self, x): y = self.ffn_2(self.dropout(self.act(self.ffn_1(x)))) return y