# image_to_text / app.py
# Owner: ovi054
# Commit abf0d14: "requirements updated"
# Hugging Face page metadata: raw | history | blame | No virus | 3.07 kB
from PIL import Image
import numpy as np
from torchvision import transforms
import torch
from data_loader import get_loader
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
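# Notebook leftover: in Jupyter, the IPython magics below reload model.py
# automatically whenever it changes. They are commented out and have no
# effect in this script.
# %load_ext autoreload
# %autoreload 2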
import os
import torch
from model import EncoderCNN, DecoderRNN
# Checkpoint file names inside the model directory.
encoder_file = 'encoder-3.pkl'
decoder_file = 'decoder-3.pkl'
# Model hyper-parameters. load_state_dict rejects mismatched shapes, so these
# must match the values the checkpoints were trained with.
embed_size = 256
hidden_size = 512
# The size of the vocabulary; must equal the vocabulary the decoder was trained with.
vocab_size = 8855


def _resolve_model_dir():
    """Return the directory that holds the checkpoint files.

    Prefers a ``models`` directory next to this script (the layout of an app
    checkout). Falls back to the absolute ``/models`` path when only that
    exists.
    """
    local_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
    if os.path.isdir(local_dir) or not os.path.isdir('/models'):
        return local_dir
    return '/models'


# Initialize the encoder and decoder, and set each to inference mode.
encoder = EncoderCNN(embed_size)
encoder.eval()
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
decoder.eval()
# Load the trained weights. map_location='cpu' lets a GPU-saved checkpoint load
# on a CPU-only machine; the models are moved to `device` right after.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model_dir = _resolve_model_dir()
encoder.load_state_dict(torch.load(os.path.join(model_dir, encoder_file), map_location=torch.device('cpu')))
decoder.load_state_dict(torch.load(os.path.join(model_dir, decoder_file), map_location=torch.device('cpu')))
# Move models to GPU if CUDA is available.
encoder.to(device)
decoder.to(device)
def process_image(image):
    """Scale, crop and normalize a PIL image for the CNN encoder.

    The smaller edge is resized to 256 pixels and the central 224x224 region
    is kept. The result is converted to a float tensor and normalized with the
    per-channel mean/std constants expected by the pre-trained model.

    Args:
        image: a PIL image (``function`` converts it to RGB before calling).

    Returns:
        A float tensor; for an RGB input its shape is ``(3, 224, 224)``.
    """
    transformation = transforms.Compose([
        transforms.Resize(256),      # smaller edge of image resized to 256
        # Deterministic center crop. A random crop is a training-time
        # augmentation: at inference it makes the caption vary between calls
        # on the same image and can crop away the subject.
        transforms.CenterCrop(224),
        transforms.ToTensor(),       # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                             (0.229, 0.224, 0.225)),
    ])
    return transformation(image)
def function(img_np):
    """Prepare an uploaded image array for captioning.

    Args:
        img_np: image array in a form accepted by ``PIL.Image.fromarray``.

    Returns:
        A ``(original, tensor)`` pair: the image as an RGB numpy array, and
        the pre-processed tensor produced by ``process_image``.
    """
    rgb_image = Image.fromarray(img_np).convert('RGB')
    original = np.array(rgb_image)
    return original, process_image(rgb_image)
def clean_sentence(output, idx2word=None):
    """Turn a sequence of predicted word indices into a caption string.

    Index 0 is skipped and index 1 ends the caption (presumably the
    ``<start>`` / ``<end>`` tokens -- confirm against the vocab file).
    Index 18 is appended without a leading space (presumably the ``.`` token
    -- confirm). Every other word is preceded by a single space.

    Args:
        output: iterable of integer word indices, e.g. from ``decoder.sample``.
        idx2word: optional index -> word mapping. Defaults to the vocabulary of
            the module-level ``data_loader``.

    Returns:
        The caption with surrounding whitespace stripped ('' if empty).

    Raises:
        Whatever ``idx2word[idx]`` raises for an unknown emitted index
        (``KeyError`` for a dict).
    """
    if idx2word is None:
        idx2word = data_loader.dataset.vocab.idx2word
    skip_idx, end_idx, no_space_idx = 0, 1, 18
    pieces = []
    for idx in output:
        if idx == skip_idx:
            continue
        if idx == end_idx:
            break
        # Look the word up only for indices that are actually emitted, so the
        # special indices never need a vocabulary entry.
        word = idx2word[idx]
        pieces.append(word if idx == no_space_idx else ' ' + word)
    # join instead of repeated += keeps this linear in the caption length.
    return ''.join(pieces).strip()
# In this file only the loader's vocabulary is used (by clean_sentence).
# `transform` is presumably applied to images the loader yields, so pass a
# real transform pipeline (the app's own preprocessing) rather than the
# `transforms` module itself.
data_loader = get_loader(
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ]),
    mode='test',
)
def get_caption(image):
    """Generate a caption for an image supplied by the Gradio ``"image"`` input.

    Args:
        image: the uploaded image as an array (assumed to be Gradio's default
            numpy format -- confirm), or ``None`` when no image was submitted.

    Returns:
        The generated caption string ('' when no image was given).
    """
    if image is None:
        return ''
    # Only the pre-processed tensor is needed. There is no matplotlib preview:
    # this runs in a server process with no display.
    _, image_tensor = function(image)
    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    image_tensor = image_tensor.unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        features = encoder(image_tensor).unsqueeze(1)
        output = decoder.sample(features)
    return clean_sentence(output)
import gradio as gr
# get_caption returns a caption string, so the output component is text.
demo = gr.Interface(fn=get_caption, inputs="image", outputs="text")
demo.launch()