# app.py
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
from PIL import Image
import os
import nltk
import argparse
from collections import Counter          # Needed for Vocabulary unpickling
from torch.serialization import safe_globals  # For secure loading
import gradio as gr                       # Import Gradio

# --- 1. Define Classes EXACTLY as during training ---
# Paste the final versions of Vocabulary, EncoderCNN, DecoderRNN here.
# This is CRUCIAL for loading the model correctly.

class Vocabulary:
    # --- Paste your final Vocabulary class definition here ---
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        # Special tokens: <pad>, <start>, <end>, <unk>
        self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.idx = 4

    def build_vocabulary(self, sentence_list):
        # Needs to be present for unpickling
        frequencies = Counter()
        for sentence in sentence_list:
            tokens = nltk.tokenize.word_tokenize(sentence.lower())
            frequencies.update(tokens)
        filtered_freq = {word: freq for word, freq in frequencies.items()
                         if freq >= self.freq_threshold}
        for word in filtered_freq:
            if word not in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1

    def numericalize(self, text):
        tokens = nltk.tokenize.word_tokenize(text.lower())
        return [self.word2idx.get(token, self.word2idx["<unk>"]) for token in tokens]

    def __len__(self):
        return self.idx


class EncoderCNN(nn.Module):
    # --- Paste your final EncoderCNN class definition here ---
    def __init__(self, embed_size, dropout_p=0.5, fine_tune=True):
        super(EncoderCNN, self).__init__()
        try:
            # Handle potential torchvision version differences
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        except TypeError:
            resnet = models.resnet101(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False  # Fine-tune status doesn't matter for eval, but architecture must match
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.squeeze(3).squeeze(2)  # (B, 2048, 1, 1) -> (B, 2048)
        features = self.fc(features)
        features = self.bn(features)
        return features


class DecoderRNN(nn.Module):
    # --- Paste your final DecoderRNN class definition here ---
    # --- including forward_step and init_hidden_state ---
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, dropout_p=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.embed_dropout = nn.Dropout(dropout_p)
        lstm_dropout = dropout_p if num_layers > 1 else 0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.dropout = nn.Dropout(dropout_p)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.num_layers = num_layers

    def init_hidden_state(self, features):
        # Initialize LSTM hidden/cell states from the image features
        h0 = self.init_h(features).unsqueeze(0)
        c0 = self.init_c(features).unsqueeze(0)
        if self.num_layers > 1:
            h0 = h0.repeat(self.num_layers, 1, 1)
            c0 = c0.repeat(self.num_layers, 1, 1)
        return (h0, c0)

    def forward_step(self, embedded_input, hidden_state):
        # One decoding step: embedded token -> LSTM -> vocabulary logits
        lstm_out, hidden_state = self.lstm(embedded_input, hidden_state)
        outputs = self.linear(lstm_out.squeeze(1))
        return outputs, hidden_state

# --- End Class Definitions ---

# --- Configuration ---
CHECKPOINT_PATH = 'best_model_improved.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use CPU for typical Spaces hardware
MAX_LEN = 25

# --- Global variables for loaded model (load ONCE) ---
encoder_global = None
decoder_global = None
vocab_global = None
transform_global = None

# --- Model Loading Function ---
def load_model_and_vocab():
    global encoder_global, decoder_global, vocab_global, transform_global
    if encoder_global is not None:  # Already loaded
        print("Model already loaded.")
        return

    print(f"Loading checkpoint: {CHECKPOINT_PATH} onto device: {DEVICE}")
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")

    try:
        with safe_globals([Vocabulary, Counter]):  # Allowlist custom classes
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"Error loading checkpoint with safe_globals: {e}. Trying weights_only=False...")
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        except Exception as e2:
            raise RuntimeError(f"Failed to load checkpoint: {e2}")

    # Load vocabulary and hyperparameters
    vocab_global = checkpoint['vocab']
    embed_size = checkpoint.get('embed_size', 256)
    hidden_size = checkpoint.get('hidden_size', 512)
    num_layers = checkpoint.get('num_layers', 1)
    dropout_prob = checkpoint.get('dropout_prob', 0.5)
    fine_tune_encoder = checkpoint.get('fine_tune_encoder', True)  # Match saved config
    vocab_size = len(vocab_global)
    print(f"Vocabulary loaded (size: {vocab_size}). Hyperparameters extracted.")

    # Initialize models
    encoder_global = EncoderCNN(embed_size, dropout_p=dropout_prob,
                                fine_tune=fine_tune_encoder).to(DEVICE)
    decoder_global = DecoderRNN(embed_size, hidden_size, vocab_size,
                                num_layers, dropout_p=dropout_prob).to(DEVICE)
    encoder_global.load_state_dict(checkpoint['encoder_state_dict'])
    decoder_global.load_state_dict(checkpoint['decoder_state_dict'])

    # Set to evaluation mode
    encoder_global.eval()
    decoder_global.eval()
    print("Models initialized, weights loaded, and set to eval mode.")

    # Define image transformation (same as validation/inference)
    transform_global = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    print("Transforms defined.")

# --- Helper: Tokens to Sentence ---
def tokens_to_sentence(tokens, vocab):
    words = [vocab.idx2word.get(token, "<unk>") for token in tokens]
    words = [word for word in words if word not in ["<start>", "<end>", "<pad>"]]
    return " ".join(words)

# --- Inference Function for Gradio ---
def predict(input_image):
    """Generate a caption for a PIL image provided by Gradio."""
    if encoder_global is None or decoder_global is None or vocab_global is None or transform_global is None:
        print("Error: Model not loaded.")
        # Optionally try loading here, but it's better to load upfront
        # load_model_and_vocab()
        # if encoder_global is None:  # Check again
        return "Error: Model components not loaded. Check logs."

    # 1. Preprocess image
    try:
        image_tensor = transform_global(input_image)
        image_tensor = image_tensor.unsqueeze(0).to(DEVICE)  # Add batch dimension
    except Exception as e:
        print(f"Error transforming image: {e}")
        return f"Error processing image: {e}"
    # 2. Generate caption (greedy search)
    generated_indices = []
    with torch.no_grad():
        try:
            features = encoder_global(image_tensor)
            hidden_state = decoder_global.init_hidden_state(features)

            start_token_idx = vocab_global.word2idx["<start>"]
            inputs = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)

            for _ in range(MAX_LEN):
                embedded = decoder_global.embed(inputs)
                outputs, hidden_state = decoder_global.forward_step(embedded, hidden_state)
                predicted_idx = outputs.argmax(1)
                predicted_word_idx = predicted_idx.item()
                if predicted_word_idx == vocab_global.word2idx["<end>"]:
                    break  # Stop if <end> is predicted
                generated_indices.append(predicted_word_idx)
                inputs = predicted_idx.unsqueeze(1)  # Prepare input for next step
        except Exception as e:
            print(f"Error during caption generation: {e}")
            return f"Error during generation: {e}"

    # 3. Convert token indices to a sentence
    caption = tokens_to_sentence(generated_indices, vocab_global)
    return caption

# --- Load Model when script starts ---
# Ensure NLTK data is available if needed by the tokenizer within the Vocabulary class
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("NLTK 'punkt' tokenizer data not found. Downloading...")
    nltk.download('punkt', quiet=True)

load_model_and_vocab()  # Load model into global variables

# --- Create Gradio Interface ---
title = "Image Captioning Demo"
description = ("Upload an image and this model (ResNet101 Encoder + LSTM Decoder) "
               "will generate a caption. Trained on COCO.")
# Optional: Define example images (paths relative to app.py)
example_list = [["images/example1.jpg"], ["images/example2.jpg"]] if os.path.exists("images") else None

iface = gr.Interface(
    fn=predict,                                         # The function to call for inference
    inputs=gr.Image(type="pil", label="Upload Image"),  # Input: image upload, passed to fn as a PIL image
    outputs=gr.Textbox(label="Generated Caption"),      # Output: textbox
    title=title,
    description=description,
    examples=example_list,                              # Optional: provide examples
    allow_flagging="never"                              # Optional: disable flagging
)

# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch()  # share=True is not needed for Spaces; it's handled automatically
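
# Optional local smoke test (a sketch; assumes a sample image exists at the hypothetical
# path "sample.jpg" next to app.py and that the checkpoint above is present). Uncomment to
# exercise the preprocess -> encode -> greedy-decode path without launching the Gradio UI:
# if __name__ == "__main__":
#     test_image = Image.open("sample.jpg").convert("RGB")
#     print("Caption:", predict(test_image))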