# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License # License can be found in LICENSES/LICENSE_ADP.txt import math from inspect import isfunction from math import ceil, floor, log, pi, log2 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union from packaging import version import torch import torch.nn as nn from einops import rearrange, reduce, repeat from einops.layers.torch import Rearrange from einops_exts import rearrange_many from torch import Tensor, einsum from torch.backends.cuda import sdp_kernel from torch.nn import functional as F from dac.nn.layers import Snake1d import pdb """ Utils """ class ConditionedSequential(nn.Module): def __init__(self, *modules): super().__init__() self.module_list = nn.ModuleList(*modules) def forward(self, x: Tensor, mapping: Optional[Tensor] = None): for module in self.module_list: x = module(x, mapping) return x T = TypeVar("T") def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: if exists(val): return val return d() if isfunction(d) else d def exists(val: Optional[T]) -> T: return val is not None def closest_power_2(x: float) -> int: exponent = log2(x) distance_fn = lambda z: abs(x - 2 ** z) # noqa exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) return 2 ** int(exponent_closest) def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: return_dicts: Tuple[Dict, Dict] = ({}, {}) for key in d.keys(): no_prefix = int(not key.startswith(prefix)) return_dicts[no_prefix][key] = d[key] return return_dicts def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) if keep_prefix: return kwargs_with_prefix, kwargs kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} return kwargs_no_prefix, kwargs """ Convolutional Blocks """ import typing as tp # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License # License available in LICENSES/LICENSE_META.txt def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0) -> int: """See `pad_for_conv1d`.""" length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) return ideal_length - length def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): """Pad for a convolution to make sure that the last window is full. Extra padding is added at the end. This is required to ensure that we can rebuild an output of the same length, as otherwise, even with padding, some time steps might get removed. For instance, with total padding = 4, kernel size = 4, stride = 2: 0 0 1 2 3 4 5 0 0 # (0s are padding) 1 2 3 # (output frames of a convolution, last 0 is never used) 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) 1 2 3 4 # once you removed padding, we are missing one time step ! """ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) return F.pad(x, (0, extra_padding)) def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): """Tiny wrapper around F.pad, just to allow for reflect padding on small input. If this is the case, we insert extra 0 padding to the right before the reflection happen. """ length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) if mode == 'reflect': max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: extra_pad = max_pad - length + 1 x = F.pad(x, (0, extra_pad)) padded = F.pad(x, paddings, mode, value) end = padded.shape[-1] - extra_pad return padded[..., :end] else: return F.pad(x, paddings, mode, value) def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): """Remove padding from x, handling properly zero padding. Only for 1d!""" padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert (padding_left + padding_right) <= x.shape[-1] end = x.shape[-1] - padding_right return x[..., padding_left: end] class Conv1d(nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x: Tensor, causal=False) -> Tensor: kernel_size = self.kernel_size[0] stride = self.stride[0] dilation = self.dilation[0] kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations padding_total = kernel_size - stride extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) if causal: # Left padding for causal x = pad1d(x, (padding_total, extra_padding)) else: # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right x = pad1d(x, (padding_left, padding_right + extra_padding)) return super().forward(x) class ConvTranspose1d(nn.ConvTranspose1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x: Tensor, causal=False) -> Tensor: kernel_size = self.kernel_size[0] stride = self.stride[0] padding_total = kernel_size - stride y = super().forward(x) # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be # removed at the very end, when keeping only the right length for the output, # as removing it here would require also passing the length at the matching layer # in the encoder. if causal: padding_right = ceil(padding_total) padding_left = padding_total - padding_right y = unpad1d(y, (padding_left, padding_right)) else: # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right y = unpad1d(y, (padding_left, padding_right)) return y def Downsample1d( in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 ) -> nn.Module: assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" return Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=factor * kernel_multiplier + 1, stride=factor ) def Upsample1d( in_channels: int, out_channels: int, factor: int, use_nearest: bool = False ) -> nn.Module: if factor == 1: return Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=3 ) if use_nearest: return nn.Sequential( nn.Upsample(scale_factor=factor, mode="nearest"), Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=3 ), ) else: return ConvTranspose1d( in_channels=in_channels, out_channels=out_channels, kernel_size=factor * 2, stride=factor ) class ConvBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, kernel_size: int = 3, stride: int = 1, dilation: int = 1, num_groups: int = 8, use_norm: bool = True, use_snake: bool = False ) -> None: super().__init__() self.groupnorm = ( nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) if use_norm else nn.Identity() ) if use_snake: self.activation = Snake1d(in_channels) else: self.activation = nn.SiLU() self.project = Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, ) def forward( self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False ) -> Tensor: x = self.groupnorm(x) if exists(scale_shift): scale, shift = scale_shift x = x * (scale + 1) + shift x = self.activation(x) return self.project(x, causal=causal) class MappingToScaleShift(nn.Module): def __init__( self, features: int, channels: int, ): super().__init__() self.to_scale_shift = nn.Sequential( nn.SiLU(), nn.Linear(in_features=features, out_features=channels * 2), ) def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: scale_shift = self.to_scale_shift(mapping) scale_shift = rearrange(scale_shift, "b c -> b c 1") scale, shift = scale_shift.chunk(2, dim=1) return scale, shift class ResnetBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, kernel_size: int = 3, stride: int = 1, dilation: int = 1, use_norm: bool = True, use_snake: bool = False, num_groups: int = 8, context_mapping_features: Optional[int] = None, ) -> None: super().__init__() self.use_mapping = exists(context_mapping_features) self.block1 = ConvBlock1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, use_norm=use_norm, num_groups=num_groups, use_snake=use_snake ) if self.use_mapping: assert exists(context_mapping_features) self.to_scale_shift = MappingToScaleShift( features=context_mapping_features, channels=out_channels ) self.block2 = ConvBlock1d( in_channels=out_channels, out_channels=out_channels, use_norm=use_norm, num_groups=num_groups, use_snake=use_snake ) self.to_out = ( Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: assert_message = "context mapping required if context_mapping_features > 0" assert not (self.use_mapping ^ exists(mapping)), assert_message h = self.block1(x, causal=causal) scale_shift = None if self.use_mapping: scale_shift = self.to_scale_shift(mapping) h = self.block2(h, scale_shift=scale_shift, causal=causal) return h + self.to_out(x) class Patcher(nn.Module): def __init__( self, in_channels: int, out_channels: int, patch_size: int, context_mapping_features: Optional[int] = None, use_snake: bool = False, ): super().__init__() assert_message = f"out_channels must be divisible by patch_size ({patch_size})" assert out_channels % patch_size == 0, assert_message self.patch_size = patch_size self.block = ResnetBlock1d( in_channels=in_channels, out_channels=out_channels // patch_size, num_groups=1, context_mapping_features=context_mapping_features, use_snake=use_snake ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: x = self.block(x, mapping, causal=causal) x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) return x class Unpatcher(nn.Module): def __init__( self, in_channels: int, out_channels: int, patch_size: int, context_mapping_features: Optional[int] = None, use_snake: bool = False ): super().__init__() assert_message = f"in_channels must be divisible by patch_size ({patch_size})" assert in_channels % patch_size == 0, assert_message self.patch_size = patch_size self.block = ResnetBlock1d( in_channels=in_channels // patch_size, out_channels=out_channels, num_groups=1, context_mapping_features=context_mapping_features, use_snake=use_snake ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) x = self.block(x, mapping, causal=causal) return x """ Attention Components """ def FeedForward(features: int, multiplier: int) -> nn.Module: mid_features = features * multiplier return nn.Sequential( nn.Linear(in_features=features, out_features=mid_features), nn.GELU(), nn.Linear(in_features=mid_features, out_features=features), ) def add_mask(sim: Tensor, mask: Tensor) -> Tensor: b, ndim = sim.shape[0], mask.ndim if ndim == 3: mask = rearrange(mask, "b n m -> b 1 n m") if ndim == 2: mask = repeat(mask, "n m -> b 1 n m", b=b) max_neg_value = -torch.finfo(sim.dtype).max sim = sim.masked_fill(~mask, max_neg_value) return sim def causal_mask(q: Tensor, k: Tensor) -> Tensor: b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) mask = repeat(mask, "n m -> b n m", b=b) return mask class AttentionBase(nn.Module): def __init__( self, features: int, *, head_features: int, num_heads: int, out_features: Optional[int] = None, ): super().__init__() self.scale = head_features**-0.5 self.num_heads = num_heads mid_features = head_features * num_heads out_features = default(out_features, features) self.to_out = nn.Linear( in_features=mid_features, out_features=out_features ) self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') if not self.use_flash: return device_properties = torch.cuda.get_device_properties(torch.device('cuda')) if device_properties.major == 8 and device_properties.minor == 0: # Use flash attention for A100 GPUs self.sdp_kernel_config = (True, False, False) else: # Don't use flash attention for other GPUs self.sdp_kernel_config = (False, True, True) def forward( self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False ) -> Tensor: # Split heads q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) if not self.use_flash: if is_causal and not mask: # Mask out future tokens for causal attention mask = causal_mask(q, k) # Compute similarity matrix and add eventual mask sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale sim = add_mask(sim, mask) if exists(mask) else sim # Get attention matrix with softmax attn = sim.softmax(dim=-1, dtype=torch.float32) # Compute values out = einsum("... n m, ... m d -> ... n d", attn, v) else: with sdp_kernel(*self.sdp_kernel_config): out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) class Attention(nn.Module): def __init__( self, features: int, *, head_features: int, num_heads: int, out_features: Optional[int] = None, context_features: Optional[int] = None, causal: bool = False, ): super().__init__() self.context_features = context_features self.causal = causal mid_features = head_features * num_heads context_features = default(context_features, features) self.norm = nn.LayerNorm(features) self.norm_context = nn.LayerNorm(context_features) self.to_q = nn.Linear( in_features=features, out_features=mid_features, bias=False ) self.to_kv = nn.Linear( in_features=context_features, out_features=mid_features * 2, bias=False ) self.attention = AttentionBase( features, num_heads=num_heads, head_features=head_features, out_features=out_features, ) def forward( self, x: Tensor, # [b, n, c] context: Optional[Tensor] = None, # [b, m, d] context_mask: Optional[Tensor] = None, # [b, m], false is masked, causal: Optional[bool] = False, ) -> Tensor: assert_message = "You must provide a context when using context_features" assert not self.context_features or exists(context), assert_message # Use context if provided context = default(context, x) # Normalize then compute q from input and k,v from context x, context = self.norm(x), self.norm_context(context) q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) if exists(context_mask): # Mask out cross-attention for padding tokens mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) k, v = k * mask, v * mask # Compute and return attention return self.attention(q, k, v, is_causal=self.causal or causal) def FeedForward(features: int, multiplier: int) -> nn.Module: mid_features = features * multiplier return nn.Sequential( nn.Linear(in_features=features, out_features=mid_features), nn.GELU(), nn.Linear(in_features=mid_features, out_features=features), ) """ Transformer Blocks """ class TransformerBlock(nn.Module): def __init__( self, features: int, num_heads: int, head_features: int, multiplier: int, context_features: Optional[int] = None, ): super().__init__() self.use_cross_attention = exists(context_features) and context_features > 0 self.attention = Attention( features=features, num_heads=num_heads, head_features=head_features ) if self.use_cross_attention: self.cross_attention = Attention( features=features, num_heads=num_heads, head_features=head_features, context_features=context_features ) self.feed_forward = FeedForward(features=features, multiplier=multiplier) def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: x = self.attention(x, causal=causal) + x if self.use_cross_attention: x = self.cross_attention(x, context=context, context_mask=context_mask) + x x = self.feed_forward(x) + x return x """ Transformers """ class Transformer1d(nn.Module): def __init__( self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int, context_features: Optional[int] = None, ): super().__init__() self.to_in = nn.Sequential( nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), Conv1d( in_channels=channels, out_channels=channels, kernel_size=1, ), Rearrange("b c t -> b t c"), ) self.blocks = nn.ModuleList( [ TransformerBlock( features=channels, head_features=head_features, num_heads=num_heads, multiplier=multiplier, context_features=context_features, ) for i in range(num_layers) ] ) self.to_out = nn.Sequential( Rearrange("b t c -> b c t"), Conv1d( in_channels=channels, out_channels=channels, kernel_size=1, ), ) def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: x = self.to_in(x) for block in self.blocks: x = block(x, context=context, context_mask=context_mask, causal=causal) x = self.to_out(x) return x """ Time Embeddings """ class SinusoidalEmbedding(nn.Module): def __init__(self, dim: int): super().__init__() self.dim = dim def forward(self, x: Tensor) -> Tensor: device, half_dim = x.device, self.dim // 2 emb = torch.tensor(log(10000) / (half_dim - 1), device=device) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") return torch.cat((emb.sin(), emb.cos()), dim=-1) class LearnedPositionalEmbedding(nn.Module): """Used for continuous time""" def __init__(self, dim: int): super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, x: Tensor) -> Tensor: x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) fouriered = torch.cat((x, fouriered), dim=-1) return fouriered def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: return nn.Sequential( LearnedPositionalEmbedding(dim), nn.Linear(in_features=dim + 1, out_features=out_features), ) """ Encoder/Decoder Components """ class DownsampleBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, factor: int, num_groups: int, num_layers: int, kernel_multiplier: int = 2, use_pre_downsample: bool = True, use_skip: bool = False, use_snake: bool = False, extract_channels: int = 0, context_channels: int = 0, num_transformer_blocks: int = 0, attention_heads: Optional[int] = None, attention_features: Optional[int] = None, attention_multiplier: Optional[int] = None, context_mapping_features: Optional[int] = None, context_embedding_features: Optional[int] = None, ): super().__init__() self.use_pre_downsample = use_pre_downsample self.use_skip = use_skip self.use_transformer = num_transformer_blocks > 0 self.use_extract = extract_channels > 0 self.use_context = context_channels > 0 channels = out_channels if use_pre_downsample else in_channels self.downsample = Downsample1d( in_channels=in_channels, out_channels=out_channels, factor=factor, kernel_multiplier=kernel_multiplier, ) self.blocks = nn.ModuleList( [ ResnetBlock1d( in_channels=channels + context_channels if i == 0 else channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) for i in range(num_layers) ] ) if self.use_transformer: assert ( (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) ) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads if attention_heads is None and attention_features is not None: attention_heads = channels // attention_features self.transformer = Transformer1d( num_layers=num_transformer_blocks, channels=channels, num_heads=attention_heads, head_features=attention_features, multiplier=attention_multiplier, context_features=context_embedding_features ) if self.use_extract: num_extract_groups = min(num_groups, extract_channels) self.to_extracted = ResnetBlock1d( in_channels=out_channels, out_channels=extract_channels, num_groups=num_extract_groups, use_snake=use_snake ) def forward( self, x: Tensor, *, mapping: Optional[Tensor] = None, channels: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: if self.use_pre_downsample: x = self.downsample(x) if self.use_context and exists(channels): x = torch.cat([x, channels], dim=1) skips = [] for block in self.blocks: x = block(x, mapping=mapping, causal=causal) skips += [x] if self.use_skip else [] if self.use_transformer: x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) skips += [x] if self.use_skip else [] if not self.use_pre_downsample: x = self.downsample(x) if self.use_extract: extracted = self.to_extracted(x) return x, extracted return (x, skips) if self.use_skip else x class UpsampleBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, factor: int, num_layers: int, num_groups: int, use_nearest: bool = False, use_pre_upsample: bool = False, use_skip: bool = False, use_snake: bool = False, skip_channels: int = 0, use_skip_scale: bool = False, extract_channels: int = 0, num_transformer_blocks: int = 0, attention_heads: Optional[int] = None, attention_features: Optional[int] = None, attention_multiplier: Optional[int] = None, context_mapping_features: Optional[int] = None, context_embedding_features: Optional[int] = None, ): super().__init__() self.use_extract = extract_channels > 0 self.use_pre_upsample = use_pre_upsample self.use_transformer = num_transformer_blocks > 0 self.use_skip = use_skip self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 channels = out_channels if use_pre_upsample else in_channels self.blocks = nn.ModuleList( [ ResnetBlock1d( in_channels=channels + skip_channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) for _ in range(num_layers) ] ) if self.use_transformer: assert ( (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) ) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads if attention_heads is None and attention_features is not None: attention_heads = channels // attention_features self.transformer = Transformer1d( num_layers=num_transformer_blocks, channels=channels, num_heads=attention_heads, head_features=attention_features, multiplier=attention_multiplier, context_features=context_embedding_features, ) self.upsample = Upsample1d( in_channels=in_channels, out_channels=out_channels, factor=factor, use_nearest=use_nearest, ) if self.use_extract: num_extract_groups = min(num_groups, extract_channels) self.to_extracted = ResnetBlock1d( in_channels=out_channels, out_channels=extract_channels, num_groups=num_extract_groups, use_snake=use_snake ) def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: return torch.cat([x, skip * self.skip_scale], dim=1) def forward( self, x: Tensor, *, skips: Optional[List[Tensor]] = None, mapping: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False ) -> Union[Tuple[Tensor, Tensor], Tensor]: if self.use_pre_upsample: x = self.upsample(x) for block in self.blocks: x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x x = block(x, mapping=mapping, causal=causal) if self.use_transformer: x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) if not self.use_pre_upsample: x = self.upsample(x) if self.use_extract: extracted = self.to_extracted(x) return x, extracted return x class BottleneckBlock1d(nn.Module): def __init__( self, channels: int, *, num_groups: int, num_transformer_blocks: int = 0, attention_heads: Optional[int] = None, attention_features: Optional[int] = None, attention_multiplier: Optional[int] = None, context_mapping_features: Optional[int] = None, context_embedding_features: Optional[int] = None, use_snake: bool = False, ): super().__init__() self.use_transformer = num_transformer_blocks > 0 self.pre_block = ResnetBlock1d( in_channels=channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) if self.use_transformer: assert ( (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) ) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads if attention_heads is None and attention_features is not None: attention_heads = channels // attention_features self.transformer = Transformer1d( num_layers=num_transformer_blocks, channels=channels, num_heads=attention_heads, head_features=attention_features, multiplier=attention_multiplier, context_features=context_embedding_features, ) self.post_block = ResnetBlock1d( in_channels=channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) def forward( self, x: Tensor, *, mapping: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False ) -> Tensor: x = self.pre_block(x, mapping=mapping, causal=causal) if self.use_transformer: x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) x = self.post_block(x, mapping=mapping, causal=causal) return x """ UNet """ class UNet1d(nn.Module): def __init__( self, in_channels: int, channels: int, multipliers: Sequence[int], factors: Sequence[int], num_blocks: Sequence[int], attentions: Sequence[int], patch_size: int = 1, resnet_groups: int = 8, use_context_time: bool = True, kernel_multiplier_downsample: int = 2, use_nearest_upsample: bool = False, use_skip_scale: bool = True, use_snake: bool = False, use_stft: bool = False, use_stft_context: bool = False, out_channels: Optional[int] = None, context_features: Optional[int] = None, context_features_multiplier: int = 4, context_channels: Optional[Sequence[int]] = None, context_embedding_features: Optional[int] = None, **kwargs, ): super().__init__() out_channels = default(out_channels, in_channels) context_channels = list(default(context_channels, [])) num_layers = len(multipliers) - 1 use_context_features = exists(context_features) use_context_channels = len(context_channels) > 0 context_mapping_features = None attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) self.num_layers = num_layers self.use_context_time = use_context_time self.use_context_features = use_context_features self.use_context_channels = use_context_channels self.use_stft = use_stft self.use_stft_context = use_stft_context self.context_features = context_features context_channels_pad_length = num_layers + 1 - len(context_channels) context_channels = context_channels + [0] * context_channels_pad_length self.context_channels = context_channels self.context_embedding_features = context_embedding_features if use_context_channels: has_context = [c > 0 for c in context_channels] self.has_context = has_context self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] assert ( len(factors) == num_layers and len(attentions) >= num_layers and len(num_blocks) == num_layers ) if use_context_time or use_context_features: context_mapping_features = channels * context_features_multiplier self.to_mapping = nn.Sequential( nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(), nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(), ) if use_context_time: assert exists(context_mapping_features) self.to_time = nn.Sequential( TimePositionalEmbedding( dim=channels, out_features=context_mapping_features ), nn.GELU(), ) if use_context_features: assert exists(context_features) and exists(context_mapping_features) self.to_features = nn.Sequential( nn.Linear( in_features=context_features, out_features=context_mapping_features ), nn.GELU(), ) if use_stft: stft_kwargs, kwargs = groupby("stft_", kwargs) assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 in_channels *= stft_channels out_channels *= stft_channels context_channels[0] *= stft_channels if use_stft_context else 1 assert exists(in_channels) and exists(out_channels) self.stft = STFT(**stft_kwargs) assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" self.to_in = Patcher( in_channels=in_channels + context_channels[0], out_channels=channels * multipliers[0], patch_size=patch_size, context_mapping_features=context_mapping_features, use_snake=use_snake ) self.downsamples = nn.ModuleList( [ DownsampleBlock1d( in_channels=channels * multipliers[i], out_channels=channels * multipliers[i + 1], context_mapping_features=context_mapping_features, context_channels=context_channels[i + 1], context_embedding_features=context_embedding_features, num_layers=num_blocks[i], factor=factors[i], kernel_multiplier=kernel_multiplier_downsample, num_groups=resnet_groups, use_pre_downsample=True, use_skip=True, use_snake=use_snake, num_transformer_blocks=attentions[i], **attention_kwargs, ) for i in range(num_layers) ] ) self.bottleneck = BottleneckBlock1d( channels=channels * multipliers[-1], context_mapping_features=context_mapping_features, context_embedding_features=context_embedding_features, num_groups=resnet_groups, num_transformer_blocks=attentions[-1], use_snake=use_snake, **attention_kwargs, ) self.upsamples = nn.ModuleList( [ UpsampleBlock1d( in_channels=channels * multipliers[i + 1], out_channels=channels * multipliers[i], context_mapping_features=context_mapping_features, context_embedding_features=context_embedding_features, num_layers=num_blocks[i] + (1 if attentions[i] else 0), factor=factors[i], use_nearest=use_nearest_upsample, num_groups=resnet_groups, use_skip_scale=use_skip_scale, use_pre_upsample=False, use_skip=True, use_snake=use_snake, skip_channels=channels * multipliers[i + 1], num_transformer_blocks=attentions[i], **attention_kwargs, ) for i in reversed(range(num_layers)) ] ) self.to_out = Unpatcher( in_channels=channels * multipliers[0], out_channels=out_channels, patch_size=patch_size, context_mapping_features=context_mapping_features, use_snake=use_snake ) def get_channels( self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 ) -> Optional[Tensor]: """Gets context channels at `layer` and checks that shape is correct""" use_context_channels = self.use_context_channels and self.has_context[layer] if not use_context_channels: return None assert exists(channels_list), "Missing context" # Get channels index (skipping zero channel contexts) channels_id = self.channels_ids[layer] # Get channels channels = channels_list[channels_id] message = f"Missing context for layer {layer} at index {channels_id}" assert exists(channels), message # Check channels num_channels = self.context_channels[layer] message = f"Expected context with {num_channels} channels at idx {channels_id}" assert channels.shape[1] == num_channels, message # STFT channels if requested channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa return channels def get_mapping( self, time: Optional[Tensor] = None, features: Optional[Tensor] = None ) -> Optional[Tensor]: """Combines context time features and features into mapping""" items, mapping = [], None # Compute time features if self.use_context_time: assert_message = "use_context_time=True but no time features provided" assert exists(time), assert_message items += [self.to_time(time)] # Compute features if self.use_context_features: assert_message = "context_features exists but no features provided" assert exists(features), assert_message items += [self.to_features(features)] # Compute joint mapping if self.use_context_time or self.use_context_features: mapping = reduce(torch.stack(items), "n b m -> b m", "sum") mapping = self.to_mapping(mapping) return mapping def forward( self, x: Tensor, time: Optional[Tensor] = None, *, features: Optional[Tensor] = None, channels_list: Optional[Sequence[Tensor]] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False, ) -> Tensor: channels = self.get_channels(channels_list, layer=0) # Apply stft if required print(x.shape) x = self.stft.encode1d(x) if self.use_stft else x # type: ignore print(x.shape) # Concat context channels at layer 0 if provided x = torch.cat([x, channels], dim=1) if exists(channels) else x print(x.shape) # Compute mapping from time and features mapping = self.get_mapping(time, features) x = self.to_in(x, mapping, causal=causal) print(x.shape) skips_list = [x] for i, downsample in enumerate(self.downsamples): channels = self.get_channels(channels_list, layer=i + 1) x, skips = downsample( x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal ) skips_list += [skips] x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) for i, upsample in enumerate(self.upsamples): skips = skips_list.pop() x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) x += skips_list.pop() x = self.to_out(x, mapping, causal=causal) x = self.stft.decode1d(x) if self.use_stft else x return x """ Conditioning Modules """ class FixedEmbedding(nn.Module): def __init__(self, max_length: int, features: int): super().__init__() self.max_length = max_length self.embedding = nn.Embedding(max_length, features) def forward(self, x: Tensor) -> Tensor: batch_size, length, device = *x.shape[0:2], x.device assert_message = "Input sequence length must be <= max_length" assert length <= self.max_length, assert_message position = torch.arange(length, device=device) fixed_embedding = self.embedding(position) fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) return fixed_embedding def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: if proba == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif proba == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) class UNetCFG1d(UNet1d): """UNet1d with Classifier-Free Guidance""" def __init__( self, context_embedding_max_length: int, context_embedding_features: int, use_xattn_time: bool = False, **kwargs, ): super().__init__( context_embedding_features=context_embedding_features, **kwargs ) self.use_xattn_time = use_xattn_time if use_xattn_time: assert exists(context_embedding_features) self.to_time_embedding = nn.Sequential( TimePositionalEmbedding( dim=kwargs["channels"], out_features=context_embedding_features ), nn.GELU(), ) context_embedding_max_length += 1 # Add one for time embedding self.fixed_embedding = FixedEmbedding( max_length=context_embedding_max_length, features=context_embedding_features ) def forward( # type: ignore self, x: Tensor, time: Tensor, *, embedding: Tensor, embedding_mask: Optional[Tensor] = None, embedding_scale: float = 1.0, embedding_mask_proba: float = 0.0, batch_cfg: bool = False, rescale_cfg: bool = False, scale_phi: float = 0.4, negative_embedding: Optional[Tensor] = None, negative_embedding_mask: Optional[Tensor] = None, **kwargs, ) -> Tensor: b, device = embedding.shape[0], embedding.device if self.use_xattn_time: embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) if embedding_mask is not None: embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) fixed_embedding = self.fixed_embedding(embedding) if embedding_mask_proba > 0.0: # Randomly mask embedding batch_mask = rand_bool( shape=(b, 1, 1), proba=embedding_mask_proba, device=device ) embedding = torch.where(batch_mask, fixed_embedding, embedding) if embedding_scale != 1.0: if batch_cfg: batch_x = torch.cat([x, x], dim=0) batch_time = torch.cat([time, time], dim=0) if negative_embedding is not None: if negative_embedding_mask is not None: negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) batch_embed = torch.cat([embedding, negative_embedding], dim=0) else: batch_embed = torch.cat([embedding, fixed_embedding], dim=0) batch_mask = None if embedding_mask is not None: batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) batch_features = None features = kwargs.pop("features", None) if self.use_context_features: batch_features = torch.cat([features, features], dim=0) batch_channels = None channels_list = kwargs.pop("channels_list", None) if self.use_context_channels: batch_channels = [] for channels in channels_list: batch_channels += [torch.cat([channels, channels], dim=0)] # Compute both normal and fixed embedding outputs batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) out, out_masked = batch_out.chunk(2, dim=0) else: # Compute both normal and fixed embedding outputs out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) out_cfg = out_masked + (out - out_masked) * embedding_scale if rescale_cfg: out_std = out.std(dim=1, keepdim=True) out_cfg_std = out_cfg.std(dim=1, keepdim=True) return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg else: return out_cfg else: return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) class UNetNCCA1d(UNet1d): """UNet1d with Noise Channel Conditioning Augmentation""" def __init__(self, context_features: int, **kwargs): super().__init__(context_features=context_features, **kwargs) self.embedder = NumberEmbedder(features=context_features) def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: x = x if torch.is_tensor(x) else torch.tensor(x) return x.expand(shape) def forward( # type: ignore self, x: Tensor, time: Tensor, *, channels_list: Sequence[Tensor], channels_augmentation: Union[ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor ] = False, channels_scale: Union[ float, Sequence[float], Sequence[Sequence[float]], Tensor ] = 0, **kwargs, ) -> Tensor: b, n = x.shape[0], len(channels_list) channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) # Augmentation (for each channel list item) for i in range(n): scale = channels_scale[:, i] * channels_augmentation[:, i] scale = rearrange(scale, "b -> b 1 1") item = channels_list[i] channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa # Scale embedding (sum reduction if more than one channel list item) channels_scale_emb = self.embedder(channels_scale) channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") return super().forward( x=x, time=time, channels_list=channels_list, features=channels_scale_emb, **kwargs, ) class UNetAll1d(UNetCFG1d, UNetNCCA1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, *args, **kwargs): # type: ignore return UNetCFG1d.forward(self, *args, **kwargs) def XUNet1d(type: str = "base", **kwargs) -> UNet1d: if type == "base": return UNet1d(**kwargs) elif type == "all": return UNetAll1d(**kwargs) elif type == "cfg": return UNetCFG1d(**kwargs) elif type == "ncca": return UNetNCCA1d(**kwargs) else: raise ValueError(f"Unknown XUNet1d type: {type}") class NumberEmbedder(nn.Module): def __init__( self, features: int, dim: int = 256, ): super().__init__() self.features = features self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) def forward(self, x: Union[List[float], Tensor]) -> Tensor: if not torch.is_tensor(x): device = next(self.embedding.parameters()).device x = torch.tensor(x, device=device) assert isinstance(x, Tensor) shape = x.shape x = rearrange(x, "... -> (...)") embedding = self.embedding(x) x = embedding.view(*shape, self.features) return x # type: ignore """ Audio Transforms """ class STFT(nn.Module): """Helper for torch stft and istft""" def __init__( self, num_fft: int = 1023, hop_length: int = 256, window_length: Optional[int] = None, length: Optional[int] = None, use_complex: bool = False, ): super().__init__() self.num_fft = num_fft self.hop_length = default(hop_length, floor(num_fft // 4)) self.window_length = default(window_length, num_fft) self.length = length self.register_buffer("window", torch.hann_window(self.window_length)) self.use_complex = use_complex def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: b = wave.shape[0] wave = rearrange(wave, "b c t -> (b c) t") stft = torch.stft( wave, n_fft=self.num_fft, hop_length=self.hop_length, win_length=self.window_length, window=self.window, # type: ignore return_complex=True, normalized=True, ) if self.use_complex: # Returns real and imaginary stft_a, stft_b = stft.real, stft.imag else: # Returns magnitude and phase matrices magnitude, phase = torch.abs(stft), torch.angle(stft) stft_a, stft_b = magnitude, phase return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: b, l = stft_a.shape[0], stft_a.shape[-1] # noqa length = closest_power_2(l * self.hop_length) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") if self.use_complex: real, imag = stft_a, stft_b else: magnitude, phase = stft_a, stft_b real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) stft = torch.stack([real, imag], dim=-1) wave = torch.istft( stft, n_fft=self.num_fft, hop_length=self.hop_length, win_length=self.window_length, window=self.window, # type: ignore length=default(self.length, length), normalized=True, ) return rearrange(wave, "(b c) t -> b c t", b=b) def encode1d( self, wave: Tensor, stacked: bool = True ) -> Union[Tensor, Tuple[Tensor, Tensor]]: stft_a, stft_b = self.encode(wave) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) def decode1d(self, stft_pair: Tensor) -> Tensor: f = self.num_fft // 2 + 1 stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) return self.decode(stft_a, stft_b)