import pandas as pd
import streamlit as st
import math
from collections import Counter
from PIL import Image
import zipfile
import io
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
import random
# Load the captions file: one row per (image, caption) pair.
data = pd.read_csv("captions.txt", sep=',')

# Remove single-character tokens (articles like "a", punctuation, etc.).
def remove_single_char(caption_list):
    return [word for word in caption_list if len(word) > 1]

# Tokenize each caption into a list of words, then drop single-character tokens.
data['caption'] = data['caption'].apply(word_tokenize)
data['caption'] = data['caption'].apply(remove_single_char)
# Pad every caption to the same length with <cell>, then wrap it in
# <start>/<end>, so all token sequences have identical size (max_length + 2).
lengths = data['caption'].apply(len)
max_length = lengths.max()
data['caption'] = data['caption'].apply(lambda caption: ['<start>'] + caption + ['<cell>'] * (max_length - len(caption)) + ['<end>'])
# Show full (non-truncated) dataframe cells when displayed.
pd.set_option('display.max_colwidth', None)

# Collect every token in the corpus (including the special tokens added above).
words = data['caption'].apply(" ".join).str.cat(sep=' ').split(' ')

# Build the vocabulary, ordered by descending frequency (count only once).
word_counts = Counter(words)
word_dict = sorted(word_counts, key=word_counts.get, reverse=True)
dict_size = len(word_dict)
vocab_threshold = 5  # note: defined but not used below

# Map each word to its index once, so encoding is O(1) per token
# (word_dict.index(word) would rescan the whole list for every word).
word_to_index = {word: index for index, word in enumerate(word_dict)}
data['sequence'] = data['caption'].apply(lambda caption: [word_to_index[word] for word in caption])
data = data.sort_values(by='image')
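# Illustrative example (hypothetical values): after the steps above, a raw
# caption like "A dog runs ." becomes
#   ['<start>', 'dog', 'runs', '<cell>', ..., '<cell>', '<end>']
# and its 'sequence' column holds the matching word_dict indices, e.g.
#   [2, 57, 301, 1, ..., 1, 3]
# (the actual indices depend on corpus word frequencies).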
class PositionalEncoding(nn.Module):
    """Adds sinusoidal positional encodings (Vaswani et al., 2017) to embeddings."""
    def __init__(self, d_model, dropout=0.1, max_len=(max_length + 2)):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model). Slice the buffer to the input length and
        # broadcast it across the batch; never mutate the registered buffer.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
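# Optional sanity check (a minimal sketch, not called by the app): the module
# should preserve the (batch, seq_len, d_model) shape it receives.
def _check_positional_encoding(d_model=512, batch=4):
    pe = PositionalEncoding(d_model)
    x = torch.zeros(batch, max_length + 2, d_model)  # dummy embedded batch
    assert pe(x).shape == x.shape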
class ImageCaptionModel(nn.Module):
    """Transformer decoder that attends over image features to emit caption tokens."""
    def __init__(self, n_head, n_decoder_layer, vocab_size, embedding_size):
        super(ImageCaptionModel, self).__init__()
        self.pos_encoder = PositionalEncoding(embedding_size, 0.1)
        self.TransformerDecoderLayer = nn.TransformerDecoderLayer(d_model=embedding_size, nhead=n_head)
        self.TransformerDecoder = nn.TransformerDecoder(decoder_layer=self.TransformerDecoderLayer, num_layers=n_decoder_layer)
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.last_linear_layer = nn.Linear(embedding_size, vocab_size)
        self.init_weights()
        self.n_head = n_head

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.last_linear_layer.bias.data.zero_()
        self.last_linear_layer.weight.data.uniform_(-initrange, initrange)

    def generate_Mask(self, size, decoder_inp):
        # Causal (look-ahead) mask: position i may attend only to positions <= i.
        decoder_input_mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        decoder_input_mask = decoder_input_mask.float().masked_fill(decoder_input_mask == 0, float('-inf')).masked_fill(decoder_input_mask == 1, float(0.0))
        # Padding masks: token id 0 is treated as padding.
        decoder_input_pad_mask = decoder_inp.float().masked_fill(decoder_inp == 0, float(0.0)).masked_fill(decoder_inp > 0, float(1.0))
        decoder_input_pad_mask_bool = decoder_inp == 0
        return decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool
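    # For example (illustrative), generate_Mask(3, ...) returns the causal mask
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so position i can only attend to positions <= i during decoding.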
    def forward(self, encoded_image, decoder_inp):
        # encoded_image: (batch, n_features, embed) -> (n_features, batch, embed)
        # as nn.TransformerDecoder expects sequence-first memory.
        encoded_image = encoded_image.permute(1, 0, 2)
        # Guard against out-of-range token ids before the embedding lookup.
        decoder_inp = torch.clamp(decoder_inp, 0, self.embedding.num_embeddings - 1)
        # Embeddings are used here without the usual sqrt(d_model) scaling.
        decoder_inp_embed = self.embedding(decoder_inp)
        decoder_inp_embed = self.pos_encoder(decoder_inp_embed)
        decoder_inp_embed = decoder_inp_embed.permute(1, 0, 2)  # (seq, batch, embed)
        decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool = self.generate_Mask(decoder_inp.size(1), decoder_inp)
        decoder_output = self.TransformerDecoder(tgt=decoder_inp_embed, memory=encoded_image, tgt_mask=decoder_input_mask, tgt_key_padding_mask=decoder_input_pad_mask_bool)
        final_output = self.last_linear_layer(decoder_output)  # (seq, batch, vocab)
        return final_output, decoder_input_pad_mask
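# A minimal usage sketch (hypothetical hyperparameters; the model used below is
# loaded from a pickle, so these values are illustrative only):
#   m = ImageCaptionModel(n_head=8, n_decoder_layer=4,
#                         vocab_size=dict_size, embedding_size=512)
#   logits, pad_mask = m(encoded_image, decoder_inp)
# with encoded_image of shape (batch, n_features, embedding_size) and
# decoder_inp of shape (batch, seq_len) holding token indices; logits come
# back as (seq_len, batch, vocab_size).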
# Load the trained captioning model and the precomputed validation-set
# ResNet image embeddings (both pickled during training).
model = pd.read_pickle('ImageCaptioning_Model.pkl')
model.eval()

# Special-token indices; these must match the positions of '<start>', '<end>'
# and '<cell>' in the vocabulary used at training time.
start_token = 2
end_token = 3
cell_token = 1
max_seq_len = 34

validation = pd.read_pickle('Image_Features_Embed_ResNet_Valid.pkl')
def process_image_from_zip(zip_path, image_name):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        with zip_ref.open(image_name) as file:
            # Read into BytesIO so the image survives the zip handle closing.
            image_data = io.BytesIO(file.read())
            image = Image.open(image_data)
            image = image.convert("RGB")  # normalize mode (e.g. palette/grayscale)
            return image
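# Example usage (hypothetical file name inside the Images.zip archive):
#   img = process_image_from_zip('Images.zip', 'Images/example.jpg')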
def generate_caption(K, image_path):
    model.eval()
    zip_image_path = 'Images/' + image_path
    image = process_image_from_zip('Images.zip', zip_image_path)

    # Look up the ground-truth caption for display, stripped of special tokens.
    valid_img_df = validation[validation['image'] == image_path]
    actual_caption_list = valid_img_df['caption'].tolist()
    filtered_actual = [word for word in actual_caption_list[0] if word not in ['<start>', '<end>', '<cell>']]
    actual_caption = " ".join(filtered_actual)

    # Precomputed image embedding for this image.
    img_embed = torch.tensor(valid_img_df['embedded'].tolist())

    # Start from a sequence of <cell> padding, seeded with <start>.
    input_seq = [cell_token] * max_seq_len
    input_seq[0] = start_token
    input_seq = torch.tensor(input_seq).unsqueeze(0)

    # Build the index -> word lookup once, outside the decoding loop.
    index_to_word = {index: word for index, word in enumerate(word_dict)}

    predicted_sentence = []
    with torch.no_grad():
        # Top-K decoding: at each step, sample the next token from the K most
        # probable candidates, weighted by their probabilities.
        for eval_iter in range(max_seq_len):
            output, padding_mask = model(img_embed, input_seq)
            # Softmax the logits so the top-K weights are valid (non-negative)
            # sampling weights; raw logits can be negative, which random.choices
            # rejects as weights.
            probs = torch.softmax(output[eval_iter, 0, :], dim=-1)
            top = torch.topk(probs, K)
            next_word_index = random.choices(top.indices.tolist(), weights=top.values.tolist(), k=1)[0]
            next_word = index_to_word[next_word_index]
            if eval_iter + 1 < max_seq_len:
                input_seq[:, eval_iter + 1] = next_word_index
            if next_word == '<end>':
                break
            predicted_sentence.append(next_word)

    filtered_caption_list = [word for word in predicted_sentence if word not in ['<start>', '<end>', '<cell>']]

    # Display the image with its actual and predicted captions.
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    st.pyplot(fig)
    st.write("Actual Caption: ", actual_caption)
    st.write("Predicted Caption: ", " ".join(filtered_caption_list))
st.title('Image Captioning')
st.write('Generate Caption for Random Image')
generate_caption_button = st.button('Generate Caption')
if generate_caption_button:
    try:
        # Caption a random validation image (K=1: always take the top token).
        random_row = validation.sample()
        random_image = random_row.iloc[0]['image']
        generate_caption(1, random_image)
    except RuntimeError:
        # Retry once with a different random image if decoding fails.
        random_row = validation.sample()
        random_image = random_row.iloc[0]['image']
        generate_caption(1, random_image)