# ML2TransformerApp / model.py
import math

import einops
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn
from torch import Tensor


class AddPositionalEncoding(nn.Module):
def __init__(self, d_model, max_sequence_len=5000):
super().__init__()
# pos - position in sequence, i - index of element embedding
# PE(pos, 2i) = sin(pos / 10000**(2i / d_model)) = sin(pos * e**(2i * (-log(10000))/d_model))
# PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model)) = cos(pos * e**(2i * (-log(10000))/d_model))
        # use float tensors so the einsum below is computed in floating point
        positions = torch.arange(max_sequence_len, dtype=torch.float)
        even_embedding_indices = torch.arange(0, d_model, 2, dtype=torch.float)
        expression = torch.exp(even_embedding_indices * (-math.log(10000.0) / d_model))
        expression = torch.einsum("i, j -> ij", positions, expression)
even_encodings = torch.sin(expression)
odd_encodings = torch.cos(expression)
positional_encodings = einops.rearrange(
[even_encodings, odd_encodings],
'even_odd pos embed -> pos (embed even_odd)'
)
self.register_buffer('positional_encodings', positional_encodings)
def forward(self, batch):
seq_len = batch.size(1)
positional_encodings = self.positional_encodings[:seq_len, :]
# implicit batch broadcasting
return batch + positional_encodings
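
# Example (hypothetical shapes): AddPositionalEncoding(d_model=512) keeps a
# (batch, seq_len, 512) input shape unchanged, adding sin values at even
# embedding indices and cos values at odd ones, as in Vaswani et al. (2017).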
class ImageEmbedding(nn.Module):
"""Reshape image into patches and project into given dimension"""
def __init__(self, d_model, input_width, input_height, patch_size=16, dropout=.1):
super().__init__()
        assert input_width % patch_size == 0 and input_height % patch_size == 0, \
            "Image dimensions must be divisible by the patch size"
tokenize = Rearrange(
'b c (h1 h2) (w1 w2) -> b (c h1 w1) (h2 w2)',
h2=patch_size,
w2=patch_size
)
project = nn.Linear(patch_size ** 2, d_model)
self.embed = nn.Sequential(
tokenize,
project,
AddPositionalEncoding(d_model),
nn.Dropout(p=dropout)
)
def forward(self, image_batch):
image_batch = self.embed(image_batch)
return image_batch
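
# Example (hypothetical shapes): with patch_size=16, a (batch, 1, 256, 512)
# grayscale batch becomes 1 * (256 / 16) * (512 / 16) = 512 patch tokens per
# image, each projected to a d_model-dimensional embedding.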
class ImageEncoder(nn.Module):
"""
Given an image, returns its vector representation.
"""
def __init__(self, image_width, image_height, d_model, num_layers=8):
super().__init__()
image_embedding = ImageEmbedding(d_model, image_width, image_height)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=8,
dim_feedforward=2048,
batch_first=True
)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.encode = nn.Sequential(image_embedding, transformer_encoder)
def forward(self, batch):
return self.encode(batch)
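
# TokenEmbedding is referenced by Seq2SeqTransformer below but is not defined in
# this file; the following is a minimal sketch assuming the usual scaled embedding
# lookup (nn.Embedding output multiplied by sqrt(emb_size)), not the project's
# original implementation.
class TokenEmbedding(nn.Module):
    """Look up token embeddings and scale them by sqrt(emb_size)."""

    def __init__(self, vocab_size, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens) * math.sqrt(self.emb_size)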
class Seq2SeqTransformer(nn.Module):
def __init__(self,
num_encoder_layers: int,
num_decoder_layers: int,
emb_size: int,
nhead: int,
image_width: int,
image_height: int,
tgt_vocab_size: int,
dim_feedforward: int = 512,
dropout: float = 0.1):
        super().__init__()
        self.transformer = nn.Transformer(d_model=emb_size,
                                          nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          # embeddings below are (batch, seq, emb)
                                          batch_first=True)
        # TODO: share weights between generator and embedding
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        # ImageEmbedding already adds positional encodings (and dropout) to the source patches
        self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        # positional encoding plus dropout for the target token embeddings
        self.positional_encoding = nn.Sequential(
            AddPositionalEncoding(emb_size),
            nn.Dropout(p=dropout)
        )
def forward(self,
src: Tensor,
trg: Tensor,
src_mask: Tensor,
tgt_mask: Tensor,
src_padding_mask: Tensor,
tgt_padding_mask: Tensor,
memory_key_padding_mask: Tensor):
        # positional encoding is already applied to the source inside ImageEmbedding
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
return self.generator(outs)
    def encode(self, src: Tensor, src_mask: Tensor):
        # ImageEmbedding already adds positional encoding to the source patches
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)
    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask
        )
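

# Hypothetical smoke test; the image size, vocabulary size, and other
# hyperparameters below are illustrative assumptions, not values taken from
# this project's training configuration.
if __name__ == "__main__":
    model = Seq2SeqTransformer(
        num_encoder_layers=3,
        num_decoder_layers=3,
        emb_size=256,
        nhead=8,
        image_width=128,
        image_height=128,
        tgt_vocab_size=300,
    )
    images = torch.rand(2, 1, 128, 128)        # (batch, channels, height, width)
    targets = torch.randint(0, 300, (2, 20))   # (batch, target_seq_len) token ids
    # causal mask so each target position only attends to earlier positions
    tgt_mask = model.transformer.generate_square_subsequent_mask(20)
    logits = model(images, targets,
                   src_mask=None, tgt_mask=tgt_mask,
                   src_padding_mask=None, tgt_padding_mask=None,
                   memory_key_padding_mask=None)
    print(logits.shape)  # expected: (2, 20, 300)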