from inspect import getargs
import logging
import os
import random
from datetime import datetime
import bisect
import copy
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import optim
from torch.cuda.amp import GradScaler
import faulthandler
import pathlib

try:
    import wandb
except ImportError:
    wandb = None

try:
    import torch.utils.tensorboard as tensorboard
except ImportError:
    tensorboard = None

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from open_clip import create_model_and_transforms, trace_model, create_model
from training.data import get_data
from training.distributed import is_master, init_distributed_device, world_info_from_env
from training.logger import setup_logging
from training.params import parse_args
from training.scheduler import cosine_lr
from training.train import train_one_epoch, evaluate
from open_clip.utils import dataset_split, get_optimizer

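# Checkpoint bookkeeping helpers: the best checkpoints are kept on disk as
# epoch_top_{0..k-1}.pt, with index 0 holding the best-scoring model so far.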
def maintain_ckpts(args, startidx, all_idx_len):
    for i in reversed(range(startidx, all_idx_len)):
        if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
            os.rename(
                os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
                os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
            )
    if os.path.exists(
        os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
    ):
        os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
    return

def update_top_k_performance(
    new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
):
    """
    Record the top-k performance of the current epoch.
    current_top_k_ckpt_metrics maps checkpoint rank to its metric value, with rank 0
    being the best, e.g. {0: top_1_ckpt_measure, 1: top_2_ckpt_measure, ...}.
    """
    if isinstance(new_metrics_inputs, (list, tuple)):
        new_metrics_inputs = np.mean(new_metrics_inputs)
        return update_top_k_performance(
            new_metrics_inputs,
            current_top_k_ckpt_metrics,
            args=args,
            ckpt=ckpt,
            bignumbetter=bignumbetter,
        )
    elif isinstance(new_metrics_inputs, dict):
        new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
        return update_top_k_performance(
            new_metrics_inputs,
            current_top_k_ckpt_metrics,
            args=args,
            ckpt=ckpt,
            bignumbetter=bignumbetter,
        )
    elif isinstance(new_metrics_inputs, (float, int)):
        update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
        sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
        sorted_values = sorted(
            current_top_k_ckpt_metrics.values(), reverse=bignumbetter
        )
        sorted_values_ = copy.deepcopy(sorted_values)
        sorted_values.append(new_metrics_inputs)
        sorted_values = sorted(sorted_values, reverse=bignumbetter)
        sorted_values = sorted_values[:-1]

        if sorted_values == sorted_values_:
            # The new metric does not enter the top-k; nothing to update.
            return current_top_k_ckpt_metrics, new_metrics_inputs
        else:
            for i in range(len(sorted_keys)):
                if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
                    current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
                    update_flag[sorted_keys[i]] = True
            # Keys are assumed to be the integers 0..k-1, so the first updated rank
            # is also the index of the checkpoint file to (re)write.
            for i in range(len(update_flag)):
                if update_flag[i]:
                    maintain_ckpts(args, i, len(sorted_keys))
                    torch.save(
                        ckpt,
                        os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
                    )
                    break
            return current_top_k_ckpt_metrics, new_metrics_inputs

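# Parameter names that belong to the pretrained text tower (the CLIP-style text
# transformer, its embeddings, final LayerNorm, text projection, and the text
# logit scale). Used to build separate optimizer groups when args.split_opt is set.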
def is_pretrained_params(n):
    return (
        n.startswith("transformer")
        or n in ["positional_embedding", "text_projection"]
        or n.startswith("token_embedding")
        or n.startswith("ln_final")
        or n.startswith("logit_scale_t")
    )

def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)

def main():
    args = parse_args()

    # sanitize model name for filesystem / uri use
    args.amodel = args.amodel.replace("/", "-")

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart":
        assert (
            args.pretrained == "" or args.pretrained is None
        ), "bert/roberta/bart text encoder does not support pretrained models."

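    # If no experiment name is given, derive one from the timestamp and the key
    # hyperparameters so each run gets its own log and checkpoint directory.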
    if args.name is None:
        args.name = "-".join(
            [
                datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
                f"model_{args.amodel}",
                f"lr_{args.lr}",
                f"b_{args.batch_size}",
                f"j_{args.workers}",
                f"p_{args.precision}",
            ]
        )

    args.distributed = False
    args.local_rank, args.rank, args.world_size = world_info_from_env()

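    # On the master process, pre-download the sizes.json manifest of every remote
    # webdataset split so the dataloaders know how many samples each shard holds.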
    if args.remotedata and is_master(args):
        for dataset_name in args.datasetnames:
            for split in dataset_split[dataset_name]:
                if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
                    os.makedirs(f"./json_files/{dataset_name}/{split}")
                    os.system(
                        f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
                    )

    args.log_path = None
    if is_master(args, local=args.log_local):
        log_base_path = os.path.join(args.logs, args.name)
        os.makedirs(log_base_path, exist_ok=True)
        log_filename = f"out-{args.rank}" if args.log_local else "out.log"
        args.log_path = os.path.join(log_base_path, log_filename)
        if os.path.exists(args.log_path):
            print(
                "Error. Experiment already exists. Use --name to specify a new experiment."
            )
            return -1

    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(args.log_path, args.log_level)

    device = init_distributed_device(args)

    args.wandb = "wandb" in args.report_to or "all" in args.report_to
    args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
    if is_master(args):
        args.tensorboard_path = (
            os.path.join(args.logs, args.name, "tensorboard")
            if args.tensorboard
            else ""
        )
        args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
        for dirname in [args.tensorboard_path, args.checkpoint_path]:
            if dirname:
                os.makedirs(dirname, exist_ok=True)
    else:
        args.tensorboard_path = ""
        args.checkpoint_path = ""

    if args.copy_codebase:
        copy_codebase(args)

    assert args.precision in ["amp", "fp16", "fp32"]
    if args.precision == "fp16":
        logging.warning(
            "It is recommended to use AMP mixed-precision instead of FP16. "
            "FP16 support needs further verification and tuning, especially for train."
        )

    if args.horovod:
        logging.info(
            f"Running in horovod mode with multiple processes / nodes. Device: {args.device}. "
            f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
        )
    elif args.distributed:
        logging.info(
            f"Running in distributed mode with multiple processes. Device: {args.device}. "
            f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
        )
    else:
        logging.info(f"Running with a single process. Device {args.device}.")

    logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}")

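    # Build the CLAP model: an audio encoder (args.amodel) paired with a text
    # encoder (args.tmodel), optionally initialized from pretrained audio/text
    # checkpoints and optionally using feature fusion.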
    model, model_cfg = create_model(
        args.amodel,
        args.tmodel,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
        skip_params=True,
        pretrained_audio=args.pretrained_audio,
        pretrained_text=args.pretrained_text,
        enable_fusion=args.enable_fusion,
        fusion_type=args.fusion_type,
    )

    if args.horovod:
        with torch.no_grad():
            for param in model.parameters():
                param.set_(param.contiguous())

    if args.trace:
        model = trace_model(model, batch_size=args.batch_size, device=device)

    if is_master(args):
        logging.info("Model:")
        logging.info(f"{str(model)}")
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

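    # Wrap the model for distributed training. SyncBatchNorm and DDP static-graph
    # mode are opt-in via args.use_bn_sync and args.ddp_static_graph.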
    if args.distributed and not args.horovod:
        if args.use_bn_sync:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        ddp_args = {}
        if args.ddp_static_graph:
            # static_graph is only available in newer PyTorch, so only pass it when enabled
            ddp_args["static_graph"] = True
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True, **ddp_args
        )

    data = get_data(args, model_cfg)
    assert len(data), "At least one train or eval dataset must be specified."
    if args.trace:
        assert "train" not in data, "Cannot train with traced model"

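    # Weight-decay exclusion rule: 1-d tensors (biases, norms) and any parameter
    # whose name contains "bn", "ln", "bias", or "logit_scale" gets no weight decay.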
    exclude = (
        lambda n, p: p.ndim < 2
        or "bn" in n
        or "ln" in n
        or "bias" in n
        or "logit_scale" in n
    )
    include = lambda n, p: not exclude(n, p)

    named_parameters = list(model.named_parameters())

    # text-branch parameters, frozen when args.freeze_text is set
    text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n]

    if args.freeze_text:
        print("Freezing text branch parameters.")
        for k in text_freeze_parameters:
            k.requires_grad = False

    gain_or_bias_params = [
        p for n, p in named_parameters if exclude(n, p) and p.requires_grad
    ]
    rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

    if args.optimizer == "adam":
        args.wd = 0
        args.wd_pretrained = 0
        args.wd_new = 0

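    # Optimizer and LR schedule. With args.split_opt, the pretrained text-tower
    # parameters and the newly initialized parameters get separate optimizers,
    # weight decay, and cosine schedules; otherwise one optimizer covers both
    # groups, with gain/bias tensors kept decay-free.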
    if args.train_data is None:
        optimizer = None
        scheduler = None
    else:
        total_steps = data["train"].dataloader.num_batches * args.epochs

        if args.split_opt:
            for x in ["lr", "beta1", "beta2", "eps", "wd"]:
                for y in ["_new", "_pretrained"]:
                    if getattr(args, x + y) is None:
                        setattr(args, x + y, getattr(args, x))

            gain_or_bias_pretrained_params = [
                p
                for n, p in named_parameters
                if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
            ]
            rest_pretrained_params = [
                p
                for n, p in named_parameters
                if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
            ]
            gain_or_bias_new_params = [
                p
                for n, p in named_parameters
                if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n))
            ]
            rest_new_params = [
                p
                for n, p in named_parameters
                if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n))
            ]
            pretrained_params_optimizer = get_optimizer(
                [
                    {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
                    {
                        "params": rest_pretrained_params,
                        "weight_decay": args.wd_pretrained,
                    },
                ],
                lr=args.lr_pretrained,
                betas=(args.beta1_pretrained, args.beta2_pretrained),
                eps=args.eps_pretrained,
                momentum=args.momentum_pretrained,
                optimizer_name=args.optimizer,
            )
            pretrained_params_scheduler = cosine_lr(
                pretrained_params_optimizer,
                args.lr_pretrained,
                args.warmup,
                total_steps,
            )
            new_params_optimizer = get_optimizer(
                [
                    {"params": gain_or_bias_new_params, "weight_decay": 0.0},
                    {"params": rest_new_params, "weight_decay": args.wd_new},
                ],
                lr=args.lr_new,
                betas=(args.beta1_new, args.beta2_new),
                eps=args.eps_new,
                momentum=args.momentum_new,
                optimizer_name=args.optimizer,
            )
            new_params_scheduler = cosine_lr(
                new_params_optimizer, args.lr_new, args.warmup, total_steps
            )

            optimizer = {
                "pretrained": pretrained_params_optimizer,
                "new": new_params_optimizer,
            }
            scheduler = {
                "pretrained": pretrained_params_scheduler,
                "new": new_params_scheduler,
            }
            if args.horovod:
                pretrained_params_optimizer = hvd.DistributedOptimizer(
                    pretrained_params_optimizer,
                    named_parameters=model.named_parameters(),
                )
                new_params_optimizer = hvd.DistributedOptimizer(
                    new_params_optimizer, named_parameters=model.named_parameters()
                )
                hvd.broadcast_parameters(model.state_dict(), root_rank=0)
                hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0)
                hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)
        else:
            optimizer = get_optimizer(
                [
                    {"params": gain_or_bias_params, "weight_decay": 0.0},
                    {"params": rest_params, "weight_decay": args.wd},
                ],
                lr=args.lr,
                betas=(args.beta1, args.beta2),
                eps=args.eps,
                momentum=args.momentum,
                optimizer_name=args.optimizer,
            )
            scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)

            if args.horovod:
                optimizer = hvd.DistributedOptimizer(
                    optimizer, named_parameters=model.named_parameters()
                )
                hvd.broadcast_parameters(model.state_dict(), root_rank=0)
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    scaler = GradScaler() if args.precision == "amp" else None

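    # Optionally resume training. A full training checkpoint restores the epoch
    # counter, model weights, optimizer state(s), and the AMP grad scaler; a bare
    # state_dict restores model weights only.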
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=device)
            if "epoch" in checkpoint:
                # resuming a train checkpoint with epoch and optimizer state
                start_epoch = checkpoint["epoch"]
                sd = checkpoint["state_dict"]
                if not args.distributed and next(iter(sd.items()))[0].startswith(
                    "module"
                ):
                    sd = {k[len("module.") :]: v for k, v in sd.items()}
                model.load_state_dict(sd)
                if args.split_opt:
                    if optimizer is not None:
                        for k, o_ in optimizer.items():
                            o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
                elif optimizer is not None:
                    optimizer.load_state_dict(checkpoint["optimizer"])
                if scaler is not None and "scaler" in checkpoint:
                    scaler.load_state_dict(checkpoint["scaler"])
                logging.info(
                    f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
                )
            else:
                # loading a bare (model-only) state_dict
                model.load_state_dict(checkpoint)
                logging.info(
                    f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
                )
            if args.freeze_text:
                print("Freezing text branch parameters.")
                for k in text_freeze_parameters:
                    k.requires_grad = False
        else:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    cudnn.deterministic = False

    args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
    writer = None
    if args.save_logs and args.tensorboard:
        assert tensorboard is not None, "Please install tensorboard."
        writer = tensorboard.SummaryWriter(args.tensorboard_path)

    if args.wandb and is_master(args):
        assert wandb is not None, "Please install wandb."
        logging.debug("Starting wandb.")
        args.train_sz = data["train"].dataloader.num_samples
        if args.val_data is not None:
            args.val_sz = data["val"].dataloader.num_samples

        wandb.init(
            project="clap",
            notes=args.wandb_notes,
            name=args.wandb_notes,
            tags=[],
            config=vars(args),
        )
        if args.debug:
            wandb.watch(model, log="all")
        wandb.save(params_file)
        logging.debug("Finished loading wandb.")

    if "train" not in data:
        # evaluation-only mode: no training data was provided
        evaluate(model, data, start_epoch, args, writer)
        return
    elif start_epoch == 0 and "val" in data and not args.no_eval:
        evaluate(model, data, 0, args, writer)

    if args.save_top_performance:
        # one slot per kept checkpoint, initialized to a metric value of 0
        current_top_k_ckpt_metrics = {
            i: 0 for i in range(args.save_top_performance)
        }

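    # Main training loop: optionally freeze the text tower at a chosen epoch, train
    # for one epoch, evaluate, then write regular / most-recent / top-k checkpoints.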
    for epoch in range(start_epoch, args.epochs):
        if epoch == args.freeze_text_after:
            print("Text pretrained parameters are frozen from this epoch on.")
            for k in text_freeze_parameters:
                k.requires_grad = False
        if is_master(args):
            logging.info(f"Start epoch {epoch}")

        train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
        completed_epoch = epoch + 1

        if (
            any(v in data for v in ("val", "imagenet-val", "imagenet-v2"))
            and not args.no_eval
        ):
            metrics = evaluate(model, data, completed_epoch, args, writer)
            if args.save_top_performance:
                top_k_dataset = args.top_k_checkpoint_select_dataset
                top_k_metric = args.top_k_checkpoint_select_metric
                filtered_metrics = [
                    v
                    for k, v in metrics.items()
                    if top_k_metric in k and top_k_dataset in k
                ]

        if args.save_logs:
            if args.split_opt:
                opt_dict = {
                    k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
                }
            else:
                opt_dict = {"optimizer": optimizer.state_dict()}
            checkpoint_dict = {
                "epoch": completed_epoch,
                "name": args.name,
                "state_dict": model.state_dict(),
            }
            checkpoint_dict.update(opt_dict)
            if scaler is not None:
                checkpoint_dict["scaler"] = scaler.state_dict()

            if completed_epoch == args.epochs or (
                args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
            ):
                torch.save(
                    checkpoint_dict,
                    os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
                )
            if args.save_most_recent:
                torch.save(
                    checkpoint_dict,
                    os.path.join(args.checkpoint_path, "epoch_latest.pt"),
                )
            if args.save_top_performance and not args.no_eval:
                update_top_k_performance(
                    filtered_metrics,
                    current_top_k_ckpt_metrics,
                    args,
                    checkpoint_dict,
                    bignumbetter=True,
                )

    if args.wandb and is_master(args):
        wandb.finish()

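# Copy the current codebase into the experiment's log directory so each run keeps
# a snapshot of the exact code it was launched with.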
def copy_codebase(args):
    from shutil import copytree, ignore_patterns

    new_code_path = os.path.join(args.logs, args.name, "code")
    if os.path.exists(new_code_path):
        print(
            f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
        )
        return -1
    print(f"Copying codebase to {new_code_path}")
    current_code_path = os.path.realpath(__file__)
    for _ in range(3):
        current_code_path = os.path.dirname(current_code_path)
    copytree(
        current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
    )
    print("Done copying code.")
    return 1


if __name__ == "__main__":
    main()