MusicGen / audiocraft /solvers /diffusion.py
reach-vb's picture
reach-vb HF staff
Stereo demo update (#60)
5325fcc
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import flashy
import julius
import omegaconf
import torch
import torch.nn.functional as F
from . import builders
from . import base
from .. import models
from ..modules.diffusion_schedule import NoiseSchedule
from ..metrics import RelativeVolumeMel
from ..models.builders import get_processor
from ..utils.samples.manager import SampleManager
from ..solvers.compression import CompressionSolver
class PerStageMetrics:
    """Handle prompting the metrics per stage.

    It outputs the metrics per range of diffusion states,
    e.g. avg loss when t in [250, 500].

    Args:
        num_steps (int): Total number of diffusion steps; used to normalize `step`.
        num_stages (int): Number of equal-width step ranges (stages) to report on.
    """
    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]) -> dict:
        """Split the given losses per diffusion stage.

        Args:
            losses (dict): Mapping from metric name to its value. When `step` is a
                tensor, each value is multiplied elementwise with a mask derived
                from `step` (presumably one entry per batch item — confirm with caller).
            step (int or torch.Tensor): Diffusion step shared by the whole batch (int)
                or individual steps (tensor).

        Returns:
            dict: Entries keyed `"{name}_{stage}"`. With an int step, every loss is
                reported under the single matching stage. With a tensor step, each
                loss is averaged over the items falling in each stage; stages
                without any item are omitted.

        Raises:
            TypeError: If `step` is neither an int nor a torch.Tensor.
        """
        if isinstance(step, int):
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}

        if isinstance(step, torch.Tensor):
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, torch.Tensor] = {}
            for stage_idx in range(self.num_stages):
                mask = (stage_tensor == stage_idx)
                count = mask.sum()
                if count > 0:  # pass if no elements in the stage
                    for name, loss in losses.items():
                        # Masked mean: zero out the other stages, divide by the stage size.
                        out[f"{name}_{stage_idx}"] = (mask * loss).sum() / count
            return out

        # Fail loudly here instead of returning None and breaking the caller later.
        raise TypeError(f"step must be an int or a torch.Tensor, got {type(step).__name__}.")
class DataProcess:
    """Apply filtering or resampling.

    Args:
        initial_sr (int): Initial sample rate (sample rate of the dataset).
        target_sr (int): Target sample rate (sample rate after resampling).
        use_resampling (bool): Whether to use resampling or not.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands to consider.
        idx_band (int): Index of the frequency band to keep.
            0 are lows ... (n_bands - 1) highs.
        device (torch.device or str): Device on which the band filter is placed.
        cutoffs (None or list): The cutoff frequencies of the band filtering.
            If None then we use mel scale bands.
        boost (bool): Make the data scale match our music dataset.
    """
    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
        """Store the processing options and build the band filter when requested.

        See the class docstring for the meaning of each argument.
        """
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            # Explicit cutoffs take precedence over an automatic split in `n_bands` bands.
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        """Apply the configured boost, band filtering and resampling, in that order.

        Args:
            x (torch.Tensor or None): Audio to process. The boost normalization reduces
                over dims 1 and 2, so a 3-dim tensor is expected when `boost` is set.
            metric (bool): When True, the band filtering is skipped.

        Returns:
            torch.Tensor or None: Processed audio, or None if `x` is None.
                The input tensor is never modified in place.
        """
        if x is None:
            return None
        if self.boost:
            # Out-of-place on purpose: callers keep using `x` after this call
            # (e.g. as the reference signal), so it must not be altered.
            x = x / torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            # The rescale has to be assigned; a bare `x * 0.22` is discarded.
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Upsampling only.

        Resamples from `target_sr` back to `initial_sr` when resampling is enabled.
        Boost and band filtering are not inverted.
        """
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x
class DiffusionSolver(base.StandardSolver):
    """Solver for the diffusion task.

    The diffusion task allows for MultiBand diffusion model training.

    Args:
        cfg (DictConfig): Configuration.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        # Pretrained codec used to compute the conditioning signal (see `get_condition`).
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device)

        self.codec_model.set_num_codebooks(cfg.n_q)
        assert self.codec_model.sample_rate == self.cfg.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )

        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
        self.register_stateful('sample_processor')
        self.sample_processor.to(self.device)

        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)

        self.eval_metric: tp.Optional[torch.nn.Module] = None

        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                          idx_band=cfg.filter.idx_band, device=self.device)

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Name of the metric used to select the best state: 'rvm' during the
        evaluate stage, 'loss' otherwise.
        """
        if self._current_stage == "evaluate":
            return 'rvm'
        else:
            return 'loss'

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute the conditioning embedding of a waveform with the codec model.

        The waveform is encoded to codes, which are then decoded to the latent space.
        Codec models returning a scale are not supported.
        """
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model.
        """
        # Model and optimizer
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Not implemented yet."""
        # TODO
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch.

        Args:
            idx (int): Index of the step (unused here).
            batch (torch.Tensor): Batch of waveforms.
            metrics (dict): Ignored; a fresh metrics dict is built and returned.

        Returns:
            dict: 'loss', 'normed_loss', their per-stage breakdown, and the standard
                deviation of the model input ('std_in') and output ('std_out').
        """
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

        condition = self.get_condition(x)  # [bs, 128, T/hop, n_emb]
        sample = self.data_processor.process_data(x)

        input_, target, step = self.schedule.get_training_item(sample,
                                                               tensor_step=self.cfg.schedule.variable_step_batch)
        out = self.model(input_, step, condition=condition).sample

        # Per-item losses; the reference loss is the loss of the unprocessed input
        # against the target and is used to normalize the model loss.
        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
        loss = base_loss / reference_loss ** self.cfg.loss.norm_power

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()
        metrics = {
            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
        }
        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
        metrics.update({
            'std_in': input_.std(), 'std_out': out.std()})
        return metrics

    def run_epoch(self):
        """Reset the per-epoch state (random generator, per-stage metrics) and run the epoch."""
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.
        Runs audio reconstruction evaluation.

        Returns:
            dict: The metrics returned by `self.rvm`, averaged uniformly over the
                batches and then averaged across workers.
        """
        self.model.eval()
        evaluate_stage_name = f'{self.current_stage}'
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

        metrics: dict = {}
        for idx, batch in enumerate(lp):
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
            if not metrics:
                # Copy so the accumulator does not alias the first batch's dict.
                metrics = dict(rvm)
            else:
                # Running mean: `idx` batches have been accumulated so far.
                for key, value in rvm.items():
                    metrics[key] = (metrics[key] * idx + value) / (idx + 1)
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    @torch.no_grad()
    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform.

        Args:
            wav (torch.Tensor): Waveform used both for conditioning and to derive
                the initial noise.
            step_list (list, optional): Forwarded to the schedule's subsampled generation.

        Returns:
            The generated audio after `data_processor.inverse_process`.
        """
        condition = self.get_condition(wav)
        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))  # sampling rate changes.
        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
                                                   step_list=step_list)
        result = self.data_processor.inverse_process(result)
        return result

    def generate(self):
        """Generate stage.

        Regenerates every reference of the 'generate' dataloader and stores the
        estimates along with their ground truth through the sample manager.
        """
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f'{self.current_stage}'

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            # Batches of this loader are pairs; only the first element (audio) is used.
            reference, _ = batch
            reference = reference.to(self.device)
            estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
        flashy.distrib.barrier()