# Source: reimagine-it / retrieval / train_pl.py
# Commit ebd4e51 by Alberto Carmona: "Track error cloning the repo"
# (Header copied from a web file view: raw | history | blame, 23.1 kB.)
from ast import parse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import os
from collections import defaultdict
# import captioning.utils.opts as opts
# import captioning.models as models
# from captioning.data.pth_loader import CaptionDataset
# import captioning.utils.eval_utils as eval_utils
# import captioning.utils.misc as utils
# from captioning.utils.rewards import init_scorer, get_self_critical_reward
# from captioning.modules.loss_wrapper import LossWrapper
from clip_model import CLIPScore
from caption_data import COCORetrievalDataset
import pytorch_lightning as pl
import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng
# Fix the global RNG seed once, at import time, using detectron2's helper so
# that runs are repeatable (seed 1234).
seed_all_rng(1234)
class LitModel(pl.LightningModule):
    """PyTorch Lightning wrapper that fine-tunes ``CLIPScore`` for COCO retrieval.

    The CLIP vision tower and visual projection are frozen in ``__init__``;
    every other parameter of ``self.model`` stays trainable.

    Args:
        opt: namespace of command-line options. Attributes read by this class
            include ``use_grammar``, ``joint_out``, ``batch_size``,
            ``valid_batch_size``, ``optim``, ``lr``, ``adam_eps`` and
            ``weight_decay``. It is stored as both ``self.opt`` and
            ``self.args`` (the two names are aliases of the same object).
        verbose: forwarded to ``COCORetrievalDataset``. ``None`` (default)
            means True only when ``LOCAL_RANK`` is unset or ``'0'``.
    """

    def __init__(self, opt, verbose=None):
        super().__init__()
        self.opt = opt
        # Alias of ``opt``. This used to read the module-level ``args``
        # global, which only worked because the launch script sets
        # ``opt = args`` before constructing the model.
        self.args = opt
        if verbose is None:
            verbose = os.environ.get('LOCAL_RANK', '0') == '0'
        self.verbose = verbose

        self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out)
        # Freeze the image side of CLIP; only the remaining parameters train.
        for frozen in (self.model.clip_model.vision_model,
                       self.model.clip_model.visual_projection):
            for p in frozen.parameters():
                p.requires_grad = False

    def forward(self, *args, **kwargs):
        """Not supported; call ``self.model`` (e.g. ``train_step``) directly.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ data

    def _make_dataloader(self, split, mode, batch_size, shuffle):
        """Build a ``DataLoader`` over one split of ``COCORetrievalDataset``."""
        dataset = COCORetrievalDataset(
            split=split, mode=mode,
            args=self.opt,
            verbose=self.verbose,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=4,
            drop_last=False,
            collate_fn=dataset.collate_fn,
        )

    def train_dataloader(self):
        """Shuffled loader over the ``karpathy_train`` split."""
        return self._make_dataloader(
            'karpathy_train', 'train', self.opt.batch_size, shuffle=True)

    def val_dataloader(self, split='karpathy_val'):
        """Ordered (non-shuffled) loader over ``split``."""
        return self._make_dataloader(
            split, 'val', self.opt.valid_batch_size, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the ``karpathy_test`` split."""
        return self.val_dataloader('karpathy_test')

    # -------------------------------------------------------------- training

    def _run_model(self, batch):
        """Call ``self.model.train_step`` on the fields of a collated batch."""
        return self.model.train_step(
            img_feat=batch['img_feats'],
            text=batch['text'],
            neg_text=batch['neg_text'],
        )

    def training_step(self, data, batch_idx):
        """Compute the training loss and log its components under ``train/``.

        The loss is ``clip_loss`` alone when ``opt.joint_out`` is set and
        ``clip_loss + grammar_loss`` otherwise.
        """
        self.model.train()
        model_out = self._run_model(data)

        clip_loss = model_out['clip_loss']
        if self.opt.joint_out:
            grammar_loss = None
            loss = clip_loss
        else:
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss

        # Time spent fetching this batch, as recorded by the Trainer's
        # profiler under the "get_train_batch" action.
        data_time = torch.tensor(
            self.trainer.profiler.recorded_durations["get_train_batch"][-1])

        logger_logs = {'loss': loss.detach(), 'clip_loss': clip_loss.detach()}
        if grammar_loss is not None:
            logger_logs['grammar_loss'] = grammar_loss.detach()
        logger_logs['data_time'] = data_time.detach()

        for k, v in logger_logs.items():
            # Loss components and data time also go to the progress bar.
            self.log('train/' + k, v,
                     prog_bar=k in ('data_time', 'clip_loss', 'grammar_loss'))
        return loss

    # ------------------------------------------------------------ evaluation

    def validation_step(self, data, batch_idx):
        """Run the model without gradients and return detached per-batch outputs.

        The returned dict always has ``loss``, ``clip_loss``, ``img_feat`` and
        ``text_feat``; without ``opt.joint_out`` it also has ``grammar_loss``,
        ``grammar_pos_pred`` and ``grammar_neg_pred``.
        """
        self.model.eval()
        with torch.no_grad():
            model_out = self._run_model(data)

        clip_loss = model_out['clip_loss']
        output = {
            'clip_loss': clip_loss.detach(),
            'img_feat': model_out['img_feat'].detach(),
            'text_feat': model_out['text_feat'].detach(),
        }
        if self.opt.joint_out:
            loss = clip_loss
        else:
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss
            output['grammar_loss'] = grammar_loss.detach()
            output['grammar_pos_pred'] = model_out['grammar_pos_pred'].detach()
            output['grammar_neg_pred'] = model_out['grammar_neg_pred'].detach()
        output['loss'] = loss.detach()
        return output

    def test_step(self, *args, **kwargs):
        """Same as ``validation_step``."""
        return self.validation_step(*args, **kwargs)

    @staticmethod
    def _safe_ratio(num, den):
        """Return ``num / den``, or ``0.0`` when ``den`` is zero."""
        return num / den if den else 0.0

    @staticmethod
    def _recall_at_k(logits, k):
        """Return R@k (percent) for a similarity matrix whose pairs are on the diagonal.

        Row ``i`` is a hit when column ``i`` is among its ``k`` largest logits.
        """
        topk = logits.topk(k, dim=1).indices
        n = len(topk)
        labels = torch.arange(0, n).view(-1, 1)
        n_retrieved = ((topk == labels).sum(dim=1) > 0).sum()
        return (n_retrieved / n * 100).item()

    @classmethod
    def _grammar_metrics(cls, pos_pred, neg_pred):
        """Confusion-matrix metrics for the grammar head.

        ``pos_pred`` holds predictions for the positive captions and
        ``neg_pred`` for the negative captions; a prediction equal to ``1``
        is treated as "positive". Standard definitions apply: a positive
        caption predicted ``0`` is a false negative, and a negative caption
        predicted ``1`` is a false positive. (The previous code had FP and FN
        swapped, which swapped the reported precision and recall.)

        Returns:
            ``(TP, FP, FN, TN, precision, recall, accuracy, f1)``, where the
            last four are percentages; each is ``0.0`` when undefined.
        """
        TP = (pos_pred == 1).sum().item()
        FN = (pos_pred == 0).sum().item()
        FP = (neg_pred == 1).sum().item()
        TN = (neg_pred == 0).sum().item()
        precision = cls._safe_ratio(TP, TP + FP) * 100
        recall = cls._safe_ratio(TP, TP + FN) * 100
        accuracy = cls._safe_ratio(TP + TN, TP + FP + FN + TN) * 100
        f1 = cls._safe_ratio(2 * precision * recall, precision + recall)
        return TP, FP, FN, TN, precision, recall, accuracy, f1

    def _compute_eval_metrics(self, outputs):
        """Aggregate flattened per-batch outputs into a metrics dict and print a summary.

        Loss values are means over batches (not over samples).
        """
        n_batches = len(outputs)
        val_loss_mean = sum([o['loss'].cpu() for o in outputs]) / n_batches
        val_clip_loss_mean = sum([o['clip_loss'].cpu() for o in outputs]) / n_batches
        print('loss', val_loss_mean.item())
        print('clip_loss', val_clip_loss_mean.item())
        if not self.opt.joint_out:
            val_grammar_loss_mean = sum([o['grammar_loss'].cpu() for o in outputs]) / n_batches
            print('grammar_loss', val_grammar_loss_mean.item())

        logit_scale = self.model.clip_model.logit_scale.exp().cpu()
        text_feats = torch.cat([o['text_feat'].cpu() for o in outputs], dim=0)
        img_feats = torch.cat([o['img_feat'].cpu() for o in outputs], dim=0)
        # Sanity check: the evaluation expects exactly 5000 paired rows of
        # 512-d features (presumably the 5000-image Karpathy val/test split);
        # row i of text_feats is the caption paired with row i of img_feats.
        assert text_feats.size() == (5000, 512), text_feats.size()
        assert img_feats.size() == (5000, 512), img_feats.size()

        logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale
        logits_per_image = logits_per_text.T

        out = {}
        print('Text-to-Image retrieval')
        for k in [1, 5, 10]:
            recall_k = self._recall_at_k(logits_per_text, k)
            out[f'text_to_image_recall_{k}'] = recall_k
            print(f'R@{k}: {recall_k:.2f}%')

        print('Image-to-Text retrieval')
        for k in [1, 5, 10]:
            recall_k = self._recall_at_k(logits_per_image, k)
            out[f'image_to_text_recall_{k}'] = recall_k
            print(f'R@{k}: {recall_k:.2f}%')

        out.update({
            'loss': val_loss_mean.item(),
            'clip_loss': val_clip_loss_mean.item(),
        })

        if not self.opt.joint_out:
            grammar_pos_pred = torch.cat([o['grammar_pos_pred'].cpu() for o in outputs], dim=0)
            grammar_neg_pred = torch.cat([o['grammar_neg_pred'].cpu() for o in outputs], dim=0)
            TP, FP, FN, TN, precision, recall, accuracy, f1 = self._grammar_metrics(
                grammar_pos_pred, grammar_neg_pred)
            print('Grammar check')
            print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}')
            print(f'Precision: {precision:.2f}%')
            print(f'Recall: {recall:.2f}%')
            print(f'Accuracy: {accuracy:.2f}%')
            print(f'F1: {f1:.2f}%')
            print('Total: {}'.format(len(grammar_pos_pred)))
            out.update({
                'grammar_loss': val_grammar_loss_mean.item(),
                'grammar_precision': precision,
                'grammar_recall': recall,
                'grammar_accuracy': accuracy,
                'grammar_f1': f1,
            })
        return out

    def validation_epoch_end(self, outputs, split='val'):
        """Gather outputs from all processes, compute metrics, and log them.

        Metrics are computed on the main process only; every rank then logs
        the main process's dict (via ``all_gather(...)[0]``) under
        ``f'{split}/<metric>'``.
        """
        outputs = d2comm.gather(outputs)
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            # ``gather`` returns one list per rank; flatten to one list.
            out = self._compute_eval_metrics(sum(outputs, []))
        else:
            out = {}
        out = d2comm.all_gather(out)[0]  # keep only the main process's dict
        assert len(out) > 0  # make sure the head has index 0
        for k, v in out.items():
            # self.log expects tensors.
            self.log(f'{split}/{k}', v if torch.is_tensor(v) else torch.tensor(v))

    def test_epoch_end(self, outputs):
        """Same as ``validation_epoch_end`` but logs under ``test/``."""
        self.validation_epoch_end(outputs, 'test')

    # ------------------------------------------------------------- optimizer

    def configure_optimizers(self):
        """Create AdamW with weight decay disabled for biases and LayerNorm weights.

        Only parameters with ``requires_grad`` are optimized (the frozen CLIP
        vision tower is excluded). No learning-rate scheduler is returned.

        Raises:
            ValueError: if ``opt.optim`` is not ``'adamw'`` (previously this
                surfaced as an ``UnboundLocalError``).
        """
        if self.args.optim != 'adamw':
            raise ValueError(
                f"Unsupported optimizer {self.args.optim!r}; only 'adamw' is implemented.")
        from transformers.optimization import AdamW

        no_decay = ("bias", "LayerNorm.weight")
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.args.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.args.lr, eps=self.args.adam_eps)
        return [optimizer], []

    def optimizer_step(self, epoch, batch_idx, optimizer,
                       optimizer_idx, *args, **kwargs):
        """Pass-through to Lightning's default ``optimizer_step``."""
        super().optimizer_step(epoch, batch_idx, optimizer,
                               optimizer_idx, *args, **kwargs)

    # ----------------------------------------------------------- checkpoints

    def state_dict(self):
        """Return ``self.model.state_dict()`` (keys carry no ``model.`` prefix)."""
        state_dict = self.model.state_dict()
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        """Load ``state_dict`` into the wrapped ``CLIPScore`` model."""
        self.model.load_state_dict(state_dict, strict)
class OnEpochStartCallback(pl.Callback):
    """Epoch-start hook carried over from the captioning training script.

    The disabled logic it replaced handled learning-rate decay, scheduled
    sampling and self-critical / structure-loss switches. None of that is
    used for CLIP retrieval fine-tuning, so the hook is currently a no-op.
    It stays registered with the Trainer so per-epoch logic can be added
    back here.
    """

    def on_epoch_start(self, trainer, pl_module):
        """Do nothing.

        The previous body only read ``pl_module.opt``, ``pl_module.model``,
        ``trainer.current_epoch`` and ``trainer.optimizers[0]`` into unused
        locals. The last of these raised ``IndexError`` whenever
        ``trainer.optimizers`` was empty.
        """
        return None
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """``ModelCheckpoint`` that additionally saves a checkpoint on Ctrl-C."""

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Write ``<dirpath>/<prefix>interrupt.ckpt`` when training is interrupted.

        The original called the private ``self._save_model(filepath)``. In the
        Lightning 1.x releases this script targets, that method also requires
        ``trainer`` and ``pl_module``, so the interrupt save failed with a
        ``TypeError``. The public ``trainer.save_checkpoint`` is used instead.
        """
        # Fall back to the Trainer's root dir if dirpath was never resolved.
        dirpath = self.dirpath if self.dirpath else trainer.default_root_dir
        # ``prefix`` is not present in every Lightning release.
        prefix = getattr(self, 'prefix', '') or ''
        filepath = os.path.join(dirpath, prefix + 'interrupt.ckpt')
        # The directory may not exist yet if no checkpoint was written so far.
        if trainer.is_global_zero:
            os.makedirs(dirpath, exist_ok=True)
        trainer.save_checkpoint(filepath, weights_only=self.save_weights_only)
from param import parse_args

# ``args`` and ``opt`` are two names for the same parsed-options namespace.
args = parse_args()
opt = args

checkpoint_callback = ModelCheckpoint(
    # os.path.join (instead of string concatenation) keeps checkpoints inside
    # ``checkpoint_dir`` even when it has no trailing slash.
    filepath=os.path.join(opt.checkpoint_dir, '{epoch:02d}'),
    save_last=True,
    save_top_k=1,
    verbose=True,
    monitor='val/loss',
    mode='min',
    prefix=opt.id,
)

# Only local rank 0 (or a non-distributed run) prints and talks to wandb.
verbose = os.environ.get('LOCAL_RANK', '0') == '0'

# Lightning defines batch size as batch size per GPU, so split the global
# batch sizes across the visible devices. Skipped when no GPU is visible,
# where dividing by device_count() == 0 used to raise ZeroDivisionError.
n_gpus = torch.cuda.device_count()
if n_gpus > 0:
    if opt.batch_size % n_gpus != 0:
        raise ValueError(
            f'batch_size ({opt.batch_size}) must be divisible by the '
            f'number of GPUs ({n_gpus})')
    opt.batch_size = opt.batch_size // n_gpus
    opt.valid_batch_size = opt.valid_batch_size // n_gpus

# Resume from ``<start_from>/<id>-last.ckpt`` when that file exists.
resume_from = None
if opt.start_from is not None:
    last_ckpt = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
    if os.path.isfile(last_ckpt):
        resume_from = last_ckpt
        if verbose:
            print('resume from', resume_from)

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
    project='CLIP-Finetune-COCO',
    name=opt.id,
)

if verbose:
    wandb_logger.experiment.config.update(opt)
    from pathlib import Path
    import glob
    import wandb
    # Upload the Python sources and YAML configs alongside the run.
    for glob_str in ("*.py", "**/*.yaml"):
        wandb.save(glob_str=glob_str, base_path='./')

lit = LitModel(opt)

# warning grad_clip_mode is ignored.
trainer = pl.Trainer(
    callbacks=[
        OnEpochStartCallback(),
        pl.callbacks.LearningRateMonitor(),
    ],
    default_root_dir=opt.checkpoint_dir,
    resume_from_checkpoint=resume_from,
    distributed_backend='ddp',
    gpus=torch.cuda.device_count(),
    check_val_every_n_epoch=1,
    max_epochs=opt.epochs,
    gradient_clip_val=opt.clip_grad_norm,
    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    precision=opt.precision,
    logger=wandb_logger,
)

# EVALUATE=1 runs the test loop only; otherwise train.
if os.getenv('EVALUATE', '0') == '1':
    trainer.test(lit)
else:
    trainer.fit(lit)