In [1]:
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [3]:
class ImageDataset(Dataset):
 def __init__(self, csv_file, transform=None):
 self.annotations = csv_file
 self.transform=transform
 
 def __len__(self):
 return len(self.annotations)
 
 def __getitem__(self,index):
 img_desc = self.annotations.iloc[index, 2]

 label=torch.tensor(int(self.annotations.iloc[index, 3]))
 
 if self.transform:
 img_desc = self.transform(img_desc)
 
 return (img_desc, label)

In [4]:
df = pd.read_csv('image_desc.csv')
dataset = ImageDataset(df)
train_size = int(0.85 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])
print(len(dataset))

81890


In [12]:
batch_size=16
train_loader=DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader=DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [54]:
class MyModel(nn.Module):
 def __init__(self, sentence_model, hidden_dim, output_dim):
 super(MyModel, self).__init__()
 self.sentence_model = sentence_model
 self.fc1 = nn.Linear(384, hidden_dim)
 self.fc2 = nn.Linear(hidden_dim, output_dim)
 self.sig = nn.Sigmoid()

 def forward(self, x):
 sentence_embeddings = self.sentence_model.encode(x, convert_to_tensor=True)
 sentence_embeddings = sentence_embeddings.to(device)
 hidden = self.fc1(sentence_embeddings)
 hidden = F.relu(hidden)
 logits = self.fc2(hidden)
# logits = torch.clamp(logits, min=1e-5)
 logits = self.sig(logits)
 return logits

output_dim = 102
hidden_dim = 256

model = MyModel(sentence_model, hidden_dim, output_dim).to(device)

In [37]:
min = torch.tensor(1).to(device)
similarity = nn.CosineSimilarity(dim = 0)
for sample_batch, sample_label in tqdm(train_loader):
 i = sample_batch[0]
 j = sample_batch[1]
 output_i = model(i)
 output_j = model(j)
 sim_i_j = similarity(output_i, output_j)
 if sim_i_j < min:
 min = sim_i_j
 
print(min)

100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:42<00:00, 42.36it/s]

tensor(1.0000, device='cuda:0', grad_fn=)





In [55]:
criterion = nn.CrossEntropyLoss()
# criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [56]:
num_epochs = 4
for epoch in range(num_epochs):
 model.train()
 losses = []

 for i, (sentences_batch, labels_batch) in enumerate(tqdm(train_loader)):
 labels_batch = labels_batch.to(device)
 labels_batch = F.one_hot(labels_batch, num_classes = 102).float()
 optimizer.zero_grad()
 # Forward pass
 logits = model(sentences_batch).float()
 loss = criterion(logits, labels_batch)
 
 # Backward pass and optimization
 loss.backward()
 optimizer.step()
 curr_loss = loss.item()
 losses.append(curr_loss)
 
 running_loss = sum(losses)
 
 # Print the average loss for every epoch
 epoch_loss = running_loss / batch_size
 print(f"Epoch: {epoch+1}/{num_epochs}, Loss: {epoch_loss}")


100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.25it/s]


Epoch: 1/4, Loss: 1116.4719812870026


100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.90it/s]


Epoch: 2/4, Loss: 1087.523635149002


100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.20it/s]


Epoch: 3/4, Loss: 1079.509438186884


100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:07<00:00, 64.31it/s]

Epoch: 4/4, Loss: 1074.7653084248304





In [47]:
for sample_batch, sample_label in tqdm(train_loader):
 i = sample_batch[0]
 j = sample_batch[1]
 output_i = model(i)
 output_j = model(j)
 sim_i_j = similarity(output_i, output_j)
 if sim_i_j < min:
 min = sim_i_j
 
print(min)

100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:33<00:00, 46.66it/s]

tensor(0., device='cuda:0', grad_fn=)





In [57]:
model.eval()
total_correct = 0
total_samples = 0

with torch.no_grad():
 for i, (sentences_batch, labels_batch) in enumerate(tqdm(train_loader)):
 labels_batch = labels_batch.to(device)
# labels_batch = F.one_hot(labels_batch, num_classes = 102).float()
 logits = model(sentences_batch).float()
 predicted = torch.argmax(logits, dim = 1)
 total_samples += labels_batch.size(0)
 total_correct += (predicted == labels_batch).sum().item()

accuracy = total_correct / total_samples
print("Accuracy:", accuracy)

100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:04<00:00, 67.76it/s]

Accuracy: 0.2659109846852283





In [58]:
sentence = input("ENTER DESCRIPTION ")
output = model(sentence)
predicted = torch.argmax(output)
print(output)
print(predicted)

ENTER DESCRIPTION pink
tensor([9.6708e-15, 1.0179e-04, 4.2242e-08, 1.3063e-15, 8.8056e-03, 0.0000e+00,
 1.6553e-14, 8.9271e-33, 9.0644e-27, 5.9910e-19, 2.2721e-24, 7.7432e-03,
 3.9587e-36, 7.1618e-07, 2.7430e-08, 0.0000e+00, 0.0000e+00, 1.4562e-03,
 9.8114e-06, 9.2844e-24, 7.8520e-33, 2.9296e-22, 3.5067e-13, 1.3316e-05,
 7.7768e-11, 9.2201e-39, 5.0639e-22, 1.6904e-19, 3.2689e-35, 1.0034e-14,
 9.8686e-01, 4.1330e-05, 6.3048e-01, 9.5960e-23, 1.2662e-14, 2.4540e-22,
 1.4413e-08, 9.9928e-01, 2.8299e-02, 4.9763e-10, 2.7364e-04, 9.9878e-01,
 0.0000e+00, 9.9998e-01, 6.7328e-02, 2.9939e-13, 1.9145e-17, 0.0000e+00,
 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9998e-01, 1.1818e-30, 2.2513e-22,
 0.0000e+00, 1.0346e-32, 8.8656e-21, 9.9353e-01, 4.3037e-03, 8.6023e-39,
 3.6964e-10, 3.3164e-21, 1.9611e-15, 0.0000e+00, 3.7135e-38, 1.3163e-34,
 1.8906e-07, 7.0084e-30, 1.0882e-20, 2.6501e-33, 8.9597e-39, 5.0791e-37,
 1.0000e+00, 5.7929e-03, 1.3252e-03, 1.4498e-23, 1.3656e-02, 2.0226e-07,
 8.3005e-01, 8.4326e-

In [59]:
torch.save(model.state_dict(), "sentence_embedding.pth")