karan99300 committed
Commit · b1a427a
1 Parent(s): eb08f58
Upload 5 files

- app.py       +30 -0
- inference.py +39 -0
- loader.py    +100 -0
- model.py     +74 -0
- train.py     +64 -0
app.py
ADDED
@@ -0,0 +1,30 @@
from PIL import Image
import requests
import gradio as gr
import torch
from loader import get_loader
from model import CNNtoRNN
import torchvision.transforms as transforms

# Same preprocessing as training: ImageNet resize, crop, and normalization.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The loader is built only to recover the vocabulary for decoding.
train_loader, dataset = get_loader(
    root_folder='FlickrDataset/Images',
    annotation_file='FlickrDataset/Captions/captions.txt',
    transform=transform,
    num_workers=2,
)

filepath = "ImageCaptioningusingLSTM.pth"
model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(dataset.vocab), num_layers=1)
# map_location lets a CPU-only host load weights that were saved on a GPU.
model.load_state_dict(torch.load(filepath, map_location="cpu"))
model.eval()

def launch(url):  # renamed from `input`, which shadows the builtin
    im = Image.open(requests.get(url, stream=True).raw)
    image = transform(im.convert('RGB')).unsqueeze(0)
    # Join the token list into a sentence, dropping <SOS>/<EOS> as inference.py does.
    return " ".join(model.caption_image(image, dataset.vocab)[1:-1])

iface = gr.Interface(launch, inputs="text", outputs="text")
iface.launch()
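Note (sketch, not part of the commit): app.py rebuilds the full FlickrDataset at startup solely to recover dataset.vocab, which means tokenizing every caption with spaCy before the first request. Assuming the Vocabulary class from loader.py pickles cleanly (it holds only two dicts and an int), the vocabulary could be serialized once offline and loaded at serving time; vocab.pkl is a hypothetical filename.

import pickle

# One-off, offline: persist the vocabulary built by the dataset.
with open("vocab.pkl", "wb") as f:
    pickle.dump(dataset.vocab, f)

# At serving time: skip get_loader entirely and load the vocabulary directly.
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)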
inference.py
ADDED
@@ -0,0 +1,39 @@
import torch
import os
import torchvision.transforms as transforms
from PIL import Image
from model import CNNtoRNN
from loader import get_loader

def inference():
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image_index = 100

    # Built to recover the vocabulary; the loader itself is unused here.
    train_loader, dataset = get_loader(
        root_folder='FlickrDataset/Images',
        annotation_file='FlickrDataset/Captions/captions.txt',
        transform=transform,
        num_workers=2,
    )
    imagepath = "FlickrDataset/Images/"
    # Sort so image_index picks a stable file; os.listdir order is filesystem-dependent.
    images = sorted(os.listdir(imagepath))
    im = Image.open(os.path.join(imagepath, images[image_index]))
    im.show()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    filepath = "ImageCaptioningusingLSTM.pth"
    model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(dataset.vocab), num_layers=1).to(device)
    # map_location lets a CPU-only machine load weights that were saved on a GPU.
    model.load_state_dict(torch.load(filepath, map_location=device))
    model.eval()

    image = transform(im.convert("RGB")).unsqueeze(0)

    output = model.caption_image(image.to(device), dataset.vocab)
    print("Output: " + " ".join(output[1:-1]))  # drop the <SOS>/<EOS> tokens

if __name__ == "__main__":
    inference()
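To caption an arbitrary file rather than the 100th directory entry, a small helper along these lines (hypothetical, not part of the commit) reuses the same preprocessing and decoding as inference():

from PIL import Image

def caption_file(path, model, vocab, transform, device):
    # Same transform as inference(), applied to a caller-supplied path.
    image = transform(Image.open(path).convert("RGB")).unsqueeze(0)
    tokens = model.caption_image(image.to(device), vocab)
    return " ".join(tokens[1:-1])  # drop <SOS>/<EOS>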
loader.py
ADDED
@@ -0,0 +1,100 @@
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image

spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    def __init__(self, freq_threshold):
        # Indices 0-3 are reserved for the special tokens.
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    def tokenizer_eng(self, text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                # Add a word exactly once, the moment it reaches the threshold.
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df['image']
        self.captions = self.df['caption']

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # Wrap the caption in <SOS> ... <EOS> before numericalizing.
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)

class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        # batch_first=False: captions come out as (seq_len, batch).
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)

        return imgs, targets

def get_loader(root_folder, annotation_file, transform, batch_size=32, shuffle=True, pin_memory=True, num_workers=8):
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    pad_idx = dataset.vocab.stoi["<PAD>"]
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset
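A quick sanity check of the collate behavior (sketch, assuming the dataset paths used elsewhere in this commit exist locally): images come out batch-first while captions come out sequence-first, because MyCollate pads with batch_first=False.

import torchvision.transforms as transforms
from loader import get_loader

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
loader, dataset = get_loader(
    root_folder='FlickrDataset/Images',
    annotation_file='FlickrDataset/Captions/captions.txt',
    transform=transform,
    num_workers=0,
)
imgs, captions = next(iter(loader))
print(imgs.shape)      # torch.Size([32, 3, 224, 224]) -- batch-first images
print(captions.shape)  # torch.Size([T, 32]) -- T = longest caption in the batch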
model.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Freeze the backbone; only the projection below is trained.
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final fc layer, keeping everything up to the pooled features.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)
        self.batch = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.embed.weight.data.normal_(0., 0.02)
        self.embed.bias.data.fill_(0)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.batch(self.embed(features))
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        # Prepend the image feature as the first LSTM timestep.
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)

        return outputs

class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        # Greedy decoding: feed the image feature, then each argmax token back in.
        result_caption = []
        with torch.no_grad():
            X = self.encoderCNN(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoderRNN.lstm(X, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())

                X = self.decoderRNN.embed(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break

        return [vocabulary.itos[idx] for idx in result_caption]
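The shape contract between encoder and decoder is easy to verify with dummy tensors (sketch; the vocab size is illustrative, and the pretrained ResNet weights download on first use):

import torch
from model import CNNtoRNN

model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=3000, num_layers=1)
images = torch.randn(4, 3, 224, 224)        # (B, C, H, W)
captions = torch.randint(0, 3000, (19, 4))  # (T-1, B), sequence-first as MyCollate emits
outputs = model(images, captions)
print(outputs.shape)  # torch.Size([20, 4, 3000]): the image feature adds one timestep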
train.py
ADDED
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from loader import get_loader
from model import CNNtoRNN
from tqdm import tqdm, trange

def train():
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_loader, dataset = get_loader(
        root_folder='FlickrDataset/Images',
        annotation_file='FlickrDataset/Captions/captions.txt',
        transform=transform,
        num_workers=2,
    )

    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    embed_size = 256
    hidden_size = 256
    vocab_size = len(dataset.vocab)
    num_layers = 1
    learning_rate = 3e-4
    num_epochs = 200

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # <PAD> positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in trange(num_epochs):
        pbar = tqdm(train_loader)
        for imgs, captions in pbar:
            imgs = imgs.to(device)
            captions = captions.to(device)

            # Teacher forcing: feed all tokens but the last; the prepended image
            # feature restores the sequence length, so outputs (T, B, V) align
            # with the full captions tensor (T, B) used as the target.
            outputs = model(imgs, captions[:-1])
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(loss=loss.item())

    # Save once, after all epochs finish.
    filepath = "ImageCaptioningusingLSTM.pth"
    torch.save(model.state_dict(), filepath)

if __name__ == "__main__":
    train()
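Since train.py saves weights only after all 200 epochs, an interruption loses the entire run. A per-epoch checkpoint is a cheap safeguard; the snippet below is a sketch meant for the end of the epoch body above, and the filename pattern is hypothetical.

# At the end of each epoch, inside `for epoch in trange(num_epochs):`
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": epoch,
}, f"checkpoint_epoch{epoch}.pth")  # hypothetical filename pattern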