import math

import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from torch import Tensor

class AddPositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a (batch, seq_len, d_model) input."""

    def __init__(self, d_model, max_sequence_len=5000):
        super().__init__()
        # pos - position in sequence, i - index of element embedding
        # PE(pos, 2i)   = sin(pos / 10000**(2i / d_model)) = sin(pos * e**(2i * (-log(10000)) / d_model))
        # PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model)) = cos(pos * e**(2i * (-log(10000)) / d_model))
        positions = torch.arange(max_sequence_len, dtype=torch.float)
        even_embedding_indices = torch.arange(0, d_model, 2, dtype=torch.float)
        expression = torch.exp(even_embedding_indices * (-math.log(10000.0) / d_model))
        # outer product: one row of scaled frequencies per position, shape (max_sequence_len, d_model / 2)
        expression = torch.einsum("i, j -> ij", positions, expression)
        even_encodings = torch.sin(expression)
        odd_encodings = torch.cos(expression)
        # interleave sin/cos along the embedding axis -> (max_sequence_len, d_model)
        positional_encodings = einops.rearrange(
            [even_encodings, odd_encodings],
            'even_odd pos embed -> pos (embed even_odd)'
        )
        self.register_buffer('positional_encodings', positional_encodings)

    def forward(self, batch):
        seq_len = batch.size(1)
        positional_encodings = self.positional_encodings[:seq_len, :]
        # implicit batch broadcasting
        return batch + positional_encodings
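
# Hypothetical usage sketch: a quick shape/broadcast check for AddPositionalEncoding
# with made-up dimensions; not part of the model itself.
def _demo_add_positional_encoding():
    pe = AddPositionalEncoding(d_model=512)
    x = torch.zeros(2, 10, 512)  # (batch, seq_len, d_model)
    out = pe(x)
    assert out.shape == (2, 10, 512)
    # with a zero input, the output is exactly the first 10 rows of the encoding table
    assert torch.allclose(out[0], pe.positional_encodings[:10, :])
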
class ImageEmbedding(nn.Module):
    """Reshape an image into patches and project each patch into the model dimension."""

    def __init__(self, d_model, input_width, input_height, patch_size=16, dropout=.1):
        super().__init__()
        assert input_width % patch_size == 0 and input_height % patch_size == 0, \
            "Cannot split image into patches"
        # split each channel into patch_size x patch_size patches; every patch becomes one token
        tokenize = Rearrange(
            'b c (h1 h2) (w1 w2) -> b (c h1 w1) (h2 w2)',
            h2=patch_size,
            w2=patch_size
        )
        project = nn.Linear(patch_size ** 2, d_model)
        self.embed = nn.Sequential(
            tokenize,
            project,
            AddPositionalEncoding(d_model),
            nn.Dropout(p=dropout)
        )

    def forward(self, image_batch):
        return self.embed(image_batch)
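
# Hypothetical usage sketch: illustrates the patch tokenization shapes of ImageEmbedding
# for a made-up single-channel 64x48 input.
def _demo_image_embedding():
    embed = ImageEmbedding(d_model=128, input_width=48, input_height=64)
    images = torch.randn(2, 1, 64, 48)  # (batch, channels, height, width)
    tokens = embed(images)
    # a 64x48 image yields (64 / 16) * (48 / 16) = 4 * 3 = 12 patches, each projected to d_model
    assert tokens.shape == (2, 12, 128)
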
class ImageEncoder(nn.Module):
    """
    Given an image, returns its vector representation.
    """

    def __init__(self, image_width, image_height, d_model, num_layers=8):
        super().__init__()
        image_embedding = ImageEmbedding(d_model, image_width, image_height)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=2048,
            batch_first=True
        )
        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.encode = nn.Sequential(image_embedding, transformer_encoder)

    def forward(self, batch):
        return self.encode(batch)
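
# Hypothetical usage sketch: run ImageEncoder on a made-up batch of single-channel images.
def _demo_image_encoder():
    encoder = ImageEncoder(image_width=32, image_height=32, d_model=128, num_layers=2)
    images = torch.randn(2, 1, 32, 32)  # (batch, channels, height, width)
    features = encoder(images)  # one d_model vector per 16x16 patch
    assert features.shape == (2, 4, 128)  # a 32x32 image gives 2 * 2 = 4 patches
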
class TokenEmbedding(nn.Module):
    """Scaled token embedding for the target sequence.

    Note: assumed minimal implementation of the TokenEmbedding class referenced below,
    following the usual nn.Embedding scaled by sqrt(d_model) pattern.
    """

    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor) -> Tensor:
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 image_width: int,
                 image_height: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        # batch_first=True because ImageEmbedding and TokenEmbedding produce (batch, seq, emb) tensors
        self.transformer = nn.Transformer(d_model=emb_size,
                                          nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        # TODO: share weights between generator and embedding
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        # ImageEmbedding already applies AddPositionalEncoding to the source patches
        self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = AddPositionalEncoding(emb_size)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        # positional encodings are added inside ImageEmbedding, so only the target needs them here
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)),
                                        memory, tgt_mask)
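
# Hypothetical usage sketch: wire Seq2SeqTransformer together with a causal target mask.
# Vocabulary size, image size and token ids below are made up.
def _demo_seq2seq_transformer():
    model = Seq2SeqTransformer(num_encoder_layers=2, num_decoder_layers=2,
                               emb_size=128, nhead=8,
                               image_width=32, image_height=32,
                               tgt_vocab_size=100)
    images = torch.randn(2, 1, 32, 32)       # source images: (batch, channels, H, W)
    targets = torch.randint(0, 100, (2, 7))  # target token ids: (batch, seq_len)
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(7)
    logits = model(images, targets,
                   src_mask=None, tgt_mask=tgt_mask,
                   src_padding_mask=None, tgt_padding_mask=None,
                   memory_key_padding_mask=None)
    assert logits.shape == (2, 7, 100)  # per-token logits over the target vocabulary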