# coding: utf-8
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
__version__ = "1.0.5"
import argparse
import sys
import warnings
from typing import Callable, List, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import wandb
from ml_collections import ConfigDict
from tqdm.auto import tqdm
from utils.model_utils import (
initialize_model_and_device,
normalize_batch,
save_last_weights,
save_weights,
)
from utils.settings import (
get_model_from_config,
get_scheduler,
initialize_environment,
initialize_environment_ddp,
parse_args_train,
wandb_init,
)
from valid import valid, valid_multi_gpu
warnings.filterwarnings("ignore")
def forward_step(
x, y, active_stem_ids, get_internal_loss, model, multi_loss, device_ids
):
if get_internal_loss:
loss = model(x, y, active_stem_ids=active_stem_ids)
if isinstance(device_ids, (list, tuple)):
loss = loss.mean()
return loss
else:
y_ = model(x)
return multi_loss(y_, y, x)
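# Note: `forward_step` covers the two loss conventions used by models in this
# repo. The Roformer / Mamba / Conformer variants compute their loss internally
# when called as model(x, y, ...); under nn.DataParallel that internal loss
# arrives as one value per replica, hence the .mean() reduction. All other
# models return separated stems, which are scored against the targets by
# `multi_loss`.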
def train_one_epoch(
model: torch.nn.Module,
config: ConfigDict,
args: argparse.Namespace,
optimizer: torch.optim.Optimizer,
device: torch.device,
device_ids: List[int],
epoch: int,
use_amp: bool,
scaler: torch.cuda.amp.GradScaler,
scheduler,
gradient_accumulation_steps: int,
train_loader: torch.utils.data.DataLoader,
multi_loss: Callable[
[
torch.Tensor,
torch.Tensor,
torch.Tensor,
],
torch.Tensor,
],
all_losses=None,
world_size=None,
ema_model=None,
safe_mode=None,
) -> None:
"""
Train the model for one epoch.
Args:
world_size:
scheduler:
model: The model to train.
config: Configuration object containing training parameters.
args: Command-line arguments with specific settings (e.g., model type).
optimizer: Optimizer used for training.
device: Device to run the model on (CPU or GPU).
device_ids: List of GPU device IDs if using multiple GPUs.
epoch: The current epoch number.
use_amp: Whether to use automatic mixed precision (AMP) for training.
scaler: Scaler for AMP to manage gradient scaling.
gradient_accumulation_steps: Number of gradient accumulation steps before updating the optimizer.
train_loader: DataLoader for the training dataset.
multi_loss: The loss function to use during training.
Returns:
None
"""
    ddp = bool(world_size)
should_print = not dist.is_initialized() or dist.get_rank() == 0
model.train()
if not ddp:
model.to(device)
if should_print:
print(f"Train epoch: {epoch} Learning rate: {optimizer.param_groups[0]['lr']}")
sys.stdout.flush()
loss_val = 0.0
total = 0
all_losses[f"epoch_{epoch}"] = []
normalize = getattr(config.training, "normalize", False)
get_internal_loss = (
args.model_type
in (
"mel_band_roformer",
"bs_roformer",
"bs_mamba2",
"mel_band_conformer",
"bs_conformer",
)
and not args.use_standard_loss
)
if ddp:
pbar = (
tqdm(train_loader, dynamic_ncols=True)
if dist.get_rank() == 0
else train_loader
)
else:
pbar = tqdm(train_loader)
for i, data in enumerate(pbar):
if len(data) == 3:
batch, mixes, active_stem_ids = data
elif len(data) == 2:
batch, mixes = data
active_stem_ids = None
else:
            raise ValueError(f"Unexpected number of items in training batch: {len(data)}")
x = mixes.to(device)
y = batch.to(device)
if normalize:
x, y = normalize_batch(x, y)
if safe_mode:
try:
with torch.cuda.amp.autocast(enabled=use_amp):
loss = forward_step(
x,
y,
active_stem_ids,
get_internal_loss,
model,
multi_loss,
device_ids,
)
except Exception as e:
print(f"Error: {e}")
continue
else:
with torch.cuda.amp.autocast(enabled=use_amp):
loss = forward_step(
x,
y,
active_stem_ids,
get_internal_loss,
model,
multi_loss,
device_ids,
)
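        # The loss is scaled down so that gradients accumulated over
        # `gradient_accumulation_steps` micro-batches match a single
        # large-batch step; the logging below multiplies it back for
        # readability.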
loss /= gradient_accumulation_steps
scaler.scale(loss).backward()
if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_loader) - 1):
scaler.unscale_(optimizer)
if config.training.grad_clip:
nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)
scaler.step(optimizer)
scaler.update()
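            # Update order: unscale -> clip -> step -> scaler.update, then the
            # EMA on the freshly stepped weights. Under DDP the EMA tracks
            # model.module, so saved EMA weights carry no "module." prefix.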
if ema_model is not None:
if ddp:
ema_model.update_parameters(model.module)
else:
ema_model.update_parameters(model)
if scheduler.name in ["linear_scheduler"]:
scheduler.step()
optimizer.zero_grad(set_to_none=True)
if ddp:
with torch.no_grad():
loss_copy = loss.detach().clone()
dist.all_reduce(loss_copy, op=dist.ReduceOp.SUM)
loss_copy /= dist.get_world_size()
if dist.get_rank() == 0:
li = loss_copy.item() * gradient_accumulation_steps
all_losses[f"epoch_{epoch}"].append(li)
loss_val += li
total += 1
pbar.set_postfix(
{"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1)}
)
sys.stdout.flush()
wandb.log(
{"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1), "i": i}
)
else:
li = loss.item() * gradient_accumulation_steps
all_losses[f"epoch_{epoch}"].append(li)
loss_val += li
total += 1
pbar.set_postfix({"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1)})
wandb.log({"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1), "i": i})
        del loss  # a bare loss.detach() was a no-op; drop the graph reference instead
if should_print:
print(f"Training loss: {loss_val / total}")
wandb.log(
{
"train_loss": loss_val / total,
"epoch": epoch,
"learning_rate": optimizer.param_groups[0]["lr"],
}
)
def compute_epoch_metrics(
model: torch.nn.Module,
args: argparse.Namespace,
config: ConfigDict,
device: torch.device,
device_ids: List[int],
best_metric: float,
epoch: int,
    scheduler,
optimizer,
all_time_all_metrics,
all_losses,
world_size=None,
metrics_avg=None,
all_metrics=None,
) -> float:
"""
Compute and log the metrics for the current epoch, and save model weights if the metric improves.
Args:
all_losses:
all_metrics:
metrics_avg:
world_size:
model: The model to evaluate.
args: Command-line arguments containing configuration paths and other settings.
config: Configuration dictionary containing training settings.
device: The device (CPU or GPU) used for evaluation.
device_ids: List of GPU device IDs when using multiple GPUs.
best_metric: The best metric value seen so far.
epoch: The current epoch number.
scheduler: The learning rate scheduler to adjust the learning rate.
optimizer:
all_time_all_metrics:
Returns:
The updated best_metric.
"""
    ddp = bool(world_size)
should_print = not dist.is_initialized() or dist.get_rank() == 0
if not ddp:
if torch.cuda.is_available() and len(device_ids) > 1:
metrics_avg, all_metrics = valid_multi_gpu(
model, args, config, args.device_ids, verbose=False
)
else:
metrics_avg, all_metrics = valid(model, args, config, device, verbose=False)
all_time_all_metrics[f"epoch_{epoch}"] = all_metrics
metric_avg = metrics_avg[args.metric_for_scheduler]
if metric_avg > best_metric:
if args.each_metrics_in_name:
stem_parts = []
for stem_name, values in all_metrics[args.metric_for_scheduler].items():
stem_values = np.array(values)
mean_val = stem_values.mean()
std_val = stem_values.std()
stem_parts.append(
f"{stem_name}_{args.metric_for_scheduler}_{mean_val:.4f}_std_{std_val:.4f}"
)
stem_info = "__".join(stem_parts)
store_path = f"{args.results_path}/model_{args.model_type}_ep_{epoch}_{stem_info}.ckpt"
else:
store_path = f"{args.results_path}/model_{args.model_type}_ep_{epoch}_{args.metric_for_scheduler}_{metric_avg:.4f}.ckpt"
if should_print:
print(f"Store weights: {store_path}")
save_weights(
store_path=store_path,
model=model,
device_ids=device_ids,
optimizer=optimizer,
epoch=epoch,
all_time_all_metrics=all_time_all_metrics,
all_losses=all_losses,
best_metric=best_metric,
args=args,
scheduler=scheduler,
)
best_metric = metric_avg
if args.save_weights_every_epoch:
metric_string = ""
for m in metrics_avg:
metric_string += "_{}_{:.4f}".format(m, metrics_avg[m])
store_path = f"{args.results_path}/model_{args.model_type}_ep_{epoch}{metric_string}.ckpt"
save_weights(
store_path=store_path,
model=model,
device_ids=device_ids,
optimizer=optimizer,
epoch=epoch,
all_time_all_metrics=all_time_all_metrics,
all_losses=all_losses,
best_metric=best_metric,
args=args,
scheduler=scheduler,
)
if scheduler.name in ["ReduceLROnPlateau"]:
scheduler.step(metric_avg)
if should_print:
wandb.log({"metric_main": metric_avg, "best_metric": best_metric})
for metric_name in metrics_avg:
wandb.log({f"metric_{metric_name}": metrics_avg[metric_name]})
return best_metric
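# Checkpoint filenames encode the run state, e.g.
#   model_<model_type>_ep_<epoch>_<metric>_<value>.ckpt
# or, with --each_metrics_in_name, one <stem>_<metric>_<mean>_std_<std> block
# per stem, so the best checkpoint can be identified from its name alone.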
def train_model(
args: Union[argparse.Namespace, None], rank=None, world_size=None
) -> None:
"""
Trains the model based on the provided arguments, including data preparation, optimizer setup,
and loss calculation. The model is trained for multiple epochs with logging via wandb.
Args:
world_size:
rank:
args: Command-line arguments containing configuration paths, hyperparameters, and other settings.
Returns:
None
"""
from torch.cuda.amp.grad_scaler import GradScaler
from utils.dataset import prepare_data
from utils.losses import choice_loss
from utils.model_utils import (
get_lora,
get_optimizer,
load_start_checkpoint,
log_model_info,
)
args = parse_args_train(args)
    ddp = bool(world_size)
if ddp:
initialize_environment_ddp(rank, world_size, args.seed, args.results_path)
else:
initialize_environment(args.seed, args.results_path)
model, config = get_model_from_config(args.model_type, args.config_path)
if "model_type" in config.training:
args.model_type = config.training.model_type
use_amp = getattr(config.training, "use_amp", True)
device_ids = args.device_ids
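    # With DataParallel one loader feeds all GPUs, so the configured batch size
    # is multiplied by the number of devices; with DDP each process runs its
    # own loader at the configured per-GPU batch size.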
if ddp:
batch_size = config.training.batch_size
else:
batch_size = config.training.batch_size * len(device_ids)
if not dist.is_initialized() or dist.get_rank() == 0:
wandb_init(args, config, batch_size)
train_loader = prepare_data(config, args, batch_size)
if args.start_check_point:
checkpoint = torch.load(
args.start_check_point, weights_only=False, map_location="cpu"
)
load_start_checkpoint(args, model, checkpoint, type_="train")
model = get_lora(args, config, model)
if args.freeze_layers is not None:
freeze_layers = []
train_layers = []
for name, param in model.named_parameters():
if any(name.startswith(prefix) for prefix in args.freeze_layers):
freeze_layers.append(name)
print("Freezing layer:", name)
param.requires_grad = False
else:
train_layers.append(name)
print("Trainable layers: {}".format(len(train_layers)))
print("Frozen layers: {}".format(len(freeze_layers)))
if ddp:
device = torch.device(f"cuda:{rank}")
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[rank], find_unused_parameters=True
)
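        # find_unused_parameters=True lets DDP tolerate parameters that receive
        # no gradient on some steps (e.g. frozen or conditionally used
        # branches), at the cost of an extra graph traversal per iteration.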
model_module = model.module
else:
device, model = initialize_model_and_device(model, args.device_ids)
# If model is DataParallel, get underlying module
model_module = model.module if hasattr(model, "module") else model
ema_model = None
if hasattr(config.training, "ema_momentum") and config.training.ema_momentum > 0:
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
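        # AveragedModel with get_ema_multi_avg_fn(m) keeps an exponential
        # moving average of the weights, p_avg = m * p_avg + (1 - m) * p, so
        # config.training.ema_momentum is the decay rate: values close to 1.0
        # change the average slowly.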
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"Initializing EMA with decay: {config.training.ema_momentum}")
ema_model = AveragedModel(
model_module,
multi_avg_fn=get_ema_multi_avg_fn(config.training.ema_momentum),
)
if args.pre_valid:
model_to_valid = ema_model if ema_model is not None else model
if ddp:
valid_multi_gpu(
model_to_valid, args, config, args.device_ids, verbose=False
)
else:
if torch.cuda.is_available() and len(args.device_ids) > 1:
valid_multi_gpu(
model_to_valid, args, config, args.device_ids, verbose=True
)
else:
valid(model_to_valid, args, config, device, verbose=True)
gradient_accumulation_steps = int(
getattr(config.training, "gradient_accumulation_steps", 1)
)
# load optimizer
optimizer = get_optimizer(config, model)
scheduler = get_scheduler(config, optimizer)
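    # Each piece of checkpoint state (optimizer, scheduler, epoch, best metric,
    # histories) is restored behind its own --load_* flag, so a checkpoint can
    # also serve as a weights-only starting point for fine-tuning.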
if (
args.start_check_point
and "optimizer_state_dict" in checkpoint
and args.load_optimizer
):
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if (
args.start_check_point
and "scheduler_state_dict" in checkpoint
and args.load_scheduler
):
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # restore the epoch counter
if args.start_check_point and "epoch" in checkpoint and args.load_epoch:
start_epoch = checkpoint["epoch"] + 1
else:
start_epoch = 0
if args.start_check_point and "best_metric" in checkpoint and args.load_best_metric:
best_metric = checkpoint["best_metric"]
else:
best_metric = float("-inf")
if args.start_check_point and "all_metrics" in checkpoint and args.load_all_metrics:
all_time_all_metrics = checkpoint["all_metrics"]
else:
all_time_all_metrics = {}
if args.start_check_point and "all_losses" in checkpoint and args.load_all_losses:
all_losses = checkpoint["all_losses"]
else:
all_losses = {}
multi_loss = choice_loss(args, config)
scaler = GradScaler()
if args.set_per_process_memory_fraction:
torch.cuda.set_per_process_memory_fraction(1.0)
torch.cuda.empty_cache()
safe_mode = args.safe_mode
should_print = not dist.is_initialized() or dist.get_rank() == 0
if should_print:
if world_size:
batch_size = config.training.batch_size
ef_batch_size = batch_size * gradient_accumulation_steps * world_size
num_gpu = world_size
else:
device_ids = args.device_ids
batch_size = config.training.batch_size * len(device_ids)
ef_batch_size = batch_size * gradient_accumulation_steps
num_gpu = len(device_ids)
print(
f"Instruments: {config.training.instruments}\n"
f"Metrics for training: {args.metrics}. Metric for scheduler: {args.metric_for_scheduler}\n"
f"Patience: {config.training.patience} "
f"Reduce factor: {config.training.reduce_factor}\n"
f"Batch size: {batch_size} "
f"Grad accum steps: {gradient_accumulation_steps} "
f"Num gpus: {num_gpu} "
f"Effective batch size: {ef_batch_size}\n"
f"Dataset type: {args.dataset_type}\n"
f"Optimizer: {config.training.optimizer}"
)
print(f"Train for: {config.training.num_epochs} epochs")
log_model_info(model, args.results_path)
for epoch in range(start_epoch, config.training.num_epochs):
if ddp:
train_loader.sampler.set_epoch(epoch)
train_one_epoch(
model,
config,
args,
optimizer,
device,
device_ids,
epoch,
use_amp,
scaler,
scheduler,
gradient_accumulation_steps,
train_loader,
multi_loss,
all_losses,
world_size,
ema_model=ema_model,
safe_mode=safe_mode,
)
model_to_valid = ema_model if ema_model is not None else model
if should_print:
save_last_weights(
args,
model,
device_ids,
optimizer,
epoch,
all_time_all_metrics,
best_metric,
scheduler,
)
if ddp:
metrics_avg, all_metrics = valid_multi_gpu(
model, args, config, args.device_ids, verbose=False
)
if rank == 0:
all_time_all_metrics[f"epoch_{epoch}"] = all_metrics
best_metric = compute_epoch_metrics(
model=model,
args=args,
config=config,
device=device,
device_ids=device_ids,
best_metric=best_metric,
epoch=epoch,
scheduler=scheduler,
optimizer=optimizer,
all_time_all_metrics=all_time_all_metrics,
all_losses=all_losses,
world_size=world_size,
metrics_avg=metrics_avg,
all_metrics=all_metrics,
)
else:
best_metric = compute_epoch_metrics(
model=model,
args=args,
config=config,
device=device,
device_ids=device_ids,
best_metric=best_metric,
epoch=epoch,
scheduler=scheduler,
optimizer=optimizer,
all_time_all_metrics=all_time_all_metrics,
all_losses=all_losses,
)
if __name__ == "__main__":
train_model(None)
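# A minimal multi-process launch sketch for the DDP path (assumes this repo's
# initialize_environment_ddp handles process-group setup; `spawn_train` here is
# a hypothetical helper, not part of this file):
#
#   import torch.multiprocessing as mp
#
#   def spawn_train(rank, world_size):
#       train_model(None, rank=rank, world_size=world_size)
#
#   if torch.cuda.device_count() > 1:
#       world_size = torch.cuda.device_count()
#       mp.spawn(spawn_train, args=(world_size,), nprocs=world_size)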