import math

import einops
import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange


class AddPositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings (Vaswani et al., 2017):
    PE(pos, 2i) = sin(pos / 10000^(2i / d_model)),
    PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model)).
    """

    def __init__(self, d_model, max_sequence_len=5000):
        super().__init__()
        # Float positions so the einsum with the float frequencies below is well-defined
        positions = torch.arange(max_sequence_len, dtype=torch.float)
        even_embedding_indices = torch.arange(0, d_model, 2)
        # 1 / 10000^(2i / d_model), computed in log space
        expression = torch.exp(even_embedding_indices * (-math.log(10000.0) / d_model))
        # Outer product: one row of phase values per position
        expression = torch.einsum('i, j -> ij', positions, expression)
        even_encodings = torch.sin(expression)
        odd_encodings = torch.cos(expression)
        # Interleave so even embedding indices hold sin and odd indices hold cos
        positional_encodings = einops.rearrange(
            [even_encodings, odd_encodings],
            'even_odd pos embed -> pos (embed even_odd)'
        )
        self.register_buffer('positional_encodings', positional_encodings)

    def forward(self, batch):
        seq_len = batch.size(1)
        positional_encodings = self.positional_encodings[:seq_len, :]
        return batch + positional_encodings
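

# Illustrative sanity check (hypothetical sizes, not part of the original module):
# at position 0 every sin column is 0 and every cos column is 1, which confirms
# the interleaved sin/cos layout, and the encodings broadcast over the batch.
def _demo_positional_encoding():
    pe = AddPositionalEncoding(d_model=8, max_sequence_len=16)
    x = torch.zeros(2, 5, 8)  # (batch, seq, d_model)
    assert pe(x).shape == (2, 5, 8)
    assert torch.allclose(pe.positional_encodings[0], torch.tensor([0., 1.] * 4))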


class ImageEmbedding(nn.Module):
    """Reshape an image into patches and project each patch into the model dimension."""

    def __init__(self, d_model, input_width, input_height, patch_size, dropout):
        super().__init__()
        assert input_width % patch_size == 0 and input_height % patch_size == 0, \
            "Cannot split image into patches"
        # Cut each channel into an (h1 x w1) grid of patches and flatten every
        # patch_size x patch_size patch into a single token
        tokenize = Rearrange(
            'b c (h1 h2) (w1 w2) -> b (c h1 w1) (h2 w2)',
            h2=patch_size,
            w2=patch_size
        )
        project = nn.Linear(patch_size ** 2, d_model)
        self.embed = nn.Sequential(
            tokenize,
            project,
            AddPositionalEncoding(d_model),
            nn.Dropout(p=dropout)
        )

    def forward(self, image_batch):
        return self.embed(image_batch)
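

# Shape sketch for the patch embedding (the 64x64 grayscale input and d_model=32
# are illustrative assumptions, not values from the original training setup):
# a 64x64 image with patch_size=16 gives a 4x4 grid, i.e. 16 tokens of length 256.
def _demo_image_embedding():
    embed = ImageEmbedding(d_model=32, input_width=64, input_height=64, patch_size=16, dropout=0.0)
    tokens = embed(torch.rand(2, 1, 64, 64))  # (batch, channel, height, width)
    assert tokens.shape == (2, 16, 32)  # (batch, num_patches, d_model)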


class TexEmbedding(nn.Module):
    """Embed TeX token ids and add positional encodings."""

    def __init__(self, d_model: int, vocab_size: int, dropout: float):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.add_positional_encoding = AddPositionalEncoding(d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

    def forward(self, tex_ids_batch):
        # Scale by sqrt(d_model) as in Vaswani et al. so the embeddings are not
        # drowned out by the positional encodings
        tex_ids_batch = self.embedding(tex_ids_batch.long()) * math.sqrt(self.d_model)
        tex_ids_batch = self.add_positional_encoding(tex_ids_batch)
        return self.dropout(tex_ids_batch)


class Transformer(pl.LightningModule):
    """Encoder-decoder Transformer that maps a formula image to TeX token ids."""

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 d_model: int,
                 nhead: int,
                 image_width: int,
                 image_height: int,
                 tgt_vocab_size: int,
                 pad_idx: int,
                 dim_feedforward: int = 512,
                 dropout: float = .1,
                 ):
        super().__init__()
        self.transformer = nn.Transformer(d_model=d_model,
                                          nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        for p in self.transformer.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        self.d_model = d_model
        self.src_tok_emb = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=dropout)
        self.tgt_tok_emb = TexEmbedding(d_model, tgt_vocab_size, dropout=dropout)
        self.generator = nn.Linear(d_model, tgt_vocab_size)
        # Tie the output projection to the target embedding weights
        self.tgt_tok_emb.embedding.weight = self.generator.weight
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
        self.save_hyperparameters()

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,
                tgt_padding_mask=None):
        """Attention masks: positions with ``True`` are not allowed to attend;
        positions with ``False`` are left unchanged. Padding masks: positions
        with ``True`` are ignored; positions with ``False`` are left unchanged."""
        src = self.src_tok_emb(src)
        tgt = self.tgt_tok_emb(tgt)
        outs = self.transformer(src, tgt, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
        return self.generator(outs)

    def encode(self, src, src_mask=None, src_padding_mask=None):
        src = self.src_tok_emb(src)
        return self.transformer.encoder(src, src_mask, src_padding_mask)

    def decode(self, tgt, memory=None, tgt_mask=None, memory_mask=None, tgt_padding_mask=None):
        tgt = self.tgt_tok_emb(tgt)
        outs = self.transformer.decoder(tgt, memory, tgt_mask, memory_mask, tgt_padding_mask)
        return self.generator(outs)

    def _shared_step(self, batch):
        src = batch['images']
        tgt = batch['tex_ids']
        # Teacher forcing: predict token t + 1 from tokens up to t
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        src_mask = None
        # Boolean causal mask: True above the diagonal blocks attention to future tokens
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device, torch.bool)
        memory_mask = None
        src_padding_mask = None
        # Padding masks use True for "ignore"; the dataset's attention masks use 1 for "keep"
        tgt_padding_mask = torch.logical_not(batch['tex_attention_masks'][:, :-1])
        outs = self(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
        # CrossEntropyLoss expects the class dimension second: (batch, vocab, seq)
        loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        # Base lr of 1 so LambdaLR's multiplicative lambda yields the Noam learning rate directly
        optimizer = torch.optim.Adam(self.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, NoamLRLambda(self.d_model))
        return [optimizer], [scheduler]


class NoamLRLambda:
    """Noam schedule from "Attention Is All You Need":
    lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    """

    def __init__(self, d_model, factor=1, warmup=4000):
        """
        :param d_model: size of the hidden model dimension
        :param factor: multiplicative factor
        :param warmup: number of warmup steps
        """
        self.d_model = d_model
        self.factor = factor
        self.warmup = warmup

    def __call__(self, step):
        step += 1  # LambdaLR starts counting at 0; avoid 0 ** -0.5
        return self.factor * self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))
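

# Minimal end-to-end smoke test. All hyperparameters and token ids below are
# illustrative assumptions, not values from the original training config; it
# runs one teacher-forced forward pass, a naive greedy decode built on
# encode()/decode(), and prints a few points of the Noam schedule.
if __name__ == "__main__":
    PAD_IDX, BOS_IDX = 0, 1  # hypothetical special-token ids
    model = Transformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        d_model=64,
        nhead=4,
        image_width=64,
        image_height=64,
        tgt_vocab_size=100,
        pad_idx=PAD_IDX,
        dim_feedforward=128,
    )
    images = torch.rand(2, 1, 64, 64)         # (batch, channel, height, width)
    tex_ids = torch.randint(2, 100, (2, 10))  # (batch, seq)
    logits = model(images, tex_ids)
    print(logits.shape)  # torch.Size([2, 10, 100])

    # Greedy decoding: repeatedly feed the growing prefix back into the decoder
    model.eval()
    with torch.no_grad():
        memory = model.encode(images)
        ys = torch.full((2, 1), BOS_IDX, dtype=torch.long)
        for _ in range(10):
            causal = model.transformer.generate_square_subsequent_mask(
                ys.size(1)).to(ys.device, torch.bool)
            out = model.decode(ys, memory, tgt_mask=causal)
            next_token = out[:, -1, :].argmax(dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
    print(ys.shape)  # torch.Size([2, 11])

    # Learning rate multipliers at a few steps of the Noam schedule
    noam = NoamLRLambda(d_model=64)
    print([noam(s) for s in (0, 1000, 4000, 10000)])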