from typing import Tuple
import torch.nn as nn
from torch import Tensor
from modules.transformer_embedding import TransformerEmbedding
from modules.positional_encoding import PositionalEncoding
from model.encoder import Encoder
from model.decoder import Decoder
from layers.projection_layer import ProjectionLayer


class Transformer(nn.Module):
    """
    Transformer model: an encoder-decoder architecture with a final
    projection onto the target vocabulary.

    Args:
    - src_vocab_size (int): source vocabulary size
    - tgt_vocab_size (int): target vocabulary size
    - src_max_seq_len (int): source max sequence length
    - tgt_max_seq_len (int): target max sequence length
    - d_model (int): dimension of the model
    - num_heads (int): number of attention heads
    - d_ff (int): dimension of the hidden feed-forward layer
    - dropout_p (float): dropout probability
    - num_encoder_layers (int): number of encoder layers
    - num_decoder_layers (int): number of decoder layers
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout_p: float = 0.1,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
    ) -> None:
        super().__init__()

        # Token embedding layers for the source and target vocabularies
        self.src_embedding = TransformerEmbedding(
            d_model=d_model,
            num_embeddings=src_vocab_size
        )
        self.tgt_embedding = TransformerEmbedding(
            d_model=d_model,
            num_embeddings=tgt_vocab_size
        )

        # Positional encoding layers
        self.src_positional_encoding = PositionalEncoding(
            d_model=d_model,
            dropout_p=dropout_p,
            max_length=src_max_seq_len
        )
        self.tgt_positional_encoding = PositionalEncoding(
            d_model=d_model,
            dropout_p=dropout_p,
            max_length=tgt_max_seq_len
        )

        # Encoder stack
        self.encoder = Encoder(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_encoder_layers
        )

        # Decoder stack
        self.decoder = Decoder(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_decoder_layers
        )

        # Projection layer mapping the decoder's output to the target vocabulary
        self.projection_layer = ProjectionLayer(
            d_model=d_model,
            vocab_size=tgt_vocab_size
        )

    def encode(
        self,
        src: Tensor,
        src_mask: Tensor
    ) -> Tensor:
        """
        Embed the source tokens, add positional encodings, and return the
        encoder outputs.
        """
        src = self.src_embedding(src)
        src = self.src_positional_encoding(src)
        return self.encoder(src, src_mask)

    def decode(
        self,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Embed the target tokens, add positional encodings, and run the decoder
        over the encoder outputs. Returns the decoder outputs and the
        attention weights.
        """
        tgt = self.tgt_embedding(tgt)
        tgt = self.tgt_positional_encoding(tgt)
        return self.decoder(
            x=tgt,
            encoder_output=encoder_output,
            src_mask=src_mask,
            tgt_mask=tgt_mask
        )

    def project(self, decoder_output: Tensor) -> Tensor:
        """
        Project decoder outputs to the target vocabulary.
        """
        return self.projection_layer(decoder_output)

    def forward(
        self,
        src: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Full forward pass: encode the source, decode the target, and project
        the decoder outputs to the target vocabulary. Masks are expected to
        be built by the caller.
        """
        # src_mask = self.make_src_mask(src)
        # tgt_mask = self.make_tgt_mask(tgt)
        encoder_output = self.encode(src, src_mask)
        decoder_output, attn = self.decode(
            encoder_output, src_mask, tgt, tgt_mask
        )
        output = self.project(decoder_output)
        return output, attn
def count_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
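

# --- Example: building the attention masks expected by forward() -----------
# The commented-out `make_src_mask` / `make_tgt_mask` calls in forward()
# suggest that masks are constructed by the caller. The helpers below are a
# minimal sketch of the standard approach (padding mask for the source,
# padding plus causal mask for the target); `pad_token_id` and the broadcast
# shape (batch, 1, 1, seq_len) are assumptions that depend on the attention
# implementation inside `Encoder`/`Decoder`, not part of this module.
import torch  # imported here only to keep the sketch self-contained


def make_src_mask_example(src: Tensor, pad_token_id: int) -> Tensor:
    # (batch, src_len) -> (batch, 1, 1, src_len); True where attention is allowed
    return (src != pad_token_id).unsqueeze(1).unsqueeze(2)


def make_tgt_mask_example(tgt: Tensor, pad_token_id: int) -> Tensor:
    # Padding mask: (batch, 1, 1, tgt_len)
    pad_mask = (tgt != pad_token_id).unsqueeze(1).unsqueeze(2)
    # Causal mask: (tgt_len, tgt_len), lower-triangular so that position i
    # only attends to positions <= i
    tgt_len = tgt.size(1)
    causal_mask = torch.tril(
        torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device)
    )
    # Broadcast-combine to (batch, 1, tgt_len, tgt_len)
    return pad_mask & causal_mask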


def get_model(config, src_vocab_size: int, tgt_vocab_size: int) -> Transformer:
    """
    Returns a `Transformer` model for a given config.
    """
    return Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
        d_model=config['model']['d_model'],
        num_heads=config['model']['num_heads'],
        d_ff=config['model']['d_ff'],
        dropout_p=config['model']['dropout_p'],
        num_encoder_layers=config['model']['num_encoder_layers'],
        num_decoder_layers=config['model']['num_decoder_layers'],
    )
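

# --- Usage sketch ----------------------------------------------------------
# A minimal end-to-end check with random token ids and all-True masks. The
# config values mirror the keys read by get_model(); the vocabulary sizes,
# batch/sequence sizes, and 4-D mask shapes are illustrative assumptions and
# may need to match the attention implementation used by `Encoder`/`Decoder`.
if __name__ == "__main__":
    config = {
        'dataset': {'src_max_seq_len': 64, 'tgt_max_seq_len': 64},
        'model': {
            'd_model': 512,
            'num_heads': 8,
            'd_ff': 2048,
            'dropout_p': 0.1,
            'num_encoder_layers': 6,
            'num_decoder_layers': 6,
        },
    }
    model = get_model(config, src_vocab_size=1000, tgt_vocab_size=1000)
    print(f"Trainable parameters: {model.count_parameters():,}")

    batch_size, src_len, tgt_len = 2, 10, 12
    src = torch.randint(0, 1000, (batch_size, src_len))
    tgt = torch.randint(0, 1000, (batch_size, tgt_len))
    src_mask = torch.ones(batch_size, 1, 1, src_len, dtype=torch.bool)
    tgt_mask = torch.ones(batch_size, 1, tgt_len, tgt_len, dtype=torch.bool)

    logits, attn = model(src, src_mask, tgt, tgt_mask)
    # Expected logits shape: (batch_size, tgt_len, tgt_vocab_size)
    print(logits.shape)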