import datetime
import logging
import time
from os.path import join

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import wandb
from torch.utils.data import ConcatDataset

from dataset import MetaLoader_rs, create_dataset, create_loader, create_sampler, create_stateful_sampler
from dataset.serialize import local_broadcast_process_authkey
from models import *
from tasks_clip.retrieval_utils import evaluation_wrapper
from tasks_clip.shared_utils import get_media_types, setup_model
from utils.basic_utils import MetricLogger, SmoothedValue, setup_seed
from utils.config_utils import setup_main
from utils.distributed import get_rank, is_main_process
from utils.logger import log_dict_to_wandb, setup_wandb

logger = logging.getLogger(__name__)


def train(
    model,
    train_loaders,
    optimizer,
    tokenizer,
    epoch,
    global_step,
    device,
    scheduler,
    scaler,
    config,
    data_type,
    skip_num=0,
):
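    """Train `model` for one epoch over the mixed media-type loaders.

    Returns the updated `global_step`. `skip_num` is forwarded to
    MetaLoader_rs so a resumed run can skip batches already consumed
    in the interrupted epoch (see the call site in main()).
    """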
    model.train()

    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}"))
    metric_logger.add_meter("temperature", SmoothedValue(window=100, fmt="{value:.4f}"))
    loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]

    media_types = get_media_types(train_loaders)
    for name in loss_names:
        for m in media_types:
            metric_logger.add_meter(
                f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}")
            )

    header = f"Train Epoch: [{epoch}]"
    log_freq = config.log_freq

    if config.distributed:
        for d in train_loaders:
            d.sampler.set_epoch(epoch)
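    # MetaLoader_rs interleaves the per-media-type loaders into a single stream;
    # skip_num fast-forwards past batches already consumed before a restart.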
    train_loader = MetaLoader_rs(name2loader=dict(list(zip(media_types, train_loaders))), skip_num=skip_num)

    model_without_ddp = model.module if config.distributed else model
    iterator = metric_logger.log_every(train_loader, log_freq, header)
    for i, (media_type, (image, text, idx)) in enumerate(iterator):
        image = image.to(device, non_blocking=True)
        idx = idx.to(device, non_blocking=True)
        text_input = tokenizer(text).to(device)

        with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
            loss_dict = model(image, text_input, idx=idx)
            loss = sum(loss_dict.values())

        if hasattr(config, "deepspeed") and config.deepspeed.enable:
            # DeepSpeed owns backward, gradient clipping, and the optimizer step.
            model.backward(loss)
            model.step()
        else:
            if not config.use_half_precision or config.get('use_bf16', True):
                # fp32 / bf16 path: no GradScaler needed.
                optimizer.zero_grad()
                loss.backward()
                if config.optimizer.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
                optimizer.step()
                scheduler.step()
            else:
                # fp16 path: scale the loss, and unscale gradients before clipping.
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                if config.optimizer.max_grad_norm > 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

        # logging
        for name in loss_names:
            value = loss_dict[name]
            value = value if isinstance(value, float) else value.item()
            metric_logger.update(**{f"{media_type}-{name}": value})
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(temperature=model_without_ddp.temp.item())

        if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
            logs = metric_logger.get_global_avg_dict()
            log_dict_to_wandb(logs, step=global_step, prefix="train/")

        global_step += 1

        if config.debug and global_step % 20 == 0:
            logger.info("debug mode, break training loop")
            break

        if config.debug and global_step % (2 * log_freq + 3) == 0:
            logger.info("debug mode, break training loop")
            break
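
        # Optionally checkpoint mid-epoch every `save_iter` steps.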
        if config.get('save_iter', 0) and global_step % config.save_iter == 0:
            if hasattr(config, "deepspeed") and config.deepspeed.enable:
                tag = f"ckpt_iter{global_step:02d}.pth"
                model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, exclude_frozen_parameters=True)
            elif is_main_process():
                state_dict = model_without_ddp.state_dict()
                param_grad_dict = {
                    k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
                }
                for k in list(state_dict.keys()):
                    if k in param_grad_dict.keys() and not param_grad_dict[k]:
                        # delete parameters that do not require gradient
                        logger.info(f"Not saving {k}")
                        del state_dict[k]
                save_obj = {
                    "model": state_dict,  # the pruned state dict, not a fresh full copy
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "scaler": scaler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                    "global_step": global_step,
                }
                torch.save(save_obj, join(config.output_dir, f"ckpt_iter{global_step:02d}.pth"))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger.global_avg()}")
    return global_step


def setup_dataloaders(config, mode="pt"):
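    """Build train loaders (one per media type) and a name->loader dict for eval."""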
    # train datasets, create a list of data loaders
    logger.info(f"Creating dataset for {mode}")
    train_datasets = create_dataset(f"{mode}_train", config)
    media_types = get_media_types(train_datasets)

    if config.distributed:
        batch_size = [config.inputs.batch_size[k] for k in media_types]  # batch size per GPU
        samplers = create_stateful_sampler(train_datasets, batch_size)
    else:
        # only the distributed code path is implemented
        raise NotImplementedError

    train_loaders = create_loader(
        train_datasets,
        samplers,
        batch_size=[config.inputs.batch_size[k] for k in media_types],
        num_workers=[config.num_workers] * len(media_types),
        is_trains=[True] * len(media_types),
        collate_fns=[None] * len(media_types),
    )

    # test datasets, a mapping from dataset name to data loader
    test_datasets, test_dataset_names = create_dataset(f"{mode}_eval", config)
    test_loaders = create_loader(
        test_datasets,
        [None] * len(test_datasets),
        batch_size=[config.inputs.batch_size_test[d.media_type] for d in test_datasets],
        num_workers=[config.num_workers] * len(test_datasets),
        is_trains=[False] * len(test_datasets),
        collate_fns=[None] * len(test_datasets),
    )
    test_name2loaders = {k: v for k, v in zip(test_dataset_names, test_loaders)}
    return train_loaders, test_name2loaders, media_types


def main(config):
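    """Full training/evaluation driver: data, model, epoch loop, checkpointing."""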
    if is_main_process() and config.wandb.enable:
        run = setup_wandb(config)

    is_pretrain = config.mode == "pt"
    logger.info(f"train_file: {config.train_file}")

    setup_seed(config.seed + get_rank())
    device = torch.device(config.device)

    train_loaders, test_name2loaders, train_media_types = setup_dataloaders(
        config, mode=config.mode
    )
    num_steps_per_epoch = sum(len(d) for d in train_loaders)
    config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
    config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs

    # set cudnn.benchmark=True only when input size is fixed
    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
    cudnn.benchmark = len(train_media_types) == 1

    model_cls = eval(config.model.get('model_cls', 'InternVideo2_CLIP'))
    (
        model,
        model_without_ddp,
        optimizer,
        scheduler,
        scaler,
        tokenizer,
        start_epoch,
        global_step,
    ) = setup_model(
        config,
        model_cls=model_cls,
        pretrain=is_pretrain,
        find_unused_parameters=True,
        num_steps_per_epoch=num_steps_per_epoch,
    )
    if is_main_process() and config.wandb.enable:
        wandb.watch(model)

    best = 0
    best_epoch = 0

    # bf16 autocasts without a GradScaler; fp16 uses the scaler path in train().
    if config.get('use_bf16', True):
        data_type = torch.bfloat16
    else:
        data_type = torch.float16

    logger.info("Start training")
    logger.info(f"Epoch: {start_epoch}")
    start_time = time.time()
    start_step = start_epoch * num_steps_per_epoch
    for epoch in range(start_epoch, config.scheduler.epochs):
        if not config.evaluate:
            global_step = train(
                model,
                train_loaders,
                optimizer,
                tokenizer,
                epoch,
                global_step,
                device,
                scheduler,
                scaler,
                config,
                data_type,
                skip_num=global_step - start_step,
            )

        # save checkpoint before evaluation
        # only save those with gradient
        if hasattr(config, "deepspeed") and config.deepspeed.enable:
            if config.get("save_latest", False):
                tag = "ckpt_latest.pth"
            else:
                tag = f"ckpt_{epoch:02d}.pth"
            model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, exclude_frozen_parameters=True)
        elif is_main_process():
            state_dict = model_without_ddp.state_dict()
            param_grad_dict = {
                k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
            }
            for k in list(state_dict.keys()):
                if k in param_grad_dict.keys() and not param_grad_dict[k]:
                    # delete parameters that do not require gradient
                    logger.info(f"Not saving {k}")
                    del state_dict[k]
            save_obj = {
                "model": state_dict,  # the pruned state dict, not a fresh full copy
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict(),
                "config": config,
                "epoch": epoch,
                "global_step": global_step,
            }
            if config.get("save_latest", False):
                torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
            else:
                torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))

        # evaluation
        with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
            eval_res = {}
            for test_name, test_loader in test_name2loaders.items():
                if test_name not in config.test_types:
                    logger.info(
                        f"Skip eval {test_name} split. All test_types {config.test_types}"
                    )
                    continue
                res = evaluation_wrapper(
                    model_without_ddp, test_loader, tokenizer, device, config, data_type=data_type, prefix=test_name
                )
                eval_res.update(res)

        # save the best checkpoint
        if is_main_process():
            # log to wandb
            if config.wandb.enable:
                for p, v in eval_res.items():
                    log_dict_to_wandb(v, step=global_step, prefix=p)

            if config.stop_key is not None and config.stop_key in eval_res:
                cur_r_mean = eval_res[config.stop_key]["r_mean"]
            else:  # stop_key is None: treat the latest checkpoint as the best
                cur_r_mean = best + 1

            eval_res = pd.DataFrame(eval_res)
            logger.info(f"Epoch {epoch}")
            logger.info(f"\n{eval_res.transpose().to_string(max_cols=30)}")
            eval_res.to_json(join(config.output_dir, "eval_res_latest.json"))

            if not config.evaluate and cur_r_mean > best:
                if not hasattr(config, "deepspeed") or not config.deepspeed.enable:
                    torch.save(save_obj, join(config.output_dir, "ckpt_best.pth"))
                eval_file = "eval_res_best.json"
                eval_res.to_json(join(config.output_dir, eval_file))
                best = cur_r_mean
                best_epoch = epoch
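
        # Under DeepSpeed, save_checkpoint must run on every rank, so broadcast
        # the main rank's metrics before the all-rank best check.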
        if hasattr(config, "deepspeed") and config.deepspeed.enable:
            r_mean_best = torch.tensor([0.0, 0.0]).to(device)
            if is_main_process():
                r_mean_best[0] = cur_r_mean
                r_mean_best[1] = best
            dist.broadcast(r_mean_best, 0)
            cur_r_mean, best = r_mean_best[0].item(), r_mean_best[1].item()
            if not config.evaluate and cur_r_mean > best:
                model.save_checkpoint(config.output_dir, tag="ckpt_best.pth", save_latest=False, exclude_frozen_parameters=True)

        if config.evaluate:
            break

        start_step = global_step
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Training time {total_time_str}")
    logger.info(f"best epoch {best_epoch} [config.stop_key {config.stop_key}]")
    logger.info(f"Checkpoints and Logs saved at {config.output_dir}")

    if is_main_process() and config.wandb.enable:
        run.finish()
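
# Note: setup_dataloaders only implements the distributed branch, so this script
# is expected to run under a distributed launcher (e.g. torchrun); the exact
# command and config file depend on your setup.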

if __name__ == "__main__":
    cfg = setup_main()
    local_broadcast_process_authkey()
    main(cfg)