"""Caption a single image from the Flickr dataset with a trained CNN-to-RNN model."""

import os

import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image

from loader import get_loader
from model import CNNtoRNN


def inference(
    image_index=100,
    image_folder="FlickrDataset/Images",
    annotation_file="FlickrDataset/Captions/captions.txt",
    checkpoint_path="ImageCaptioningusingLSTM.pth",
):
    """Show one image from ``image_folder`` and print the model's caption for it.

    Args:
        image_index: Position of the image within the *sorted* listing of
            ``image_folder``.
        image_folder: Directory containing the dataset images.
        annotation_file: Captions file passed to ``get_loader``. It is needed
            here only to rebuild the vocabulary the model was trained with.
        checkpoint_path: Path to the saved ``state_dict`` of ``CNNtoRNN``.

    Returns:
        The printed caption string: the tokens produced by
        ``model.caption_image`` with the first and last dropped. Those are
        presumably the start/end markers; confirm against
        ``CNNtoRNN.caption_image``.

    Raises:
        IndexError: If ``image_index`` is outside the folder listing.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # The dataset is built only for its vocabulary; the loader itself is unused.
    _, dataset = get_loader(
        root_folder=image_folder,
        annotation_file=annotation_file,
        transform=transform,
        num_workers=2,
    )

    # os.listdir order is arbitrary and filesystem-dependent. Sort it so a
    # given index selects the same image on every machine and run.
    images = sorted(os.listdir(image_folder))
    try:
        image_name = images[image_index]
    except IndexError:
        raise IndexError(
            f"image_index {image_index} is out of range: "
            f"{image_folder!r} contains {len(images)} entries"
        ) from None

    # Close the file handle once the pixels have been transformed.
    with Image.open(os.path.join(image_folder, image_name)) as im:
        im.show()
        image = transform(im.convert("RGB")).unsqueeze(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNNtoRNN(
        embed_size=256,
        hidden_size=256,
        vocab_size=len(dataset.vocab),
        num_layers=1,
    ).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # machine, matching the CPU fallback chosen above.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # Pure inference: disable autograd bookkeeping to save memory.
    with torch.no_grad():
        output = model.caption_image(image.to(device), dataset.vocab)

    caption = " ".join(output[1:-1])
    print("Output:" + caption)
    return caption


if __name__ == "__main__":
    inference()