# coding: utf-8
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
__version__ = "1.0.5"

import argparse
import sys
import warnings
from typing import Callable, List, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import wandb
from ml_collections import ConfigDict
from tqdm.auto import tqdm

from utils.model_utils import (
    initialize_model_and_device,
    normalize_batch,
    save_last_weights,
    save_weights,
)
from utils.settings import (
    get_model_from_config,
    get_scheduler,
    initialize_environment,
    initialize_environment_ddp,
    parse_args_train,
    wandb_init,
)
from valid import valid, valid_multi_gpu

warnings.filterwarnings("ignore")
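

# forward_step dispatches between two loss paths: when get_internal_loss is set
# (Roformer / Mamba / Conformer model types), the model computes its loss internally
# from (x, y); otherwise a plain forward pass is run and the external multi_loss is
# applied to the predicted stems.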
def forward_step(
    x, y, active_stem_ids, get_internal_loss, model, multi_loss, device_ids
):
    if get_internal_loss:
        loss = model(x, y, active_stem_ids=active_stem_ids)
        if isinstance(device_ids, (list, tuple)):
            loss = loss.mean()
        return loss
    else:
        y_ = model(x)
        return multi_loss(y_, y, x)


def train_one_epoch(
    model: torch.nn.Module,
    config: ConfigDict,
    args: argparse.Namespace,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    device_ids: List[int],
    epoch: int,
    use_amp: bool,
    scaler: torch.cuda.amp.GradScaler,
    scheduler,
    gradient_accumulation_steps: int,
    train_loader: torch.utils.data.DataLoader,
    multi_loss: Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor],
        torch.Tensor,
    ],
    all_losses=None,
    world_size=None,
    ema_model=None,
    safe_mode=None,
) -> None:
    """
    Train the model for one epoch.

    Args:
        model: The model to train.
        config: Configuration object containing training parameters.
        args: Command-line arguments with specific settings (e.g., model type).
        optimizer: Optimizer used for training.
        device: Device to run the model on (CPU or GPU).
        device_ids: List of GPU device IDs if using multiple GPUs.
        epoch: The current epoch number.
        use_amp: Whether to use automatic mixed precision (AMP) for training.
        scaler: Scaler for AMP to manage gradient scaling.
        scheduler: Learning rate scheduler (stepped per optimizer update for linear schedulers).
        gradient_accumulation_steps: Number of gradient accumulation steps before updating the optimizer.
        train_loader: DataLoader for the training dataset.
        multi_loss: The loss function to use during training.
        all_losses: Dictionary collecting per-batch losses, keyed by epoch.
        world_size: Number of DDP processes; None when training without DDP.
        ema_model: Optional EMA (AveragedModel) updated after each optimizer step.
        safe_mode: If truthy, skip batches whose forward pass raises an exception.

    Returns:
        None
    """
    ddp = True if world_size else False
    should_print = not dist.is_initialized() or dist.get_rank() == 0

    model.train()
    if not ddp:
        model.to(device)

    if should_print:
        print(f"Train epoch: {epoch} Learning rate: {optimizer.param_groups[0]['lr']}")
        sys.stdout.flush()

    loss_val = 0.0
    total = 0
    all_losses[f"epoch_{epoch}"] = []

    normalize = getattr(config.training, "normalize", False)
    get_internal_loss = (
        args.model_type
        in (
            "mel_band_roformer",
            "bs_roformer",
            "bs_mamba2",
            "mel_band_conformer",
            "bs_conformer",
        )
        and not args.use_standard_loss
    )

    if ddp:
        pbar = (
            tqdm(train_loader, dynamic_ncols=True)
            if dist.get_rank() == 0
            else train_loader
        )
    else:
        pbar = tqdm(train_loader)
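
    # Each batch from the loader is either (targets, mixtures) or
    # (targets, mixtures, active_stem_ids); active_stem_ids is only used by models
    # that compute their loss internally.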
    for i, data in enumerate(pbar):
        if len(data) == 3:
            batch, mixes, active_stem_ids = data
        elif len(data) == 2:
            batch, mixes = data
            active_stem_ids = None
        else:
            raise ValueError(f"len data is {len(data)}")
        x = mixes.to(device)
        y = batch.to(device)

        if normalize:
            x, y = normalize_batch(x, y)

        if safe_mode:
            try:
                with torch.cuda.amp.autocast(enabled=use_amp):
                    loss = forward_step(
                        x,
                        y,
                        active_stem_ids,
                        get_internal_loss,
                        model,
                        multi_loss,
                        device_ids,
                    )
            except Exception as e:
                print(f"Error: {e}")
                continue
        else:
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss = forward_step(
                    x,
                    y,
                    active_stem_ids,
                    get_internal_loss,
                    model,
                    multi_loss,
                    device_ids,
                )
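
        # Gradient accumulation: scale the loss down, accumulate gradients over
        # gradient_accumulation_steps batches, and only then unscale, clip, and step
        # the optimizer (the final, possibly partial, window also triggers a step).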
        loss /= gradient_accumulation_steps
        scaler.scale(loss).backward()
        if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_loader) - 1):
            scaler.unscale_(optimizer)
            if config.training.grad_clip:
                nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)

            scaler.step(optimizer)
            scaler.update()
            if ema_model is not None:
                if ddp:
                    ema_model.update_parameters(model.module)
                else:
                    ema_model.update_parameters(model)
            if scheduler.name in ["linear_scheduler"]:
                scheduler.step()
            optimizer.zero_grad(set_to_none=True)
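
        # With DDP, average the loss across ranks so the logged value reflects the
        # global batch; only rank 0 updates the progress bar and wandb.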
        if ddp:
            with torch.no_grad():
                loss_copy = loss.detach().clone()
                dist.all_reduce(loss_copy, op=dist.ReduceOp.SUM)
                loss_copy /= dist.get_world_size()

            if dist.get_rank() == 0:
                li = loss_copy.item() * gradient_accumulation_steps
                all_losses[f"epoch_{epoch}"].append(li)
                loss_val += li
                total += 1
                pbar.set_postfix(
                    {"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1)}
                )
                sys.stdout.flush()
                wandb.log(
                    {"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1), "i": i}
                )
        else:
            li = loss.item() * gradient_accumulation_steps
            all_losses[f"epoch_{epoch}"].append(li)
            loss_val += li
            total += 1
            pbar.set_postfix({"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1)})
            wandb.log({"loss": 100 * li, "avg_loss": 100 * loss_val / (i + 1), "i": i})

        loss.detach()

    if should_print:
        print(f"Training loss: {loss_val / total}")
        wandb.log(
            {
                "train_loss": loss_val / total,
                "epoch": epoch,
                "learning_rate": optimizer.param_groups[0]["lr"],
            }
        )


def compute_epoch_metrics(
    model: torch.nn.Module,
    args: argparse.Namespace,
    config: ConfigDict,
    device: torch.device,
    device_ids: List[int],
    best_metric: float,
    epoch: int,
    scheduler: torch.optim.lr_scheduler,
    optimizer,
    all_time_all_metrics,
    all_losses,
    world_size=None,
    metrics_avg=None,
    all_metrics=None,
) -> float:
    """
    Compute and log the metrics for the current epoch, and save model weights if the metric improves.

    Args:
        model: The model to evaluate.
        args: Command-line arguments containing configuration paths and other settings.
        config: Configuration dictionary containing training settings.
        device: The device (CPU or GPU) used for evaluation.
        device_ids: List of GPU device IDs when using multiple GPUs.
        best_metric: The best metric value seen so far.
        epoch: The current epoch number.
        scheduler: The learning rate scheduler to adjust the learning rate.
        optimizer: Optimizer whose state is stored in saved checkpoints.
        all_time_all_metrics: Dictionary of per-epoch metrics, stored in checkpoints.
        all_losses: Dictionary of per-epoch training losses, stored in checkpoints.
        world_size: Number of DDP processes; None when not using DDP.
        metrics_avg: Precomputed averaged metrics (passed in DDP mode); computed here otherwise.
        all_metrics: Precomputed per-track metrics (passed in DDP mode); computed here otherwise.

    Returns:
        The updated best_metric.
    """
    ddp = True if world_size else False
    should_print = not dist.is_initialized() or dist.get_rank() == 0
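
    # In DDP mode the caller has already run validation and passes metrics_avg /
    # all_metrics in; otherwise run validation here (multi-GPU if several device ids).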
    if not ddp:
        if torch.cuda.is_available() and len(device_ids) > 1:
            metrics_avg, all_metrics = valid_multi_gpu(
                model, args, config, args.device_ids, verbose=False
            )
        else:
            metrics_avg, all_metrics = valid(model, args, config, device, verbose=False)
        all_time_all_metrics[f"epoch_{epoch}"] = all_metrics

    metric_avg = metrics_avg[args.metric_for_scheduler]
    if metric_avg > best_metric:
        if args.each_metrics_in_name:
            stem_parts = []
            for stem_name, values in all_metrics[args.metric_for_scheduler].items():
                stem_values = np.array(values)
                mean_val = stem_values.mean()
                std_val = stem_values.std()
                stem_parts.append(
                    f"{stem_name}_{args.metric_for_scheduler}_{mean_val:.4f}_std_{std_val:.4f}"
                )
            stem_info = "__".join(stem_parts)
            store_path = f"{args.results_path}/model_{args.model_type}_ep_{epoch}_{stem_info}.ckpt"
        else:
            store_path = f"{args.results_path}/model_{args.model_type}_ep_{epoch}_{args.metric_for_scheduler}_{metric_avg:.4f}.ckpt"
        if should_print:
            print(f"Store weights: {store_path}")
        save_weights(
            store_path=store_path,
            model=model,
            device_ids=device_ids,
            optimizer=optimizer,
            epoch=epoch,
            all_time_all_metrics=all_time_all_metrics,
            all_losses=all_losses,
            best_metric=best_metric,
            args=args,
            scheduler=scheduler,
        )
        best_metric = metric_avg

    if args.save_weights_every_epoch:
        metric_string = ""
        for m in metrics_avg:
            metric_string += "_{}_{:.4f}".format(m, metrics_avg[m])
        store_path = f"{args.results_path}/model_{args.model_type}_ep_{epoch}{metric_string}.ckpt"
        save_weights(
            store_path=store_path,
            model=model,
            device_ids=device_ids,
            optimizer=optimizer,
            epoch=epoch,
            all_time_all_metrics=all_time_all_metrics,
            all_losses=all_losses,
            best_metric=best_metric,
            args=args,
            scheduler=scheduler,
        )

    if scheduler.name in ["ReduceLROnPlateau"]:
        scheduler.step(metric_avg)

    if should_print:
        wandb.log({"metric_main": metric_avg, "best_metric": best_metric})
        for metric_name in metrics_avg:
            wandb.log({f"metric_{metric_name}": metrics_avg[metric_name]})

    return best_metric


def train_model(
    args: Union[argparse.Namespace, None], rank=None, world_size=None
) -> None:
    """
    Trains the model based on the provided arguments, including data preparation, optimizer setup,
    and loss calculation. The model is trained for multiple epochs with logging via wandb.

    Args:
        args: Command-line arguments containing configuration paths, hyperparameters, and other settings.
        rank: Process rank when launched with DDP; None for single-process training.
        world_size: Total number of DDP processes; None for single-process training.

    Returns:
        None
    """
    from torch.cuda.amp.grad_scaler import GradScaler

    from utils.dataset import prepare_data
    from utils.losses import choice_loss
    from utils.model_utils import (
        get_lora,
        get_optimizer,
        load_start_checkpoint,
        log_model_info,
    )

    args = parse_args_train(args)

    ddp = True if world_size else False
    if ddp:
        initialize_environment_ddp(rank, world_size, args.seed, args.results_path)
    else:
        initialize_environment(args.seed, args.results_path)

    model, config = get_model_from_config(args.model_type, args.config_path)
    if "model_type" in config.training:
        args.model_type = config.training.model_type

    use_amp = getattr(config.training, "use_amp", True)
    device_ids = args.device_ids
    if ddp:
        batch_size = config.training.batch_size
    else:
        batch_size = config.training.batch_size * len(device_ids)

    if not dist.is_initialized() or dist.get_rank() == 0:
        wandb_init(args, config, batch_size)

    train_loader = prepare_data(config, args, batch_size)
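
    # Optionally resume from a checkpoint. The loaded checkpoint dict is reused
    # further below to restore optimizer, scheduler, epoch, metric and loss history.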
    if args.start_check_point:
        checkpoint = torch.load(
            args.start_check_point, weights_only=False, map_location="cpu"
        )
        load_start_checkpoint(args, model, checkpoint, type_="train")

    model = get_lora(args, config, model)
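
    # Optionally freeze all parameters whose names start with one of the prefixes
    # given in args.freeze_layers; everything else stays trainable.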
    if args.freeze_layers is not None:
        freeze_layers = []
        train_layers = []
        for name, param in model.named_parameters():
            if any(name.startswith(prefix) for prefix in args.freeze_layers):
                freeze_layers.append(name)
                print("Freezing layer:", name)
                param.requires_grad = False
            else:
                train_layers.append(name)
        print("Trainable layers: {}".format(len(train_layers)))
        print("Frozen layers: {}".format(len(freeze_layers)))

    if ddp:
        device = torch.device(f"cuda:{rank}")
        model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], find_unused_parameters=True
        )
        model_module = model.module
    else:
        device, model = initialize_model_and_device(model, args.device_ids)
        # If model is DataParallel, get underlying module
        model_module = model.module if hasattr(model, "module") else model
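
    # Optional EMA of the model weights: AveragedModel keeps an exponential moving
    # average (decay = config.training.ema_momentum) that is updated after every
    # optimizer step in train_one_epoch and can be used for validation.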
    ema_model = None
    if hasattr(config.training, "ema_momentum") and config.training.ema_momentum > 0:
        from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

        if not dist.is_initialized() or dist.get_rank() == 0:
            print(f"Initializing EMA with decay: {config.training.ema_momentum}")
        ema_model = AveragedModel(
            model_module,
            multi_avg_fn=get_ema_multi_avg_fn(config.training.ema_momentum),
        )

    if args.pre_valid:
        model_to_valid = ema_model if ema_model is not None else model
        if ddp:
            valid_multi_gpu(
                model_to_valid, args, config, args.device_ids, verbose=False
            )
        else:
            if torch.cuda.is_available() and len(args.device_ids) > 1:
                valid_multi_gpu(
                    model_to_valid, args, config, args.device_ids, verbose=True
                )
            else:
                valid(model_to_valid, args, config, device, verbose=True)

    gradient_accumulation_steps = int(
        getattr(config.training, "gradient_accumulation_steps", 1)
    )

    # Set up optimizer and scheduler, optionally restoring their state from the checkpoint
    optimizer = get_optimizer(config, model)
    scheduler = get_scheduler(config, optimizer)
    if (
        args.start_check_point
        and "optimizer_state_dict" in checkpoint
        and args.load_optimizer
    ):
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if (
        args.start_check_point
        and "scheduler_state_dict" in checkpoint
        and args.load_scheduler
    ):
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    # Restore training progress (epoch number, best metric, metric/loss history) if requested
    if args.start_check_point and "epoch" in checkpoint and args.load_epoch:
        start_epoch = checkpoint["epoch"] + 1
    else:
        start_epoch = 0
    if args.start_check_point and "best_metric" in checkpoint and args.load_best_metric:
        best_metric = checkpoint["best_metric"]
    else:
        best_metric = float("-inf")
    if args.start_check_point and "all_metrics" in checkpoint and args.load_all_metrics:
        all_time_all_metrics = checkpoint["all_metrics"]
    else:
        all_time_all_metrics = {}
    if args.start_check_point and "all_losses" in checkpoint and args.load_all_losses:
        all_losses = checkpoint["all_losses"]
    else:
        all_losses = {}

    multi_loss = choice_loss(args, config)
    scaler = GradScaler()

    if args.set_per_process_memory_fraction:
        torch.cuda.set_per_process_memory_fraction(1.0)
        torch.cuda.empty_cache()

    safe_mode = args.safe_mode
    should_print = not dist.is_initialized() or dist.get_rank() == 0
    if should_print:
        if world_size:
            batch_size = config.training.batch_size
            ef_batch_size = batch_size * gradient_accumulation_steps * world_size
            num_gpu = world_size
        else:
            device_ids = args.device_ids
            batch_size = config.training.batch_size * len(device_ids)
            ef_batch_size = batch_size * gradient_accumulation_steps
            num_gpu = len(device_ids)
        print(
            f"Instruments: {config.training.instruments}\n"
            f"Metrics for training: {args.metrics}. Metric for scheduler: {args.metric_for_scheduler}\n"
            f"Patience: {config.training.patience} "
            f"Reduce factor: {config.training.reduce_factor}\n"
            f"Batch size: {batch_size} "
            f"Grad accum steps: {gradient_accumulation_steps} "
            f"Num gpus: {num_gpu} "
            f"Effective batch size: {ef_batch_size}\n"
            f"Dataset type: {args.dataset_type}\n"
            f"Optimizer: {config.training.optimizer}"
        )
        print(f"Train for: {config.training.num_epochs} epochs")
        log_model_info(model, args.results_path)
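
    # Main training loop: train one epoch, save "last" weights, validate, and let
    # compute_epoch_metrics save a new best checkpoint and step ReduceLROnPlateau.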
    for epoch in range(start_epoch, config.training.num_epochs):
        if ddp:
            train_loader.sampler.set_epoch(epoch)
        train_one_epoch(
            model,
            config,
            args,
            optimizer,
            device,
            device_ids,
            epoch,
            use_amp,
            scaler,
            scheduler,
            gradient_accumulation_steps,
            train_loader,
            multi_loss,
            all_losses,
            world_size,
            ema_model=ema_model,
            safe_mode=safe_mode,
        )

        model_to_valid = ema_model if ema_model is not None else model

        if should_print:
            save_last_weights(
                args,
                model,
                device_ids,
                optimizer,
                epoch,
                all_time_all_metrics,
                best_metric,
                scheduler,
            )

        if ddp:
            metrics_avg, all_metrics = valid_multi_gpu(
                model, args, config, args.device_ids, verbose=False
            )
            if rank == 0:
                all_time_all_metrics[f"epoch_{epoch}"] = all_metrics
            best_metric = compute_epoch_metrics(
                model=model,
                args=args,
                config=config,
                device=device,
                device_ids=device_ids,
                best_metric=best_metric,
                epoch=epoch,
                scheduler=scheduler,
                optimizer=optimizer,
                all_time_all_metrics=all_time_all_metrics,
                all_losses=all_losses,
                world_size=world_size,
                metrics_avg=metrics_avg,
                all_metrics=all_metrics,
            )
        else:
            best_metric = compute_epoch_metrics(
                model=model,
                args=args,
                config=config,
                device=device,
                device_ids=device_ids,
                best_metric=best_metric,
                epoch=epoch,
                scheduler=scheduler,
                optimizer=optimizer,
                all_time_all_metrics=all_time_all_metrics,
                all_losses=all_losses,
            )


if __name__ == "__main__":
    train_model(None)