import math

import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from torch import Tensor
class AddPositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings."""

    def __init__(self, d_model, max_sequence_len=5000):
        super().__init__()
        # pos - position in the sequence, i - index within the embedding
        # PE(pos, 2i)   = sin(pos / 10000**(2i / d_model)) = sin(pos * e**(2i * (-log(10000)) / d_model))
        # PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model)) = cos(pos * e**(2i * (-log(10000)) / d_model))
        positions = torch.arange(max_sequence_len, dtype=torch.float)
        even_embedding_indices = torch.arange(0, d_model, 2, dtype=torch.float)
        expression = torch.exp(even_embedding_indices * (-math.log(10000.0) / d_model))
        expression = torch.einsum("i, j -> ij", positions, expression)
        even_encodings = torch.sin(expression)
        odd_encodings = torch.cos(expression)
        # Interleave so even embedding indices hold sines and odd indices hold cosines
        positional_encodings = einops.rearrange(
            [even_encodings, odd_encodings],
            'even_odd pos embed -> pos (embed even_odd)'
        )
        self.register_buffer('positional_encodings', positional_encodings)

    def forward(self, batch):
        seq_len = batch.size(1)
        positional_encodings = self.positional_encodings[:seq_len, :]
        # (seq_len, d_model) broadcasts over the batch dimension
        return batch + positional_encodings
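# Usage sketch (hypothetical d_model and shapes): the module is shape-preserving
# and simply adds the first seq_len rows of the precomputed encoding table.
#
#     pe = AddPositionalEncoding(d_model=256)
#     x = torch.zeros(2, 10, 256)      # (batch, seq_len, d_model)
#     assert pe(x).shape == (2, 10, 256)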
class ImageEmbedding(nn.Module):
    """Reshape an image into patches and project each patch into the given dimension."""

    def __init__(self, d_model, input_width, input_height, patch_size=16, dropout=.1):
        super().__init__()
        assert input_width % patch_size == 0 and input_height % patch_size == 0, \
            "Cannot split image into patches"
        # Channels are folded into the token axis: (c h1 w1) tokens of
        # patch_size**2 pixels each, so the projection below assumes
        # single-channel input if standard ViT-style patches are intended
        tokenize = Rearrange(
            'b c (h1 h2) (w1 w2) -> b (c h1 w1) (h2 w2)',
            h2=patch_size,
            w2=patch_size
        )
        project = nn.Linear(patch_size ** 2, d_model)
        self.embed = nn.Sequential(
            tokenize,
            project,
            AddPositionalEncoding(d_model),
            nn.Dropout(p=dropout)
        )

    def forward(self, image_batch):
        return self.embed(image_batch)
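# Shape walkthrough with hypothetical sizes: a (2, 1, 64, 64) batch with
# patch_size=16 is tokenized into (2, 16, 256) patch vectors, then projected
# to (2, 16, d_model); positional encodings and dropout preserve that shape.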
class ImageEncoder(nn.Module):
    """Given an image, returns its vector representation."""

    def __init__(self, image_width, image_height, d_model, num_layers=8):
        super().__init__()
        image_embedding = ImageEmbedding(d_model, image_width, image_height)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=2048,
            batch_first=True
        )
        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.encode = nn.Sequential(image_embedding, transformer_encoder)

    def forward(self, batch):
        return self.encode(batch)
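class TokenEmbedding(nn.Module):
    """Learned target-token embedding scaled by sqrt(emb_size), following
    "Attention Is All You Need"; used by Seq2SeqTransformer below."""

    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor) -> Tensor:
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)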
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 image_width: int,
                 image_height: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.transformer = nn.Transformer(d_model=emb_size,
                                          nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        # TODO: share weights between generator and embedding
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = AddPositionalEncoding(emb_size)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        # ImageEmbedding already adds positional encodings to the source patches
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
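# Minimal smoke test under assumed sizes (64x64 single-channel images, a
# 100-token target vocabulary); every number here is illustrative, not taken
# from any particular training setup.
if __name__ == "__main__":
    model = Seq2SeqTransformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        emb_size=256,
        nhead=8,
        image_width=64,
        image_height=64,
        tgt_vocab_size=100,
    )
    images = torch.rand(2, 1, 64, 64)        # (batch, channels, height, width)
    targets = torch.randint(0, 100, (2, 7))  # (batch, target sequence length)
    # Causal mask so each target position only attends to earlier positions
    tgt_mask = model.transformer.generate_square_subsequent_mask(7)
    logits = model(images, targets,
                   src_mask=None, tgt_mask=tgt_mask,
                   src_padding_mask=None, tgt_padding_mask=None,
                   memory_key_padding_mask=None)
    print(logits.shape)  # expected: torch.Size([2, 7, 100])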