| | """
|
| | TRANSFORMER ENCODER & DECODER (FIXED)
|
| | Xây dựng hoàn chỉnh Encoder và Decoder layers
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from .transformer_components import (
|
| | MultiHeadAttention,
|
| | PositionwiseFeedForward,
|
| | ResidualConnection,
|
| | LayerNorm
|
| | )
|
| |
|
| |
|
| |
|
| |
|
| |
|


class EncoderLayer(nn.Module):
    """
    A single Transformer Encoder layer.

    Consists of:
        1. Multi-Head Self-Attention
        2. Add & Norm
        3. Feed-Forward Network
        4. Add & Norm

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        d_ff: Dimension of the feed-forward network
        dropout: Dropout rate
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        self.residual1 = ResidualConnection(d_model, dropout)
        self.residual2 = ResidualConnection(d_model, dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input [batch_size, seq_len, d_model]
            mask: Mask tensor [batch_size, 1, 1, seq_len] (masks padding positions)

        Returns:
            output: [batch_size, seq_len, d_model]
        """
        # Self-attention sub-layer; MultiHeadAttention returns (output, attention_weights),
        # so [0] keeps only the output.
        x = self.residual1(x, lambda x: self.self_attention(x, x, x, mask)[0])

        # Position-wise feed-forward sub-layer
        x = self.residual2(x, self.feed_forward)

        return x
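
# NOTE (assumption): ResidualConnection lives in transformer_components and is not shown
# in this file. The layers below assume it implements the pre-norm residual pattern,
# roughly:
#
#     def forward(self, x, sublayer):
#         return x + self.dropout(sublayer(self.norm(x)))
#
# which is why Encoder and Decoder apply one extra LayerNorm after their last layer.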


class Encoder(nn.Module):
    """
    Transformer Encoder - a stack of N encoder layers.

    Args:
        vocab_size: Vocabulary size
        d_model: Model dimension
        n_layers: Number of encoder layers
        n_heads: Number of attention heads
        d_ff: Dimension of the feed-forward network
        dropout: Dropout rate
        max_len: Maximum sequence length
    """
    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1, max_len=5000):
        super().__init__()

        from .transformer_components import Embedding, PositionalEncoding

        self.embedding = Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        self.norm = LayerNorm(d_model)

    def forward(self, src, src_mask=None):
        """
        Args:
            src: Source sequence [batch_size, src_len]
            src_mask: Source mask [batch_size, 1, 1, src_len]

        Returns:
            output: [batch_size, src_len, d_model]
        """
        # Token embedding + positional encoding
        x = self.embedding(src)
        x = self.pos_encoding(x)

        # Pass through the stack of encoder layers
        for layer in self.layers:
            x = layer(x, src_mask)

        # Final normalization of the encoder output
        x = self.norm(x)

        return x


class DecoderLayer(nn.Module):
    """
    A single Transformer Decoder layer.

    Consists of:
        1. Masked Multi-Head Self-Attention
        2. Add & Norm
        3. Multi-Head Cross-Attention (over the Encoder output)
        4. Add & Norm
        5. Feed-Forward Network
        6. Add & Norm

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        d_ff: Dimension of the feed-forward network
        dropout: Dropout rate
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        self.residual1 = ResidualConnection(d_model, dropout)
        self.residual2 = ResidualConnection(d_model, dropout)
        self.residual3 = ResidualConnection(d_model, dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Target input [batch_size, tgt_len, d_model]
            encoder_output: Encoder output [batch_size, src_len, d_model]
            src_mask: Source mask [batch_size, 1, 1, src_len]
            tgt_mask: Target mask [batch_size, 1, tgt_len, tgt_len] (causal mask)

        Returns:
            output: [batch_size, tgt_len, d_model]
        """
        # Masked self-attention over the target sequence (causal mask)
        x = self.residual1(x, lambda x: self.self_attention(x, x, x, tgt_mask)[0])

        # Cross-attention: queries come from the decoder, keys/values from the
        # encoder output, masked with the source padding mask
        x = self.residual2(x, lambda x: self.cross_attention(x, encoder_output, encoder_output, src_mask)[0])

        # Position-wise feed-forward sub-layer
        x = self.residual3(x, self.feed_forward)

        return x


class Decoder(nn.Module):
    """
    Transformer Decoder - a stack of N decoder layers.

    Args:
        vocab_size: Vocabulary size
        d_model: Model dimension
        n_layers: Number of decoder layers
        n_heads: Number of attention heads
        d_ff: Dimension of the feed-forward network
        dropout: Dropout rate
        max_len: Maximum sequence length
    """
    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1, max_len=5000):
        super().__init__()

        from .transformer_components import Embedding, PositionalEncoding

        self.embedding = Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        self.norm = LayerNorm(d_model)

        # Final projection from d_model to the target vocabulary
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            tgt: Target sequence [batch_size, tgt_len]
            encoder_output: Encoder output [batch_size, src_len, d_model]
            src_mask: Source mask [batch_size, 1, 1, src_len]
            tgt_mask: Target mask [batch_size, 1, tgt_len, tgt_len]

        Returns:
            output: [batch_size, tgt_len, vocab_size]
        """
        # Token embedding + positional encoding
        x = self.embedding(tgt)
        x = self.pos_encoding(x)

        # Pass through the stack of decoder layers
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        # Final normalization, then project to vocabulary logits
        x = self.norm(x)
        output = self.fc_out(x)

        return output
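
# Illustrative sketch (an assumption, not part of this module's API): one way to wire the
# Encoder and Decoder above into a full seq2seq Transformer. The class name and argument
# defaults below are hypothetical.
#
#     class Transformer(nn.Module):
#         def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_layers=6,
#                      n_heads=8, d_ff=2048, dropout=0.1):
#             super().__init__()
#             self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, d_ff, dropout)
#             self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, d_ff, dropout)
#
#         def forward(self, src, tgt, src_mask=None, tgt_mask=None):
#             encoder_output = self.encoder(src, src_mask)
#             return self.decoder(tgt, encoder_output, src_mask, tgt_mask)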


def create_padding_mask(seq, pad_idx=0):
    """
    Create a mask for padding tokens.

    Args:
        seq: Sequence [batch_size, seq_len]
        pad_idx: Index of the padding token

    Returns:
        mask: [batch_size, 1, 1, seq_len] (bool type)
    """
    # True at real tokens, False at padding positions; the two singleton
    # dimensions broadcast over attention heads and query positions.
    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
    return mask
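
# For example, with pad_idx=0:
#     seq = torch.tensor([[5, 7, 0]])    # one padded position
#     create_padding_mask(seq)           # -> tensor([[[[True, True, False]]]])
#     create_padding_mask(seq).shape     # -> torch.Size([1, 1, 1, 3])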


def create_causal_mask(seq_len, device):
    """
    Create a causal (look-ahead) mask for the decoder.
    Prevents the decoder from attending to future tokens.

    Args:
        seq_len: Sequence length
        device: Device (cuda or cpu)

    Returns:
        mask: [1, 1, seq_len, seq_len] (bool type)
    """
    # Lower-triangular matrix: position i may attend to positions <= i
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    mask = mask.bool()
    mask = mask.unsqueeze(0).unsqueeze(1)
    return mask
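
# For example, create_causal_mask(3, device) has shape [1, 1, 3, 3] and contains:
#     [[True, False, False],
#      [True, True,  False],
#      [True, True,  True ]]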


def create_target_mask(tgt, pad_idx=0):
    """
    Create the combined mask for the target sequence (padding + causal).

    Args:
        tgt: Target sequence [batch_size, tgt_len]
        pad_idx: Index of the padding token

    Returns:
        mask: [batch_size, 1, tgt_len, tgt_len] (bool type)
    """
    batch_size, tgt_len = tgt.size()
    device = tgt.device

    # Padding mask: [batch_size, 1, 1, tgt_len]
    padding_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)

    # Causal mask: [1, 1, tgt_len, tgt_len]
    causal_mask = create_causal_mask(tgt_len, device)

    # Broadcasting the logical AND yields [batch_size, 1, tgt_len, tgt_len]
    mask = padding_mask & causal_mask

    return mask
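
# For example, tgt = torch.tensor([[5, 7, 0]]) gives a [1, 1, 3, 3] mask:
#     [[True, False, False],
#      [True, True,  False],
#      [True, True,  False]]
# i.e. the causal pattern with the padded third position masked out as a key everywhere.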
| | if __name__ == "__main__":
|
| | print("="*70)
|
| | print("KIỂM TRA ENCODER & DECODER")
|
| | print("="*70)
|
| |
|
| |
|
| | batch_size = 2
|
| | src_len = 10
|
| | tgt_len = 12
|
| | src_vocab_size = 10000
|
| | tgt_vocab_size = 8000
|
| | d_model = 512
|
| | n_layers = 6
|
| | n_heads = 8
|
| | d_ff = 2048
|
| | dropout = 0.1
|
| | pad_idx = 0
|
| |
|
| | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| | print(f"\nDevice: {device}")
|
| |
|
| |
|
| | src = torch.randint(1, src_vocab_size, (batch_size, src_len)).to(device)
|
| | tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len)).to(device)
|
| |
|
| |
|
| | src_mask = create_padding_mask(src, pad_idx).to(device)
|
| | tgt_mask = create_target_mask(tgt, pad_idx).to(device)
|
| |
|
| | print(f"\nInput shapes:")
|
| | print(f" Source: {src.shape}")
|
| | print(f" Target: {tgt.shape}")
|
| | print(f" Source mask: {src_mask.shape}, dtype: {src_mask.dtype}")
|
| | print(f" Target mask: {tgt_mask.shape}, dtype: {tgt_mask.dtype}")
|
| |
|
| |
|
| | print("\n" + "="*70)
|
| | print("Test Encoder")
|
| | print("="*70)
|
| |
|
| | encoder = Encoder(
|
| | vocab_size=src_vocab_size,
|
| | d_model=d_model,
|
| | n_layers=n_layers,
|
| | n_heads=n_heads,
|
| | d_ff=d_ff,
|
| | dropout=dropout
|
| | ).to(device)
|
| |
|
| | encoder_output = encoder(src, src_mask)
|
| | print(f"Encoder output shape: {encoder_output.shape}")
|
| | print(f"Expected: [{batch_size}, {src_len}, {d_model}]")
|
| |
|
| |
|
| | print("\n" + "="*70)
|
| | print("Test Decoder")
|
| | print("="*70)
|
| |
|
| | decoder = Decoder(
|
| | vocab_size=tgt_vocab_size,
|
| | d_model=d_model,
|
| | n_layers=n_layers,
|
| | n_heads=n_heads,
|
| | d_ff=d_ff,
|
| | dropout=dropout
|
| | ).to(device)
|
| |
|
| | decoder_output = decoder(tgt, encoder_output, src_mask, tgt_mask)
|
| | print(f"Decoder output shape: {decoder_output.shape}")
|
| | print(f"Expected: [{batch_size}, {tgt_len}, {tgt_vocab_size}]")
|
| |
|
| |
|
| | encoder_params = sum(p.numel() for p in encoder.parameters())
|
| | decoder_params = sum(p.numel() for p in decoder.parameters())
|
| |
|
| | print("\n" + "="*70)
|
| | print("THỐNG KÊ MÔ HÌNH")
|
| | print("="*70)
|
| | print(f"Encoder parameters: {encoder_params:,}")
|
| | print(f"Decoder parameters: {decoder_params:,}")
|
| | print(f"Total parameters: {encoder_params + decoder_params:,}")
|
| |
|
| | print("\n" + "="*70)
|
| | print("✓ ENCODER & DECODER HOẠT ĐỘNG ĐÚNG!")
|
| | print("="*70) |