"""Gradio demo for an image-captioning model trained on Flickr8k.

Loads a tokenizer and a model checkpoint from ./checkpoints, then serves a
Gradio interface that captions either an uploaded image or an image fetched
from a user-supplied URL.
"""
import torch
import requests
import gradio as gr
from io import BytesIO
from PIL import Image
from model import CaptionModel
from torchvision import transforms
from preprocess import Tokenizer, return_user_agent

# --- One-time setup: tokenizer, model weights, inference transforms ---
tokenizer = Tokenizer('./')
tokenizer.load_tokenizer('./checkpoints/vocab-v1.pkl')
weights = torch.load('./checkpoints/caption_model.pt', map_location=torch.device('cpu'))
model = CaptionModel(tokenizer)
model.load_state_dict(weights['state_dict'])
# Inference-only service: set eval mode once at startup instead of per request.
model.eval()

val_tfms = transforms.Compose([
    # smaller edge of image resized to 256
    transforms.Resize(256),
    transforms.ToTensor(),
    # normalize with ImageNet statistics, as expected by the pre-trained backbone
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def decode_caption(idxs, tokenizer):
    """Convert a sequence of token indices into a space-joined caption string.

    NOTE(review): the original implementation ended with ``.replace('', '')``,
    which is a no-op and was removed; it presumably meant to strip a special
    pad/eos token — confirm the intended token against the vocabulary.
    """
    return ' '.join(tokenizer.idx2val[i] for i in idxs)


def predict_fn(image, link):
    """Caption an uploaded image, or an image downloaded from ``link``.

    Parameters
    ----------
    image : PIL.Image or None
        Image from the Gradio upload widget (ignored when ``link`` is set).
    link : str
        Optional image URL; when non-empty it takes precedence over ``image``.

    Returns
    -------
    (PIL.Image, str)
        A small preview image and the generated caption, or a stock error
        image plus an error message when the download fails.
    """
    if link != '':
        try:
            # timeout keeps a dead/slow URL from hanging the request handler
            resp = requests.get(link, headers=return_user_agent(), timeout=10)
            # fail fast on HTTP errors instead of feeding an error page to PIL
            resp.raise_for_status()
            image = Image.open(BytesIO(resp.content))
        # was a bare ``except:`` — narrowed so KeyboardInterrupt/SystemExit
        # still propagate; any download/decode failure yields the error image
        except Exception:
            error_image = Image.open('./error.jpg')
            error_text = 'Image from given link could not be downloaded, please try again with valid link'
            return error_image, error_text
    display_image = transforms.Resize(100)(image)
    image = val_tfms(image).unsqueeze(0)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        out = model.predict(image, torch.device('cpu'))
    caption = decode_caption(out[0], tokenizer)
    return display_image, caption


demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image", type='pil'),
        gr.Textbox(label='Enter Image Link', placeholder='Enter or Paste any Image link from Internet'),
    ],
    outputs=[
        gr.Image(label="Display Image for link as input", type='pil'),
        gr.Textbox(label="Generated Caption"),
    ],
    title="Image Captioning System",
    description="Image Captioning Model trained on Flick8k Dataset ",
)

demo.launch(debug=True)