# BeamDiffusion/models/CoSeD/sequence_predictor.py
import math
import pickle
import sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger  # same package as the Trainer below, not pytorch_lightning

import wandb
class SoftAttention(L.LightningModule):
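    """Soft-attention sequence predictor for CoSeD.

    Scores each candidate key (a text/image embedding pair) against a query
    (also a text/image pair) with scaled dot-product attention and is trained
    with cross-entropy against the index of the correct candidate.  The
    ``random_*`` and ``fixed_text`` flags replace parts of the input with noise
    or a frozen embedding for ablation runs.
    """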
def __init__(self, learning_rate=0.001, batch_size=10, unfreeze=0, random_text=False, random_everything=False,
fixed_text=False, random_images=False):
        super().__init__()
self.my_optimizer = None
self.my_scheduler = None
self.save_hyperparameters()
self.learning_rate = learning_rate
self.batch_size = batch_size
self.frozen = False
self.unfreeze_epoch = unfreeze
self.loss_method = torch.nn.CrossEntropyLoss()
self.train_sum_precision = 0
self.train_sum_accuracy = 0
self.train_sum_recall = 0
self.train_sum_runs = 0
self.val_sum_precision = 0
self.val_sum_accuracy = 0
self.val_sum_recall = 0
self.val_sum_runs = 0
# NETWORK
# Linear layers to reduce dimensionality
self.text_reduction = torch.nn.Linear(512, 256)
self.image_reduction = torch.nn.Linear(512, 256)
# Soft attention weights
self.W_query_text_half_dim = torch.nn.Linear(256, 256)
self.W_query_image_half_dim = torch.nn.Linear(256, 256)
self.W_query_text_full_dim = torch.nn.Linear(512, 512)
self.W_query_image_full_dim = torch.nn.Linear(512, 512)
self.W_key_text_half_dim = torch.nn.Linear(256, 256)
self.W_key_image_half_dim = torch.nn.Linear(256, 256)
self.W_key_image_full_dim = torch.nn.Linear(512, 512)
self.W_key_text_full_dim = torch.nn.Linear(512, 512)
        # Frozen 512-dim text embedding used for the fixed-text ablation (the same "caption" for every sample)
self.fixed_text = torch.tensor([2.2875e-01, 2.3762e-02, 1.3448e-01, 6.5997e-02, 2.5605e-01,
-1.6183e-01, 7.1169e-03, -1.6895e+00, 1.8110e-01, 1.7249e-01,
7.0582e-02, -6.3566e-02, -1.5862e-01, -2.3586e-01, 6.9382e-02,
9.4649e-02, 6.3127e-01, -4.1287e-02, -4.9883e-02, -2.1821e-01,
5.8677e-01, -2.5353e-01, 1.4792e-01, 2.2195e-02, -6.8436e-02,
-1.5512e-01, -9.8894e-02, 6.3377e-02, -2.3078e-01, 9.3588e-02,
5.2875e-02, -5.1388e-01, -7.0461e-02, 2.4253e-02, -7.8069e-02,
7.6921e-02, -1.1610e-01, -1.3345e-01, 7.8038e-03, -2.0226e-01,
1.1381e-01, -9.6335e-02, -2.2195e-02, -6.5028e-02, 1.4025e-01,
2.6969e-01, -1.0758e-01, 3.6736e-02, 3.2893e-01, -1.9067e-01,
4.9070e-02, 8.0207e-02, 7.2942e-02, 7.7496e-03, 2.0883e-01,
1.7339e-01, 1.0072e-01, -1.7874e-01, -4.6898e-02, -6.2682e-02,
5.9596e-02, 5.2925e-02, 2.4633e-01, -7.2811e-02, -1.4157e-01,
8.8013e-03, -4.6815e-02, -7.4260e-02, 8.6530e-03, -1.8174e-01,
1.6101e-01, -4.8832e-02, -5.8030e-02, -3.2518e-02, -6.2896e-02,
-2.3472e-01, -8.0996e-02, 1.1261e-01, -2.1039e-01, -2.3837e-01,
-2.6827e-02, -2.3075e-01, -2.2087e-02, 5.4009e-01, 3.7671e-02,
3.3140e-01, -4.2569e-02, -1.6946e-01, 1.7165e-01, 3.0887e-01,
4.9847e-02, 1.2438e-02, -2.0701e+00, 2.7104e-01, 1.9001e-01,
3.1907e-01, -9.1116e-02, -8.3141e-02, 4.5765e-03, -2.5675e-01,
-2.2119e-02, 3.4949e-02, 2.8192e-01, 7.9688e-02, -2.1810e-01,
8.1565e-02, 3.3208e-01, -9.1857e-02, -2.1145e-01, -1.6843e-01,
6.7942e-02, 5.1067e-01, -1.6835e-01, 2.2090e-02, 1.8061e-02,
-2.1313e-01, 2.6867e-02, -2.2734e-01, 8.4164e-02, -4.7868e-02,
2.0980e-02, -2.1424e-01, -2.2919e-02, 1.7554e-01, 5.2253e-02,
-2.2049e-01, 6.9408e-02, 7.0811e-02, -1.1892e-02, -4.7958e-02,
7.9476e-02, 1.8851e-01, 2.2516e-02, 8.6119e+00, -7.8583e-02,
1.0218e-01, 1.6675e-01, -4.0961e-01, 4.5291e-02, 7.9783e-02,
-1.1764e-01, -2.3162e-01, -2.7717e-02, 1.2963e-01, -3.0165e-01,
-2.1588e-02, -1.2324e-01, 1.9732e-02, -1.9312e-01, -7.1229e-02,
2.5102e-01, -4.1674e-01, -1.5610e-01, -6.1321e-03, -4.5332e-02,
6.1500e-02, -1.5942e-01, 3.5142e-01, -2.1119e-01, 4.5057e-02,
-5.6277e-02, -3.4298e-01, -1.6499e-01, -2.9384e-02, -2.7163e-01,
6.5339e-03, 2.7674e-02, -1.1302e-01, -2.6373e-02, -1.4370e-01,
2.1936e-01, 1.3103e-01, 2.5538e-01, 1.9502e-01, -1.5278e-01,
1.4978e-01, -2.5552e-01, 2.2397e-01, -1.0369e-01, -1.0491e-01,
5.1112e-01, 2.4879e-01, 7.0940e-02, 1.7351e-01, -3.6831e-02,
1.5027e-01, -1.9452e-01, 2.0322e-01, 8.5931e-02, -2.8588e-03,
3.1146e-02, -3.3307e-01, 1.1595e-01, 1.9435e-01, -3.4536e-02,
2.5245e-01, 4.5388e-02, 2.1197e-02, 4.2232e-02, 4.2436e-02,
4.9622e-02, -2.0907e-01, 1.2264e-01, -7.3529e-02, -2.1788e-01,
-1.2429e-01, -8.1422e-02, 1.6572e-01, -6.0989e-02, 8.0322e-02,
3.3477e-01, -7.2207e-02, -8.8658e-02, -2.4944e-01, 9.9211e-02,
8.6244e-02, 8.8807e-02, -1.9676e-01, -4.5365e-03, -3.7754e-01,
-1.7204e-01, -1.3001e-01, 6.4961e-02, -5.8192e-03, 2.4670e-01,
-8.3591e-02, -3.0810e-01, -3.4549e-02, -1.4452e-01, -5.5416e-02,
1.0527e-02, 3.1159e-01, -1.3857e-01, -2.2676e-01, 1.4768e-01,
3.2650e-01, 2.3971e-01, 6.8196e-02, -2.6235e-02, -2.9741e-01,
4.7721e-02, -1.2859e-02, 2.0340e-01, 1.7823e-02, -1.1337e-01,
4.4077e-02, -1.3949e-01, 2.9229e-01, 1.7425e-01, -5.0722e-03,
-6.3722e-02, 1.0181e-01, 2.3344e-02, 2.2200e-01, 3.5022e-02,
1.5361e-01, -1.0702e-03, 2.9319e-02, 1.8938e-01, -7.2263e-02,
2.2192e-02, 9.5394e-02, -4.4459e-03, 7.6698e-02, -1.7830e-01,
1.0213e-01, -8.8493e-02, -1.6439e-01, -1.1085e-01, 1.2938e-01,
2.3929e-01, -4.9047e-02, -1.2814e-01, -2.1075e-01, 2.4423e-01,
-4.4565e-02, -5.1225e-02, -4.0214e-02, -1.4033e-01, 6.3284e-02,
4.7094e-01, -2.6821e-02, 2.1138e-02, 1.1590e-01, -2.0023e-02,
1.7200e-01, 3.8215e-01, -2.4871e-01, -1.5359e-01, 2.4691e-01,
1.4904e-01, -1.0636e-01, 2.4185e-01, 1.7119e-03, 1.4618e-01,
-1.6813e-01, -4.4372e-01, -1.7475e-01, -6.9891e-02, -4.5553e-02,
9.3102e-02, 1.7686e-02, -1.1781e-01, 6.9423e-02, 1.0211e-02,
3.2742e-01, 7.5272e-02, 8.5080e-02, -1.7731e-01, 1.4030e-01,
2.7764e-01, -6.5041e-02, 8.5968e+00, 2.5900e-01, -2.0825e-01,
9.6241e-02, -1.5257e-01, -3.4269e-01, -1.1251e-01, 3.0549e-01,
3.1628e-01, 6.1856e-01, 1.5791e-03, 6.5656e-02, 1.8862e-02,
-7.1927e-02, 1.3239e-01, -1.1126e-01, 1.1135e-02, -3.2411e+00,
-4.7349e-02, 1.4775e-01, -9.7712e-02, 4.5727e-02, -1.3868e-01,
2.1260e-01, 1.5465e-01, 1.1308e-01, -8.0110e-02, -1.3123e-01,
1.8527e-01, -8.6424e-02, -1.9778e-01, -1.3295e-01, -1.5880e-01,
2.0800e-01, -3.6513e-02, 2.6472e-02, 2.7275e-01, 1.8995e-01,
-7.7340e-02, 1.2059e-02, 3.5163e-02, 1.5442e-02, 5.1417e-02,
5.0993e-01, 1.2994e-01, 2.3873e-01, -7.2816e-02, 1.5850e-01,
-2.0404e-01, -2.2941e-01, 2.3660e-01, 2.0418e-01, 6.7775e-02,
-3.9195e-01, 3.6655e-01, 1.6498e-01, 6.4065e-02, 4.9579e-02,
2.8265e-01, -5.9919e-03, 4.0163e-02, 8.9072e-02, 1.5125e-01,
9.0711e-02, -1.2608e-01, -1.0413e-01, -2.1931e-01, 5.0183e-02,
-3.4841e-02, -8.1449e-02, -1.1225e-01, -4.5787e-02, -7.8871e-02,
3.8858e-02, 9.2660e-02, 1.5991e-01, -6.7528e-02, -6.3166e-02,
-4.7824e-03, -1.3528e-01, 1.4845e-01, 2.0460e-01, -9.3238e-02,
1.4902e-03, 1.1896e-01, -3.1337e-01, 2.1637e-02, 1.4990e-01,
-2.1179e-03, -8.1374e-02, -1.0241e-01, -8.0754e-02, -1.4449e-01,
-1.3549e-01, -7.5588e-02, -8.0083e-02, -1.4114e-01, 2.9467e-03,
3.5340e-01, -4.3351e-02, 9.6934e-02, 1.3625e-01, 1.3339e-01,
-1.2059e-02, -1.4325e-01, -2.1202e-01, 3.8758e-02, 2.5965e-01,
-7.8454e-02, 1.5983e-01, 1.0115e-02, 2.2192e-01, -1.4043e-01,
6.7966e-02, -1.4672e-01, -1.8846e-01, 1.9488e-01, 1.2942e-01,
-1.3165e-02, -1.6099e-01, -9.6146e-02, 1.3439e-01, -5.0560e-02,
8.2779e-02, -2.4827e-01, -7.8047e-02, -3.1163e-01, -1.7481e-01,
2.1450e-01, -7.6112e-02, -1.9967e-02, 5.7099e-02, 7.7664e-02,
-7.9647e-02, 3.3941e-02, 2.9551e-02, 1.4231e-01, 2.3480e-02,
1.5209e-01, -2.0011e-01, 1.1153e-01, 1.2694e-01, 8.7853e-02,
2.6997e-01, 1.3525e-01, 1.9541e-01, 3.4429e-03, -9.6446e-02,
7.6708e-02, -3.0698e-02, -1.8507e-01, 2.5645e-01, 2.8182e-01,
-1.2282e-01, -1.1017e-01, 2.2249e-01, 2.1966e-01, 3.5795e-01,
1.6279e-01, 1.7276e-01, 2.1410e-01, -3.2499e-01, 5.0327e-02,
7.9813e-02, -1.5915e-01, -3.6175e-02, 1.4376e-01, 2.9565e-01,
6.9097e-02, -8.0661e-01, 4.9966e-02, 6.2506e-02, 1.8852e-02,
-8.6921e-02, 6.0899e-02, 2.2442e-01, -1.4272e-01, -4.0656e-04,
-1.2531e-01, 1.5240e-01, -6.8841e-02, 4.2114e-01, -4.4379e-02,
-3.5105e-02, 1.4931e-01, -8.3358e-02, -1.0498e-01, 1.4575e-01,
-1.6491e-01, 4.7820e-02, 2.5958e-01, 1.1974e-01, 1.8271e-01,
1.7439e-02, -1.5855e-01, -9.0135e-02, -2.6199e-01, -2.5709e-01,
6.3203e-03, 7.5823e-02])
self.random_text_flag = random_text
self.random_everything_flag = random_everything
self.fixed_text_flag = fixed_text
self.random_image_flag = random_images
# Weight Stacks
self.W_query = {
"multimodal": [self.text_reduction, self.image_reduction, self.W_query_text_half_dim,
self.W_query_image_half_dim],
"image": [self.W_query_image_full_dim],
}
self.W_key = {
"multimodal": [self.text_reduction, self.image_reduction, self.W_key_text_half_dim,
self.W_key_image_half_dim],
"image": [self.W_key_image_full_dim]
}
def weight_pass(self, query_text, query_image, key_text, key_image):
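        """Project queries and keys with the layer stack matching the available modalities.

        Supported (query_text, query_image, key_text, key_image) presence
        patterns: text+image / text+image, image-only / image-only, and
        image-only / text+image.  Images are mandatory on both sides.
        """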
inference_functions = [
(True, True, True, True), # Input: text and image Context: text and image
(False, True, False, True), # Input: image Context: image
(False, True, True, True) # Input: image Context: text and image
]
if None in (query_image, key_image):
raise ValueError("Query and Key image cannot be None")
if (query_text is not None, query_image is not None, key_text is not None,
key_image is not None) in inference_functions:
query = self._queries_inference(query_text, query_image)
key = self._keys_inference(key_text, key_image)
return query, key
else:
raise ValueError("Invalid input")
def _queries_inference(self, query_text, query_image):
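        """Build the query vector: a full-dim image projection, or the concatenation of half-dim text and image projections."""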
if query_text is None:
output = self.W_query_image_full_dim(query_image)
elif query_image is None:
raise ValueError("Query image cannot be None")
else:
text_reduction = self.text_reduction(query_text)
image_reduction = self.image_reduction(query_image)
query_text_half_dim = self.W_query_text_half_dim(text_reduction)
query_image_half_dim = self.W_query_image_half_dim(image_reduction)
output = torch.cat((query_text_half_dim, query_image_half_dim), dim=-1)
return output
def _keys_inference(self, key_text, key_image):
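        """Build the key vector: a full-dim image projection, or the concatenation of half-dim text and image projections."""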
if key_text is None:
output = self.W_key_image_full_dim(key_image)
elif key_image is None:
raise ValueError("Key image cannot be None")
else:
text_reduction = self.text_reduction(key_text)
image_reduction = self.image_reduction(key_image)
key_text_half_dim = self.W_key_text_half_dim(text_reduction)
key_image_half_dim = self.W_key_image_half_dim(image_reduction)
output = torch.cat((key_text_half_dim, key_image_half_dim), dim=-1)
return output
def forward(self, query_text, query_image, key_text, key_image):
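        """Scaled dot-product attention over the candidate keys.

        Returns ``(softmax, logits)`` with ``logits = Q @ K^T / sqrt(d_k)``;
        the softmax is applied along the query-step dimension (dim 1 for a
        batched 3-D tensor, dim 0 after squeezing a single-sample batch).
        """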
query_text = query_text.to(self.device)
query_image = query_image.to(self.device)
key_text = key_text.to(self.device)
key_image = key_image.to(self.device)
query, key = self.weight_pass(query_text, query_image, key_text, key_image)
d_k = key.size()[-1] # Get the size of the key
key_transposed = key.transpose(1, 2)
logits = torch.matmul(query, key_transposed) / math.sqrt(d_k)
logits = logits.squeeze()
if len(logits.shape) <= 2:
softmax = F.softmax(logits, dim=0)
else:
softmax = F.softmax(logits, dim=1)
return softmax, logits
def training_step(self, train_batch, batch_idx):
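        """One training step.

        Stacks the per-step (text, image) pairs into batched tensors, applies
        the optional ablation flags, runs the attention forward pass, masks
        positions whose labels are padded with -100, and logs cross-entropy
        loss plus precision/accuracy/recall.
        """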
if self.current_epoch == 0 and not self.frozen and self.unfreeze_epoch != 0:
print("Freezing....................................................")
for param in self.image_reduction.parameters():
param.requires_grad = False
self.frozen = True
if self.current_epoch == self.unfreeze_epoch and self.frozen:
print("Unfreezing....................................................")
for param in self.image_reduction.parameters():
param.requires_grad = True
self.frozen = False
# Unpack the batch data
queries = train_batch['queries']
keys = train_batch['keys']
real_labels = train_batch['real_index']
keys_text = []
keys_image = []
for batch in keys:
temp_key_text = []
temp_key_image = []
for key_text, key_image in batch:
temp_key_text.append(key_text)
temp_key_image.append(key_image)
keys_text.append(torch.stack(temp_key_text))
keys_image.append(torch.stack(temp_key_image))
queries_text = []
queries_image = []
for batch in queries:
temp_query_text = []
temp_query_image = []
for query_text, query_image in batch:
temp_query_text.append(query_text)
temp_query_image.append(query_image)
queries_text.append(torch.stack(temp_query_text))
queries_image.append(torch.stack(temp_query_image))
queries_text = torch.stack(queries_text)
queries_image = torch.stack(queries_image)
keys_text = torch.stack(keys_text)
keys_image = torch.stack(keys_image)
if self.fixed_text_flag:
print("Fixed text flag")
queries_text_shape = queries_text.shape
keys_text_shape = keys_text.shape
queries_text = self.fixed_text.expand(*queries_text_shape).to(queries_text.device)
keys_text = self.fixed_text.expand(*keys_text_shape).to(keys_text.device)
if self.random_text_flag:
print("Random text flag")
old_queries_text = queries_text.clone()
old_keys_text = keys_text.clone()
queries_text = torch.randn(queries_text.shape).to(queries_text.device)
keys_text = torch.randn(keys_text.shape).to(keys_text.device)
if torch.equal(queries_text, old_queries_text):
print("Queries text are equal")
if torch.equal(keys_text, old_keys_text):
print("Keys text are equal")
if self.random_image_flag:
print("Random image flag")
old_queries_image = queries_image.clone()
old_keys_image = keys_image.clone()
queries_image = torch.randn(queries_image.shape).to(queries_image.device)
keys_image = torch.randn(keys_image.shape).to(keys_image.device)
if torch.equal(queries_image, old_queries_image):
print("Queries image are equal")
if torch.equal(keys_image, old_keys_image):
print("Keys image are equal")
if self.random_everything_flag:
print("Random everything flag")
old_queries_text = queries_text.clone()
old_keys_text = keys_text.clone()
old_queries_image = queries_image.clone()
old_keys_image = keys_image.clone()
queries_text = torch.randn(queries_text.shape).to(queries_text.device)
keys_text = torch.randn(keys_text.shape).to(keys_text.device)
queries_image = torch.randn(queries_image.shape).to(queries_image.device)
keys_image = torch.randn(keys_image.shape).to(keys_image.device)
if torch.equal(queries_text, old_queries_text):
print("Queries text are equal")
if torch.equal(keys_text, old_keys_text):
print("Keys text are equal")
if torch.equal(queries_image, old_queries_image):
print("Queries image are equal")
if torch.equal(keys_image, old_keys_image):
print("Keys image are equal")
# Forward pass
softmax, logits = self.forward(queries_text, queries_image, keys_text, keys_image)
softmax = softmax.squeeze()
real_labels = real_labels.squeeze()
logits = logits.squeeze()
real_labels = real_labels.float()
if real_labels.dim() < 3:
real_labels = real_labels.unsqueeze(0)
softmax = softmax.unsqueeze(0)
logits = logits.unsqueeze(0)
temp_real_labels = []
temp_logits = []
global_padding = 0
for batch_l, batch_r in zip(logits, real_labels):
padding = torch.nonzero(batch_r[0] == -100)
if padding.nelement() == 0:
temp_real_labels.append(batch_r)
temp_logits.append(batch_l)
continue
global_padding = global_padding + padding.nelement()
padding_index = padding[0]
temp_r = batch_r.clone()
temp_r[:, padding_index:] = 0
temp_l = batch_l.clone()
temp_l[:, padding_index:] = -100
temp_real_labels.append(temp_r)
temp_logits.append(temp_l)
for_loss_real_labels = torch.stack(temp_real_labels).float()
for_loss_logits = torch.stack(temp_logits)
loss = self.loss_method(for_loss_logits.mT, for_loss_real_labels.mT)
batched_precision = []
batched_accuracy = []
batched_recall = []
for batch_s, batch_r in zip(softmax, real_labels):
padding = torch.nonzero(batch_r[0] == -100)
if padding.nelement() > 0:
padding_index = padding[0]
batch_r = batch_r[:, :padding_index]
batch_s = batch_s[:, :padding_index]
max_indices = batch_s.argmax(dim=0)
# print("Max indices: ", max_indices)
target_index = batch_r.argmax(dim=0)
# print("Target index: ", target_index)
subtraction = max_indices - target_index
# print("Subtraction: ", subtraction)
different_values = torch.count_nonzero(subtraction)
# print("Different values: ", different_values)
# print("Sample size: ", target_index.shape)
# print("Len target index: ", len(target_index))
samples = batch_s.shape[1] * batch_s.shape[0]
TP = len(target_index) - different_values
FP = different_values
FN = different_values
TN = samples - TP - FP - FN
precision = TP / (TP + FP)
accuracy = (TP + TN) / samples
recall = TP / (TP + FN)
batched_precision.append(precision.item())
batched_accuracy.append(accuracy.item())
batched_recall.append(recall.item())
precision = sum(batched_precision) / len(batched_precision)
accuracy = sum(batched_accuracy) / len(batched_accuracy)
recall = sum(batched_recall) / len(batched_recall)
self.train_sum_precision += precision
self.train_sum_accuracy += accuracy
self.train_sum_recall += recall
self.train_sum_runs += 1
self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
self.log("train_precision", precision, on_epoch=True, on_step=False, prog_bar=True, logger=True)
self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True, logger=True)
self.log("train_recall", recall, on_epoch=True, on_step=False, prog_bar=True, logger=True)
return loss
def on_train_epoch_end(self) -> None:
self.log("train_precision_epoch", self.train_sum_precision / self.train_sum_runs)
self.log("train_accuracy_epoch", self.train_sum_accuracy / self.train_sum_runs)
self.log("train_recall_epoch", self.train_sum_recall / self.train_sum_runs)
self.train_sum_precision = 0
self.train_sum_accuracy = 0
self.train_sum_recall = 0
self.train_sum_runs = 0
def configure_optimizers(self):
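        """Plain Adam over all parameters; the block below is a disabled CyclicLR schedule."""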
self.my_optimizer = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
optimizer = self.my_optimizer
"""self.my_scheduler = torch.optim.lr_scheduler.CyclicLR(self.my_optimizer, base_lr=0.01, max_lr=0.05,step_size_up=100,cycle_momentum=False)
scheduler = {
'scheduler': self.my_scheduler,
'interval': 'step',
'frequency': 1,
'name': 'learning_rate'
}"""
return [optimizer]
def validation_step(self, val_batch, batch_idx):
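        """Validation step: same batching, padding mask, loss and metric computation as ``training_step``, without the freeze/unfreeze and ablation logic."""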
# Unpack the batch data
queries = val_batch['queries']
keys = val_batch['keys']
real_labels = val_batch['real_index']
keys_text = []
keys_image = []
for batch in keys:
temp_key_text = []
temp_key_image = []
for key_text, key_image in batch:
temp_key_text.append(key_text)
temp_key_image.append(key_image)
keys_text.append(torch.stack(temp_key_text))
keys_image.append(torch.stack(temp_key_image))
queries_text = []
queries_image = []
for batch in queries:
temp_query_text = []
temp_query_image = []
for query_text, query_image in batch:
temp_query_text.append(query_text)
temp_query_image.append(query_image)
queries_text.append(torch.stack(temp_query_text))
queries_image.append(torch.stack(temp_query_image))
queries_text = torch.stack(queries_text)
queries_image = torch.stack(queries_image)
keys_text = torch.stack(keys_text)
keys_image = torch.stack(keys_image)
# Forward pass
softmax, logits = self.forward(queries_text, queries_image, keys_text, keys_image)
softmax = softmax.squeeze()
real_labels = real_labels.squeeze()
if real_labels.dim() < 3:
real_labels = real_labels.unsqueeze(0)
softmax = softmax.unsqueeze(0)
logits = logits.unsqueeze(0)
temp_real_labels = []
temp_logits = []
        for batch_l, batch_r in zip(logits, real_labels):
            padding = torch.nonzero(batch_r[0] == -100)
            if padding.nelement() == 0:
                # No padding in this element: keep it unchanged so it still contributes
                # to the loss, mirroring the handling in training_step.
                temp_real_labels.append(batch_r)
                temp_logits.append(batch_l)
                continue
            padding_index = padding[0]
            temp_r = batch_r.clone()
            temp_r[:, padding_index:] = 0
            temp_l = batch_l.clone()
            temp_l[:, padding_index:] = -100
            temp_real_labels.append(temp_r)
            temp_logits.append(temp_l)
        for_loss_real_labels = torch.stack(temp_real_labels).float()
        for_loss_logits = torch.stack(temp_logits)
        loss = self.loss_method(for_loss_logits.mT, for_loss_real_labels.mT)
if loss < 0:
print("Padding: ", padding.nelement())
print("Loss: ", loss)
print("Logits: ", logits)
print("Real labels: ", real_labels)
exit()
batched_precision = []
batched_accuracy = []
batched_recall = []
for batch_s, batch_r in zip(softmax, real_labels):
padding = torch.nonzero(batch_r[0] == -100)
if padding.nelement() > 0:
padding_index = padding[0]
batch_r = batch_r[:, :padding_index]
batch_s = batch_s[:, :padding_index]
max_indices = batch_s.argmax(dim=0)
# print("Max indices: ", max_indices)
target_index = batch_r.argmax(dim=0)
# print("Target index: ", target_index)
subtraction = max_indices - target_index
# print("Subtraction: ", subtraction)
different_values = torch.count_nonzero(subtraction)
# print("Different values: ", different_values)
# print("Sample size: ", target_index.shape)
# print("Len target index: ", len(target_index))
samples = batch_s.shape[1] * batch_s.shape[0]
TP = len(target_index) - different_values
FP = different_values
FN = different_values
TN = samples - TP - FP - FN
precision = TP / (TP + FP)
accuracy = (TP + TN) / samples
recall = TP / (TP + FN)
batched_precision.append(precision.item())
batched_accuracy.append(accuracy.item())
batched_recall.append(recall.item())
precision = sum(batched_precision) / len(batched_precision)
accuracy = sum(batched_accuracy) / len(batched_accuracy)
recall = sum(batched_recall) / len(batched_recall)
self.val_sum_precision += precision
self.val_sum_accuracy += accuracy
self.val_sum_recall += recall
self.val_sum_runs += 1
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
self.log("val_precision", precision, on_epoch=True, on_step=False, prog_bar=True, logger=True)
self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True, logger=True)
self.log("val_recall", recall, on_epoch=True, on_step=False, prog_bar=True, logger=True)
def on_validation_epoch_end(self) -> None:
self.log("val_precision_epoch", self.val_sum_precision / self.val_sum_runs)
self.log("val_accuracy_epoch", self.val_sum_accuracy / self.val_sum_runs)
self.log("val_recall_epoch", self.val_sum_recall / self.val_sum_runs)
self.val_sum_precision = 0
self.val_sum_accuracy = 0
self.val_sum_recall = 0
self.val_sum_runs = 0
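
# Example usage (a minimal sketch; the checkpoint path and tensor shapes below are
# illustrative assumptions, not part of this repository):
#
#     model = SoftAttention.load_from_checkpoint("path/to/checkpoint.ckpt")
#     model.eval()
#     query_text = torch.randn(1, 3, 512)   # [batch, steps, 512] text embeddings
#     query_image = torch.randn(1, 3, 512)  # [batch, steps, 512] image embeddings
#     key_text = torch.randn(1, 5, 512)
#     key_image = torch.randn(1, 5, 512)
#     with torch.no_grad():
#         scores, logits = model(query_text, query_image, key_text, key_image)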
if __name__ == '__main__':
    # Ablation flags (not exposed on the command line); defined up front so they
    # exist regardless of whether CLI arguments are given.
    random_text = False
    random_everything = False
    random_images = False
    fixed_text = False
    if len(sys.argv) > 1:
        print("Using arguments")
        batch_size = int(sys.argv[1])
        learning_rate = float(sys.argv[2])
        epochs = int(sys.argv[3])
        wandb_flag = sys.argv[4] == "True"
        find_lr = sys.argv[5] == "True"
        unfreeze = int(sys.argv[6])
    else:
        print("Using default values")
        batch_size = 500
        learning_rate = 0.01
        epochs = 50
        wandb_flag = True
        find_lr = False
        unfreeze = 10
print("Batch size: ", batch_size)
print("Learning rate: ", learning_rate)
print("Epochs: ", epochs)
print("Wandb flag: ", wandb_flag)
print("Find lr: ", find_lr)
print("Unfreeze: ", unfreeze)
train_path = "./recipe_dataset_3500_real_1.pkl"
val_path = "./recipe_dataset_3500_real_2.pkl"
    with open(train_path, "rb") as f:
        train = pickle.load(f)
    with open(val_path, "rb") as f:
        val = pickle.load(f)
if "wrong" in train_path and "wrong" in val_path:
print("Using dataset with false positives")
string_wrong = "WRONG_"
elif "wrong" in train_path or "wrong" in val_path:
raise ValueError("One of the datasets is wrong")
else:
print("Using normal dataset")
string_wrong = ""
if random_text:
string_wrong += "RANDOM_TEXT_"
elif random_everything:
string_wrong += "RANDOM_EVERYTHING_"
elif random_images:
string_wrong += "RANDOM_IMAGES_"
elif fixed_text:
string_wrong += "FIXED_TEXT_"
# remove fields that are not needed
for batch in train:
batch.pop('ids_queries')
batch.pop('ids_keys')
for batch in val:
batch.pop('ids_queries')
batch.pop('ids_keys')
    train_dataset = DataLoader(train, num_workers=0, shuffle=False, batch_size=batch_size)
    print("Train batches:", len(train_dataset))
    val_dataset = DataLoader(val, num_workers=0, shuffle=False, batch_size=batch_size)
    print("Val batches:", len(val_dataset))
model = SoftAttention(learning_rate=learning_rate, batch_size=batch_size, unfreeze=unfreeze,
random_text=random_text, random_everything=random_everything, fixed_text=fixed_text,
random_images=random_images)
lr_monitor = LearningRateMonitor(logging_interval='step')
if wandb_flag:
run_name = f"{string_wrong}MORE_RECIPES_{len(train_dataset)}_batch_{batch_size}_lr_{learning_rate}_epochs_{epochs}_unfreeze_{unfreeze}"
wandb_logger = WandbLogger(project='reference_training', name=run_name, log_model="all")
wandb_logger.experiment.config["batch_size"] = batch_size
wandb_logger.experiment.config["max_epochs"] = epochs
wandb_logger.experiment.config["learning_rate"] = learning_rate
trainer = L.Trainer(max_epochs=epochs, detect_anomaly=False, logger=wandb_logger, callbacks=[lr_monitor])
else:
trainer = L.Trainer(max_epochs=epochs, default_root_dir="./", callbacks=[lr_monitor])
if find_lr:
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_dataloaders=train_dataset, val_dataloaders=val_dataset)
print(lr_finder.suggestion())
else:
trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=val_dataset)
# trainer.fit(model, train_dataloaders=train_dataset)
if wandb_flag:
wandb.finish()