|
from typing import Callable, Optional, Sequence |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from a_unet import ( |
|
ClassifierFreeGuidancePlugin, |
|
Conv, |
|
Module, |
|
TextConditioningPlugin, |
|
TimeConditioningPlugin, |
|
default, |
|
exists, |
|
) |
|
from a_unet.apex import ( |
|
AttentionItem, |
|
CrossAttentionItem, |
|
InjectChannelsItem, |
|
ModulationItem, |
|
ResnetItem, |
|
SkipCat, |
|
SkipModulate, |
|
XBlock, |
|
XUNet, |
|
) |
|
from einops import pack, unpack |
|
from torch import Tensor, nn |
|
from torchaudio import transforms |
|
|
|
""" |
|
UNets (built with a-unet: https://github.com/archinetai/a-unet) |
|
""" |
|
|
|
|
|
def UNetV0(
    dim: int,
    in_channels: int,
    channels: Sequence[int],
    factors: Sequence[int],
    items: Sequence[int],
    attentions: Optional[Sequence[int]] = None,
    cross_attentions: Optional[Sequence[int]] = None,
    context_channels: Optional[Sequence[int]] = None,
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    embedding_features: Optional[int] = None,
    resnet_groups: int = 8,
    use_modulation: bool = True,
    modulation_features: int = 1024,
    embedding_max_length: Optional[int] = None,
    use_time_conditioning: bool = True,
    use_embedding_cfg: bool = False,
    use_text_conditioning: bool = False,
    out_channels: Optional[int] = None,
):
    """Build an ``XUNet`` from per-layer settings, optionally wrapped in plugins.

    ``channels``, ``factors``, ``items``, ``attentions``, ``cross_attentions``
    and ``context_channels`` each hold one entry per layer and must all have
    the same length as ``channels``.

    Args:
        dim: Forwarded to the network constructor.
        in_channels: Forwarded to the network constructor.
        channels: Per layer, the ``channels`` argument of its ``XBlock``.
        factors: Per layer, the ``factor`` argument of its ``XBlock``.
        items: Per layer, how many times the layer's item group is repeated.
        attentions: Per layer, number of ``AttentionItem`` entries in one item
            group. Defaults to all zeros.
        cross_attentions: Per layer, number of ``CrossAttentionItem`` entries in
            one item group. Defaults to all zeros.
        context_channels: Per layer, the ``context_channels`` argument of its
            ``XBlock``; an ``InjectChannelsItem`` is added to the item group
            when the value is greater than zero. Defaults to all zeros.
        attention_features: Forwarded to the network constructor.
        attention_heads: Forwarded to the network constructor.
        embedding_features: Forwarded to the network constructor.
        resnet_groups: Forwarded to the network constructor.
        use_modulation: Adds a ``ModulationItem`` to every item group and
            selects ``SkipModulate`` (instead of ``SkipCat``) as ``skip_t``.
        modulation_features: Forwarded to the network constructor.
        embedding_max_length: Passed to ``ClassifierFreeGuidancePlugin``;
            required when ``use_embedding_cfg`` is set.
        use_time_conditioning: Wrap with ``TimeConditioningPlugin``; requires
            ``use_modulation``.
        use_embedding_cfg: Wrap with ``ClassifierFreeGuidancePlugin``.
        use_text_conditioning: Wrap with ``TextConditioningPlugin``.
        out_channels: Forwarded to the network constructor.

    Returns:
        Whatever the (possibly plugin-wrapped) ``XUNet`` constructor returns.

    Raises:
        AssertionError: If the per-layer sequences differ in length, if
            ``use_embedding_cfg`` is set without ``embedding_max_length``, or
            if ``use_time_conditioning`` is set without ``use_modulation``.
    """
    # Per-layer options that were not supplied fall back to all zeros.
    depth = len(channels)
    attentions = default(attentions, [0] * depth)
    cross_attentions = default(cross_attentions, [0] * depth)
    context_channels = default(context_channels, [0] * depth)
    per_layer = (
        channels,
        factors,
        items,
        attentions,
        cross_attentions,
        context_channels,
    )
    assert all(len(setting) == depth for setting in per_layer)

    # Start from the plain UNet constructor and wrap it one plugin at a time.
    net_t = XUNet

    if use_embedding_cfg:
        msg = "use_embedding_cfg requires embedding_max_length"
        assert exists(embedding_max_length), msg
        net_t = ClassifierFreeGuidancePlugin(net_t, embedding_max_length)

    if use_text_conditioning:
        net_t = TextConditioningPlugin(net_t)

    if use_time_conditioning:
        assert use_modulation, "use_time_conditioning requires use_modulation=True"
        net_t = TimeConditioningPlugin(net_t)

    def layer_items(repeats, num_attention, num_cross, num_context):
        """Return one layer's item list: the item group repeated ``repeats`` times."""
        group = [ResnetItem]
        group += [ModulationItem] * use_modulation
        group += [InjectChannelsItem] * (num_context > 0)
        group += [AttentionItem] * num_attention
        group += [CrossAttentionItem] * num_cross
        return group * repeats

    blocks = [
        XBlock(
            channels=width,
            factor=factor,
            context_channels=num_context,
            items=layer_items(repeats, num_attention, num_cross, num_context),
        )
        for width, factor, repeats, num_attention, num_cross, num_context in zip(
            *per_layer
        )
    ]

    if use_modulation:
        skip_t = SkipModulate
    else:
        skip_t = SkipCat

    return net_t(
        dim=dim,
        in_channels=in_channels,
        out_channels=out_channels,
        blocks=blocks,
        skip_t=skip_t,
        attention_features=attention_features,
        attention_heads=attention_heads,
        embedding_features=embedding_features,
        modulation_features=modulation_features,
        resnet_groups=resnet_groups,
    )
|
|
|
|
|
""" |
|
Plugins |
|
""" |
|
|
|
|
|
def LTPlugin(
    net_t: Callable, num_filters: int, window_length: int, stride: int
) -> Callable[..., nn.Module]:
    """Learned Transform Plugin.

    Wraps ``net_t`` between a learned strided convolution (encode) and a learned
    transposed convolution (decode). The wrapped network is built with
    ``in_channels * num_filters`` input channels and
    ``out_channels * num_filters`` output channels.

    Args:
        net_t: Constructor of the inner network; called with ``dim``,
            ``in_channels``, ``out_channels`` and any extra keyword arguments.
        num_filters: Channel multiplier applied by the encode convolution and
            undone by the decode convolution.
        window_length: Kernel size of both convolutions.
        stride: Stride of both convolutions.

    Returns:
        A constructor ``Net(dim, in_channels, out_channels=None, **kwargs)``
        that builds the wrapped module. ``out_channels`` defaults to
        ``in_channels``.
    """
    # Transposed-convolution class per dimensionality, so that the decoder
    # follows ``dim`` just like the encoder built through ``Conv(dim=...)``.
    conv_transpose_by_dim = {
        1: nn.ConvTranspose1d,
        2: nn.ConvTranspose2d,
        3: nn.ConvTranspose3d,
    }

    def Net(
        dim: int, in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        out_channels = default(out_channels, in_channels)
        in_channel_transform = in_channels * num_filters
        out_channel_transform = out_channels * num_filters

        # Both convolutions share kernel size, stride and padding.
        padding = window_length // 2 - stride // 2
        encode = Conv(
            dim=dim,
            in_channels=in_channels,
            out_channels=in_channel_transform,
            kernel_size=window_length,
            stride=stride,
            padding=padding,
            padding_mode="reflect",
            bias=False,
        )

        # Previously this was always nn.ConvTranspose1d, which left the decoder
        # 1D even when the encoder was built for dim 2 or 3.
        conv_transpose_t = conv_transpose_by_dim.get(dim)
        if conv_transpose_t is None:
            raise ValueError(f"LTPlugin supports dim in (1, 2, 3), got {dim!r}")
        decode = conv_transpose_t(
            in_channels=out_channel_transform,
            out_channels=out_channels,
            kernel_size=window_length,
            stride=stride,
            padding=padding,
            bias=False,
        )
        net = net_t(
            dim=dim,
            in_channels=in_channel_transform,
            out_channels=out_channel_transform,
            **kwargs
        )

        def forward(x: Tensor, *args, **kwargs):
            # encode -> inner network -> decode; extra arguments go to the
            # inner network only.
            x = encode(x)
            x = net(x, *args, **kwargs)
            x = decode(x)
            return x

        return Module([encode, decode, net], forward)

    return Net
|
|
|
|
|
def AppendChannelsPlugin(
    net_t: Callable,
    channels: int,
):
    """Wrap ``net_t`` so the built network receives ``channels`` extra input channels.

    The returned constructor builds ``net_t`` with ``in_channels + channels``
    input channels; ``out_channels`` defaults to the un-widened ``in_channels``.
    When called, the resulting module requires a keyword-only
    ``append_channels`` tensor, concatenates it to ``x`` along dimension 1 and
    passes the result, with all remaining arguments, to the wrapped network.
    """

    def Net(
        in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        # Resolve the output width from the caller's in_channels, before the
        # appended channels are added to the input width.
        resolved_out = default(out_channels, in_channels)
        widened_in = in_channels + channels
        inner = net_t(in_channels=widened_in, out_channels=resolved_out, **kwargs)

        def forward(x: Tensor, *args, append_channels: Tensor, **kwargs):
            stacked = torch.cat((x, append_channels), dim=1)
            return inner(stacked, *args, **kwargs)

        return Module([inner], forward)

    return Net
|
|
|
|
|
""" |
|
Other |
|
""" |
|
|
|
|
|
class MelSpectrogram(nn.Module):
    """Mel-scaled magnitude spectrogram of a waveform, with optional normalization.

    Args:
        n_fft: FFT size passed to ``transforms.Spectrogram``; also fixes the
            number of STFT bins (``n_fft // 2 + 1``) given to ``MelScale``.
        hop_length: Hop length passed to ``transforms.Spectrogram``.
        win_length: Window length passed to ``transforms.Spectrogram``.
        sample_rate: Sample rate passed to ``transforms.MelScale``.
        n_mel_channels: Number of mel bins (``n_mels``) of ``MelScale``.
        center: Passed to ``transforms.Spectrogram``.
        normalize: If true, divide by the global maximum and map the result
            through ``2 * x ** 0.25 - 1``.
        normalize_log: If true, take ``log`` of the values clamped to at least
            ``1e-5``. Applied after ``normalize`` when both are set.
    """

    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        sample_rate: int,
        n_mel_channels: int,
        center: bool = False,
        normalize: bool = False,
        normalize_log: bool = False,
    ) -> None:
        super().__init__()
        # Amount of reflect padding added to each end of the time axis in forward().
        self.padding = (n_fft - hop_length) // 2
        self.normalize = normalize
        self.normalize_log = normalize_log
        self.hop_length = hop_length

        # power=None: the magnitude is taken explicitly with torch.abs in forward().
        self.to_spectrogram = transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=center,
            power=None,
        )

        self.to_mel_scale = transforms.MelScale(
            n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate
        )

    def forward(self, waveform: Tensor) -> Tensor:
        """Return the mel spectrogram of ``waveform`` (last axis is time).

        All leading axes are packed into one before processing and restored on
        the output, which ends in the axes ``f l`` of the unpack pattern.
        """
        # Pack every non-time dimension into a single leading one.
        waveform, ps = pack([waveform], "* t")

        # Reflect-pad both ends of the time axis.
        waveform = F.pad(waveform, [self.padding] * 2, mode="reflect")

        # STFT, then magnitude.
        spectrogram = self.to_spectrogram(waveform)
        spectrogram = torch.abs(spectrogram)

        # Convert to the mel scale.
        mel_spectrogram = self.to_mel_scale(spectrogram)

        if self.normalize:
            # An all-zero (silent) input has a maximum of 0, and 0 / 0 would
            # turn the whole output into NaN. Clamping to the smallest positive
            # normal value leaves every non-silent result unchanged.
            peak = torch.max(mel_spectrogram)
            peak = torch.clamp(peak, min=torch.finfo(mel_spectrogram.dtype).tiny)
            mel_spectrogram = mel_spectrogram / peak
            mel_spectrogram = 2 * torch.pow(mel_spectrogram, 0.25) - 1
        if self.normalize_log:
            mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5))

        # Restore the packed leading dimensions.
        return unpack(mel_spectrogram, ps, "* f l")[0]
|
|