karan99300 committed on
Commit b1a427a · 1 Parent(s): eb08f58

Upload 5 files

Files changed (5)
  1. app.py +30 -0
  2. inference.py +39 -0
  3. loader.py +100 -0
  4. model.py +74 -0
  5. train.py +64 -0
app.py ADDED
@@ -0,0 +1,30 @@
+ from PIL import Image
+ import requests
+ import gradio as gr
+ import torch
+ from loader import get_loader
+ from model import CNNtoRNN
+ import torchvision.transforms as transforms
+
+ # Same preprocessing as at training time (ImageNet normalization).
+ transform = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ # The dataset is loaded here only to rebuild the vocabulary the model was trained with.
+ train_loader, dataset = get_loader(root_folder='FlickrDataset/Images', annotation_file='FlickrDataset/Captions/captions.txt', transform=transform, num_workers=2)
+ filepath = "ImageCaptioningusingLSTM.pth"
+ model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(dataset.vocab), num_layers=1)
+ model.load_state_dict(torch.load(filepath, map_location="cpu"))  # map_location so GPU-trained weights also load on CPU-only hosts
+ model.eval()  # inference mode: disables dropout, uses BatchNorm running stats
+
+ def launch(input):
+     im = Image.open(requests.get(input, stream=True).raw)  # fetch the image from the given URL
+     image = transform(im.convert('RGB')).unsqueeze(0)       # add a batch dimension
+     return " ".join(model.caption_image(image, dataset.vocab)[1:-1])  # drop <SOS>/<EOS> and join tokens
+
+ iface = gr.Interface(launch, inputs="text", outputs="text")
+ iface.launch()
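One way to exercise the running app programmatically, rather than through the web UI, is the `gradio_client` package. This is a minimal sketch under assumptions not stated in the commit: the client package is installed, the app is serving on Gradio's default local address, and the sample image URL is a placeholder.

# Query the locally running Gradio app (started by app.py) with an image URL.
# Assumes: `pip install gradio_client`, app serving at 127.0.0.1:7860,
# and a reachable sample image URL (placeholder below).
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
caption = client.predict(
    "https://example.com/sample.jpg",  # hypothetical image URL
    api_name="/predict",               # default endpoint name for gr.Interface
)
print(caption)  # expected: a space-joined caption string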
inference.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ import os
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from model import CNNtoRNN
+ import pandas as pd
+ from loader import get_loader
+
+ def inference():
+     transform = transforms.Compose([
+         transforms.Resize(256),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     image_index = 100  # which file in the Images folder listing to caption
+
+     train_loader, dataset = get_loader(root_folder='FlickrDataset/Images', annotation_file='FlickrDataset/Captions/captions.txt', transform=transform, num_workers=2)
+     df = pd.read_csv("FlickrDataset/Captions/captions.txt")  # caption file (not used further here)
+     imagepath = "FlickrDataset/Images/"
+     images = os.listdir(imagepath)
+     im = Image.open(os.path.join(imagepath, images[image_index]))
+     im.show()
+
+     device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
+
+     filepath = "ImageCaptioningusingLSTM.pth"
+     model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(dataset.vocab), num_layers=1).to(device)
+     model.load_state_dict(torch.load(filepath, map_location=device))  # map_location so GPU-trained weights also load on CPU
+     model.eval()
+
+     image = transform(im.convert("RGB")).unsqueeze(0)  # add a batch dimension
+
+     output = model.caption_image(image.to(device), dataset.vocab)
+     print("Output: " + " ".join(output[1:-1]))  # drop the <SOS>/<EOS> tokens
+
+ if __name__ == "__main__":
+     inference()
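caption_image returns the raw token list including the <SOS> and <EOS> markers, which is why the script slices [1:-1]. A small sketch, under the same checkpoint and folder assumptions, that captions an arbitrary local file by path instead of by index (the helper name and sample filename are illustrative, not part of this commit):

# caption_file.py -- hypothetical helper, same assumptions as inference.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import CNNtoRNN
from loader import get_loader

def caption_file(path, checkpoint="ImageCaptioningusingLSTM.pth"):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Rebuild the vocabulary the checkpoint was trained with.
    _, dataset = get_loader(root_folder='FlickrDataset/Images',
                            annotation_file='FlickrDataset/Captions/captions.txt',
                            transform=transform, num_workers=0)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CNNtoRNN(256, 256, len(dataset.vocab), 1).to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()
    image = transform(Image.open(path).convert("RGB")).unsqueeze(0).to(device)
    tokens = model.caption_image(image, dataset.vocab)  # e.g. ['<SOS>', 'a', 'dog', ..., '<EOS>']
    return " ".join(tokens[1:-1])

if __name__ == "__main__":
    print(caption_file("FlickrDataset/Images/some_image.jpg"))  # hypothetical filename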
loader.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ import pandas as pd
+ import spacy
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import DataLoader, Dataset
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ spacy_eng = spacy.load("en_core_web_sm")  # spaCy English model, used only for tokenization
+ class Vocabulary:
+     def __init__(self, freq_threshold):
+         self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}  # index -> token
+         self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}  # token -> index
+         self.freq_threshold = freq_threshold
+
+     def __len__(self):
+         return len(self.itos)
+
+     def tokenizer_eng(self, text):
+         return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
+
+     def build_vocabulary(self, sentence_list):
+         frequencies = {}
+         idx = 4  # indices 0-3 are reserved for the special tokens
+
+         for sentence in sentence_list:
+             for word in self.tokenizer_eng(sentence):
+                 if word not in frequencies:
+                     frequencies[word] = 1
+                 else:
+                     frequencies[word] += 1
+
+                 # Add a word once it has been seen freq_threshold times.
+                 if frequencies[word] == self.freq_threshold:
+                     self.stoi[word] = idx
+                     self.itos[idx] = word
+                     idx += 1
+
+     def numericalize(self, text):
+         tokenized_text = self.tokenizer_eng(text)
+         return [
+             self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
+             for token in tokenized_text
+         ]
+
+ class FlickrDataset(Dataset):
+     def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
+         self.root_dir = root_dir
+         self.df = pd.read_csv(captions_file)
+         self.transform = transform
+
+         self.imgs = self.df['image']
+         self.captions = self.df['caption']
+
+         # Build the vocabulary from all captions in the file.
+         self.vocab = Vocabulary(freq_threshold)
+         self.vocab.build_vocabulary(self.captions.tolist())
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         caption = self.captions[index]
+         img_id = self.imgs[index]
+         img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         # Wrap the numericalized caption in <SOS> ... <EOS>.
+         numericalized_caption = [self.vocab.stoi["<SOS>"]]
+         numericalized_caption += self.vocab.numericalize(caption)
+         numericalized_caption.append(self.vocab.stoi["<EOS>"])
+
+         return img, torch.tensor(numericalized_caption)
+
+ class MyCollate:
+     def __init__(self, pad_idx):
+         self.pad_idx = pad_idx
+
+     def __call__(self, batch):
+         imgs = [item[0].unsqueeze(0) for item in batch]
+         imgs = torch.cat(imgs, dim=0)
+         targets = [item[1] for item in batch]
+         # Pad captions to the longest in the batch; shape becomes (max_len, batch_size).
+         targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
+
+         return imgs, targets
+
+ def get_loader(root_folder, annotation_file, transform, batch_size=32, shuffle=True, pin_memory=True, num_workers=8):
+     dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
+     pad_idx = dataset.vocab.stoi["<PAD>"]
+     loader = DataLoader(
+         dataset=dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=shuffle,
+         pin_memory=pin_memory,
+         collate_fn=MyCollate(pad_idx=pad_idx)
+     )
+
+     return loader, dataset
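A quick sketch of how the loader is consumed and what shapes it yields, assuming the FlickrDataset/Images folder and captions.txt used elsewhere in this commit (batch_size=4 and num_workers=0 are chosen here only for the smoke test):

import torchvision.transforms as transforms
from loader import get_loader

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

loader, dataset = get_loader(root_folder='FlickrDataset/Images',
                             annotation_file='FlickrDataset/Captions/captions.txt',
                             transform=transform, batch_size=4, num_workers=0)

imgs, captions = next(iter(loader))
print(imgs.shape)      # torch.Size([4, 3, 224, 224])
print(captions.shape)  # (max_caption_len, 4) -- time-major because batch_first=False in MyCollate
print(len(dataset.vocab), dataset.vocab.stoi["<PAD>"])  # vocabulary size and the padding index (0)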
model.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+
+ class EncoderCNN(nn.Module):
+     def __init__(self, embed_size):
+         super(EncoderCNN, self).__init__()
+         resnet = models.resnet50(weights='ResNet50_Weights.DEFAULT')  # pretrained ImageNet backbone
+         for param in resnet.parameters():
+             param.requires_grad_(False)  # freeze the backbone; only the projection layer is trained
+
+         modules = list(resnet.children())[:-1]  # drop the final classification layer
+         self.resnet = nn.Sequential(*modules)
+         self.embed = nn.Linear(resnet.fc.in_features, embed_size)
+         self.batch = nn.BatchNorm1d(embed_size, momentum=0.01)
+         self.embed.weight.data.normal_(0., 0.02)
+         self.embed.bias.data.fill_(0)
+
+     def forward(self, images):
+         features = self.resnet(images)
+         features = features.view(features.size(0), -1)  # flatten to (batch, 2048)
+         features = self.batch(self.embed(features))     # project to (batch, embed_size)
+         return features
+
+
+ class DecoderRNN(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
+         super(DecoderRNN, self).__init__()
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)  # time-major input: (seq_len, batch, embed)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, features, captions):
+         embeddings = self.dropout(self.embed(captions))
+         # Prepend the image feature as the first "token" of the sequence.
+         embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
+         hiddens, _ = self.lstm(embeddings)
+         outputs = self.linear(hiddens)  # (seq_len, batch, vocab_size)
+
+         return outputs
+
+ class CNNtoRNN(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
+         super(CNNtoRNN, self).__init__()
+         self.encoderCNN = EncoderCNN(embed_size)
+         self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
+
+     def forward(self, images, captions):
+         features = self.encoderCNN(images)
+         outputs = self.decoderRNN(features, captions)
+         return outputs
+
+     def caption_image(self, image, vocabulary, max_length=50):
+         # Greedy decoding: feed the image feature first, then the embedding of each predicted word.
+         result_caption = []
+         with torch.no_grad():
+             X = self.encoderCNN(image).unsqueeze(0)
+             states = None
+
+             for _ in range(max_length):
+                 hiddens, states = self.decoderRNN.lstm(X, states)
+                 output = self.decoderRNN.linear(hiddens.squeeze(0))
+                 predicted = output.argmax(1)  # most likely next word
+                 result_caption.append(predicted.item())
+
+                 X = self.decoderRNN.embed(predicted).unsqueeze(0)
+
+                 if vocabulary.itos[predicted.item()] == "<EOS>":
+                     break
+
+             return [vocabulary.itos[idx] for idx in result_caption]
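A shape sanity-check sketch for the forward pass, using dummy tensors; the vocab_size value here is a placeholder (in this repo it comes from len(dataset.vocab)):

import torch
from model import CNNtoRNN

vocab_size = 3000  # placeholder vocabulary size
model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=vocab_size, num_layers=1)
model.eval()  # BatchNorm uses running stats; dropout is off

images = torch.randn(2, 3, 224, 224)               # (batch, channels, H, W)
captions = torch.randint(0, vocab_size, (17, 2))   # (seq_len, batch), time-major like the DataLoader output

with torch.no_grad():
    outputs = model(images, captions[:-1])         # same slicing as train.py
print(outputs.shape)  # torch.Size([17, 2, 3000]): one image-feature step + 16 word steps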
train.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torchvision.transforms as transforms
+ from loader import get_loader
+ from model import CNNtoRNN
+ from tqdm import tqdm
+ from tqdm import trange
+
+ def train():
+     transform = transforms.Compose([
+         transforms.Resize(256),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     train_loader, dataset = get_loader(root_folder='FlickrDataset/Images', annotation_file='FlickrDataset/Captions/captions.txt', transform=transform, num_workers=2)
+
+     torch.backends.cudnn.benchmark = True
+     device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
+     embed_size = 256
+     hidden_size = 256
+     vocab_size = len(dataset.vocab)
+     num_layers = 1
+     learning_rate = 3e-4
+     num_epochs = 200
+
+     model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
+     criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])  # do not penalize padding positions
+     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+     train_iterator = trange(0, num_epochs)
+     for _ in train_iterator:
+         pbar = tqdm(train_loader)
+         for idx, (imgs, captions) in enumerate(pbar):
+             model.train()
+             imgs = imgs.to(device)
+             captions = captions.to(device)
+
+             # Teacher forcing: feed captions[:-1]; the image feature occupies the first time step,
+             # so the outputs line up with the full caption sequence used as the target.
+             outputs = model(imgs, captions[:-1])
+             loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
+
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+
+             pbar.set_postfix(loss=loss.item())
+
+     filepath = "ImageCaptioningusingLSTM.pth"
+     torch.save(model.state_dict(), filepath)  # save weights only (state_dict); loaded by app.py and inference.py
+
+
+ if __name__ == "__main__":
+     train()
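The script writes the weights once, after all 200 epochs. A hedged sketch of per-epoch checkpointing with optimizer state, should one want to resume interrupted runs; the helper names and "checkpoint.pth" path are illustrative, not part of this commit:

import torch

# Call save_checkpoint at the end of each epoch inside train();
# call load_checkpoint before the epoch loop to resume.
def save_checkpoint(model, optimizer, epoch, path="checkpoint.pth"):
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path="checkpoint.pth", device="cpu"):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"] + 1  # epoch to resume from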