"""Neural Storyteller - Gradio App for Hugging Face Spaces (Attention model)."""

import json
import pickle
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import gradio as gr
|
|
# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
# Hyperparameters are read from the config.json shipped alongside the app.
with open("config.json", "r") as f:
    cfg = json.load(f)

EMBED_SIZE = cfg["embed_size"]
HIDDEN_SIZE = cfg["hidden_size"]
NUM_REGIONS = cfg["num_regions"]
VOCAB_SIZE = cfg["vocab_size"]
MAX_LEN = cfg["max_len"]
DROPOUT = cfg["dropout"]
BEAM_WIDTH = cfg["beam_width"]
LENGTH_PEN = cfg.get("length_penalty", 0.7)
REP_PEN = cfg.get("repetition_penalty", 1.2)
|
class Vocabulary:
    """Vocabulary shim whose real state is restored from vocab.pkl.

    The class must be importable under the same name it had at training
    time so that pickle can resolve it.
    """

    PAD, START, END, UNK = '<pad>', '<start>', '<end>', '<unk>'

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.word2idx = {}
        self.idx2word = {}
        self._idx = 0

    def __len__(self):
        return len(self.word2idx)
|
|
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
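
# A sketch of the lookups the rest of this file relies on (indices are
# illustrative; the real mappings come from vocab.pkl). Note the decoder's
# embedding uses padding_idx=0, which assumes '<pad>' maps to index 0:
#
#   vocab.word2idx['<pad>']    # -> 0
#   vocab.word2idx['<start>']  # -> e.g. 1
#   vocab.idx2word[0]          # -> '<pad>'
#   len(vocab)                 # -> VOCAB_SIZE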
|
|
|
|
|
|
class Encoder(nn.Module):
    def __init__(self, feature_dim=2048, hidden_size=HIDDEN_SIZE,
                 num_regions=NUM_REGIONS, dropout=DROPOUT):
        super().__init__()
        self.num_regions = num_regions
        self.hidden_size = hidden_size
        self.project = nn.Linear(feature_dim, hidden_size * num_regions)
        self.bn = nn.BatchNorm1d(hidden_size * num_regions)
        self.dropout = nn.Dropout(dropout)
        self.init_h = nn.Linear(feature_dim, hidden_size)
        self.init_c = nn.Linear(feature_dim, hidden_size)

    def forward(self, features):
        proj = self.dropout(F.relu(self.bn(self.project(features))))
        regions = proj.view(-1, self.num_regions, self.hidden_size)
        h0 = torch.tanh(self.init_h(features))
        c0 = torch.tanh(self.init_c(features))
        return regions, h0, c0
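
# Shape flow through the encoder, for reference:
#   features: (B, 2048) pooled ResNet features
#   regions:  (B, NUM_REGIONS, HIDDEN_SIZE) pseudo-regions the decoder attends to
#   h0, c0:   (B, HIDDEN_SIZE) initial LSTM state derived from the image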
|
|
|
|
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.W_enc = nn.Linear(hidden_size, hidden_size)
        self.W_dec = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(self, encoder_out, decoder_hidden):
        energy = self.V(torch.tanh(
            self.W_enc(encoder_out) + self.W_dec(decoder_hidden).unsqueeze(1)
        ))
        weights = F.softmax(energy.squeeze(2), dim=1)
        context = (weights.unsqueeze(2) * encoder_out).sum(1)
        return context, weights
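
# Additive (Bahdanau) attention as implemented above, in equation form:
#   e_i   = V^T tanh(W_enc @ enc_i + W_dec @ h)   # scalar energy per region
#   alpha = softmax(e)                            # weights over regions
#   ctx   = sum_i alpha_i * enc_i                 # context vector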
|
|
|
|
class AttentionDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size=EMBED_SIZE,
                 hidden_size=HIDDEN_SIZE, dropout=DROPOUT):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.attention = BahdanauAttention(hidden_size)
        self.lstm_cell = nn.LSTMCell(embed_size + hidden_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size + hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward_step(self, word_idx, h, c, encoder_out):
        embed = self.dropout(self.embed(word_idx))
        context, attn_w = self.attention(encoder_out, h)
        lstm_in = torch.cat([embed, context], dim=1)
        h, c = self.lstm_cell(lstm_in, (h, c))
        logits = self.fc_out(self.dropout(torch.cat([h, context], dim=1)))
        return logits, h, c, attn_w
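
# One decoding step, for reference:
#   word_idx (B,) token ids -> logits (B, vocab_size),
#   updated h, c (B, HIDDEN_SIZE), and attn_w (B, NUM_REGIONS)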
|
|
|
|
class Seq2SeqCaptioner(nn.Module):
    def __init__(self, vocab_size, embed_size=EMBED_SIZE,
                 hidden_size=HIDDEN_SIZE, dropout=DROPOUT,
                 num_regions=NUM_REGIONS):
        super().__init__()
        self.encoder = Encoder(2048, hidden_size, num_regions, dropout)
        self.decoder = AttentionDecoder(vocab_size, embed_size, hidden_size, dropout)
        self.hidden_size = hidden_size

    def forward(self, features, captions, teacher_forcing_ratio=1.0):
        B = features.size(0)
        T = captions.size(1) - 1
        V = self.decoder.fc_out.out_features
        encoder_out, h, c = self.encoder(features)
        outputs = torch.zeros(B, T, V, device=features.device)
        inp = captions[:, 0]
        for t in range(T):
            logits, h, c, _ = self.decoder.forward_step(inp, h, c, encoder_out)
            outputs[:, t] = logits
            # Teacher forcing: feed the ground-truth next token with the given
            # probability, otherwise feed the model's own prediction.
            if random.random() < teacher_forcing_ratio:
                inp = captions[:, t + 1]
            else:
                inp = logits.argmax(dim=-1)
        return outputs
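
# Note: Seq2SeqCaptioner.forward() is only the training-time path (teacher
# forcing over a whole batch); the inference helpers below drive
# decoder.forward_step() directly, one token at a time.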
|
|
|
|
caption_model = Seq2SeqCaptioner(VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE, DROPOUT, NUM_REGIONS).to(device)
caption_model.load_state_dict(torch.load("best_model.pth", map_location=device))
caption_model.eval()
|
|
# ResNet50 feature extractor: drop the final FC layer and keep everything up
# to global average pooling, which yields a 2048-d vector per image.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1])
resnet = resnet.to(device)
resnet.eval()

img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
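
# Quick feature-shape sanity check (illustrative; safe to run locally):
#
#   with torch.no_grad():
#       dummy = torch.zeros(1, 3, 224, 224, device=device)
#       feat = resnet(dummy).view(1, -1)   # -> torch.Size([1, 2048])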
|
|
|
|
@torch.no_grad()
def greedy_search_inference(feature):
    """Decode a caption by always taking the most probable next token."""
    feature = feature.unsqueeze(0).to(device)
    encoder_out, h, c = caption_model.encoder(feature)

    start_idx = vocab.word2idx[vocab.START]
    end_idx = vocab.word2idx[vocab.END]

    sequence = [start_idx]
    inp = torch.tensor([start_idx], device=device)

    for _ in range(MAX_LEN):
        logits, h, c, _ = caption_model.decoder.forward_step(inp, h, c, encoder_out)
        predicted = logits.argmax(dim=-1).item()

        if predicted == end_idx:
            break

        sequence.append(predicted)
        inp = torch.tensor([predicted], device=device)

    words = [vocab.idx2word[i] for i in sequence
             if vocab.idx2word[i] not in (vocab.START, vocab.END, vocab.PAD)]
    return " ".join(words)
|
|
|
|
@torch.no_grad()
def beam_search_inference(feature, beam_width=BEAM_WIDTH,
                          length_penalty=LENGTH_PEN,
                          repetition_penalty=REP_PEN):
    """Decode a caption with beam search plus length and repetition penalties."""
    feature = feature.unsqueeze(0).to(device)
    encoder_out, h0, c0 = caption_model.encoder(feature)

    start_idx = vocab.word2idx[vocab.START]
    end_idx = vocab.word2idx[vocab.END]
    pad_idx = vocab.word2idx[vocab.PAD]

    # Each beam: (cumulative log-prob, token sequence, hidden state, cell state).
    beams = [(0.0, [start_idx], h0, c0)]
    completed = []

    for _ in range(MAX_LEN):
        new_beams = []
        for log_prob, seq, h, c in beams:
            inp = torch.tensor([seq[-1]], device=device)
            logits, h_new, c_new, _ = caption_model.decoder.forward_step(
                inp, h, c, encoder_out)
            logits = logits.squeeze(0)

            # Repetition penalty (CTRL-style): damp tokens already generated.
            # Positive logits are divided and negative logits multiplied, so
            # the penalty always lowers the repeated token's probability
            # (dividing a negative logit would raise it instead).
            for prev_tok in set(seq):
                if prev_tok not in (start_idx, end_idx, pad_idx):
                    if logits[prev_tok] > 0:
                        logits[prev_tok] /= repetition_penalty
                    else:
                        logits[prev_tok] *= repetition_penalty

            log_probs = F.log_softmax(logits, dim=-1)
            topk_lp, topk_idx = log_probs.topk(beam_width)

            for k in range(beam_width):
                token = topk_idx[k].item()
                new_lp = log_prob + topk_lp[k].item()
                new_seq = seq + [token]
                if token == end_idx:
                    # Length-normalise finished hypotheses so longer captions
                    # are not unfairly penalised.
                    score = new_lp / (len(new_seq) ** length_penalty)
                    completed.append((score, new_seq))
                else:
                    new_beams.append((new_lp, new_seq, h_new, c_new))

        new_beams.sort(key=lambda x: x[0], reverse=True)
        beams = new_beams[:beam_width]
        if not beams or len(completed) >= beam_width:
            break

    # Fall back to the best unfinished beams if nothing reached <end>.
    if not completed and beams:
        for lp, seq, _, _ in beams:
            completed.append((lp / (len(seq) ** length_penalty), seq))

    completed.sort(key=lambda x: x[0], reverse=True)
    best_seq = completed[0][1] if completed else [start_idx]

    words = [vocab.idx2word[i] for i in best_seq
             if vocab.idx2word[i] not in (vocab.START, vocab.END, vocab.PAD)]
    return " ".join(words)
|
|
|
|
def predict(image, search_method, beam_width, length_penalty, repetition_penalty):
    """Take a PIL image and return the generated caption as an HTML snippet."""
    if image is None:
        return """
        <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 30px; border-radius: 15px; text-align: center;">
            <p style="color: white; font-size: 20px; margin: 0;">⚠️ Please upload an image first</p>
        </div>
        """

    image = image.convert("RGB")
    img_tensor = img_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        feature = resnet(img_tensor).view(1, -1).squeeze(0)

    if search_method == "Greedy Search (Fast)":
        caption = greedy_search_inference(feature)
        method_info = "🔍 Generated using Greedy Search"
    else:
        caption = beam_search_inference(
            feature,
            beam_width=int(beam_width),
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty
        )
        method_info = f"🔍 Generated using Beam Search (width={int(beam_width)})"

    return f"""
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 40px; border-radius: 15px; box-shadow: 0 8px 32px rgba(0,0,0,0.1);">
        <p style="color: white; font-size: 28px; font-weight: 600; text-align: center; line-height: 1.6; margin: 0; text-shadow: 2px 2px 4px rgba(0,0,0,0.2);">
            "{caption}"
        </p>
        <p style="color: rgba(255,255,255,0.9); font-size: 14px; text-align: center; margin-top: 20px; font-style: italic;">
            {method_info}
        </p>
    </div>
    """
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(), title="Neural Storyteller", css="""
    .caption-box {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 30px;
        border-radius: 15px;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
        margin: 20px 0;
    }
    .caption-text {
        color: white;
        font-size: 24px;
        font-weight: 600;
        text-align: center;
        line-height: 1.6;
        text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
    }
    .method-info {
        color: rgba(255,255,255,0.9);
        font-size: 14px;
        text-align: center;
        margin-top: 15px;
        font-style: italic;
    }
""") as demo:
| gr.Markdown(""" |
| # π§ Neural Storyteller β AI Image Captioning |
| |
| Upload any image and let the AI generate a natural language description using a **Seq2Seq model** |
| with ResNet50 encoder and Attention-based LSTM decoder, trained on Flickr30k dataset. |
| """) |
| |
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📸 Upload Your Image", height=400)

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Generation Settings")

            search_method = gr.Radio(
                choices=["Greedy Search (Fast)", "Beam Search (Better Quality)"],
                value="Beam Search (Better Quality)",
                label="🎯 Decoding Method",
                info="Greedy is faster; beam search usually gives better results"
            )

            with gr.Accordion("🔧 Advanced Options (Beam Search Only)", open=False):
                beam_width = gr.Slider(
                    minimum=1, maximum=10, value=5, step=1,
                    label="Beam Width",
                    info="Number of candidates to explore (higher = better quality but slower)"
                )

                length_penalty = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1,
                    label="Length Penalty",
                    info="Controls caption length (lower = shorter, higher = longer)"
                )

                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                    label="Repetition Penalty",
                    info="Reduces word repetition (higher = less repetition)"
                )

            generate_btn = gr.Button("✨ Generate Caption", variant="primary", size="lg", scale=1)

| gr.Markdown("## π Generated Caption") |
| output_text = gr.HTML(label="") |
| |
| with gr.Accordion("π‘ Tips & Model Details", open=False): |
| gr.Markdown(""" |
| ### Tips: |
| - Try both **Greedy** and **Beam** search to compare results |
| - Increase **Beam Width** for more diverse captions |
| - Adjust **Length Penalty** if captions are too short/long |
| - Use **Repetition Penalty** to avoid repeated words |
| |
| ### Model Details: |
| - **Encoder**: ResNet50 (pretrained on ImageNet) |
| - **Decoder**: Attention-based LSTM |
| - **Training Data**: Flickr30k dataset |
| - **Vocabulary**: ~8000 words |
| """) |

    generate_btn.click(
        fn=predict,
        inputs=[image_input, search_method, beam_width, length_penalty, repetition_penalty],
        outputs=output_text
    )

| gr.Markdown(""" |
| --- |
| <p style="text-align: center; color: #666;"> |
| Built with PyTorch, Gradio, and β€οΈ | Model trained on Flickr30k |
| </p> |
| """) |
|
|
if __name__ == "__main__":
    demo.launch()
|
|