import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy
import numpy as np
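
# Two-layer 3x3 residual block (stride 1): conv-BN-ReLU-conv-BN plus an optional
# 3x3 conv + BN projection on the skip path when the channel counts differ.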
class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, downsample):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out += residual
        out = self.relu(out)
        return out
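
# Convolutional backbone for feature extraction. Each stage is a max-pool, a stack
# of BasicBlocks, and a 3x3 conv-BN-ReLU. For a 3x64x64 input it yields a 1024x4x4
# feature map; layer4_pool is defined but never applied in forward(), so the last
# stage keeps its spatial resolution.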
class ResNet(nn.Module):
    def __init__(self, num_in, block, layers):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(num_in, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d((2, 2), (2, 2))
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU(inplace=True)
        self.layer1_pool = nn.MaxPool2d((2, 2), (2, 2))
        self.layer1 = self._make_layer(block, 128, 256, layers[0])
        self.layer1_conv = nn.Conv2d(256, 256, 3, 1, 1)
        self.layer1_bn = nn.BatchNorm2d(256)
        self.layer1_relu = nn.ReLU(inplace=True)
        self.layer2_pool = nn.MaxPool2d((2, 2), (2, 2))
        self.layer2 = self._make_layer(block, 256, 256, layers[1])
        self.layer2_conv = nn.Conv2d(256, 256, 3, 1, 1)
        self.layer2_bn = nn.BatchNorm2d(256)
        self.layer2_relu = nn.ReLU(inplace=True)
        self.layer3_pool = nn.MaxPool2d((2, 2), (2, 2))
        self.layer3 = self._make_layer(block, 256, 512, layers[2])
        self.layer3_conv = nn.Conv2d(512, 512, 3, 1, 1)
        self.layer3_bn = nn.BatchNorm2d(512)
        self.layer3_relu = nn.ReLU(inplace=True)
        self.layer4_pool = nn.MaxPool2d((2, 2), (2, 2))
        self.layer4 = self._make_layer(block, 512, 512, layers[3])
        self.layer4_conv2 = nn.Conv2d(512, 1024, 3, 1, 1)
        self.layer4_conv2_bn = nn.BatchNorm2d(1024)
        self.layer4_conv2_relu = nn.ReLU(inplace=True)
    def _make_layer(self, block, inplanes, planes, blocks):
        if inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, 3, 1, 1),
                nn.BatchNorm2d(planes),
            )
        else:
            downsample = None
        layers = []
        layers.append(block(inplanes, planes, downsample))
        for _ in range(1, blocks):
            layers.append(block(planes, planes, downsample=None))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.layer1_pool(x)
        x = self.layer1(x)
        x = self.layer1_conv(x)
        x = self.layer1_bn(x)
        x = self.layer1_relu(x)
        x = self.layer2_pool(x)
        x = self.layer2(x)
        x = self.layer2_conv(x)
        x = self.layer2_bn(x)
        x = self.layer2_relu(x)
        x = self.layer3_pool(x)
        x = self.layer3(x)
        x = self.layer3_conv(x)
        x = self.layer3_bn(x)
        x = self.layer3_relu(x)
        x = self.layer4(x)
        x = self.layer4_conv2(x)
        x = self.layer4_conv2_bn(x)
        x = self.layer4_conv2_relu(x)
        return x
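
# Produce N independent deep copies of a module.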
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
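
# Fixed sinusoidal positional encoding: sin on even dimensions, cos on odd,
# with frequencies 1/10000^(2i/d_model); added to the input, then dropout.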
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=7000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # The encoding is a registered buffer, not a Parameter, so no gradient
        # flows through it.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
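
# Standard multi-head scaled dot-product attention. Also returns the attention
# map; if compress_attention is set, the per-head maps are mixed down to a single
# map by a learned linear layer over the head dimension.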
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.compress_attention = compress_attention
        self.compress_attention_linear = nn.Linear(h, 1)

    def forward(self, query, key, value, mask=None, align=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query, key, value = \
            [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for lin, x in zip(self.linears, (query, key, value))]
        x, attention_map = attention(query, key, value, mask=mask,
                                     dropout=self.dropout, align=align)
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        if self.compress_attention:
            batch, head, s1, s2 = attention_map.shape
            attention_map = attention_map.permute(0, 2, 3, 1).contiguous()
            attention_map = self.compress_attention_linear(attention_map).permute(0, 3, 1, 2).contiguous()
        return self.linears[-1](x), attention_map
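
# Boolean causal mask of shape (1, size, size): position i may only attend to
# positions <= i.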
def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0
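
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V, with optional
# masking (masked positions set to -inf before the softmax) and dropout on the weights.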
def attention(query, key, value, mask=None, dropout=None, align=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
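
# Layer normalization over the last dimension with learnable gain (a_2) and bias (b_2).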
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
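
# Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.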
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
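
# Output projection from the decoder dimension to the target size (e.g. the
# vocabulary); an optional ReLU is applied when norm=True.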
class Generator(nn.Module):
    def __init__(self, d_model, vocab, norm=False):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)
        self.norm = norm
        self.activation = nn.ReLU()  # nn.Sigmoid()

    def forward(self, x):
        if self.norm:
            return self.activation(self.proj(x))
        else:
            return self.proj(x)
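
# Token embedding lookup scaled by sqrt(d_model).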
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        embed = self.lut(x) * math.sqrt(self.d_model)
        return embed
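
# Pre-norm Transformer encoder block (self-attention + feed-forward, each with a
# residual connection); not referenced elsewhere in this file.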
class TransformerBlock(nn.Module):
    """Transformer Block"""
    def __init__(self, dim, num_heads, ff_dim, dropout):
        super(TransformerBlock, self).__init__()
        self.attn = MultiHeadedAttention(h=num_heads, d_model=dim, dropout=dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionwiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        x = self.norm1(x)
        h = self.drop(self.proj(self.attn(x, x, x, mask)[0]))
        x = x + h
        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x
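
# Single-layer Transformer decoder (d_model=1024, 4 heads): masked self-attention
# over the text sequence, cross-attention over the flattened convolutional feature
# map, then a position-wise feed-forward, each followed by a residual connection
# and LayerNorm.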
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.mask_multihead = MultiHeadedAttention(h=4, d_model=1024, dropout=0.1)
        self.mul_layernorm1 = LayerNorm(features=1024)
        self.multihead = MultiHeadedAttention(h=4, d_model=1024, dropout=0.1, compress_attention=False)
        self.mul_layernorm2 = LayerNorm(features=1024)
        self.pff = PositionwiseFeedForward(1024, 2048)
        self.mul_layernorm3 = LayerNorm(features=1024)

    def forward(self, text, conv_feature):
        text_max_length = text.shape[1]
        # Build the causal mask on the same device as the decoder input.
        mask = subsequent_mask(text_max_length).to(text.device)
        result = text
        result = self.mul_layernorm1(result + self.mask_multihead(result, result, result, mask=mask)[0])
        b, c, h, w = conv_feature.shape
        conv_feature = conv_feature.view(b, c, h * w).permute(0, 2, 1).contiguous()
        word_image_align, attention_map = self.multihead(result, conv_feature, conv_feature, mask=None)
        result = self.mul_layernorm2(result + word_image_align)
        result = self.mul_layernorm3(result + self.pff(result))
        return result, attention_map
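
# Full OCR model: ResNet encoder -> 512-d token embeddings concatenated with 512-d
# positional encodings (giving 1024-d decoder inputs) -> Transformer decoder ->
# per-step vocabulary logits, plus an optional location head when use_bbox=True.
# Calling it with text_length=None returns only the convolutional feature map.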
class TransformerOCR(nn.Module):
    def __init__(self, word_n_class=6738, use_bbox=False):
        super(TransformerOCR, self).__init__()
        self.word_n_class = word_n_class
        self.use_bbox = use_bbox
        self.embedding_word = Embeddings(512, self.word_n_class)
        self.pe = PositionalEncoding(d_model=512, dropout=0.1, max_len=7000)
        self.encoder = ResNet(num_in=3, block=BasicBlock, layers=[3, 4, 6, 3]).cuda()
        self.decoder = Decoder()
        self.generator_word = Generator(1024, self.word_n_class)
        if self.use_bbox:
            self.generator_loc = Generator(1024, 1, True)

    def forward(self, image, text_length, text_input, conv_feature=None):
        if conv_feature is None:
            conv_feature = self.encoder(image)
        if text_length is None:
            return {
                'conv': conv_feature,
            }
        text_embedding = self.embedding_word(text_input)
        # Positional encoding of a zero tensor matching the text embedding's
        # shape and device; it is concatenated with (not added to) the embedding.
        position_embedding = self.pe(torch.zeros_like(text_embedding))
        text_input_with_pe = torch.cat([text_embedding, position_embedding], 2)
        text_input_with_pe, attention_map = self.decoder(text_input_with_pe, conv_feature)
        word_decoder_result = self.generator_word(text_input_with_pe)
        if self.use_bbox:
            word_loc_result = self.generator_loc(text_input_with_pe)
        else:
            word_loc_result = None
        return {
            'pred': word_decoder_result,
            'map': attention_map,
            'conv': conv_feature,
            'loc': word_loc_result
        }
if __name__ == '__main__':
    # Quick shape check for the ResNet backbone (requires a CUDA device).
    net = ResNet(num_in=3, block=BasicBlock, layers=[3, 4, 6, 3]).cuda()
    image = torch.randn(8, 3, 64, 64).cuda()
    result = net(image)
    print(result.shape)
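
    # A minimal sketch of how TransformerOCR might be driven end to end, assuming
    # a CUDA device and using arbitrary dummy token ids and lengths (batch of 8,
    # sequence length 25) purely as a shape check; these values are illustrative,
    # not taken from any training setup.
    ocr = TransformerOCR(word_n_class=6738).cuda()
    dummy_text = torch.randint(0, 6738, (8, 25)).cuda()                # hypothetical token ids
    dummy_lengths = torch.full((8,), 25, dtype=torch.long).cuda()      # hypothetical per-sample lengths
    out = ocr(image, dummy_lengths, dummy_text)
    print(out['pred'].shape)   # expected (8, 25, 6738): per-step vocabulary logits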