Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
# decoder.py | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    Even feature indices hold ``sin(pos * 10000^(-2i/d_model))`` and odd
    indices hold the matching ``cos`` term. The table is precomputed once for
    ``max_len`` positions.

    Args:
        d_model: Feature size of the inputs. Odd sizes are supported; the last
            (even-indexed) column then only receives a sine term.
        max_len: Longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(position * div_term)  # dim 2i
        # For odd d_model the 1::2 slice has one column fewer than there are
        # frequencies; trim div_term so the shapes line up.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # dim 2i+1
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model] -> broadcasts over batch
        # Buffer, not Parameter: follows .to(device) and state_dict, not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor of shape ``[seq_len, batch_size, d_model]``; ``seq_len``
                must not exceed ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return x + self.pe[:x.size(0)]
def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape ``[sz, sz]`` (float32).

    Entry ``(i, j)`` is ``0.0`` when position ``i`` may attend to position
    ``j`` (that is, ``j <= i``) and ``-inf`` for every future position
    (``j > i``), the float ``tgt_mask`` format used by ``nn.TransformerDecoder``.
    """
    # Fill everything with -inf, then zero out the diagonal and below, leaving
    # only the strictly-upper ("future") triangle blocked.
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float)
    return torch.triu(blocked, diagonal=1)
class TransformerDecoder(nn.Module):
    """Autoregressive Transformer text decoder conditioned on encoder features.

    Token ids are embedded, given sinusoidal positions and decoded by a stack
    of ``nn.TransformerDecoderLayer`` blocks that cross-attend to a projected
    encoder "memory". Internally the layers run sequence-first (the default
    ``batch_first=False``); the public ``forward`` takes and returns
    batch-first tensors.

    Args:
        vocab_size: Size of the output vocabulary.
        hidden_dim: Decoder width; must be divisible by ``nhead``.
        encoder_dim: Feature size of ``encoder_outputs``.
        num_layers: Number of stacked decoder layers.
        nhead: Attention heads per decoder layer (default 8).
    """

    def __init__(self, vocab_size, hidden_dim=512, encoder_dim=768, num_layers=2, nhead=8):
        super(TransformerDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)
        # Project ViT encoder output to decoder hidden_dim if needed
        self.encoder_projection = nn.Linear(encoder_dim, hidden_dim)

    def _project_memory(self, encoder_outputs):
        """Map encoder features to a sequence-first decoder memory.

        Pooled features ``(batch, encoder_dim)`` become a single memory slot
        ``(1, batch, hidden_dim)``; per-token features
        ``(batch, src_len, encoder_dim)`` become ``(src_len, batch, hidden_dim)``.

        Raises:
            ValueError: if ``encoder_outputs`` is neither 2-D nor 3-D.
        """
        memory = self.encoder_projection(encoder_outputs)
        if memory.dim() == 2:
            return memory.unsqueeze(0)
        if memory.dim() == 3:
            return memory.transpose(0, 1)
        raise ValueError(
            f"encoder_outputs must be 2-D or 3-D, got shape {tuple(encoder_outputs.shape)}"
        )

    def _decode(self, input_ids, memory, tgt_attention_mask=None):
        """Run the decoder stack on ``input_ids`` against a projected ``memory``.

        Returns:
            Batch-first logits of shape ``(batch, seq_len, vocab_size)``.
        """
        embedded = self.embedding(input_ids).permute(1, 0, 2)  # (seq_len, batch, hidden)
        embedded = self.positional_encoding(embedded)
        tgt_mask = generate_square_subsequent_mask(embedded.size(0)).to(embedded.device)
        # tgt_attention_mask marks real tokens with 1/True; PyTorch's key
        # padding mask wants True for positions to IGNORE, hence the inversion.
        tgt_key_padding_mask = (
            None if tgt_attention_mask is None else ~tgt_attention_mask.bool()
        )
        output = self.transformer_decoder(
            tgt=embedded,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        return self.fc_out(output).permute(1, 0, 2)

    def forward(self, input_ids, encoder_outputs, tgt_attention_mask=None):
        """Compute logits for every target position (teacher forcing).

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            encoder_outputs: ``(batch, encoder_dim)`` pooled features or
                ``(batch, src_len, encoder_dim)`` per-token features.
            tgt_attention_mask: Optional ``(batch, seq_len)`` mask, 1/True for
                real tokens and 0/False for padding.

        Returns:
            ``(batch, seq_len, vocab_size)`` logits.
        """
        memory = self._project_memory(encoder_outputs)
        return self._decode(input_ids, memory, tgt_attention_mask)

    @staticmethod
    def _select_next_token(next_token_logits, mode, top_k, top_p):
        """Pick one token per row from ``(batch, vocab)`` logits; returns ``(batch, 1)``."""
        if mode == "greedy":
            return next_token_logits.argmax(dim=-1, keepdim=True)
        probs = F.softmax(next_token_logits, dim=-1)
        if mode == "topk":
            # torch.topk fails when k exceeds the vocabulary; clamp to [1, vocab].
            k = max(1, min(top_k, probs.size(-1)))
            topk_probs, topk_indices = torch.topk(probs, k)
            choice = torch.multinomial(topk_probs, num_samples=1)  # index into top-k
            return topk_indices.gather(-1, choice)
        if mode == "topp":
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Nucleus = smallest prefix whose mass reaches top_p: keep a token
            # while the mass *before* it is still below top_p, so the token
            # that crosses the threshold is included.
            keep = (cumulative_probs - sorted_probs) < top_p
            keep[..., 0] = True  # Always keep at least 1 token
            filtered_probs = sorted_probs * keep
            filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
            choice = torch.multinomial(filtered_probs, num_samples=1)
            return sorted_indices.gather(-1, choice)
        raise ValueError(f"Unknown mode: {mode}")

    @torch.no_grad()
    def generate(
        self,
        encoder_outputs,
        start_token_id=101,  # [CLS] token for BERT
        eos_token_id=102,
        max_length=50,
        mode="greedy",  # "greedy", "beam", "topk", "topp"
        num_beams=3,
        top_k=50,
        top_p=0.95,
        length_penalty=1.0,
    ):
        """Generate caption using specified decoding mode.

        Args:
            encoder_outputs: Batch-first encoder features (see :meth:`forward`).
            start_token_id: Token every sequence starts from; stripped from the result.
            eos_token_id: Token that ends a sequence, or ``None`` to always run
                ``max_length`` steps.
            max_length: Maximum number of tokens to generate.
            mode: ``"greedy"`` (argmax), ``"beam"`` (beam search, batch size 1
                only), ``"topk"`` (sample among the ``top_k`` most likely
                tokens) or ``"topp"`` (nucleus sampling with mass ``top_p``).
            num_beams: Beam width for ``mode="beam"``.
            top_k: Candidate count for ``mode="topk"``, clamped to the vocabulary.
            top_p: Probability mass for ``mode="topp"``.
            length_penalty: Exponent of the length normalisation of beam scores.

        Returns:
            ``(batch, generated_len)`` token ids without the start token. In
            the non-beam modes, positions after a row's first EOS hold
            ``eos_token_id``.

        Raises:
            ValueError: if ``mode`` is not a supported decoding mode.
        """
        if mode not in ("greedy", "beam", "topk", "topp"):
            raise ValueError(f"Unknown mode: {mode}")
        device = encoder_outputs.device
        batch_size = encoder_outputs.size(0)
        input_ids = torch.full(
            (batch_size, 1),
            start_token_id,
            dtype=torch.long,
            device=device,
        )
        if mode == "beam":
            return self._generate_beam_search(
                encoder_outputs,
                input_ids,
                max_length,
                eos_token_id,
                num_beams,
                length_penalty,
            )

        # Greedy or sampling. The memory does not change between steps.
        memory = self._project_memory(encoder_outputs)
        generated = input_ids
        # Rows that already emitted EOS; their later tokens are forced to EOS
        # instead of continuing to sample garbage.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for _ in range(max_length):
            next_token_logits = self._decode(generated, memory)[:, -1, :]  # (batch, vocab)
            next_token = self._select_next_token(next_token_logits, mode, top_k, top_p)
            if eos_token_id is not None:
                next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
                finished |= next_token.squeeze(-1) == eos_token_id
            generated = torch.cat((generated, next_token), dim=1)
            if eos_token_id is not None and finished.all():
                break
        return generated[:, 1:]  # drop the start token

    def _generate_beam_search(
        self,
        encoder_outputs,
        input_ids,
        max_length=50,
        eos_token_id=102,
        num_beams=3,
        length_penalty=1.0,
    ):
        """Beam search decoder for batch_size = 1.

        Keeps the ``num_beams`` best unfinished hypotheses, scored by summed
        token log-probabilities. Hypotheses ending in ``eos_token_id`` are set
        aside. The search stops when nothing is left to extend, ``num_beams``
        hypotheses have finished, or ``max_length`` steps have run. The winner
        maximises ``score / len(seq) ** length_penalty``, where ``len`` counts
        the start token.

        Returns:
            ``(1, generated_len)`` token ids without the start token.

        Raises:
            ValueError: if the batch size is not 1 or ``num_beams < 1``.
        """
        if encoder_outputs.size(0) != 1:
            raise ValueError("Basic beam search only supports batch size 1 here.")
        if num_beams < 1:
            raise ValueError("num_beams must be >= 1")
        memory = self._project_memory(encoder_outputs)
        # Start from ONE hypothesis: num_beams identical copies would all
        # propose the same candidates and collapse the search into greedy.
        beams = [(input_ids, 0.0)]  # (sequence of shape (1, len), summed log-prob)
        finished = []
        for _ in range(max_length):
            candidates = []  # (parent sequence, token id, score)
            for seq, score in beams:
                log_probs = F.log_softmax(self._decode(seq, memory)[0, -1], dim=-1)  # (vocab,)
                # A beam's candidates hold at most one EOS, so its top
                # 2*num_beams tokens always leave enough non-EOS continuations;
                # no need to materialise all vocab_size candidates.
                k = min(2 * num_beams, log_probs.size(-1))
                top_log_probs, top_ids = log_probs.topk(k)
                candidates.extend(
                    (seq, token_id, score + lp)
                    for lp, token_id in zip(top_log_probs.tolist(), top_ids.tolist())
                )
            candidates.sort(key=lambda cand: cand[2], reverse=True)
            beams = []
            for rank, (seq, token_id, score) in enumerate(candidates):
                new_seq = torch.cat([seq, seq.new_tensor([[token_id]])], dim=1)
                if eos_token_id is not None and token_id == eos_token_id:
                    # Only accept EOS hypotheses that rank within the beam width.
                    if rank < num_beams:
                        finished.append((new_seq, score))
                else:
                    beams.append((new_seq, score))
                if len(beams) == num_beams:
                    break
            if not beams or len(finished) >= num_beams:
                break
        # Without a full set of finished hypotheses, the still-open beams compete too.
        if len(finished) < num_beams:
            finished.extend(beams)
        best_seq, _ = max(
            finished, key=lambda hyp: hyp[1] / (hyp[0].size(1) ** length_penalty)
        )
        return best_seq[:, 1:]  # drop the start token