import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to token embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Precompute the encoding table once: (max_len, d_model).
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Even dimensions get sine, odd dimensions get cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcastable over the batch

        # Registered as a buffer: moves with the module but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        """
        return x + self.pe[:, :x.size(1), :]


class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder memory, and a feed-forward network, each with a post-norm
    residual connection."""

    def __init__(self, d_model, num_heads, dim_ff, dropout=0.2):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask, tgt_key_padding_mask):
        # Masked self-attention: each position attends only to earlier positions,
        # and padded target positions are excluded via the key padding mask.
        attn_out, _ = self.self_attn(
            x, x, x, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))

        # Cross-attention: caption queries attend over the encoder memory
        # (image features); no padding mask is applied on the memory side.
        attn_out, _ = self.cross_attn(x, memory, memory)
        x = self.norm2(x + self.dropout(attn_out))

        # Position-wise feed-forward network.
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))

        return x


class TransformerDecoder(nn.Module):
    """Captioning decoder: embeds tokens, adds positional encoding, runs a
    stack of DecoderBlocks over the image features, and projects to
    vocabulary logits."""

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_len=25,
        dropout=0.1,
    ):
        super().__init__()
        self.pad_id = pad_id
        self.d_model = d_model
        self.max_len = max_len

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=self.max_len)

        self.layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, dim_ff, dropout)
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Initialize embedding and output projection weights."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def generate_square_subsequent_mask(self, sz):
        """Generate a boolean causal mask (True marks positions that may not be attended)."""
        return torch.triu(torch.ones(sz, sz), diagonal=1).bool()

    def forward(self, captions, img_features, tgt_mask=None, tgt_padding_mask=None):
        """
        Args:
            captions: (B, L) token ids.
            img_features: (B, N, D) encoder memory.
            tgt_mask: optional (L, L) causal mask; built here if not given.
            tgt_padding_mask: optional (B, L) mask; derived from pad_id if not given.
        """
        # Scale embeddings by sqrt(d_model), add positional encoding, apply dropout.
        tgt = self.dropout(
            self.pos_encoder(self.embedding(captions) * math.sqrt(self.d_model))
        )

        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        if tgt_padding_mask is None:
            tgt_padding_mask = (captions == self.pad_id)

        for layer in self.layers:
            tgt = layer(tgt, img_features, tgt_mask, tgt_padding_mask)

        logits = self.fc_out(tgt)
        return logits
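

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the vocabulary size, pad id, batch
    # size, and the (B, N, D) shape of the image features below are assumed
    # placeholder values, not tied to any particular dataset or encoder.
    vocab_size, pad_id = 1000, 0
    decoder = TransformerDecoder(vocab_size=vocab_size, pad_id=pad_id,
                                 d_model=512, num_layers=2)

    captions = torch.randint(1, vocab_size, (2, 12))  # (B, L) token ids
    captions[:, -3:] = pad_id                         # simulate right-padding
    img_features = torch.randn(2, 49, 512)            # (B, N, D) encoder memory

    logits = decoder(captions, img_features)
    print(logits.shape)  # expected: torch.Size([2, 12, 1000])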