import torch
from torch import Tensor
import torch.nn as nn
from safetensors.torch import load_model
from jaxtyping import Bool, Int, Float
from huggingface_hub import hf_hub_download

from embedding import InputEmbeddings, PositionalEncoding
from modules import Encoder, Decoder
import config


class Generator(nn.Module):
    """
    Implements the final linear (projection) layer that produces the
    output logits.

    This module takes the final output of the Decoder stack (B, T, D)
    and projects it onto the vocabulary space (B, T, vocab_size).
    The softmax over these logits is applied outside this module
    (e.g. by the loss function or during decoding).

    (This layer's weights can be tied with the target embedding layer,
    which we handle in the main 'Transformer' model class.)
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Initializes the Generator (output projection) layer.

        Args:
            d_model (int): The dimension of the model (D).
            vocab_size (int): The size of the target vocabulary.
        """
        super().__init__()
        self.proj: nn.Linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(
        self, x: Float[Tensor, "B T_tgt D"]
    ) -> Float[Tensor, "B T_tgt vocab_size"]:
        """
        Forward pass for the Generator.

        Args:
            x (Tensor): The final output tensor from the Decoder stack.

        Returns:
            Tensor: The output logits over the vocabulary.
        """
        logits = self.proj(x)
        return logits

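
# Illustrative shape check for Generator (a sketch only; the dimensions below
# are arbitrary and not taken from this project's config):
#
#     gen = Generator(d_model=512, vocab_size=32000)
#     x = torch.randn(2, 10, 512)              # (B, T_tgt, D)
#     logits = gen(x)                          # (B, T_tgt, vocab_size)
#     assert logits.shape == (2, 10, 32000)
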
class Transformer(nn.Module):
    """
    The main Transformer model architecture, combining the Encoder
    and Decoder stacks, as described in "Attention Is All You Need".

    This implementation follows modern best practices (Pre-LN) and
    is designed for a sequence-to-sequence task (e.g., translation).
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        d_ff: int,
        dropout: float = 0.1,
        max_seq_len: int = 512,
    ) -> None:
        """
        Initializes the full Transformer model.

        Args:
            src_vocab_size (int): Vocabulary size for the source language.
            tgt_vocab_size (int): Vocabulary size for the target language.
            d_model (int): The dimension of the model (D).
            n_heads (int): The number of attention heads (H).
            n_layers (int): The number of Encoder/Decoder layers (N).
            d_ff (int): The inner dimension of the Feed-Forward Network (D_FF).
            dropout (float): The dropout rate.
            max_seq_len (int): The maximum sequence length for positional encoding.
        """
        super().__init__()

        self.d_model = d_model

        # Token embeddings for the source and target vocabularies.
        self.src_embed: InputEmbeddings = InputEmbeddings(d_model, src_vocab_size)
        self.tgt_embed: InputEmbeddings = InputEmbeddings(d_model, tgt_vocab_size)

        # Positional encoding, shared by source and target sequences.
        self.pos_enc: PositionalEncoding = PositionalEncoding(
            d_model, max_seq_len, dropout
        )

        # Encoder and Decoder stacks (N layers each).
        self.encoder: Encoder = Encoder(d_model, n_heads, d_ff, n_layers, dropout)
        self.decoder: Decoder = Decoder(d_model, n_heads, d_ff, n_layers, dropout)

        # Output projection onto the target vocabulary.
        self.generator: Generator = Generator(d_model, tgt_vocab_size)

        # Weight tying: the output projection shares its weight matrix with
        # the target embedding layer.
        self.generator.proj.weight = self.tgt_embed.token_emb.weight

        # Initialize parameters.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """
        Applies Xavier/Glorot uniform initialization to linear layers and a
        scaled normal initialization to embedding layers.
        This is a common and effective initialization strategy.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            # Scale the embedding initialization by 1/sqrt(d_model).
            nn.init.normal_(module.weight, mean=0, std=self.d_model**-0.5)

    def forward(
        self,
        src: Int[Tensor, "B T_src"],
        tgt: Int[Tensor, "B T_tgt"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "B 1 T_tgt T_tgt"],
    ) -> Float[Tensor, "B T_tgt vocab_size"]:
        """
        Defines the main forward pass of the Transformer model.

        Args:
            src (Tensor): Source sequence token IDs.
            tgt (Tensor): Target sequence token IDs (shifted right).
            src_mask (Tensor): Padding mask for the source sequence.
            tgt_mask (Tensor): Combined padding and look-ahead mask
                for the target sequence.

        Returns:
            Tensor: The output logits from the model (B, T_tgt, vocab_size).
        """
        # 1. Embed the source tokens and add positional information.
        src_embedded = self.src_embed(src)
        src_with_pos = self.pos_enc(src_embedded)

        # 2. Encode the source sequence.
        enc_output: Tensor = self.encoder(src_with_pos, src_mask)

        # 3. Embed the target tokens and add positional information.
        tgt_embedded = self.tgt_embed(tgt)
        tgt_with_pos = self.pos_enc(tgt_embedded)

        # 4. Decode, attending to the encoder output.
        dec_output: Tensor = self.decoder(tgt_with_pos, enc_output, src_mask, tgt_mask)

        # 5. Project the decoder output onto the target vocabulary.
        logits: Tensor = self.generator(dec_output)
        return logits

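
# How the masks expected by Transformer.forward can be built (a sketch under
# the assumption that True means "attend" and padding positions are masked out;
# `pad_id` below is a hypothetical padding token ID, not a value defined in
# this file):
#
#     src_mask = (src != pad_id)[:, None, None, :]        # (B, 1, 1, T_src)
#     pad = (tgt != pad_id)[:, None, None, :]             # (B, 1, 1, T_tgt)
#     causal = torch.tril(                                 # (T_tgt, T_tgt)
#         torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=tgt.device)
#     )
#     tgt_mask = pad & causal                              # (B, 1, T_tgt, T_tgt)
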
def load_trained_model(
    config_obj, checkpoint_path, device: torch.device
) -> Transformer:
    """
    Instantiates a Transformer from the given config and loads trained
    weights from the Hugging Face Hub.

    Args:
        config_obj: Configuration object holding the model hyperparameters.
        checkpoint_path: Currently unused; the checkpoint is always downloaded
            from the Hub (config.REPO_ID / config.FILENAME).
        device (torch.device): Device to place the model on.

    Returns:
        Transformer: The model with trained weights loaded.
    """
    print("Downloading safetensors from Hub...")
    model_path = hf_hub_download(repo_id=config.REPO_ID, filename=config.FILENAME)

    print("Instantiating the Transformer model...")
    model = Transformer(
        src_vocab_size=config_obj.VOCAB_SIZE,
        tgt_vocab_size=config_obj.VOCAB_SIZE,
        d_model=config_obj.D_MODEL,
        n_heads=config_obj.N_HEADS,
        n_layers=config_obj.N_LAYERS,
        d_ff=config_obj.D_FF,
        dropout=config_obj.DROPOUT,
        max_seq_len=config_obj.MAX_SEQ_LEN,
    ).to(device)

    print(f"Loading model from: {model_path}")
    load_model(model, filename=model_path)

    print(f"Successfully loaded trained weights from {model_path}")
    return model
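

# A self-contained smoke test (illustrative only): the hyperparameters and
# padding ID below are arbitrary choices for this sketch, not values from
# config, and load_trained_model is not exercised here because it requires
# network access to the Hub.
if __name__ == "__main__":
    torch.manual_seed(0)
    pad_id = 0  # hypothetical padding token ID
    B, T_src, T_tgt, vocab = 2, 7, 5, 1000

    model = Transformer(
        src_vocab_size=vocab,
        tgt_vocab_size=vocab,
        d_model=64,
        n_heads=4,
        n_layers=2,
        d_ff=128,
        dropout=0.0,
        max_seq_len=32,
    )

    # The tied output projection and target embedding share a single tensor.
    assert model.generator.proj.weight is model.tgt_embed.token_emb.weight

    src = torch.randint(1, vocab, (B, T_src))
    tgt = torch.randint(1, vocab, (B, T_tgt))

    # Build the masks in the shapes forward() expects (True = attend).
    src_mask = (src != pad_id)[:, None, None, :]                     # (B, 1, 1, T_src)
    causal = torch.tril(torch.ones(T_tgt, T_tgt, dtype=torch.bool))  # (T_tgt, T_tgt)
    tgt_mask = (tgt != pad_id)[:, None, None, :] & causal            # (B, 1, T_tgt, T_tgt)

    logits = model(src, tgt, src_mask, tgt_mask)
    print(logits.shape)  # expected: torch.Size([2, 5, 1000])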