# Image-captioning inference demo (CNN encoder + RNN decoder) served via gradio.
# NOTE(review): the original export contained extraction residue here
# ("Spaces:" / "Build error" lines) — replaced with this header comment.
# --- Imports and environment setup ---
from PIL import Image
import numpy as np
from torchvision import transforms
import torch
import os

# Project-local modules: data pipeline and the encoder/decoder definitions.
from data_loader import get_loader
from model import EncoderCNN, DecoderRNN

# Run inference on the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the original notebook used IPython magics here
# (%load_ext autoreload / %autoreload 2). They are not valid Python in a
# plain .py file and are development-time conveniences only, so they are
# intentionally omitted.
# TODO #2: Saved model checkpoints to load.
encoder_file = 'encoder-3.pkl'
decoder_file = 'decoder-3.pkl'

# TODO #3: Hyperparameters — these must match the values used at training
# time or load_state_dict will fail on shape mismatches.
embed_size = 256    # dimensionality of the image/word embedding vectors
hidden_size = 512   # hidden units in the decoder RNN
vocab_size = 8855   # size of the training vocabulary

# Build the encoder and decoder and switch both to inference mode
# (eval() disables dropout and freezes batch-norm statistics).
encoder = EncoderCNN(embed_size)
encoder.eval()
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
decoder.eval()

# Load the trained weights. map_location='cpu' lets CPU-only machines load
# checkpoints that were saved from a GPU run.
encoder.load_state_dict(torch.load(os.path.join('/models', encoder_file),
                                   map_location=torch.device('cpu')))
decoder.load_state_dict(torch.load(os.path.join('/models', decoder_file),
                                   map_location=torch.device('cpu')))

# Move both models to the selected device.
encoder.to(device)
decoder.to(device)
def process_image(image):
    '''Scale, crop, and normalize a PIL image for the pre-trained encoder.

    Args:
        image: a PIL.Image instance (RGB).

    Returns:
        A (3, 224, 224) float tensor normalized with ImageNet statistics.
    '''
    transformation = transforms.Compose([
        transforms.Resize(256),          # smaller edge of image resized to 256
        # Bug fix: use CenterCrop, not RandomCrop — inference must be
        # deterministic; a random crop produces different captions per call.
        transforms.CenterCrop(224),
        transforms.ToTensor(),           # PIL Image -> float tensor in [0, 1]
        transforms.Normalize((0.485, 0.456, 0.406),    # ImageNet channel means
                             (0.229, 0.224, 0.225))])  # ImageNet channel stds
    return transformation(image)
def function(img_np):
    '''Convert a numpy image into (original array, preprocessed tensor).

    Args:
        img_np: image as a numpy array (H, W, C), e.g. from a gradio input.

    Returns:
        Tuple of (orig_image, image) where orig_image is the RGB numpy
        array for display and image is the normalized tensor produced by
        process_image(), ready for the encoder.
    '''
    pil_image = Image.fromarray(img_np).convert('RGB')  # force 3-channel RGB
    orig_image = np.array(pil_image)
    image = process_image(pil_image)
    return orig_image, image
def clean_sentence(output):
    '''Convert a sequence of word indices into a readable caption string.

    Assumes index 0 is the <start> token (skipped) and index 1 is the
    <end> token (stops decoding). Index 18 is presumably a punctuation
    token that is appended without a leading space — TODO confirm against
    data_loader.dataset.vocab.

    Args:
        output: iterable of integer word indices from decoder.sample().

    Returns:
        The caption as a single stripped string.
    '''
    sentence = ''
    for idx in output:
        if idx == 0:      # <start> token: skip it
            continue
        if idx == 1:      # <end> token: caption is complete
            break
        # Look up the word only for indices we actually keep.
        word = data_loader.dataset.vocab.idx2word[idx]
        if idx == 18:
            sentence += word          # punctuation: no leading space
        else:
            sentence += ' ' + word
    return sentence.strip()
# Build the test-split data loader (used by clean_sentence for its vocab).
# Bug fix: the original passed the torchvision `transforms` MODULE itself as
# the transform; supply an actual deterministic preprocessing pipeline.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])
data_loader = get_loader(transform=transform_test, mode='test')
def get_caption(image):
    '''Generate a caption for an image supplied by the gradio "image" input.

    Args:
        image: numpy array (H, W, C) as delivered by gradio.

    Returns:
        The generated caption string.
    '''
    # Bug fix: the original called function('image') — the literal string —
    # instead of the image argument, so every request failed.
    orig_image, tensor = function(image)
    # NOTE(review): the original plotted orig_image with plt, but matplotlib
    # was never imported (NameError) and gradio already displays the input;
    # the plotting calls are removed.
    tensor = tensor.unsqueeze(0).to(device)   # add batch dimension, move to device
    features = encoder(tensor).unsqueeze(1)   # image features for the decoder
    output = decoder.sample(features)         # greedy-decoded word indices
    return clean_sentence(output)
import gradio as gr

# Bug fix: get_caption returns a caption STRING, so the output component
# must be "text" — the original's outputs="image" cannot render a string.
demo = gr.Interface(fn=get_caption, inputs="image", outputs="text")
demo.launch()