from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch
from einops import pack, rearrange, unpack
from torch import Generator, Tensor, nn

from .components import AppendChannelsPlugin, MelSpectrogram
from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler
from .utils import (
    closest_power_2,
    default,
    downsample,
    exists,
    groupby,
    randn_like,
    upsample,
)


class DiffusionModel(nn.Module):
    """Wraps a network with a diffusion training objective and a sampler."""

    def __init__(
        self,
        net_t: Callable,
        diffusion_t: Callable = VDiffusion,
        sampler_t: Callable = VSampler,
        loss_fn: Callable = torch.nn.functional.mse_loss,
        dim: int = 1,
        **kwargs,
    ):
        super().__init__()
        # Route prefixed kwargs: "diffusion_*" to the objective, "sampler_*" to
        # the sampler; everything else is forwarded to the network constructor.
        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
        sampler_kwargs, kwargs = groupby("sampler_", kwargs)

        self.net = net_t(dim=dim, **kwargs)
        self.diffusion = diffusion_t(net=self.net, loss_fn=loss_fn, **diffusion_kwargs)
        self.sampler = sampler_t(net=self.net, **sampler_kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        return self.diffusion(*args, **kwargs)

    @torch.no_grad()
    def sample(self, *args, **kwargs) -> Tensor:
        return self.sampler(*args, **kwargs)
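

# Usage sketch (illustrative only, not executed at import). `UNetV0` stands in
# for any U-Net constructor compatible with `net_t` (it must accept `dim`,
# `in_channels`, and `channels`); shapes and channel counts are examples.
#
#   model = DiffusionModel(
#       net_t=UNetV0,
#       in_channels=2,
#       channels=[8, 32, 64, 128, 256, 512],
#       diffusion_t=VDiffusion,  # "diffusion_*" kwargs are routed here
#       sampler_t=VSampler,      # "sampler_*" kwargs are routed here
#   )
#   audio = torch.randn(1, 2, 2**18)  # [batch, in_channels, length]
#   loss = model(audio)               # training loss from the diffusion objective
#   loss.backward()
#   sample = model.sample(torch.randn(1, 2, 2**18), num_steps=50)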


class EncoderBase(nn.Module, ABC):
    """Abstract class for DiffusionAE encoder"""

    @abstractmethod
    def __init__(self):
        super().__init__()
        self.out_channels = None
        self.downsample_factor = None


class AdapterBase(nn.Module, ABC):
    """Abstract class for DiffusionAE adapter"""

    @abstractmethod
    def encode(self, x: Tensor) -> Tensor:
        pass

    @abstractmethod
    def decode(self, x: Tensor) -> Tensor:
        pass


class DiffusionAE(DiffusionModel):
    """Diffusion Auto Encoder"""

    def __init__(
        self,
        in_channels: int,
        channels: Sequence[int],
        encoder: EncoderBase,
        inject_depth: int,
        latent_factor: Optional[int] = None,
        adapter: Optional[AdapterBase] = None,
        **kwargs,
    ):
        # Request context (latent) channels from the net only at the injection depth
        context_channels = [0] * len(channels)
        context_channels[inject_depth] = encoder.out_channels
        super().__init__(
            in_channels=in_channels,
            channels=channels,
            context_channels=context_channels,
            **kwargs,
        )
        self.in_channels = in_channels
        self.encoder = encoder
        self.inject_depth = inject_depth
        # Ratio of waveform length to latent length; defaults to the encoder's
        # downsampling factor
        self.latent_factor = default(latent_factor, self.encoder.downsample_factor)
        # The optional adapter maps to/from the diffusion space and is kept frozen
        self.adapter = adapter.requires_grad_(False) if exists(adapter) else None

    def forward(
        self, x: Tensor, with_info: bool = False, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        # Encode input and inject the latent as context at the chosen depth
        latent, info = self.encode(x, with_info=True)
        channels = [None] * self.inject_depth + [latent]
        # Adapt input to diffusion space if an adapter is provided
        x = self.adapter.encode(x) if exists(self.adapter) else x
        loss = super().forward(x, channels=channels, **kwargs)
        return (loss, info) if with_info else loss

    def encode(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)

    @torch.no_grad()
    def decode(
        self, latent: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        b = latent.shape[0]
        # Infer the noise length from the latent length, rounded to a power of two
        noise_length = closest_power_2(latent.shape[2] * self.latent_factor)
        noise = torch.randn(
            (b, self.in_channels, noise_length),
            device=latent.device,
            dtype=latent.dtype,
            generator=generator,
        )
        # Decode by sampling while conditioning on the latent channels
        channels = [None] * self.inject_depth + [latent]
        out = super().sample(noise, channels=channels, **kwargs)
        # Map back to waveform space if an adapter is provided
        return self.adapter.decode(out) if exists(self.adapter) else out
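

# Usage sketch (illustrative). `encoder` is any EncoderBase subclass exposing
# `out_channels` and `downsample_factor`; `UNetV0` is a placeholder for a
# compatible U-Net constructor. `inject_depth` must index into `channels`.
#
#   autoencoder = DiffusionAE(
#       net_t=UNetV0,
#       in_channels=2,
#       channels=[8, 32, 64, 128],
#       encoder=encoder,
#       inject_depth=2,
#   )
#   audio = torch.randn(1, 2, 2**18)
#   loss = autoencoder(audio)           # diffusion loss in (adapted) space
#   latent = autoencoder.encode(audio)  # [b, encoder.out_channels, t / factor]
#   reconstruction = autoencoder.decode(latent, num_steps=50)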


class DiffusionUpsampler(DiffusionModel):
    def __init__(
        self,
        in_channels: int,
        upsample_factor: int,
        net_t: Callable,
        **kwargs,
    ):
        self.upsample_factor = upsample_factor
        super().__init__(
            # Condition by appending the re-upsampled signal as extra channels
            net_t=AppendChannelsPlugin(net_t, channels=in_channels),
            in_channels=in_channels,
            **kwargs,
        )

    def reupsample(self, x: Tensor) -> Tensor:
        # Simulate a low-resolution input: downsample, then naively upsample back
        x = x.clone()
        x = downsample(x, factor=self.upsample_factor)
        x = upsample(x, factor=self.upsample_factor)
        return x

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        reupsampled = self.reupsample(x)
        return super().forward(x, *args, append_channels=reupsampled, **kwargs)

    @torch.no_grad()
    def sample(
        self, downsampled: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        reupsampled = upsample(downsampled, factor=self.upsample_factor)
        noise = randn_like(reupsampled, generator=generator)
        return super().sample(noise, append_channels=reupsampled, **kwargs)
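

# Usage sketch (illustrative): a 2x upsampler. During training the model is
# conditioned on a down/re-upsampled copy of the target; at sampling time only
# the low-resolution signal is needed. `UNetV0` is a placeholder net type.
#
#   upsampler = DiffusionUpsampler(
#       net_t=UNetV0,
#       in_channels=2,
#       channels=[8, 32, 64, 128],
#       upsample_factor=2,
#   )
#   loss = upsampler(torch.randn(1, 2, 2**18))
#   downsampled = torch.randn(1, 2, 2**17)  # [b, c, t / upsample_factor]
#   upsampled = upsampler.sample(downsampled, num_steps=50)  # [b, c, t]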


class DiffusionVocoder(DiffusionModel):
    def __init__(
        self,
        net_t: Callable,
        mel_channels: int,
        mel_n_fft: int,
        mel_hop_length: Optional[int] = None,
        mel_win_length: Optional[int] = None,
        in_channels: int = 1,  # Ignored: audio channels are batched automatically
        **kwargs,
    ):
        mel_hop_length = default(mel_hop_length, mel_n_fft // 4)
        mel_win_length = default(mel_win_length, mel_n_fft)
        # "mel_*" kwargs are forwarded to the MelSpectrogram transform
        mel_kwargs, kwargs = groupby("mel_", kwargs)
        super().__init__(
            net_t=AppendChannelsPlugin(net_t, channels=1),
            in_channels=1,
            **kwargs,
        )
        self.to_spectrogram = MelSpectrogram(
            n_fft=mel_n_fft,
            hop_length=mel_hop_length,
            win_length=mel_win_length,
            n_mel_channels=mel_channels,
            **mel_kwargs,
        )
        # Learned upsampling from mel frames back to waveform rate, producing the
        # single conditioning channel appended to the net input
        self.to_flat = nn.ConvTranspose1d(
            in_channels=mel_channels,
            out_channels=1,
            kernel_size=mel_win_length,
            stride=mel_hop_length,
            padding=(mel_win_length - mel_hop_length) // 2,
            bias=False,
        )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        # Compute the mel spectrogram, merge audio channels into the batch, flatten
        spectrogram = rearrange(self.to_spectrogram(x), "b c f l -> (b c) f l")
        spectrogram_flat = self.to_flat(spectrogram)
        # Merge wave channels into the batch so each channel is vocoded independently
        x = rearrange(x, "b c t -> (b c) 1 t")
        return super().forward(x, *args, append_channels=spectrogram_flat, **kwargs)

    @torch.no_grad()
    def sample(
        self, spectrogram: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        # Pack any leading (batch, channel) dims, then flatten the spectrogram
        spectrogram, ps = pack([spectrogram], "* f l")
        spectrogram_flat = self.to_flat(spectrogram)
        # Sample a waveform conditioned on the flattened spectrogram
        noise = randn_like(spectrogram_flat, generator=generator)
        waveform = super().sample(noise, append_channels=spectrogram_flat, **kwargs)
        # Restore the original leading dims
        waveform = rearrange(waveform, "... 1 t -> ... t")
        waveform = unpack(waveform, ps, "* t")[0]
        return waveform
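

# Usage sketch (illustrative): mel-spectrogram inversion. Extra "mel_*" kwargs
# (e.g. a mel_sample_rate, assuming MelSpectrogram accepts sample_rate) are
# forwarded via groupby; `UNetV0` is a placeholder net type.
#
#   vocoder = DiffusionVocoder(
#       net_t=UNetV0,
#       channels=[8, 32, 64, 128],
#       mel_channels=80,
#       mel_n_fft=1024,
#   )
#   audio = torch.randn(1, 2, 2**18)                      # [b, c, t]
#   loss = vocoder(audio)                                 # channels are batched
#   spectrogram = vocoder.to_spectrogram(audio)           # [b, c, 80, frames]
#   waveform = vocoder.sample(spectrogram, num_steps=50)  # [b, c, t']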


class DiffusionAR(DiffusionModel):
    def __init__(
        self,
        in_channels: int,
        length: int,
        num_splits: int,
        diffusion_t: Callable = ARVDiffusion,
        sampler_t: Callable = ARVSampler,
        **kwargs,
    ):
        super().__init__(
            # One extra input channel carries ARVDiffusion's per-position noise level
            in_channels=in_channels + 1,
            out_channels=in_channels,
            diffusion_t=diffusion_t,
            diffusion_length=length,
            diffusion_num_splits=num_splits,
            sampler_t=sampler_t,
            sampler_in_channels=in_channels,
            sampler_length=length,
            sampler_num_splits=num_splits,
            use_time_conditioning=False,
            use_modulation=False,
            **kwargs,
        )
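

# Usage sketch (illustrative). Training mirrors DiffusionModel; the sampling
# interface is defined by ARVSampler in .diffusion, so its arguments are not
# repeated here.
#
#   model = DiffusionAR(
#       net_t=UNetV0,  # placeholder for a compatible U-Net constructor
#       in_channels=1,
#       length=2**15,  # chunk length seen during training
#       num_splits=4,  # autoregressive splits per chunk
#   )
#   loss = model(torch.randn(1, 1, 2**15))
#   loss.backward()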