from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import torch

from data_loader import get_loader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Watch for any changes in model.py, and re-load it automatically.
%load_ext autoreload
%autoreload 2

import os
from model import EncoderCNN, DecoderRNN

# TODO #2: Specify the saved models to load.
encoder_file = 'encoder-3.pkl'
decoder_file = 'decoder-3.pkl'

# TODO #3: Select appropriate values for the Python variables below.
embed_size = 256
hidden_size = 512

# The size of the vocabulary.
vocab_size = 8855

# Initialize the encoder and decoder, and set each to inference mode.
encoder = EncoderCNN(embed_size)
encoder.eval()
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
decoder.eval()

# Load the trained weights.
encoder.load_state_dict(torch.load(os.path.join('./models', encoder_file),
                                   map_location=torch.device('cpu')))
decoder.load_state_dict(torch.load(os.path.join('./models', decoder_file),
                                   map_location=torch.device('cpu')))

# Move models to GPU if CUDA is available.
encoder.to(device)
decoder.to(device)


def process_image(image):
    '''Scales, crops, and normalizes a PIL image for a PyTorch model.'''
    transformation = transforms.Compose([
        transforms.Resize(256),                      # smaller edge of image resized to 256
        transforms.CenterCrop(224),                  # deterministic 224x224 center crop for inference
        transforms.ToTensor(),                       # convert the PIL Image to a tensor
        transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                             (0.229, 0.224, 0.225))])
    return transformation(image)


def prepare_image(img_np):
    '''Converts a numpy image to PIL and returns both the original image
    and the pre-processed image tensor.'''
    PIL_image = Image.fromarray(img_np).convert('RGB')
    orig_image = np.array(PIL_image)
    image = process_image(PIL_image)
    # Return the original image and the pre-processed image tensor.
    return orig_image, image


def clean_sentence(output):
    '''Converts a list of predicted word indices into a readable sentence.'''
    sentence = ''
    for i in output:
        word = data_loader.dataset.vocab.idx2word[i]
        if i == 0:   # skip the <start> token
            continue
        if i == 1:   # stop at the <end> token
            break
        if i == 18:  # append punctuation ('.') without a leading space
            sentence = sentence + word
        else:
            sentence = sentence + ' ' + word
    return sentence.strip()


# The data loader needs an actual transform, not the transforms module;
# use the same pre-processing as process_image.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])

data_loader = get_loader(transform=transform_test, mode='test')


def get_caption(image):
    '''Generates a caption for the given numpy image.'''
    orig_image, image = prepare_image(image)
    image = image.unsqueeze(0)  # add a batch dimension
    plt.imshow(np.squeeze(orig_image))
    plt.title('Sample Image')
    plt.show()
    image = image.to(device)
    features = encoder(image).unsqueeze(1)
    output = decoder.sample(features)
    sentence = clean_sentence(output)
    return sentence


import gradio as gr

# get_caption returns a sentence, so the Gradio output component must be text.
demo = gr.Interface(fn=get_caption, inputs="image", outputs="text")
demo.launch()
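
# Optional: a minimal smoke test of the full pipeline, runnable instead of
# demo.launch(). This is a sketch, not part of the original script:
# 'sample.jpg' is a hypothetical local image path, and it assumes
# decoder.sample returns the list of vocabulary indices consumed by
# clean_sentence above.
#
# img_np = np.array(Image.open('sample.jpg').convert('RGB'))
# print(get_caption(img_np))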