| """
|
| 10+2 Tied Transformer for English → Malay Translation
|
| =======================================================
|
| An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``.
|
|
|
| Architecture (redesigned for efficient T4 GPU training & inference):
|
| d_model = 512 (embedding dimension, head_dim = 64)
|
| n_head = 8 (attention heads)
|
| encoder layers = 10 (deep encoder for source understanding)
|
| decoder layers = 2 (shallow decoder for fast generation)
|
| d_ff = 2048 (feed-forward inner dimension)
|
| dropout = 0.1
|
| norm_first = True (pre-norm for training stability)
|
| shared embeddings = True (single vocab, en+ms share Latin script)
|
| tied output proj. = True (output reuses embedding weights)
|
|
|
| Key design choices (see architecture_report.md for full rationale):
|
| • **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives
|
| translation quality; decoder depth can be aggressively reduced
|
| with minimal quality loss and ~3× faster inference.
|
| • **Shared vocabulary:** English and Malay both use Latin script with
|
| significant lexical overlap (loanwords, numbers, proper nouns).
|
| A joint BPE naturally captures cross-lingual subword patterns.
|
| • **Tied output projection (Press & Wolf, 2017):** The decoder's output
|
| linear layer reuses the shared embedding matrix, saving ~26M params
|
| and acting as a regulariser.
|
| • **Pre-layer normalisation (Xiong et al., 2020):** Essential for stable
|
| training of a 10-layer encoder. Places LayerNorm before each sublayer.
|
| • Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention /
|
| SDPA fused kernels active (PyTorch 2.0+).
|
| """

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """
    Inject positional information via fixed sinusoidal signals.

        PE(pos, 2i)   = sin(pos / 10000^{2i / d_model})
        PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the full (max_len, d_model) table once.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # exp(2i * (-ln 10000) / d_model) == 10000^(-2i / d_model), matching the
        # denominator in the formulas above.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        # Registered as a buffer: moves with .to(device) but is not a learned parameter.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            (batch, seq_len, d_model) with positional encoding added.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerTranslator(nn.Module):
    """
    Asymmetric encoder-decoder Transformer with shared/tied embeddings.

    Parameters
    ----------
    vocab_size : int
        Size of the shared source+target vocabulary.
    d_model : int
        Embedding / hidden dimension.
    n_head : int
        Number of attention heads.
    num_encoder_layers : int
        Number of encoder blocks (default 10).
    num_decoder_layers : int
        Number of decoder blocks (default 2).
    d_ff : int
        Feed-forward inner dimension.
    dropout : float
        Dropout rate.
    max_len : int
        Maximum sequence length for positional encoding.
    pad_idx : int
        Padding token ID (used to create padding masks).
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 10,
        num_decoder_layers: int = 2,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # One embedding table shared by source and target (joint en+ms vocabulary).
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.embed_scale = math.sqrt(d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )

        # Tied output projection: the logits reuse shared_embedding.weight, so only a
        # per-token bias is introduced here.
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

        self._init_weights()

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Shared embedding, scaled by sqrt(d_model), plus positional encoding."""
        return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)

    def _init_weights(self):
        """Scaled-normal initialization (std = d_model ** -0.5) for the shared embedding."""
        nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
        # Keep the padding embedding at exactly zero after re-initialization.
        with torch.no_grad():
            self.shared_embedding.weight[self.pad_idx].zero_()

    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """
        Causal mask for the decoder: prevents attending to future positions.
        Returns a (sz, sz) boolean mask where True = blocked.
        """
        return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)

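    # Example: generate_square_subsequent_mask(3, device) returns
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    # i.e. row t may only attend to positions <= t.
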
    def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create a padding mask: True where token == pad_idx.
        Shape: (batch, seq_len)
        """
        return x == self.pad_idx

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            src: (batch, src_len) source token IDs.
            tgt: (batch, tgt_len) target token IDs (teacher-forced).

        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)

        # Causal mask so decoder position t cannot attend to positions > t.
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)

        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        # Tied output projection: logits = out @ shared_embedding.weight.T + output_bias.
        logits = torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
        return logits

    def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run only the encoder. Returns memory: (batch, src_len, d_model)."""
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        src_emb = self._embed(src)
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run only the decoder given encoder memory. Returns logits."""
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
        tgt_emb = self._embed(tgt)
        out = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)


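# Illustrative sketch (not part of the original module API): how encode()/decode()
# are typically composed for greedy inference, encoding the source once and reusing
# the memory at every step. The BOS/EOS token IDs are assumptions; substitute the
# IDs your tokenizer actually uses.
@torch.no_grad()
def greedy_translate_example(
    model: TransformerTranslator,
    src: torch.Tensor,
    bos_idx: int,
    eos_idx: int,
    max_new_tokens: int = 128,
) -> torch.Tensor:
    """Greedy decoding for one source batch. Returns (batch, <= max_new_tokens + 1) token IDs."""
    model.eval()
    src_pad_mask = src == model.pad_idx
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)  # encoded once, reused every step
    ys = torch.full((src.size(0), 1), bos_idx, dtype=torch.long, device=src.device)
    for _ in range(max_new_tokens):
        logits = model.decode(ys, memory, memory_key_padding_mask=src_pad_mask)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # most likely next token
        ys = torch.cat([ys, next_token], dim=1)
        if (next_token == eos_idx).all():
            break
    return ys

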
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model(
    vocab_size: int,
    pad_idx: int = 0,
    device: Optional[torch.device] = None,
    **kwargs,
) -> TransformerTranslator:
    """
    Build and return a TransformerTranslator with default hyperparameters.

    Any kwarg (d_model, n_head, etc.) overrides the default.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TransformerTranslator(
        vocab_size=vocab_size,
        pad_idx=pad_idx,
        **kwargs,
    ).to(device)

    return model

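
# Minimal smoke test (a sketch: the vocabulary size and token IDs below are made up
# purely to exercise shapes and the teacher-forced loss, not taken from the real
# tokenizer or training pipeline).
if __name__ == "__main__":
    vocab_size = 8000  # placeholder; the real value comes from the trained joint BPE
    model = build_model(vocab_size=vocab_size, device=torch.device("cpu"))
    print(f"Trainable parameters: {count_parameters(model):,}")

    # Fake batch: 2 sentences, source length 11, target length 9 (IDs drawn from
    # 1..vocab_size-1, reserving 0 for padding).
    src = torch.randint(1, vocab_size, (2, 11))
    tgt = torch.randint(1, vocab_size, (2, 9))

    logits = model(src, tgt[:, :-1])  # teacher forcing: predict token t+1 from tokens <= t
    assert logits.shape == (2, 8, vocab_size)

    # Typical training loss: cross-entropy against the shifted targets, ignoring padding.
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, vocab_size), tgt[:, 1:].reshape(-1), ignore_index=model.pad_idx
    )
    print(f"Smoke-test loss: {loss.item():.3f}")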