|
import datetime |
|
import json |
|
import logging |
|
import os.path as osp |
|
import time |
|
|
|
import pandas as pd |
|
import torch |
|
import torch.backends.cudnn as cudnn |
|
import torch.distributed as dist |
|
import wandb |
|
|
|
from dataset import MetaLoader, create_dataset, create_loader, create_sampler |
|
from models.vindlu import VindLU |
|
from tasks.retrieval_utils import evaluation_wrapper |
|
from tasks.shared_utils import get_media_types, setup_model |
|
from utils.basic_utils import (MetricLogger, SmoothedValue, |
|
remove_files_if_exist, setup_seed) |
|
from utils.config_utils import setup_main |
|
from utils.distributed import get_rank, get_world_size, is_main_process |
|
from utils.logger import log_dict_to_wandb, setup_wandb |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class PretrainTrainer(object): |
|
"""trainer for pretraining.""" |
|
|
|
def __init__(self, config): |
|
super(PretrainTrainer, self).__init__() |
|
self.config = config |
|
|
|
self.is_pretrain = config.mode == "pt" |
|
self.setup() |
|
|
|
self.has_decoder = False |
|
if config.mode in ["ret", "pt"]: |
|
self.evaluation_fn = evaluation_wrapper |
|
self.model_cls = VindLU |
|
elif config.mode == "vqa": |
|
raise NotImplementedError("not implemented") |
|
else: |
|
raise NotImplementedError("not implemented") |
|
|
|
self.build_dataloaders() |
|
self.build_model() |
|
|
|
def setup(self): |
|
"""setup for train.""" |
|
config = self.config |
|
if is_main_process() and config.wandb.enable: |
|
self.wandb_run = setup_wandb(config) |
|
else: |
|
self.wandb_run = None |
|
setup_seed(config.seed + get_rank()) |
|
self.device = torch.device(config.device) |
|
|
|
@torch.no_grad() |
|
def evaluate(self, epoch=0): |
|
"""evaluate the model. |
|
Args: |
|
model (nn.Module): The model to evaluate. |
|
loader (DataLoader): dataloader. |
|
tokenizer (None): tokenizer. |
|
prefix (str): The str prepended to the keys of return dict. |
|
|
|
Returns: dict. The value is the corresponding evaluation results for the key. |
|
""" |
|
eval_res = {} |
|
for test_name, test_loader in self.test_name2loaders.items(): |
|
if test_name not in self.config.test_types: |
|
logger.info( |
|
f"Skip eval {test_name} split. All test_types {self.config.test_types}" |
|
) |
|
continue |
|
with torch.cuda.amp.autocast(enabled=self.config.fp16): |
|
res = self.evaluation_fn( |
|
self.model_without_ddp, |
|
test_loader, |
|
self.tokenizer, |
|
self.device, |
|
self.config, |
|
test_name, |
|
) |
|
eval_res.update(res) |
|
|
|
df = pd.DataFrame(eval_res) |
|
logger.info(f"Epoch {epoch}") |
|
logger.info(f"\n{df.transpose().to_string(max_cols=30)}") |
|
return eval_res |
|
|
|
def build_model(self): |
|
"""TODO: Docstring for build_model. |
|
Returns: TODO |
|
|
|
""" |
|
( |
|
self.model, |
|
self.model_without_ddp, |
|
self.optimizer, |
|
self.scheduler, |
|
self.scaler, |
|
self.tokenizer, |
|
self.start_epoch, |
|
self.global_step, |
|
) = setup_model( |
|
self.config, |
|
model_cls=self.model_cls, |
|
has_decoder=self.has_decoder, |
|
pretrain=self.is_pretrain, |
|
find_unused_parameters=True, |
|
) |
|
|
|
def build_dataloaders(self): |
|
config = self.config |
|
mode = config.mode |
|
|
|
logger.info(f"Creating dataset for {mode}") |
|
train_datasets = create_dataset(f"{mode}_train", config) |
|
media_types = get_media_types(train_datasets) |
|
|
|
if config.distributed: |
|
num_tasks = get_world_size() |
|
global_rank = get_rank() |
|
samplers = create_sampler( |
|
train_datasets, [True] * len(media_types), num_tasks, global_rank |
|
) |
|
else: |
|
samplers = [None] * len(media_types) |
|
|
|
train_loaders = create_loader( |
|
train_datasets, |
|
samplers, |
|
batch_size=[config.inputs.batch_size[k] for k in media_types], |
|
num_workers=[config.num_workers] * len(media_types), |
|
is_trains=[True] * len(media_types), |
|
collate_fns=[None] * len(media_types), |
|
) |
|
|
|
|
|
test_datasets, test_dataset_names = create_dataset(f"{mode}_eval", config) |
|
test_loaders = create_loader( |
|
test_datasets, |
|
[None] * len(test_datasets), |
|
batch_size=[config.inputs.batch_size_test[d.media_type] for d in test_datasets], |
|
num_workers=[config.num_workers] * len(test_datasets), |
|
is_trains=[False] * len(test_datasets), |
|
collate_fns=[None] * len(test_datasets), |
|
) |
|
test_name2loaders = {k: v for k, v in zip(test_dataset_names, test_loaders)} |
|
|
|
self.train_loaders = train_loaders |
|
self.test_name2loaders = test_name2loaders |
|
self.media_types = media_types |
|
|
|
num_steps_per_epoch = sum(len(d) for d in self.train_loaders) |
|
|
|
config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs |
|
config.scheduler.num_warmup_steps = ( |
|
num_steps_per_epoch * config.scheduler.warmup_epochs |
|
) |
|
self.config = config |
|
|
|
def train(self): |
|
"""train the model.""" |
|
config = self.config |
|
|
|
|
|
|
|
cudnn.benchmark = len(self.media_types) == 1 |
|
|
|
if is_main_process() and config.wandb.enable: |
|
wandb.watch(self.model) |
|
|
|
best = 0 |
|
best_epoch = 0 |
|
|
|
logger.info("Start training") |
|
start_time = time.time() |
|
global_step = self.global_step |
|
for epoch in range(self.start_epoch, config.scheduler.epochs): |
|
|
|
global_step = self.train_one_epoch(epoch, global_step) |
|
|
|
|
|
eval_res = self.evaluate(epoch) |
|
|
|
if is_main_process(): |
|
|
|
|
|
if config.wandb.enable: |
|
for p, v in eval_res.items(): |
|
log_dict_to_wandb(v, step=global_step, prefix=p) |
|
|
|
if config.stop_key is not None and config.stop_key in eval_res: |
|
if config.model.multimodal.enable: |
|
cur_r_mean = eval_res[config.stop_key]["r_mean"] |
|
else: |
|
cur_r_mean = eval_res[config.stop_key.replace("/", "_emb/")]["r_mean"] |
|
else: |
|
cur_r_mean = best + 1 |
|
|
|
with open(osp.join(config.output_dir, "eval_res_latest.json"), "w") as f: |
|
json.dump(eval_res, f) |
|
|
|
|
|
save_obj = { |
|
"model": self.model_without_ddp.state_dict(), |
|
"optimizer": self.optimizer.state_dict(), |
|
"scheduler": self.scheduler.state_dict(), |
|
"scaler": self.scaler.state_dict(), |
|
"config": config, |
|
"epoch": epoch, |
|
"global_step": global_step, |
|
} |
|
torch.save(save_obj, osp.join(config.output_dir, f"ckpt_{epoch:02d}.pth")) |
|
|
|
if cur_r_mean > best: |
|
torch.save(save_obj, osp.join(config.output_dir, "ckpt_best.pth")) |
|
eval_file = "eval_res_best.json" |
|
|
|
with open(osp.join(config.output_dir, eval_file), "w") as f: |
|
json.dump(eval_res, f) |
|
best = cur_r_mean |
|
best_epoch = epoch |
|
|
|
dist.barrier() |
|
|
|
total_time = time.time() - start_time |
|
total_time_str = str(datetime.timedelta(seconds=int(total_time))) |
|
logger.info(f"Training time {total_time_str}") |
|
logger.info(f"best epoch {best_epoch} [config.stop_key {config.stop_key}]") |
|
logger.info(f"Checkpoints and Logs saved at {config.output_dir}") |
|
|
|
if is_main_process() and config.wandb.enable: |
|
self.wandb_run.finish() |
|
|
|
def train_one_epoch(self, epoch, global_step): |
|
config = self.config |
|
self.model.train() |
|
|
|
metric_logger = MetricLogger(delimiter=" ") |
|
metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}")) |
|
metric_logger.add_meter("temperature", SmoothedValue(window=100, fmt="{value:.4f}")) |
|
loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0] |
|
|
|
media_types = get_media_types(self.train_loaders) |
|
|
|
for name in loss_names: |
|
for m in media_types: |
|
metric_logger.add_meter( |
|
f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}") |
|
) |
|
|
|
header = f"Train Epoch: [{epoch}]" |
|
log_freq = config.log_freq |
|
|
|
if config.distributed: |
|
for d in self.train_loaders: |
|
d.sampler.set_epoch(epoch) |
|
train_loader = MetaLoader(name2loader=dict(list(zip(media_types, self.train_loaders)))) |
|
|
|
model_without_ddp = self.model.module if config.distributed else self.model |
|
iterator = metric_logger.log_every(train_loader, log_freq, header) |
|
for i, (media_type, (image, text, idx)) in enumerate(iterator): |
|
image = image.to(self.device, non_blocking=True) |
|
idx = idx.to(self.device, non_blocking=True) |
|
text_input = self.tokenizer( |
|
text, |
|
padding="max_length", |
|
truncation=True, |
|
max_length=config.inputs.max_txt_l[media_type], |
|
return_tensors="pt", |
|
).to(self.device) |
|
|
|
with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.bfloat16): |
|
loss_dict = self.model(image, text_input, idx=idx) |
|
loss = sum(loss_dict.values()) |
|
|
|
self.optimizer.zero_grad() |
|
self.scaler.scale(loss).backward() |
|
if config.optimizer.max_grad_norm > 0: |
|
self.scaler.unscale_(self.optimizer) |
|
torch.nn.utils.clip_grad_norm_( |
|
self.model.parameters(), config.optimizer.max_grad_norm |
|
) |
|
self.scaler.step(self.optimizer) |
|
self.scaler.update() |
|
self.scheduler.step() |
|
|
|
|
|
for name in loss_names: |
|
value = loss_dict[name] |
|
value = value if isinstance(value, float) else value.item() |
|
metric_logger.update(**{f"{media_type}-{name}": value}) |
|
metric_logger.update(lr=self.optimizer.param_groups[0]["lr"]) |
|
metric_logger.update(temperature=model_without_ddp.temp.item()) |
|
|
|
if is_main_process() and config.wandb.enable and global_step % log_freq == 0: |
|
logs = metric_logger.get_global_avg_dict() |
|
log_dict_to_wandb(logs, step=global_step, prefix="train/") |
|
|
|
global_step += 1 |
|
|
|
if config.debug and global_step % 2 == 0: |
|
logger.info("debug mode, break training loop") |
|
break |
|
|
|
|
|
metric_logger.synchronize_between_processes() |
|
logger.info(f"Averaged stats: {metric_logger.global_avg()}") |
|
return global_step |
|
|
|
|
|
if __name__ == "__main__": |
|
cfg = setup_main() |
|
trainer = PretrainTrainer(cfg) |
|
if cfg.evaluate: |
|
trainer.evaluate() |
|
else: |
|
trainer.train() |
|
|