# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez

import math
import time

import torch as th
from torch import nn
from torch.nn import functional as F

from .resample import downsample2, upsample2
from .utils import capture_init


class BLSTM(nn.Module):
    """
    LSTM whose input and output feature size are both `dim`.

    Args:
        - dim (int): input size and hidden size of the LSTM.
        - layers (int): number of stacked LSTM layers.
        - bi (bool): if true, the LSTM is bidirectional and a linear layer
            projects the `2 * dim` outputs back to `dim`.
    """
    def __init__(self, dim, layers=2, bi=True):
        super().__init__()
        self.lstm = nn.LSTM(
            bidirectional=bi, num_layers=layers, hidden_size=dim,
            input_size=dim)
        # Only needed to bring the concatenated forward/backward outputs
        # back to `dim` features.
        self.linear = None
        if bi:
            self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x, hidden=None):
        """
        Run the LSTM on `x` starting from `hidden` (or a zero state when
        `hidden` is None) and return `(output, new_hidden)`.
        """
        x, hidden = self.lstm(x, hidden)
        if self.linear is not None:
            x = self.linear(x)
        return x, hidden


def rescale_conv(conv, reference):
    """
    Rescale in place the weight (and bias, if any) of `conv` by dividing them
    by `sqrt(std / reference)`, where `std` is the standard deviation of the
    current weight.
    """
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    """
    Apply `rescale_conv` to every `nn.Conv1d` and `nn.ConvTranspose1d` found
    in `module` (including `module` itself).
    """
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


class Demucs(nn.Module):
    """
    Demucs speech enhancement model.

    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): if false, uses BiLSTM instead of LSTM.
        - resample (int): amount of resampling to apply to the input/output.
            Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every
            layer.
        - max_hidden (int): maximum number of channels. Can be useful to
            control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
            See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.

    Raises:
        ValueError: if `resample` is not 1, 2 or 4.
    """
    @capture_init
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):
        super().__init__()
        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # GLU halves the channel dimension, hence the 1x1 convolutions output
        # twice as many channels when it is used.
        activation = nn.GLU(1) if glu else nn.ReLU()
        ch_scale = 2 if glu else 1

        for index in range(depth):
            encode = []
            encode += [
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1),
                activation,
            ]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            decode += [
                nn.Conv1d(hidden, ch_scale * hidden, 1),
                activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            # The decoder layer built at index 0 is the last one applied
            # (see the `insert(0, ...)` below): it gets no final ReLU.
            if index > 0:
                decode.append(nn.ReLU())
            # Decoder layers are stored in reverse order of the encoder ones.
            self.decoder.insert(0, nn.Sequential(*decode))
            chout = hidden
            chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        # At this point `chin` is the channel count of the last encoder layer.
        self.lstm = BLSTM(chin, bi=not causal)
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in a convolution, e.g. for all
        layers, (size of the input - kernel_size) % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        # Length after each encoder layer (at least one time step is kept).
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        # Length after each decoder (transposed convolution) layer.
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        return int(math.ceil(length / self.resample))

    @property
    def total_stride(self):
        """Stride of the whole encoder, expressed in input samples."""
        return self.stride ** self.depth // self.resample

    def forward(self, mix):
        """
        Apply the model to `mix`.

        A 2 dimensional input gets a channel dimension inserted at position 1.
        The output is trimmed to the same number of time steps as the input.
        """
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        if self.normalize:
            # Normalize by the standard deviation of the channel average.
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)
        else:
            std = 1
        length = mix.shape[-1]
        x = mix
        # Right pad with zeros up to a valid length.
        x = F.pad(x, (0, self.valid_length(length) - length))
        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)
        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
        # The LSTM expects the time dimension first.
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)
        for decode in self.decoder:
            skip = skips.pop(-1)
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)
        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        x = x[..., :length]
        # Undo the input normalization.
        return std * x


def fast_conv(conv, x):
    """
    Faster convolution evaluation if either the kernel size is 1
    or the length of the sequence equals the kernel size (i.e. there is
    a single output time step). Falls back to `conv(x)` otherwise.

    Only a batch size of 1 is supported.
    """
    batch, chin, length = x.shape
    chout, chin, kernel = conv.weight.shape
    assert batch == 1
    if kernel == 1:
        # A 1x1 convolution is a matrix product over the channels.
        x = x.view(chin, length)
        out = th.addmm(conv.bias.view(-1, 1),
                       conv.weight.view(chout, chin), x)
    elif length == kernel:
        # Single output step: one matrix-vector product.
        x = x.view(chin * kernel, 1)
        out = th.addmm(conv.bias.view(-1, 1),
                       conv.weight.view(chout, chin * kernel), x)
    else:
        out = conv(x)
    return out.view(batch, chout, -1)


class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
            noise removal, 1 just returns the input signal. Small values > 0
            allows to limit distortions.
        - num_frames (int): number of frames to process at once. Higher values
            will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
            kept for resampling.
    """
    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        # Recurrent and convolutional states carried from one frame to the
        # next, see `_separate_frame`.
        self.lstm_state = None
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead
        resample_buffer = min(demucs.total_stride, resample_buffer)
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + \
            demucs.total_stride * (num_frames - 1)
        # Number of pending input samples required to process one frame.
        self.total_length = self.frame_length + self.resample_lookahead
        # Number of input samples consumed (and output samples produced)
        # per processed frame.
        self.stride = demucs.total_stride * num_frames
        # Previous inputs/outputs, used as left context when resampling.
        self.resample_in = th.zeros(
            demucs.chin, resample_buffer, device=device)
        self.resample_out = th.zeros(
            demucs.chin, resample_buffer, device=device)

        self.frames = 0
        self.total_time = 0
        self.variance = 0
        self.pending = th.zeros(demucs.chin, 0, device=device)

        # Rearranged parameters of the first decoder transposed convolution.
        # NOTE(review): not read anywhere in this file, presumably kept for
        # external users -- confirm before removing.
        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        _, _, kernel = weight.shape
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        """Reset the counters used to compute `time_per_frame`."""
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        """
        Average time in seconds spent in `feed` per processed frame.
        Raises ZeroDivisionError if no frame was processed yet.
        """
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zero. Call this
        when you have no more input and want to get back the last chunk of
        audio.
        """
        pending_length = self.pending.shape[1]
        padding = th.zeros(
            self.demucs.chin, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        # Drop the part of the output that only corresponds to the padding.
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to mix using true real time evaluation.
        Normalization is done online as is the resampling.

        Args:
            - wav (th.Tensor): two dimensional tensor whose first dimension
                is the number of channels of the model.

        Returns:
            Tensor with the same number of channels as `wav` containing all
            the output that could be computed so far (possibly of length 0).

        Raises:
            ValueError: if `wav` is not two dimensional or has the wrong
                number of channels.
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")

        self.pending = th.cat([self.pending, wav], dim=1)
        outs = []
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                mono = frame.mean(0)
                variance = (mono**2).mean()
                # Running average of the per frame variance.
                self.variance = variance / self.frames + \
                    (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            # `th.cat` copies its inputs, so `padded_frame` still holds the
            # previous buffer after `resample_in` is updated in place.
            padded_frame = th.cat([self.resample_in, frame], dim=-1)
            # The next frame starts `stride` samples later: keep the
            # `resample_buffer` samples right before that point. This must
            # index the unpadded frame, otherwise the slice is shifted by
            # `resample_buffer` samples.
            self.resample_in[:] = frame[:, stride - resample_buffer:stride]
            frame = padded_frame

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            # remove pre sampling buffer
            frame = frame[:, resample * resample_buffer:]
            # remove extra samples after window
            frame = frame[:, :resample * self.frame_length]

            out, extra = self._separate_frame(frame)
            padded_out = th.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = th.cat(outs, 1)
        else:
            out = th.zeros(chin, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        """
        Apply the model to a single (already upsampled) frame, reusing the
        LSTM state and the convolution outputs cached from the previous frame.

        Returns a pair `(out, extra)`, see the comment before the decoder loop
        for the meaning of `extra`.
        """
        demucs = self.demucs
        skips = []
        next_state = []
        # On the very first frame there is no cached state to reuse.
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            # Number of time steps the output of this layer advances by
            # between two consecutive frames.
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    # Reuse the output of the previous frame, minus the
                    # steps that are now out of the window, and only
                    # evaluate the convolution on the missing steps.
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - \
                        demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = th.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        x = x.permute(2, 0, 1)
        x, self.lstm_state = demucs.lstm(x, self.lstm_state)
        x = x.permute(1, 2, 0)
        # In the following, x contains only correct samples, i.e. the one
        # for which each time position is covered by two window of the upper
        # layer. extra contains extra samples to the right, and is used only as
        # a better padding for the online resampling.
        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            x += skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            # Cache the last `stride` output steps without the bias: they are
            # added to the first steps of the next frame below, which gets
            # its own bias.
            next_state.append(
                x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                # Every decoder layer but the last one ends with a ReLU.
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]


def test():
    """
    Benchmark the streaming implementation on random audio and print its
    relative difference with the offline model, along with timing figures.
    """
    import argparse
    parser = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, as well as "
                    "checking the delta with the offline implementation.")
    parser.add_argument("--depth", default=5, type=int)
    parser.add_argument("--resample", default=4, type=int)
    parser.add_argument("--hidden", default=48, type=int)
    parser.add_argument("--sample_rate", default=16000, type=float)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("-t", "--num_threads", type=int)
    parser.add_argument("-f", "--num_frames", type=int, default=1)
    args = parser.parse_args()
    if args.num_threads:
        th.set_num_threads(args.num_threads)
    sr = args.sample_rate
    sr_ms = sr / 1000
    demucs = Demucs(
        depth=args.depth, hidden=args.hidden, resample=args.resample
    ).to(args.device)
    # Four seconds of random audio.
    x = th.randn(1, int(sr * 4)).to(args.device)
    out = demucs(x[None])[0]
    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    out_rt = []
    # The first call must provide a full frame, the following ones provide
    # one model stride at a time.
    frame_size = streamer.total_length
    with th.no_grad():
        while x.shape[1] > 0:
            out_rt.append(streamer.feed(x[:, :frame_size]))
            x = x[:, frame_size:]
            frame_size = streamer.demucs.total_stride
    out_rt.append(streamer.flush())
    out_rt = th.cat(out_rt, 1)
    # Size in MB assuming 4 bytes per parameter.
    model_size = sum(p.numel() for p in demucs.parameters()) * 4 / 2**20
    initial_lag = streamer.total_length / sr_ms
    tpf = 1000 * streamer.time_per_frame
    print(f"model size: {model_size:.1f}MB, ", end='')
    print(f"delta batch/streaming: {th.norm(out - out_rt) / th.norm(out):.2%}")
    print(f"initial lag: {initial_lag:.1f}ms, ", end='')
    # `streamer.stride` already accounts for `num_frames`.
    print(f"stride: {streamer.stride / sr_ms:.1f}ms")
    print(f"time per frame: {tpf:.1f}ms, ", end='')
    rtf = (1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)
    print(f"RTF: {rtf:.2f}")
    print(f"Total lag with computation: {initial_lag + tpf:.1f}ms")


if __name__ == "__main__":
    test()