import torch
import pickle
from PIL import Image
import torchvision.transforms as T
import os
import sys

# # Ensure project root is on sys.path
# PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__)))
# if PROJECT_ROOT not in sys.path:
#     sys.path.insert(0, PROJECT_ROOT)
# # Backward compatibility: vocab.pkl was saved when Vocabulary lived at
# # 'data_processing_vocabulary'. Register an alias so pickle can find it
# # at its new location 'src.data_processing_vocabulary'.
# from src import data_processing_vocabulary
# sys.modules["data_processing_vocabulary"] = data_processing_vocabulary

from image_captioning_model import ImageCaptioningModel

DEVICE = "cpu"

class CaptionGenerator:
    def __init__(
        self,
        model_path="best_model.pth",
        vocab_path="vocab.pkl",
        use_vit=True
    ):
        # Load vocab
        print("Loading vocab...")
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        print(f"Vocab loaded with {len(self.vocab)} words.")

        # Build model
        print("Building model...")
        self.model = ImageCaptioningModel(
            vocab_size=len(self.vocab),
            pad_id=self.vocab.word2idx["<pad>"],
            use_vit=use_vit
        )
        print("Model built.")

        # Load weights
        print("Loading state dict...")
        state_dict = torch.load(model_path, map_location=DEVICE)
        self.model.load_state_dict(state_dict)
        print(f"Model weights loaded from {model_path}.")

        # Eval mode
        self.model.eval()
        self.model.to(DEVICE)

        # Preprocessing pipeline: 224x224 input with ImageNet normalization
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def preprocess(self, image):
        # Accept either a PIL image or a path to an image file
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert("RGB")
        image = self.transform(image)
        return image.unsqueeze(0)

    def generate(self, image, max_len=30, decoding="beam", beam_width=5):
        """
        Generate a caption for the given image.

        Args:
            image: PIL Image or file path.
            max_len: Maximum caption length.
            decoding: 'greedy' or 'beam'.
            beam_width: Number of beams (only used when decoding='beam').

        Returns:
            Generated caption string.
        """
        image = self.preprocess(image).to(DEVICE)
        if decoding == "beam":
            caption = self.model.predict_caption_beam(
                image=image,
                vocab=self.vocab,
                beam_width=beam_width,
                max_len=max_len,
                device=DEVICE
            )
        else:
            caption = self.model.predict_caption(
                image=image,
                vocab=self.vocab,
                max_len=max_len,
                device=DEVICE
            )
        return caption
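

# Usage sketch: assumes the default checkpoint "best_model.pth" and "vocab.pkl"
# are present next to this file, and that "example.jpg" is any test image you
# supply (these paths are placeholders, not part of the original module).
# It exercises both decoding modes documented in generate().
if __name__ == "__main__":
    generator = CaptionGenerator(
        model_path="best_model.pth",
        vocab_path="vocab.pkl",
        use_vit=True
    )

    # Beam search (the default) typically yields slightly better captions.
    print("Beam:", generator.generate("example.jpg", decoding="beam", beam_width=5))

    # Greedy decoding is faster and useful as a quick sanity check.
    print("Greedy:", generator.generate("example.jpg", decoding="greedy"))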