""" |
|
All the functions to build the relevant solvers and used objects |
|
from the Hydra config. |
|
""" |
|
|
|
from enum import Enum |
|
import logging |
|
import typing as tp |
|
|
|
import dora |
|
import flashy |
|
import omegaconf |
|
import torch |
|
from torch import nn |
|
from torch.optim import Optimizer |
|
|
|
try: |
|
from torch.optim.lr_scheduler import LRScheduler |
|
except ImportError: |
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler |
|
|
|
from .base import StandardSolver |
|
from .. import adversarial, data, losses, metrics, optim |
|
from ..utils.utils import dict_from_config, get_loader |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class DatasetType(Enum): |
|
AUDIO = "audio" |
|
MUSIC = "music" |
|
SOUND = "sound" |
|
|
|
|
|
def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: |
|
"""Instantiate solver from config.""" |
|
from .audiogen import AudioGenSolver |
|
from .compression import CompressionSolver |
|
from .musicgen import MusicGenSolver |
|
from .diffusion import DiffusionSolver |
|
klass = { |
|
'compression': CompressionSolver, |
|
'musicgen': MusicGenSolver, |
|
'audiogen': AudioGenSolver, |
|
'lm': MusicGenSolver, |
|
'diffusion': DiffusionSolver, |
|
'sound_lm': AudioGenSolver, |
|
}[cfg.solver] |
|
return klass(cfg) |
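# A minimal usage sketch (illustrative): the training entry point is expected to
# pass a fully resolved Hydra config whose `solver` field matches one of the keys
# above, e.g.
#
#     solver = get_solver(cfg)   # cfg.solver == 'musicgen' -> MusicGenSolver(cfg)
#     solver.run()               # standard flashy/dora training loop
#
# Anything beyond the `cfg.solver` lookup (datasets, optim, ...) is validated by
# the solver itself, not here.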
|
|
|
|
|
def get_optim_parameter_groups(model: nn.Module): |
|
"""Create parameter groups for the model using the appropriate method |
|
if defined for each modules, to create the different groups. |
|
|
|
Args: |
|
model (nn.Module): torch model |
|
Returns: |
|
List of parameter groups |
|
""" |
|
seen_params: tp.Set[nn.parameter.Parameter] = set() |
|
other_params = [] |
|
groups = [] |
|
for name, module in model.named_modules(): |
|
if hasattr(module, 'make_optim_group'): |
|
group = module.make_optim_group() |
|
params = set(group['params']) |
|
            assert params.isdisjoint(seen_params), \
                f"Parameters of module {name} are already assigned to another optim group."
            seen_params |= params
|
groups.append(group) |
|
for param in model.parameters(): |
|
if param not in seen_params: |
|
other_params.append(param) |
|
groups.insert(0, {'params': other_params}) |
|
    return groups
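# Illustrative sketch: a submodule opts into its own parameter group by exposing
# a `make_optim_group` method (the class and hyper-parameters below are hypothetical):
#
#     class Backbone(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.proj = nn.Linear(16, 16)
#
#         def make_optim_group(self):
#             # e.g. train this submodule with a smaller learning rate
#             return {'params': list(self.parameters()), 'lr': 1e-5}
#
# Parameters of modules that do not define `make_optim_group` fall back to the
# default group inserted at index 0.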
|
|
|
|
|
def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer: |
|
"""Build torch optimizer from config and set of parameters. |
|
Supported optimizers: Adam, AdamW |
|
|
|
Args: |
|
params (nn.Module or iterable of torch.Tensor): Parameters to optimize. |
|
cfg (DictConfig): Optimization-related configuration. |
|
Returns: |
|
torch.optim.Optimizer. |
|
""" |
|
if 'optimizer' not in cfg: |
|
if getattr(cfg, 'optim', None) is not None: |
|
raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?") |
|
else: |
|
raise KeyError("Optimizer not found in config.") |
|
|
|
parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params |
|
optimizer: torch.optim.Optimizer |
|
if cfg.optimizer == 'adam': |
|
optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam) |
|
elif cfg.optimizer == 'adamw': |
|
optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam) |
|
elif cfg.optimizer == 'dadam': |
|
optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam) |
|
else: |
|
raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}") |
|
return optimizer |
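# Illustrative sketch of the expected config shape (values are made up, and `model`
# is a hypothetical nn.Module; the `adam` sub-tree is forwarded verbatim to the
# optimizer constructor):
#
#     optim_cfg = omegaconf.OmegaConf.create({
#         'optimizer': 'adamw',
#         'lr': 1e-4,
#         'adam': {'betas': [0.9, 0.95], 'weight_decay': 0.1, 'eps': 1e-8},
#     })
#     optimizer = get_optimizer(model, optim_cfg)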
|
|
|
|
|
def get_lr_scheduler(optimizer: torch.optim.Optimizer, |
|
cfg: omegaconf.DictConfig, |
|
total_updates: int) -> tp.Optional[LRScheduler]: |
|
"""Build torch learning rate scheduler from config and associated optimizer. |
|
Supported learning rate schedulers: ExponentialLRScheduler, PlateauLRScheduler |
|
|
|
Args: |
|
optimizer (torch.optim.Optimizer): Optimizer. |
|
cfg (DictConfig): Schedule-related configuration. |
|
total_updates (int): Total number of updates. |
|
Returns: |
|
torch.optim.Optimizer. |
|
""" |
|
if 'lr_scheduler' not in cfg: |
|
raise KeyError("LR Scheduler not found in config") |
|
|
|
lr_sched: tp.Optional[LRScheduler] = None |
|
if cfg.lr_scheduler == 'step': |
|
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step) |
|
elif cfg.lr_scheduler == 'exponential': |
|
lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential) |
|
elif cfg.lr_scheduler == 'cosine': |
|
kwargs = dict_from_config(cfg.cosine) |
|
warmup_steps = kwargs.pop('warmup') |
|
lr_sched = optim.CosineLRScheduler( |
|
optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) |
|
elif cfg.lr_scheduler == 'polynomial_decay': |
|
kwargs = dict_from_config(cfg.polynomial_decay) |
|
warmup_steps = kwargs.pop('warmup') |
|
lr_sched = optim.PolynomialDecayLRScheduler( |
|
optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) |
|
elif cfg.lr_scheduler == 'inverse_sqrt': |
|
kwargs = dict_from_config(cfg.inverse_sqrt) |
|
warmup_steps = kwargs.pop('warmup') |
|
lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) |
|
elif cfg.lr_scheduler == 'linear_warmup': |
|
kwargs = dict_from_config(cfg.linear_warmup) |
|
warmup_steps = kwargs.pop('warmup') |
|
lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) |
|
elif cfg.lr_scheduler is not None: |
|
raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}") |
|
return lr_sched |
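# Illustrative sketch: a cosine schedule with warmup. Only `lr_scheduler` and the
# `warmup` key are interpreted here; the remaining keys under `cosine` must match
# the keyword arguments of `optim.CosineLRScheduler` (the ones below are examples,
# not an authoritative list):
#
#     sched_cfg = omegaconf.OmegaConf.create({
#         'lr_scheduler': 'cosine',
#         'cosine': {'warmup': 4000, 'lr_min_ratio': 0.0, 'cycle_length': 1.0},
#     })
#     lr_sched = get_lr_scheduler(optimizer, sched_cfg, total_updates=100_000)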
|
|
|
|
|
def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]: |
|
"""Initialize Exponential Moving Average. |
|
|
|
Args: |
|
module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA. |
|
cfg (omegaconf.DictConfig): Optim EMA configuration. |
|
Returns: |
|
optim.ModuleDictEMA: EMA version of the ModuleDict. |
|
""" |
|
kw: tp.Dict[str, tp.Any] = dict(cfg) |
|
use = kw.pop('use', False) |
|
decay = kw.pop('decay', None) |
|
device = kw.pop('device', None) |
|
if not use: |
|
return None |
|
if len(module_dict) == 0: |
|
raise ValueError("Trying to build EMA but an empty module_dict source is provided!") |
|
ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device) |
|
return ema_module |
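# Illustrative sketch: only `use`, `decay` and `device` are read here; with
# `use: false` (or missing) no EMA is created. `model` is a hypothetical nn.Module.
#
#     ema_cfg = omegaconf.OmegaConf.create({'use': True, 'decay': 0.99, 'device': 'cpu'})
#     ema = get_ema(nn.ModuleDict({'model': model}), ema_cfg)   # None when use is false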
|
|
|
|
|
def get_loss(loss_name: str, cfg: omegaconf.DictConfig): |
|
"""Instantiate loss from configuration.""" |
|
klass = { |
|
'l1': torch.nn.L1Loss, |
|
'l2': torch.nn.MSELoss, |
|
'mel': losses.MelSpectrogramL1Loss, |
|
'mrstft': losses.MRSTFTLoss, |
|
'msspec': losses.MultiScaleMelSpectrogramLoss, |
|
'sisnr': losses.SISNR, |
|
}[loss_name] |
|
kwargs = dict(getattr(cfg, loss_name)) |
|
return klass(**kwargs) |
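# Illustrative sketch: the kwargs for each loss live under a config node named
# after the loss, e.g. `cfg.l1` or `cfg.msspec` (contents depend on the loss class):
#
#     loss_cfg = omegaconf.OmegaConf.create({'l1': {}})
#     l1_loss = get_loss('l1', loss_cfg)    # -> torch.nn.L1Loss()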
|
|
|
|
|
def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer: |
|
"""Instantiate loss balancer from configuration for the provided weights.""" |
|
kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg) |
|
return losses.Balancer(loss_weights, **kwargs) |
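# Illustrative note: `loss_weights` maps loss names to their relative weights,
# e.g. {'l1': 0.1, 'msspec': 2.0, 'adv': 4.0} (values are examples); every other
# Balancer option comes from the config node passed as `cfg`.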
|
|
|
|
|
def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module: |
|
"""Initialize adversary from config.""" |
|
klass = { |
|
'msd': adversarial.MultiScaleDiscriminator, |
|
'mpd': adversarial.MultiPeriodDiscriminator, |
|
'msstftd': adversarial.MultiScaleSTFTDiscriminator, |
|
}[name] |
|
adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name)) |
|
return klass(**adv_cfg) |
|
|
|
|
|
def get_adversarial_losses(cfg: omegaconf.DictConfig) -> nn.ModuleDict:
|
"""Initialize dict of adversarial losses from config.""" |
|
device = cfg.device |
|
adv_cfg = getattr(cfg, 'adversarial') |
|
adversaries = adv_cfg.get('adversaries', []) |
|
adv_loss_name = adv_cfg['adv_loss'] |
|
feat_loss_name = adv_cfg.get('feat_loss') |
|
normalize = adv_cfg.get('normalize', True) |
|
feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None |
|
if feat_loss_name: |
|
        assert feat_loss_name in ['l1', 'l2'], f"Feature loss only supports L1 or L2 but {feat_loss_name} found."
|
loss = get_loss(feat_loss_name, cfg) |
|
feat_loss = adversarial.FeatureMatchingLoss(loss, normalize) |
|
loss = adversarial.get_adv_criterion(adv_loss_name) |
|
loss_real = adversarial.get_real_criterion(adv_loss_name) |
|
loss_fake = adversarial.get_fake_criterion(adv_loss_name) |
|
adv_losses = nn.ModuleDict() |
|
for adv_name in adversaries: |
|
adversary = get_adversary(adv_name, cfg).to(device) |
|
optimizer = get_optimizer(adversary.parameters(), cfg.optim) |
|
adv_loss = adversarial.AdversarialLoss( |
|
adversary, |
|
optimizer, |
|
loss=loss, |
|
loss_real=loss_real, |
|
loss_fake=loss_fake, |
|
loss_feat=feat_loss, |
|
normalize=normalize |
|
) |
|
adv_losses[adv_name] = adv_loss |
|
return adv_losses |
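# Illustrative sketch: the `adversarial` config section selects the adversaries and
# the adversarial/feature-matching criteria; one AdversarialLoss (with its own
# optimizer built from cfg.optim) is created per adversary. Example values only:
#
#     cfg.adversarial = {'adversaries': ['msstftd'], 'adv_loss': 'hinge',
#                        'feat_loss': 'l1', 'normalize': True}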
|
|
|
|
|
def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL: |
|
"""Instantiate ViSQOL metric from config.""" |
|
kwargs = dict_from_config(cfg) |
|
return metrics.ViSQOL(**kwargs) |
|
|
|
|
|
def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric: |
|
"""Instantiate Frechet Audio Distance metric from config.""" |
|
kwargs = dict_from_config(cfg.tf) |
|
xp = dora.get_xp() |
|
kwargs['log_folder'] = xp.folder |
|
return metrics.FrechetAudioDistanceMetric(**kwargs) |
|
|
|
|
|
def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric: |
|
"""Instantiate KL-Divergence metric from config.""" |
|
kld_metrics = { |
|
'passt': metrics.PasstKLDivergenceMetric, |
|
} |
|
klass = kld_metrics[cfg.model] |
|
kwargs = dict_from_config(cfg.get(cfg.model)) |
|
return klass(**kwargs) |
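# Illustrative sketch: the metric builders above share one pattern -- `cfg.model`
# selects the backend and the config node of the same name provides its kwargs
# (left empty here, assuming the metric's defaults suffice):
#
#     kld_cfg = omegaconf.OmegaConf.create({'model': 'passt', 'passt': {}})
#     kld = get_kldiv(kld_cfg)   # kwargs under 'passt' go to PasstKLDivergenceMetric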
|
|
|
|
|
def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric: |
|
"""Instantiate Text Consistency metric from config.""" |
|
text_consistency_metrics = { |
|
'clap': metrics.CLAPTextConsistencyMetric |
|
} |
|
klass = text_consistency_metrics[cfg.model] |
|
kwargs = dict_from_config(cfg.get(cfg.model)) |
|
return klass(**kwargs) |
|
|
|
|
|
def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric: |
|
"""Instantiate Chroma Cosine Similarity metric from config.""" |
|
    assert cfg.model == 'chroma_base', "Only the 'chroma_base' model is supported for the chroma cosine similarity metric."
|
kwargs = dict_from_config(cfg.get(cfg.model)) |
|
return metrics.ChromaCosineSimilarityMetric(**kwargs) |
|
|
|
|
|
def get_audio_datasets(cfg: omegaconf.DictConfig, |
|
dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]: |
|
"""Build AudioDataset from configuration. |
|
|
|
Args: |
|
cfg (omegaconf.DictConfig): Configuration. |
|
dataset_type: The type of dataset to create. |
|
Returns: |
|
dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split. |
|
""" |
|
dataloaders: dict = {} |
|
|
|
sample_rate = cfg.sample_rate |
|
channels = cfg.channels |
|
seed = cfg.seed |
|
max_sample_rate = cfg.datasource.max_sample_rate |
|
max_channels = cfg.datasource.max_channels |
|
|
|
assert cfg.dataset is not None, "Could not find dataset definition in config" |
|
|
|
dataset_cfg = dict_from_config(cfg.dataset) |
|
splits_cfg: dict = {} |
|
splits_cfg['train'] = dataset_cfg.pop('train') |
|
splits_cfg['valid'] = dataset_cfg.pop('valid') |
|
splits_cfg['evaluate'] = dataset_cfg.pop('evaluate') |
|
splits_cfg['generate'] = dataset_cfg.pop('generate') |
|
execute_only_stage = cfg.get('execute_only', None) |
|
|
|
for split, path in cfg.datasource.items(): |
|
if not isinstance(path, str): |
|
continue |
|
if execute_only_stage is not None and split != execute_only_stage: |
|
continue |
|
logger.info(f"Loading audio data split {split}: {str(path)}") |
|
assert ( |
|
cfg.sample_rate <= max_sample_rate |
|
), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found." |
|
assert ( |
|
cfg.channels <= max_channels |
|
), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found." |
|
|
|
split_cfg = splits_cfg[split] |
|
split_kwargs = {k: v for k, v in split_cfg.items()} |
|
kwargs = {**dataset_cfg, **split_kwargs} |
|
kwargs['sample_rate'] = sample_rate |
|
kwargs['channels'] = channels |
|
|
|
if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch: |
|
kwargs['num_samples'] = ( |
|
flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch) |
|
|
|
num_samples = kwargs['num_samples'] |
|
shuffle = kwargs['shuffle'] |
|
|
|
return_info = kwargs.pop('return_info') |
|
batch_size = kwargs.pop('batch_size', None) |
|
num_workers = kwargs.pop('num_workers') |
|
|
|
if dataset_type == DatasetType.MUSIC: |
|
dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs) |
|
elif dataset_type == DatasetType.SOUND: |
|
dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs) |
|
elif dataset_type == DatasetType.AUDIO: |
|
dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs) |
|
else: |
|
raise ValueError(f"Dataset type is unsupported: {dataset_type}") |
|
|
|
loader = get_loader( |
|
dataset, |
|
num_samples, |
|
batch_size=batch_size, |
|
num_workers=num_workers, |
|
seed=seed, |
|
collate_fn=dataset.collater if return_info else None, |
|
shuffle=shuffle, |
|
) |
|
dataloaders[split] = loader |
|
|
|
return dataloaders |
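# Illustrative sketch: `cfg.datasource` maps split names to dataset manifest paths
# and carries the source constraints checked above (paths and values are made up):
#
#     cfg.datasource = {
#         'max_sample_rate': 44100, 'max_channels': 2,
#         'train': 'egs/my_dataset/train', 'valid': 'egs/my_dataset/valid',
#         'evaluate': 'egs/my_dataset/test', 'generate': 'egs/my_dataset/test',
#     }
#     dataloaders = get_audio_datasets(cfg, dataset_type=DatasetType.MUSIC)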
|
|