# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import julius
import torch
from torch import nn
from torch.nn import functional as F

from .states import capture_init
from .utils import center_trim, unfold


class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, the input will be split into overlapping chunks
    and the LSTM applied separately on each chunk.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            # Stitch the overlapping chunks back together, keeping the center of each chunk.
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does.
    """
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            rescale_conv(sub, reference)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales the residual outputs diagonally, close to 0 initially, with a learnt scale.
    """
    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also, before entering each residual branch, the dimension is projected onto a smaller
    subspace, e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for the LayerScale.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
""" super().__init__() assert kernel % 2 == 1 self.channels = channels self.compress = compress self.depth = abs(depth) dilate = depth > 0 norm_fn: tp.Callable[[int], nn.Module] norm_fn = lambda d: nn.Identity() # noqa if norm: norm_fn = lambda d: nn.GroupNorm(1, d) # noqa hidden = int(channels / compress) act: tp.Type[nn.Module] if gelu: act = nn.GELU else: act = nn.ReLU self.layers = nn.ModuleList([]) for d in range(self.depth): dilation = 2 ** d if dilate else 1 padding = dilation * (kernel // 2) mods = [ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), norm_fn(hidden), act(), nn.Conv1d(hidden, 2 * channels, 1), norm_fn(2 * channels), nn.GLU(1), LayerScale(channels, init), ] if attn: mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) if lstm: mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) layer = nn.Sequential(*mods) self.layers.append(layer) def forward(self, x): for layer in self.layers: x = x + layer(x) return x class LocalState(nn.Module): """Local state allows to have attention based only on data (no positional embedding), but while setting a constraint on the time window (e.g. decaying penalty term). Also a failed experiments with trying to provide some frequency based attention. """ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): super().__init__() assert channels % heads == 0, (channels, heads) self.heads = heads self.nfreqs = nfreqs self.ndecay = ndecay self.content = nn.Conv1d(channels, channels, 1) self.query = nn.Conv1d(channels, channels, 1) self.key = nn.Conv1d(channels, channels, 1) if nfreqs: self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) if ndecay: self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) # Initialize decay close to zero (there is a sigmoid), for maximum initial window. self.query_decay.weight.data *= 0.01 assert self.query_decay.bias is not None # stupid type checker self.query_decay.bias.data[:] = -2 self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) def forward(self, x): B, C, T = x.shape heads = self.heads indexes = torch.arange(T, device=x.device, dtype=x.dtype) # left index are keys, right index are queries delta = indexes[:, None] - indexes[None, :] queries = self.query(x).view(B, heads, -1, T) keys = self.key(x).view(B, heads, -1, T) # t are keys, s are queries dots = torch.einsum("bhct,bhcs->bhts", keys, queries) dots /= keys.shape[2]**0.5 if self.nfreqs: periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) if self.ndecay: decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) decay_q = self.query_decay(x).view(B, heads, -1, T) decay_q = torch.sigmoid(decay_q) / 2 decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) # Kill self reference. 
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)


class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of layers in the encoder and in the decoder.
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
""" super().__init__() self.audio_channels = audio_channels self.sources = sources self.kernel_size = kernel_size self.context = context self.stride = stride self.depth = depth self.resample = resample self.channels = channels self.normalize = normalize self.samplerate = samplerate self.segment = segment self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() self.skip_scales = nn.ModuleList() if glu: activation = nn.GLU(dim=1) ch_scale = 2 else: activation = nn.ReLU() ch_scale = 1 if gelu: act2 = nn.GELU else: act2 = nn.ReLU in_channels = audio_channels padding = 0 for index in range(depth): norm_fn = lambda d: nn.Identity() # noqa if index >= norm_starts: norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa encode = [] encode += [ nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2(), ] attn = index >= dconv_attn lstm = index >= dconv_lstm if dconv_mode & 1: encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)] if rewrite: encode += [ nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation] self.encoder.append(nn.Sequential(*encode)) decode = [] if index > 0: out_channels = in_channels else: out_channels = len(self.sources) * audio_channels if rewrite: decode += [ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation] if dconv_mode & 2: decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)] decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)] if index > 0: decode += [norm_fn(out_channels), act2()] self.decoder.insert(0, nn.Sequential(*decode)) in_channels = channels channels = int(growth * channels) channels = in_channels if lstm_layers: self.lstm = BLSTM(channels, lstm_layers) else: self.lstm = None if rescale: rescale_module(self, reference=rescale) def valid_length(self, length): """ Return the nearest valid length to use with the model so that there is no time steps left over in a convolution, e.g. for all layers, size of the input - kernel_size % stride = 0. Note that input are automatically padded if necessary to ensure that the output has the same length as the input. """ if self.resample: length *= 2 for _ in range(self.depth): length = math.ceil((length - self.kernel_size) / self.stride) + 1 length = max(1, length) for idx in range(self.depth): length = (length - 1) * self.stride + self.kernel_size if self.resample: length = math.ceil(length / 2) return int(length) def forward(self, mix): x = mix length = x.shape[-1] if self.normalize: mono = mix.mean(dim=1, keepdim=True) mean = mono.mean(dim=-1, keepdim=True) std = mono.std(dim=-1, keepdim=True) x = (x - mean) / (1e-5 + std) else: mean = 0 std = 1 delta = self.valid_length(length) - length x = F.pad(x, (delta // 2, delta - delta // 2)) if self.resample: x = julius.resample_frac(x, 1, 2) saved = [] for encode in self.encoder: x = encode(x) saved.append(x) if self.lstm: x = self.lstm(x) for decode in self.decoder: skip = saved.pop(-1) skip = center_trim(skip, x) x = decode(x + skip) if self.resample: x = julius.resample_frac(x, 2, 1) x = x * std + mean x = center_trim(x, length) x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) return x def load_state_dict(self, state, strict=True): # fix a mismatch with previous generation Demucs models. 
        for idx in range(self.depth):
            for a in ['encoder', 'decoder']:
                for b in ['bias', 'weight']:
                    new = f'{a}.{idx}.3.{b}'
                    old = f'{a}.{idx}.2.{b}'
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
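

# Minimal smoke-test sketch (not part of the original module): it instantiates the model with
# default hyper-parameters and runs a forward pass on random stereo audio. The source names and
# the one-second duration are arbitrary illustration values, not values mandated by this file.
# Because of the relative imports above, run it as a module, e.g. `python -m demucs.demucs`,
# assuming this file lives at `demucs/demucs.py` inside the package.
if __name__ == "__main__":
    model = Demucs(sources=["drums", "bass", "other", "vocals"], samplerate=44100)
    wav = torch.randn(1, 2, 44100)  # (batch, audio_channels, time): one second of noise
    with torch.no_grad():
        out = model(wav)
    # One stem per source, trimmed back to the input length:
    # (batch, len(sources), audio_channels, time)
    print(out.shape)  # expected: torch.Size([1, 4, 2, 44100])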