| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Multi Band Diffusion models as described in |
| | "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" |
| | (https://arxiv.org/abs/2308.02560). |
| | """ |
| |
|
| | import typing as tp |
| |
|
| | import torch |
| | import julius |
| |
|
| | from .unet import DiffusionUnet |
| | from ..modules.diffusion_schedule import NoiseSchedule |
| | from .encodec import CompressionModel |
| | from ..solvers.compression import CompressionSolver |
| | from .loaders import load_compression_model, load_diffusion_models |
| |
|
| |
|
class DiffusionProcess:
    """Sampler wrapping a single diffusion model and the schedule that drives it.

    Args:
        model (DiffusionUnet): Diffusion U-Net used as the denoiser.
        noise_schedule (NoiseSchedule): Noise schedule whose sampling loop is run with ``model``.
    """

    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
        """Store the denoiser and its noise schedule."""
        self.model = model
        self.schedule = noise_schedule

    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Run one diffusion process, producing one of the bands.

        Args:
            condition (torch.Tensor): Embeddings from the compression model, given to the
                denoiser as conditioning.
            initial_noise (torch.Tensor): Noise tensor the sampling chain starts from.
            step_list (list of int, optional): Steps to sample; forwarded unchanged to
                ``NoiseSchedule.generate_subsampled``.

        Returns:
            The result of ``NoiseSchedule.generate_subsampled`` for this model and condition.
        """
        sampler = self.schedule.generate_subsampled
        return sampler(model=self.model, initial=initial_noise,
                       condition=condition, step_list=step_list)
| |
|
| |
|
class MultiBandDiffusion:
    """Sample from multiple diffusion models and sum their outputs into one waveform.

    Args:
        DPs (list of DiffusionProcess): Diffusion processes; ``generate`` runs each of them and
            sums their outputs.
        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens
            and their latent embeddings.
    """

    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
        self.DPs = DPs
        self.codec_model = codec_model
        # Generated buffers are allocated on the device holding the codec parameters.
        self.device = next(self.codec_model.parameters()).device

    @property
    def sample_rate(self) -> int:
        """Sample rate of the underlying compression model."""
        return self.codec_model.sample_rate

    @staticmethod
    def _build_processes(models, processors, cfgs, device) -> tp.List[DiffusionProcess]:
        """Pair each loaded diffusion model with a NoiseSchedule built from its config and processor."""
        return [
            DiffusionProcess(
                model=model,
                noise_schedule=NoiseSchedule(**cfg.schedule, sample_processor=processor, device=device),
            )
            for model, processor, cfg in zip(models, processors, cfgs)
        ]

    @staticmethod
    def get_mbd_musicgen(device: tp.Optional[tp.Union[torch.device, str]] = None):
        """Load our diffusion models trained for MusicGen.

        Uses the compression model of 'facebook/musicgen-small' and the diffusion checkpoint
        'mbd_musicgen_32khz.th' from 'facebook/multiband-diffusion'.

        Args:
            device (torch.device or str, optional): Device on which the models are loaded.
                Defaults to 'cuda' when available, otherwise 'cpu'.

        Returns:
            MultiBandDiffusion: The assembled sampler.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        path = 'facebook/multiband-diffusion'
        filename = 'mbd_musicgen_32khz.th'
        name = 'facebook/musicgen-small'
        codec_model = load_compression_model(name, device=device)
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = MultiBandDiffusion._build_processes(models, processors, cfgs, device)
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @staticmethod
    def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
                      device: tp.Optional[tp.Union[torch.device, str]] = None,
                      n_q: tp.Optional[int] = None):
        """Get the pretrained models for MultiBandDiffusion on top of the 24 kHz EnCodec checkpoint.

        Args:
            bw (float): Bandwidth of the compression model; one of 1.5, 3.0 or 6.0.
            pretrained (bool): Whether to use / download if necessary the models.
                NOTE(review): this flag is currently not read; pretrained weights are always loaded.
            device (torch.device or str, optional): Device on which the models are loaded.
                Defaults to 'cuda' when available, otherwise 'cpu'.
            n_q (int, optional): Number of quantizers to use within the compression model. If given,
                it must agree with ``bw`` (2, 4, 8 codebooks for 1.5, 3.0, 6.0 respectively).

        Returns:
            MultiBandDiffusion: The assembled sampler.

        Raises:
            AssertionError: If ``bw`` is unsupported or ``n_q`` does not match it.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        bw_to_n_q = {1.5: 2, 3.0: 4, 6.0: 8}
        assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
        if n_q is not None:
            assert n_q in [2, 4, 8]
            assert bw_to_n_q[bw] == n_q, \
                f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
        n_q = bw_to_n_q[bw]
        codec_model = CompressionSolver.model_from_checkpoint(
            '//pretrained/facebook/encodec_24khz', device=device)
        codec_model.set_num_codebooks(n_q)
        codec_model = codec_model.to(device)
        path = 'facebook/multiband-diffusion'
        # The diffusion checkpoint is selected by the number of codebooks.
        filename = f'mbd_comp_{n_q}.pt'
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = MultiBandDiffusion._build_processes(models, processors, cfgs, device)
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.

        Args:
            wav (torch.Tensor): The audio to extract the conditioning from; resampled to the codec
                sample rate first if needed.
            sample_rate (int): Sample rate of ``wav``.

        Returns:
            torch.Tensor: Latent embeddings of the codes produced by the compression model.

        Raises:
            AssertionError: If the compression model returns a scale (scaled models are unsupported).
        """
        if sample_rate != self.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        return self.get_emb(codes)

    @torch.no_grad()
    def get_emb(self, codes: torch.Tensor):
        """Get the latent representation of discrete codes via the codec's ``decode_latent``.

        Args:
            codes (torch.Tensor): Discrete tokens.
        """
        return self.codec_model.decode_latent(codes)

    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Generate waveform audio from the latent embeddings of the compression model.

        Every diffusion process is run from its own fresh Gaussian noise and the results are summed.

        Args:
            emb (torch.Tensor): Conditioning embeddings; ``emb.size(0)`` is the batch size.
            size (torch.Size, optional): Size of the output. If None, it is
                ``[batch, codec channels, emb.size(-1) * (sample_rate / frame_rate)]``.
            step_list (list of int, optional): Markov chain steps forwarded to each diffusion process.
                NOTE(review): the default when None is decided by NoiseSchedule (originally documented
                as 50 linearly spaced steps) — confirm there.

        Returns:
            torch.Tensor: Generated audio of shape ``size`` on ``self.device``.
        """
        if size is None:
            # Number of waveform samples covered by one latent frame.
            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
        assert size[0] == emb.size(0)
        # Allocate directly on the target device rather than on CPU followed by a copy.
        out = torch.zeros(size, device=self.device)
        for DP in self.DPs:
            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
        return out

    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
        """Match the EQ of ``wav`` to ``ref`` by matching the standard deviation of frequency bands.

        Both signals are split with ``julius.SplitBands``. Each band of ``wav`` is scaled by
        ``(std(ref_band) / std(wav_band)) ** strictness``, and the scaled bands are summed. The
        standard deviation is computed over the whole band tensor (all batch items and channels together).

        Args:
            wav (torch.Tensor): Audio to equalize.
            ref (torch.Tensor): Reference audio whose spectral balance is matched.
            n_bands (int): Number of bands of the EQ.
            strictness (float): How strict the matching is. 0 is no matching, 1 is exact matching.

        Returns:
            torch.Tensor: Equalized audio, shaped like ``ref``.
        """
        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
        bands = split(wav)
        bands_ref = split(ref)
        out = torch.zeros_like(ref)
        for band, band_ref in zip(bands, bands_ref):
            band_std = band.std()
            # A zero-variance band would divide by zero and poison the whole output with NaN/inf;
            # leave such a band unscaled instead.
            gain = torch.where(band_std > 0, band_ref.std() / band_std, torch.ones_like(band_std))
            out += band * gain ** strictness
        return out

    def regenerate(self, wav: torch.Tensor, sample_rate: int):
        """Regenerate a waveform through compression and diffusion regeneration.

        Args:
            wav (torch.Tensor): Original 'ground truth' audio.
            sample_rate (int): Sample rate of the input (and output) wav.

        Returns:
            torch.Tensor: Regenerated audio, resampled back to ``sample_rate`` if needed.
        """
        if sample_rate != self.codec_model.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
        size = wav.size()
        out = self.generate(emb, size=size)
        if sample_rate != self.codec_model.sample_rate:
            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
        return out

    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
        """Generate waveform audio with diffusion from the discrete codes.

        The diffusion output is re-equalized against the codec's own decoding of ``tokens``.

        Args:
            tokens (torch.Tensor): Discrete codes.
            n_bands (int): Bands for the EQ matching.

        Returns:
            torch.Tensor: Generated audio with the same size as the codec decoding.
        """
        wav_encodec = self.codec_model.decode(tokens)
        condition = self.get_emb(tokens)
        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
| |
|