Respair's picture
Upload folder using huggingface_hub
bcdb559 verified
raw
history blame
7.06 kB
from typing import Callable, Optional, Sequence
import torch
import torch.nn.functional as F
from a_unet import (
ClassifierFreeGuidancePlugin,
Conv,
Module,
TextConditioningPlugin,
TimeConditioningPlugin,
default,
exists,
)
from a_unet.apex import (
AttentionItem,
CrossAttentionItem,
InjectChannelsItem,
ModulationItem,
ResnetItem,
SkipCat,
SkipModulate,
XBlock,
XUNet,
)
from einops import pack, unpack
from torch import Tensor, nn
from torchaudio import transforms
"""
UNets (built with a-unet: https://github.com/archinetai/a-unet)
"""
def UNetV0(
    dim: int,
    in_channels: int,
    channels: Sequence[int],
    factors: Sequence[int],
    items: Sequence[int],
    attentions: Optional[Sequence[int]] = None,
    cross_attentions: Optional[Sequence[int]] = None,
    context_channels: Optional[Sequence[int]] = None,
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    embedding_features: Optional[int] = None,
    resnet_groups: int = 8,
    use_modulation: bool = True,
    modulation_features: int = 1024,
    embedding_max_length: Optional[int] = None,
    use_time_conditioning: bool = True,
    use_embedding_cfg: bool = False,
    use_text_conditioning: bool = False,
    out_channels: Optional[int] = None,
):
    """Build a configurable XUNet, optionally wrapped with CFG, text and
    time-conditioning plugins.

    Each of `channels`, `factors`, `items`, `attentions`, `cross_attentions`
    and `context_channels` is a per-resolution-level sequence and must have
    the same length. The unspecified per-level toggles default to all zeros.
    """
    num_layers = len(channels)
    # Per-level toggles default to "off" everywhere
    attentions = default(attentions, [0] * num_layers)
    cross_attentions = default(cross_attentions, [0] * num_layers)
    context_channels = default(context_channels, [0] * num_layers)
    per_layer = (channels, factors, items, attentions, cross_attentions, context_channels)
    assert all(len(seq) == num_layers for seq in per_layer)  # type: ignore

    # Stack the requested conditioning plugins around the base UNet type.
    net_t = XUNet
    if use_embedding_cfg:
        msg = "use_embedding_cfg requires embedding_max_length"
        assert exists(embedding_max_length), msg
        net_t = ClassifierFreeGuidancePlugin(net_t, embedding_max_length)
    if use_text_conditioning:
        net_t = TextConditioningPlugin(net_t)
    if use_time_conditioning:
        assert use_modulation, "use_time_conditioning requires use_modulation=True"
        net_t = TimeConditioningPlugin(net_t)

    def build_block(ch, factor, repeats, num_att, num_cross, ctx_channels):
        # One resolution level: a base unit of items, repeated `repeats` times.
        unit = [ResnetItem]
        if use_modulation:
            unit.append(ModulationItem)
        if ctx_channels > 0:
            unit.append(InjectChannelsItem)
        unit.extend([AttentionItem] * num_att)
        unit.extend([CrossAttentionItem] * num_cross)
        return XBlock(
            channels=ch,
            factor=factor,
            context_channels=ctx_channels,
            items=unit * repeats,
        )

    return net_t(
        dim=dim,
        in_channels=in_channels,
        out_channels=out_channels,
        blocks=[build_block(*cfg) for cfg in zip(*per_layer)],  # type: ignore # noqa
        skip_t=SkipModulate if use_modulation else SkipCat,
        attention_features=attention_features,
        attention_heads=attention_heads,
        embedding_features=embedding_features,
        modulation_features=modulation_features,
        resnet_groups=resnet_groups,
    )
"""
Plugins
"""
def LTPlugin(
    net_t: Callable, num_filters: int, window_length: int, stride: int
) -> Callable[..., nn.Module]:
    """Learned Transform Plugin"""

    def Net(
        dim: int, in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        out_channels = default(out_channels, in_channels)
        # Each channel is expanded into `num_filters` learned filter responses.
        hidden_in = in_channels * num_filters
        hidden_out = out_channels * num_filters  # type: ignore
        # Padding chosen so the strided conv/deconv pair roughly preserves length.
        pad = window_length // 2 - stride // 2

        encode = Conv(
            dim=dim,
            in_channels=in_channels,
            out_channels=hidden_in,
            kernel_size=window_length,
            stride=stride,
            padding=pad,
            padding_mode="reflect",
            bias=False,
        )
        decode = nn.ConvTranspose1d(
            in_channels=hidden_out,
            out_channels=out_channels,  # type: ignore
            kernel_size=window_length,
            stride=stride,
            padding=pad,
            bias=False,
        )
        net = net_t(  # type: ignore
            dim=dim,
            in_channels=hidden_in,
            out_channels=hidden_out,
            **kwargs
        )

        def forward(x: Tensor, *args, **kwargs):
            # Learned analysis transform -> inner network -> learned synthesis.
            return decode(net(encode(x), *args, **kwargs))

        return Module([encode, decode, net], forward)

    return Net
def AppendChannelsPlugin(
    net_t: Callable,
    channels: int,
):
    """Wrap `net_t` so that `channels` extra conditioning channels are
    concatenated onto the input before the wrapped network runs."""

    def Net(
        in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        out_channels = default(out_channels, in_channels)
        # Inner net consumes the original input plus the appended channels.
        inner = net_t(  # type: ignore
            in_channels=in_channels + channels, out_channels=out_channels, **kwargs
        )

        def forward(x: Tensor, *args, append_channels: Tensor, **kwargs):
            x = torch.cat([x, append_channels], dim=1)
            return inner(x, *args, **kwargs)

        return Module([inner], forward)

    return Net
"""
Other
"""
class MelSpectrogram(nn.Module):
    """Convert a waveform to a mel spectrogram, with optional normalization.

    The STFT is computed with `power=None` (complex output) and the magnitude
    is taken explicitly, then mapped onto `n_mel_channels` mel bins.
    """

    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        sample_rate: int,
        n_mel_channels: int,
        center: bool = False,
        normalize: bool = False,
        normalize_log: bool = False,
    ):
        super().__init__()
        # Reflect-pad amount applied in forward (since center=False by default).
        self.padding = (n_fft - hop_length) // 2
        self.normalize = normalize
        self.normalize_log = normalize_log
        self.hop_length = hop_length
        self.to_spectrogram = transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=center,
            power=None,
        )
        self.to_mel_scale = transforms.MelScale(
            n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate
        )

    def forward(self, waveform: Tensor) -> Tensor:
        # Collapse all leading (non-time) dims; einops restores them at the end.
        waveform, packed_shape = pack([waveform], "* t")
        padded = F.pad(waveform, [self.padding] * 2, mode="reflect")
        # Complex STFT -> magnitude -> mel scale.
        magnitude = torch.abs(self.to_spectrogram(padded))
        mel = self.to_mel_scale(magnitude)
        if self.normalize:
            # Scale to [0, 1] by the global max, then compress and map to [-1, 1].
            mel = mel / torch.max(mel)
            mel = 2 * torch.pow(mel, 0.25) - 1
        if self.normalize_log:
            mel = torch.log(torch.clamp(mel, min=1e-5))
        return unpack(mel, packed_shape, "* f l")[0]