# Spaces: Sleeping
import torch | |
from tqdm import tqdm | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
from torch.utils.tensorboard import SummaryWriter # For TensorBoard | |
from utils import save_checkpoint, load_checkpoint, print_examples | |
from dataset import get_loader | |
from model import SeqToSeq | |
from tabulate import tabulate # To tabulate loss and epoch | |
import argparse | |
import json | |
def main(args):
    """Train the SeqToSeq image-captioning model.

    Builds the training DataLoader from ``args.root_dir`` / ``args.csv_file``,
    reads the vocabulary from ``vocab.json``, and trains for ``args.num_epochs``
    epochs. Training uses Adam and cross-entropy that ignores ``<PAD>`` tokens.
    Each batch loss is logged to TensorBoard under ``args.log_dir``. A
    checkpoint is written to ``args.save_path`` whenever the mean epoch loss
    improves.

    Args:
        args: ``argparse.Namespace`` providing ``root_dir``, ``csv_file``,
            ``log_dir``, ``save_path``, ``num_epochs``, ``embed_size``,
            ``hidden_size``, ``lr``, ``num_layers`` and (optionally)
            ``batch_size``, which defaults to 64 when absent.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=transform,
        # Honour --batch_size; it used to be silently ignored (hard-coded 64).
        batch_size=getattr(args, "batch_size", 64),
        num_workers=2,
    )
    # Use a context manager so the vocab file handle is closed promptly.
    with open("vocab.json", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab["stoi"])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    # TensorBoard logging
    writer = SummaryWriter(args.log_dir)
    step = 0

    model_params = {
        "embed_size": embed_size,
        "hidden_size": hidden_size,
        "vocab_size": vocab_size,
        "num_layers": num_layers,
    }
    # Initialize model, loss and optimizer.
    model = SeqToSeq(**model_params, device=device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab["stoi"]["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train only the Inception backbone's final fc layer, unless train_CNN
    # is set, in which case the whole backbone is trained.
    for name, param in model.encoder.inception.named_parameters():
        is_fc = "fc.weight" in name or "fc.bias" in name
        param.requires_grad = is_fc or train_CNN

    # Resume from a saved checkpoint.
    if load_model:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU host.
        checkpoint_state = torch.load(args.save_path, map_location=device)
        step = load_checkpoint(checkpoint_state, model, optimizer)

    best_loss, best_epoch = float("inf"), 0
    try:
        for epoch in range(num_epochs):
            print_examples(model, device, vocab["itos"])
            # NOTE(review): print_examples' effect on train/eval mode is not
            # visible here, so re-assert train mode before each epoch.
            model.train()

            epoch_loss_sum, num_batches = 0.0, 0
            for imgs, captions in tqdm(
                train_loader, total=len(train_loader), leave=False
            ):
                imgs = imgs.to(device)
                captions = captions.to(device)
                # presumably captions are (seq_len, batch) and the model
                # prepends the image feature as step 0, so feeding all but the
                # last token yields outputs aligned with `captions` — confirm
                # against model.SeqToSeq.
                outputs = model(imgs, captions[:-1])
                loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
                )
                batch_loss = loss.item()
                writer.add_scalar("Training loss", batch_loss, global_step=step)
                step += 1

                optimizer.zero_grad()
                # Scalar loss: no explicit gradient argument. Passing `loss`
                # here would scale every gradient by the loss value.
                loss.backward()
                optimizer.step()

                epoch_loss_sum += batch_loss
                num_batches += 1

            # Mean epoch loss is a far less noisy "best" criterion than the
            # last batch's loss. With no batches it is NaN, so nothing is saved.
            train_loss = epoch_loss_sum / num_batches if num_batches else float("nan")

            if train_loss < best_loss:
                best_loss = train_loss
                best_epoch = epoch + 1
                if save_model:
                    checkpoint = {
                        "model_params": model_params,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                    }
                    save_checkpoint(checkpoint, args.save_path)

            table = [
                ["Loss:", train_loss],
                ["Step:", step],
                ["Epoch:", epoch + 1],
                ["Best Loss:", best_loss],
                ["Best Epoch:", best_epoch],
            ]
            print(tabulate(table))
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
if __name__ == "__main__":
    # Command-line entry point: collect paths and hyperparameters, then train.
    parser = argparse.ArgumentParser()

    # Filesystem locations, as (flag, default, help text), in --help order.
    path_flags = (
        ("--root_dir", "./flickr30k/flickr30k_images", "path to images folder"),
        ("--csv_file", "./flickr30k/results.csv", "path to captions csv file"),
        ("--log_dir", "./drive/MyDrive/TensorBoard/", "path to save tensorboard logs"),
        ("--save_path", "./drive/MyDrive/checkpoints/Seq2Seq.pt", "path to save checkpoint"),
    )
    for flag, default, help_text in path_flags:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    # Model / training hyperparameters.
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--embed_size", type=int, default=256)
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--num_layers", type=int, default=3, help="number of lstm layers")

    args = parser.parse_args()
    main(args)