Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from cnn_encoder import CNNEncoder | |
| from vit_encoder import ViTEncoder | |
| from transformer_encoder import TransformerEncoder | |
| from transformer_decoder import TransformerDecoder | |
class ImageCaptioningModel(nn.Module):
    """Encoder-decoder Transformer model for image captioning.

    A visual backbone (CNNEncoder or ViTEncoder) maps images to a sequence
    of feature vectors, a TransformerEncoder refines them, and a
    TransformerDecoder generates caption tokens autoregressively
    conditioned on the encoded image features.
    """

    # Tokens stripped from decoded captions before joining into a string.
    _SPECIAL_TOKENS = ("<bos>", "<eos>", "<pad>")

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_seq_len=50,
        dropout=0.1,
        freeze_backbone=True,
        use_vit=False,
        encoder_max_len=200,
    ):
        """
        Args:
            vocab_size: Size of the caption vocabulary.
            pad_id: Index of the <pad> token (forwarded to the decoder).
            d_model: Shared embedding / hidden dimension.
            num_encoder_layers: Number of Transformer encoder layers.
            num_decoder_layers: Number of Transformer decoder layers.
            num_heads: Attention heads per layer.
            dim_ff: Feed-forward hidden size inside each layer.
            max_seq_len: Maximum caption (decoder) sequence length.
            dropout: Dropout probability.
            freeze_backbone: If True, freeze the visual backbone's weights.
            use_vit: Use a ViT backbone instead of a CNN backbone.
            encoder_max_len: Maximum number of image-feature positions the
                Transformer encoder supports (previously hard-coded to 200;
                the default preserves the old behavior).
        """
        super().__init__()
        self.use_vit = use_vit
        # Visual backbone: projects images to a sequence of d_model features.
        if self.use_vit:
            self.encoder = ViTEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        else:
            self.encoder = CNNEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            num_layers=num_encoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=encoder_max_len,
            dropout=dropout,
            use_vit=self.use_vit,
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            pad_id=pad_id,
            d_model=d_model,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=max_seq_len,
            dropout=dropout,
        )
        self.d_model = d_model

    def generate_square_subsequent_mask(self, sz):
        """Return the decoder's causal (look-ahead) mask for length ``sz``."""
        return self.decoder.generate_square_subsequent_mask(sz)

    def unfreeze_encoder(self, unfreeze=True):
        """(Un)freeze the visual backbone, e.g. for fine-tuning."""
        self.encoder.unfreeze_backbone(unfreeze)

    def encode_image(self, images):
        """Run images through the backbone and Transformer encoder."""
        img_features = self.encoder(images)
        return self.transformer_encoder(img_features)

    def forward(self, images, captions, tgt_mask=None, tgt_padding_mask=None):
        """Teacher-forced forward pass; returns decoder logits."""
        img_features = self.encode_image(images)
        return self.decoder(
            captions=captions,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_padding_mask,
        )

    def _decoder_step(self, seq, img_features, device):
        """Run one decoder pass on a partial token sequence.

        Shared by greedy and beam-search decoding to avoid duplicating the
        tensor/mask construction. Returns logits of shape (1, len(seq), V).
        """
        tgt_tensor = torch.tensor(seq).unsqueeze(0).to(device)
        tgt_mask = self.generate_square_subsequent_mask(len(seq)).to(device)
        return self.decoder(
            captions=tgt_tensor,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=None,
        )

    def _indices_to_caption(self, indices, vocab):
        """Convert token indices to a caption string, dropping special tokens.

        Unknown indices map to "<unk>" (kept in the output), matching the
        previous inline behavior of both decoding methods.
        """
        tokens = []
        for idx in indices:
            word = vocab.idx2word.get(idx, "<unk>")
            if word not in self._SPECIAL_TOKENS:
                tokens.append(word)
        return " ".join(tokens)

    def predict_caption_beam(self, image, vocab, beam_width=5, max_len=50, alpha=0.7, device="cpu"):
        """
        Generates a caption using beam search decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            beam_width: Number of candidate sequences to keep at each step.
            max_len: Maximum caption length.
            alpha: Length normalization penalty. Higher values favor longer captions.
            device: Device to run inference on.

        Returns:
            The highest-scoring caption as a string.
        """
        self.eval()
        with torch.no_grad():
            # Image features are computed once and shared by all beams.
            img_features = self.encode_image(image)
            bos_idx = vocab.word2idx["<bos>"]
            eos_idx = vocab.word2idx["<eos>"]
            # Each beam: (cumulative log-probability, token_indices_list)
            beams = [(0.0, [bos_idx])]
            completed = []
            for _ in range(max_len):
                candidates = []
                for score, seq in beams:
                    # A beam that already emitted <eos> is finished; retire it.
                    if seq[-1] == eos_idx:
                        completed.append((score, seq))
                        continue
                    logits = self._decoder_step(seq, img_features, device)
                    # Log-probabilities for the next token only.
                    log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
                    topk_log_probs, topk_indices = log_probs.topk(beam_width)
                    for log_p, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                        candidates.append((score + log_p, seq + [idx]))
                # Keep the beam_width best expansions.
                candidates.sort(key=lambda x: x[0], reverse=True)
                beams = candidates[:beam_width]
                # Early stop: every beam has been retired to `completed`.
                if not beams:
                    break
            # Beams still alive at max_len compete alongside finished ones.
            completed.extend(beams)
            # Length-normalized scoring: score / (length ** alpha) counters
            # the bias toward short sequences.
            completed.sort(key=lambda x: x[0] / (len(x[1]) ** alpha), reverse=True)
            best_seq = completed[0][1]
            return self._indices_to_caption(best_seq, vocab)

    def predict_caption(self, image, vocab, max_len=50, device="cpu"):
        """
        Generates a caption using greedy decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            max_len: Maximum caption length.
            device: Device to run inference on.

        Returns:
            The generated caption as a string.
        """
        self.eval()
        with torch.no_grad():
            img_features = self.encode_image(image)
            start_token_idx = vocab.word2idx["<bos>"]
            end_token_idx = vocab.word2idx["<eos>"]
            tgt_indices = [start_token_idx]
            for _ in range(max_len):
                logits = self._decoder_step(tgt_indices, img_features, device)
                # Greedy choice: argmax over the final position's logits.
                predicted_id = logits[:, -1, :].argmax(dim=-1).item()
                if predicted_id == end_token_idx:
                    break
                tgt_indices.append(predicted_id)
            return self._indices_to_caption(tgt_indices, vocab)