# demo/transformer.py
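"""Transformer encoder/decoder for the KOBE demo.

Supports the baseline, attribute, knowledge, and full input modes defined in
helpers, and provides greedy and nucleus-sampling decoding.
"""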
import math
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn.modules.transformer import (
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from dataset import Batched, EncodedBatch
from vocab import BOS_ID, EOS_ID, PAD_ID
import helpers


class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        """
        sinusoidal positional encoding added to token embeddings
        :param dropout: dropout probability
        :param dim: hidden size
        :param max_len: maximum sequence length
        """
        super().__init__()
# positional encoding initialization
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
# term to divide
div_term = torch.exp(
(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
)
# sinusoidal positional encoding
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
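        # [max_len, 1, dim]: add a length-1 batch dimension so pe broadcasts over the batch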
pe = pe.unsqueeze(1)
self.register_buffer("pe", pe)
self.dropout = nn.Dropout(p=dropout)
self.dim = dim

    def forward(self, emb):
        """
        add positional encoding to the input embeddings
        :param emb: word embeddings, shape [len, batch, dim]
        :return: scaled embeddings with positional encoding and dropout applied
        """
        emb = emb * math.sqrt(self.dim)  # scale embeddings before adding positions
        emb = emb + self.pe[: emb.size(0)]  # [len, batch, dim]
        emb = self.dropout(emb)
        return emb


class Encoder(nn.Module):
@staticmethod
def from_args(args) -> "Encoder":
return Encoder(
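            # text and condition tokens share a single embedding table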
args.text_vocab_size + args.cond_vocab_size,
args.max_seq_len,
args.d_model,
args.nhead,
args.num_encoder_layers,
args.dropout,
args.mode,
)

    def __init__(
self,
vocab_size: int,
max_seq_len: int,
d_model: int,
nhead: int,
num_layers: int,
dropout: float,
mode: str,
):
super().__init__()
self.d_model = d_model
self.max_seq_len = max_seq_len
self.input_embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(dropout, d_model)
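        # pre-norm encoder layers with a feed-forward size of 4 * d_model;
        # inputs are sequence-first: [src_len, batch, d_model]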
encoder_layer = TransformerEncoderLayer(
d_model, nhead, d_model * 4, dropout, norm_first=True
)
self.encoder = TransformerEncoder(
encoder_layer, num_layers, nn.LayerNorm(d_model)
)
self.mode = mode

    def device(self):
        return next(self.parameters()).device

    def forward(self, batched: Batched) -> EncodedBatch:
src, src_key_padding_mask = Encoder._get_input(batched, self.mode)
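        # token ids and their key padding mask for the input variant selected by mode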
src = self.input_embedding(src)
src = self.pos_encoder(src)
        token_encodings = self.encoder(
            src=src, src_key_padding_mask=src_key_padding_mask
        )
return EncodedBatch(
context_encodings=token_encodings,
context_encodings_mask=src_key_padding_mask,
)

    @staticmethod
def _get_input(batched: Batched, mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
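        # map each mode to its (token_ids, key_padding_mask) input pair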
return {
helpers.BASELINE: (batched.title_token_ids, batched.title_token_ids_mask),
helpers.KOBE_ATTRIBUTE: (
batched.cond_title_token_ids,
batched.cond_title_token_ids_mask,
),
helpers.KOBE_KNOWLEDGE: (
batched.title_fact_token_ids,
batched.title_fact_token_ids_mask,
),
helpers.KOBE_FULL: (
batched.cond_title_fact_token_ids,
batched.cond_title_fact_token_ids_mask,
),
}[mode]


class Decoder(nn.Module):
@staticmethod
def from_args(args) -> "Decoder":
return Decoder(
args.text_vocab_size,
args.max_seq_len,
args.d_model,
args.nhead,
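            # the decoder reuses the encoder's layer count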
args.num_encoder_layers,
args.dropout,
)

    def __init__(
self,
vocab_size: int,
max_seq_len: int,
d_model: int,
nhead: int,
num_layers: int,
dropout: float,
):
        super().__init__()
self.max_seq_len = max_seq_len
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(dropout, d_model)
decoder_layer = TransformerDecoderLayer(
d_model, nhead, 4 * d_model, dropout, norm_first=True
)
self.decoder = TransformerDecoder(
decoder_layer, num_layers, nn.LayerNorm(d_model)
)
self.output = nn.Linear(d_model, vocab_size)

    def forward(self, batch: Batched, encoded_batch: EncodedBatch) -> torch.Tensor:
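        # teacher forcing: feed every target token except the last and predict
        # the token at the next position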
tgt = self.embedding(batch.description_token_ids[:-1])
tgt = self.pos_encoder(tgt)
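        # causal mask: position i may only attend to positions <= i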
tgt_mask = Decoder.generate_square_subsequent_mask(tgt.shape[0], tgt.device)
outputs = self.decoder(
tgt=tgt,
tgt_mask=tgt_mask,
tgt_key_padding_mask=batch.description_token_ids_mask[:, :-1],
memory=encoded_batch.context_encodings,
memory_key_padding_mask=encoded_batch.context_encodings_mask,
)
return self.output(outputs)

    def predict(self, encoded_batch: EncodedBatch, decoding_strategy: str):
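        # autoregressively decode up to max_seq_len tokens, either greedily or
        # with nucleus (top-p) sampling, stopping early once every sequence has finished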
batch_size = encoded_batch.context_encodings.shape[1]
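        # start all sequences from BOS; tgt grows to [steps, batch] as tokens are generated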
tgt = torch.tensor(
[BOS_ID] * batch_size, device=encoded_batch.context_encodings.device
).unsqueeze(dim=0)
tgt_mask = Decoder.generate_square_subsequent_mask(self.max_seq_len, tgt.device)
pred_all = []
for idx in range(self.max_seq_len):
tgt_emb = self.pos_encoder(self.embedding(tgt))
outputs = self.decoder(
tgt_emb,
tgt_mask=tgt_mask[: idx + 1, : idx + 1],
memory=encoded_batch.context_encodings,
memory_key_padding_mask=encoded_batch.context_encodings_mask,
)
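            # only the last position's output is needed to predict the next token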
logits = self.output(outputs[-1])
if decoding_strategy == "greedy":
pred_step = logits.argmax(dim=1).tolist()
elif decoding_strategy == "nucleus":
pred_step = [
helpers.top_k_top_p_sampling(logits[i], top_p=0.95)
for i in range(batch_size)
]
else:
raise NotImplementedError
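            # once a sequence has emitted EOS (or PAD), keep emitting PAD for it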
for b in range(batch_size):
if pred_all and pred_all[-1][b].item() in [EOS_ID, PAD_ID]:
pred_step[b] = PAD_ID
            if all(pred == PAD_ID for pred in pred_step):
break
pred_step = torch.tensor(pred_step, device=tgt.device)
pred_all.append(pred_step)
if idx < self.max_seq_len - 1:
tgt_step = pred_step.unsqueeze(dim=0)
tgt = torch.cat([tgt, tgt_step], dim=0)
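        # stack per-step predictions into [generated_len, batch]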
preds = torch.stack(pred_all)
return preds

    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        r"""
        Generate a square causal mask for the sequence. Masked positions are filled
        with float('-inf'); unmasked positions are filled with float(0.0).
        """
return torch.triu(
torch.full((sz, sz), float("-inf"), device=device), diagonal=1
)