# File: clip_gpt2/neuralnet/train.py
# Hosting-page metadata: uploader avatar "Vageesh1", commit 4e527a6 ("Upload 4 files"), size 4.58 kB.
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter # For TensorBoard
from utils import save_checkpoint, load_checkpoint, print_examples
from dataset import get_loader
from model import SeqToSeq
from tabulate import tabulate # To tabulate loss and epoch
import argparse
import json
def main(args):
    """Train the ``SeqToSeq`` captioning model and checkpoint it when the loss improves.

    The function builds the image transform and data loader, loads the
    vocabulary from ``vocab.json`` in the current working directory, creates
    the model / loss / optimizer, and then runs ``args.num_epochs`` epochs.
    Every optimisation step logs the batch loss to TensorBoard; after each
    epoch a small summary table is printed.

    Args:
        args: Parsed command-line namespace. Attributes read here are
            ``root_dir``, ``csv_file``, ``log_dir``, ``save_path``,
            ``embed_size``, ``hidden_size``, ``num_layers``, ``lr`` and
            ``num_epochs``. ``batch_size`` is used when present and falls
            back to 64 otherwise.

    Raises:
        RuntimeError: If the training loader yields no batches in an epoch.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Honour --batch_size; keep 64 as the fallback for callers that pass a
    # namespace without that attribute.
    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=transform,
        batch_size=getattr(args, "batch_size", 64),
        num_workers=2,
    )

    # Use a context manager so the vocabulary file handle is closed promptly.
    with open('vocab.json', encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab['stoi'])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    # for tensorboard
    writer = SummaryWriter(args.log_dir)
    step = 0

    model_params = {
        'embed_size': embed_size,
        'hidden_size': hidden_size,
        'vocab_size': vocab_size,
        'num_layers': num_layers,
    }

    # initialize model, loss etc
    model = SeqToSeq(**model_params, device=device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi']["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Freeze the CNN encoder unless train_CNN is set; its final fully
    # connected layer is always left trainable.
    for name, param in model.encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    # load from a saved checkpoint
    if load_model:
        # map_location lets a checkpoint written on GPU be restored on a
        # CPU-only machine (and vice versa).
        step = load_checkpoint(
            torch.load(args.save_path, map_location=device), model, optimizer
        )

    model.train()
    # Start from +inf so the first epoch always counts as an improvement,
    # whatever the scale of the loss.
    best_loss, best_epoch = float("inf"), 0

    try:
        for epoch in range(num_epochs):
            print_examples(model, device, vocab['itos'])
            # print_examples is defined elsewhere; re-assert training mode in
            # case it leaves the model in eval mode.
            model.train()

            train_loss = None
            for imgs, captions in tqdm(
                train_loader, total=len(train_loader), leave=False
            ):
                imgs = imgs.to(device)
                captions = captions.to(device)

                # The last caption step is withheld from the model input.
                # NOTE(review): presumably captions are (seq_len, batch) and
                # the model prepends one step so the output length matches
                # the full caption - confirm against SeqToSeq.
                outputs = model(imgs, captions[:-1])
                loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
                )

                # Read the scalar once and reuse it for logging and the
                # end-of-epoch summary.
                train_loss = loss.item()
                writer.add_scalar("Training loss", train_loss, global_step=step)
                step += 1

                optimizer.zero_grad()
                # A scalar loss needs no argument; passing the loss itself as
                # the `gradient` would scale every gradient by the loss value.
                loss.backward()
                optimizer.step()

            if train_loss is None:
                raise RuntimeError(
                    "train_loader yielded no batches; nothing to train on"
                )

            # train_loss is the loss of the final batch of this epoch.
            # NOTE(review): the checkpoint is written only when that loss
            # improves on the best seen so far - confirm this is intended.
            if train_loss < best_loss:
                best_loss = train_loss
                best_epoch = epoch + 1
                if save_model:
                    checkpoint = {
                        "model_params": model_params,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                    }
                    save_checkpoint(checkpoint, args.save_path)

            table = [["Loss:", train_loss],
                     ["Step:", step],
                     ["Epoch:", epoch + 1],
                     ["Best Loss:", best_loss],
                     ["Best Epoch:", best_epoch]]
            print(tabulate(table))
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
if __name__ == "__main__":
    # Script entry point: register the command-line options, parse them,
    # and hand the resulting namespace to main().
    parser = argparse.ArgumentParser()

    # Dataset locations and output destinations: (flag, default, help text).
    path_options = (
        ('--root_dir', './flickr30k/flickr30k_images', 'path to images folder'),
        ('--csv_file', './flickr30k/results.csv', 'path to captions csv file'),
        ('--log_dir', './drive/MyDrive/TensorBoard/', 'path to save tensorboard logs'),
        ('--save_path', './drive/MyDrive/checkpoints/Seq2Seq.pt', 'path to save checkpoint'),
    )
    for flag, default_value, description in path_options:
        parser.add_argument(flag, type=str, default=default_value, help=description)

    # Model Params: (flag, type, default, help text). A help of None is
    # argparse's own default, i.e. the option has no help string.
    model_options = (
        ('--batch_size', int, 64, None),
        ('--num_epochs', int, 100, None),
        ('--embed_size', int, 256, None),
        ('--hidden_size', int, 512, None),
        ('--lr', float, 0.001, None),
        ('--num_layers', int, 3, 'number of lstm layers'),
    )
    for flag, value_type, default_value, description in model_options:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=description
        )

    args = parser.parse_args()
    main(args)