"""
Transformer-based variational encoder model.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy


def clones(module, N):
    """Return a ModuleList of N deep copies of the given module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def build_mask(base_mask):
    assert len(base_mask.shape) == 2
    batch_size, seq_len = base_mask.shape[0], base_mask.shape[-1]

    sub_mask = torch.tril(torch.ones([seq_len, seq_len],
                                     dtype=torch.uint8)).type_as(base_mask)
    sub_mask = sub_mask.unsqueeze(0).expand(batch_size, -1, -1)
    base_mask = base_mask.unsqueeze(1).expand(-1, seq_len, -1)
    return sub_mask & base_mask
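
# Shape sketch for build_mask (assuming base_mask is a 0/1 padding mask, 1 = real token):
#   base_mask: [batch, seq_len]
#   sub_mask:  [batch, seq_len, seq_len]  lower-triangular (causal) mask
#   returned:  [batch, seq_len, seq_len]  causal AND padding, one row per query position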


class Adaptor(nn.Module):
    """MLP plus a small conv head that maps a flat feature vector to a larger flat latent (512px targets)."""

    def __init__(self, input_dim, tar_dim):
        super(Adaptor, self).__init__()

        if tar_dim == 32768:
            output_channel = 8
        elif tar_dim == 16384:
            output_channel = 4
        else:
            raise NotImplementedError(
                "Only 512px targets are supported; the 256px setting does not need this adaptor.")

        self.tar_dim = tar_dim

        self.fc1 = nn.Linear(input_dim, 4096)
        self.ln_fc1 = nn.LayerNorm(4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.ln_fc2 = nn.LayerNorm(4096)

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.ln_conv1 = nn.LayerNorm([32, 64, 64])
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.ln_conv2 = nn.LayerNorm([64, 64, 64])
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=output_channel, kernel_size=3, padding=1)

    def forward(self, x):
        x = torch.relu(self.ln_fc1(self.fc1(x)))
        x = torch.relu(self.ln_fc2(self.fc2(x)))

        # Reshape the 4096-dim feature into a 1x64x64 map for the conv stack.
        x = x.view(-1, 1, 64, 64)

        x = torch.relu(self.ln_conv1(self.conv1(x)))
        x = torch.relu(self.ln_conv2(self.conv2(x)))

        x = self.conv3(x)
        x = x.view(-1, self.tar_dim)

        return x
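
# Adaptor tensor flow (shapes follow directly from the layers above):
#   (B, input_dim) -> fc1/fc2 -> (B, 4096) -> view -> (B, 1, 64, 64)
#   -> conv1 -> (B, 32, 64, 64) -> conv2 -> (B, 64, 64, 64)
#   -> conv3 -> (B, output_channel, 64, 64) -> flatten -> (B, tar_dim)
#   where output_channel * 64 * 64 equals tar_dim (8*4096 = 32768, 4*4096 = 16384).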


class Compressor(nn.Module):
    """Two-layer MLP that compresses the per-token feature dimension."""

    def __init__(self, input_dim=4096, tar_dim=2048):
        super(Compressor, self).__init__()

        self.fc1 = nn.Linear(input_dim, tar_dim)
        self.ln_fc1 = nn.LayerNorm(tar_dim)
        self.fc2 = nn.Linear(tar_dim, tar_dim)

    def forward(self, x):
        x = torch.relu(self.ln_fc1(self.fc1(x)))
        x = self.fc2(x)

        return x


class TransEncoder(nn.Module):
    """Transformer encoder that maps a token sequence to a flat latent vector."""

    def __init__(self, d_model, N, num_token, head_num, d_ff, latten_size,
                 down_sample_block=3, dropout=0.1, last_norm=True):
        super(TransEncoder, self).__init__()
        self.N = N

        # Very wide inputs are first compressed to a manageable model width.
        if d_model == 4096:
            self.compressor = Compressor(input_dim=d_model, tar_dim=1024)
            d_model = 1024
        else:
            self.compressor = None

        self.layers = clones(EncoderLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout),
                                          FeedForward(d_model, d_ff, dropout=dropout),
                                          LayerNorm(d_model),
                                          LayerNorm(d_model)), N)

        # Each reduction layer halves the feature dimension.
        self.reduction_layers = nn.ModuleList()
        for _ in range(down_sample_block):
            self.reduction_layers.append(
                EncoderReductionLayer(MultiHeadAttentioin(d_model, head_num, dropout=dropout),
                                      FeedForward(d_model, d_ff, dropout=dropout),
                                      nn.Linear(d_model, d_model // 2),
                                      LayerNorm(d_model),
                                      LayerNorm(d_model)))
            d_model = d_model // 2

        if latten_size == 8192 or latten_size == 4096:
            self.arc = 0
            self.linear = nn.Linear(d_model * num_token, latten_size)
            self.norm = LayerNorm(latten_size) if last_norm else None
        else:
            self.arc = 1
            self.adaptor = Adaptor(d_model * num_token, latten_size)

    def forward(self, x, mask):
        mask = mask.unsqueeze(1)

        if self.compressor is not None:
            x = self.compressor(x)

        for layer in self.layers:
            x = layer(x, mask)

        for layer in self.reduction_layers:
            x = layer(x, mask)

        if self.arc == 0:
            x = self.linear(x.view(x.shape[0], -1))
            x = self.norm(x) if self.norm else x
        else:
            x = self.adaptor(x.view(x.shape[0], -1))

        return x
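
# TransEncoder data flow (a sketch; exact sizes depend on the constructor arguments):
#   x: (B, num_token, d_model) --[optional Compressor]--> (B, num_token, 1024)
#   --[N x EncoderLayer]--> same shape
#   --[down_sample_block x EncoderReductionLayer]--> (B, num_token, d_model / 2**down_sample_block)
#   --[flatten + Linear or Adaptor]--> (B, latten_size)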


class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention and feed-forward, each with a residual connection."""

    def __init__(self, attn, feed_forward, norm1, norm2, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.attn = attn
        self.feed_forward = feed_forward
        self.norm1, self.norm2 = norm1, norm2

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Self-attention sub-layer with residual connection and post-norm.
        a = self.attn(x, x, x, mask)
        t = self.norm1(x + self.dropout1(a))

        # Feed-forward sub-layer with residual connection and post-norm.
        z = self.feed_forward(t)
        y = self.norm2(t + self.dropout2(z))

        return y


class EncoderReductionLayer(nn.Module):
    """Encoder layer whose output is projected to a smaller feature dimension (d_model -> d_model // 2)."""

    def __init__(self, attn, feed_forward, reduction, norm1, norm2, dropout=0.1):
        super(EncoderReductionLayer, self).__init__()
        self.attn = attn
        self.feed_forward = feed_forward
        self.reduction = reduction
        self.norm1, self.norm2 = norm1, norm2

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        a = self.attn(x, x, x, mask)
        t = self.norm1(x + self.dropout1(a))

        z = self.feed_forward(t)
        y = self.norm2(t + self.dropout2(z))

        # Halve the feature dimension before the next layer.
        y = self.reduction(y)

        return y


class MultiHeadAttentioin(nn.Module):
    """Multi-head scaled dot-product attention."""

    def __init__(self, d_model, head_num, dropout=0.1, d_v=None):
        super(MultiHeadAttentioin, self).__init__()
        assert d_model % head_num == 0, "d_model must be divisible by head_num"

        self.d_model = d_model
        self.head_num = head_num
        self.d_k = d_model // head_num
        self.d_v = self.d_k if d_v is None else d_v

        self.W_Q = nn.Linear(d_model, head_num * self.d_k)
        self.W_K = nn.Linear(d_model, head_num * self.d_k)
        self.W_V = nn.Linear(d_model, head_num * self.d_v)
        self.W_O = nn.Linear(head_num * self.d_v, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dp_attn(self, query, key, value, mask=None):
        assert self.d_k == query.shape[-1]

        # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            assert mask.ndim == 3, "Mask shape {} doesn't seem right...".format(mask.shape)
            mask = mask.unsqueeze(1)
            try:
                # Use a smaller fill value in reduced precision to avoid overflow.
                if scores.dtype == torch.float32:
                    scores = scores.masked_fill(mask == 0, -1e9)
                else:
                    scores = scores.masked_fill(mask == 0, -1e4)
            except RuntimeError:
                print("- scores device: {}".format(scores.device))
                print("- mask device: {}".format(mask.device))
                raise

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, value), attn

    def forward(self, q, k, v, mask):
        batch_size = q.shape[0]

        # Project and split into heads: (batch, head_num, seq_len, d_k or d_v).
        query = self.W_Q(q).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
        key = self.W_K(k).view(batch_size, -1, self.head_num, self.d_k).transpose(1, 2)
        value = self.W_V(v).view(batch_size, -1, self.head_num, self.d_v).transpose(1, 2)

        heads, attn = self.scaled_dp_attn(query, key, value, mask)
        heads = heads.transpose(1, 2).contiguous().view(batch_size, -1,
                                                        self.head_num * self.d_v)
        assert heads.shape[-1] == self.head_num * self.d_v and heads.shape[0] == batch_size

        y = self.W_O(heads)

        assert y.shape == q.shape
        return y
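
# Per-head shapes inside MultiHeadAttentioin.forward, for inputs of shape (B, T, d_model):
#   query/key: (B, head_num, T, d_k)
#   value:     (B, head_num, T, d_v)
#   scores:    (B, head_num, T, T)
#   output y:  (B, T, d_model)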


class LayerNorm(nn.Module):
    """Layer normalization with learnable gain and bias."""

    def __init__(self, layer_size, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.g = nn.Parameter(torch.ones(layer_size))
        self.b = nn.Parameter(torch.zeros(layer_size))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        x = (x - mean) / (std + self.eps)
        return self.g * x + self.b
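
# Note: this LayerNorm normalizes by (std + eps) rather than sqrt(var + eps),
# i.e. y = g * (x - mean) / (std + eps) + b, so it is close to but not numerically
# identical to torch.nn.LayerNorm.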


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1, act='relu', d_output=None):
        super(FeedForward, self).__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        d_output = d_model if d_output is None else d_output

        self.ffn_1 = nn.Linear(d_model, d_ff)
        self.ffn_2 = nn.Linear(d_ff, d_output)

        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'rrelu':
            self.act = nn.RReLU()
        else:
            raise NotImplementedError

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        y = self.ffn_2(self.dropout(self.act(self.ffn_1(x))))
        return y
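

if __name__ == "__main__":
    # Minimal smoke test; the hyperparameters below are illustrative assumptions,
    # not values taken from any training configuration.
    batch_size, num_token, d_model = 2, 16, 512
    encoder = TransEncoder(d_model=d_model, N=2, num_token=num_token, head_num=8,
                           d_ff=1024, latten_size=4096, down_sample_block=3)

    x = torch.randn(batch_size, num_token, d_model)
    mask = torch.ones(batch_size, num_token, dtype=torch.uint8)  # 1 = real token

    z = encoder(x, mask)
    print(z.shape)  # expected: torch.Size([2, 4096])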