import math
import time

import torch as th
from torch import nn
from torch.nn import functional as F

from .resample import downsample2, upsample2
from .utils import capture_init


class BLSTM(nn.Module):
    def __init__(self, dim, layers=2, bi=True):
        super().__init__()
        klass = nn.LSTM
        self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = None
        if bi:
            self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x, hidden=None):
        x, hidden = self.lstm(x, hidden)
        if self.linear:
            x = self.linear(x)
        return x, hidden
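
# BLSTM keeps the nn.LSTM default layout (batch_first=False): inputs and
# outputs are (time, batch, dim). A minimal sketch, with assumed sizes:
#   lstm = BLSTM(dim=768, bi=False)
#   y, state = lstm(th.randn(100, 1, 768))  # y: (100, 1, 768)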


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)
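
# A worked example of the rescaling (assumed numbers): if a conv's weight std
# is 0.4 and reference=0.1, then scale = (0.4 / 0.1) ** 0.5 = 2, so the
# weights and bias are divided by 2, partially pulling each layer's output
# scale toward the reference (see https://arxiv.org/abs/1911.13254).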


class Demucs(nn.Module):
    """
    Demucs speech enhancement model.

    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): if false, uses BiLSTM instead of LSTM.
        - resample (int): amount of resampling to apply to the input/output.
            Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every layer.
        - max_hidden (int): maximum number of channels. Can be useful to
            control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true, uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
            See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.
    """

    @capture_init
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):
        super().__init__()
        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        activation = nn.GLU(1) if glu else nn.ReLU()
        ch_scale = 2 if glu else 1

        for index in range(depth):
            encode = []
            encode += [
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
            ]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            decode += [
                nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            chout = hidden
            chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        self.lstm = BLSTM(chin, bi=not causal)
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in the convolutions, i.e. for
        all layers, (size of the input - kernel_size) % stride == 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        for idx in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        length = int(math.ceil(length / self.resample))
        return int(length)
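
    # Worked example with the defaults (depth=5, kernel_size=8, stride=4,
    # resample=4): valid_length(1) upsamples 1 sample to 4, collapses them to
    # a single step through the five strided convs, grows back to 2388 through
    # the transposed convs, and downsamples to ceil(2388 / 4) = 597.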

    @property
    def total_stride(self):
        return self.stride ** self.depth // self.resample
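
    # With the defaults this is 4 ** 5 // 4 = 256 samples, i.e. 16 ms at
    # 16 kHz: the granularity at which the streaming version emits audio.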

    def forward(self, mix):
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)
        else:
            std = 1
        length = mix.shape[-1]
        x = mix
        x = F.pad(x, (0, self.valid_length(length) - length))
        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)
        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)
        for decode in self.decoder:
            skip = skips.pop(-1)
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)
        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        x = x[..., :length]
        return std * x


def fast_conv(conv, x):
    """
    Faster convolution evaluation if either the kernel size is 1, or the
    input length equals the kernel size (so the convolution produces a
    single output step).
    """
    batch, chin, length = x.shape
    chout, chin, kernel = conv.weight.shape
    assert batch == 1
    if kernel == 1:
        x = x.view(chin, length)
        out = th.addmm(conv.bias.view(-1, 1),
                       conv.weight.view(chout, chin), x)
    elif length == kernel:
        x = x.view(chin * kernel, 1)
        out = th.addmm(conv.bias.view(-1, 1),
                       conv.weight.view(chout, chin * kernel), x)
    else:
        out = conv(x)
    return out.view(batch, chout, -1)
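
# Both fast paths above turn the convolution into a single matrix product:
# a 1x1 conv is a (chout, chin) @ (chin, length) matmul, and a full-width
# window is a (chout, chin * kernel) @ (chin * kernel, 1) matmul; th.addmm
# fuses the bias addition into the product.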


class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (i.e. input) signal to keep. 0 is maximum
            noise removal, 1 just returns the input signal. Small values > 0
            allow limiting distortions.
        - num_frames (int): number of frames to process at once. Higher values
            will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
            kept for resampling.
    """

    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        self.lstm_state = None
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead
        # Clamp so the resampling buffers never exceed one stride of context
        # (a larger buffer would produce a negative slice in feed).
        resample_buffer = min(demucs.total_stride, resample_buffer)
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
        self.total_length = self.frame_length + self.resample_lookahead
        self.stride = demucs.total_stride * num_frames
        self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device)
        self.resample_out = th.zeros(demucs.chin, resample_buffer, device=device)

        self.frames = 0
        self.total_time = 0
        self.variance = 0
        self.pending = th.zeros(demucs.chin, 0, device=device)

        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        chin, chout, kernel = weight.shape
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zeros. Call this
        when you have no more input and want to get back the last chunk of audio.
        """
        pending_length = self.pending.shape[1]
        padding = th.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to `wav`, using true streaming evaluation.
        Normalization is done online, as is the resampling.
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")

        self.pending = th.cat([self.pending, wav], dim=1)
        outs = []
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                mono = frame.mean(0)
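                # Cumulative moving average of the input power, so that
                # normalization can be done online, frame by frame.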
                variance = (mono**2).mean()
                self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            frame = th.cat([self.resample_in, frame], dim=-1)
            self.resample_in[:] = frame[:, stride - resample_buffer:stride]

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            frame = frame[:, resample * resample_buffer:]
            frame = frame[:, :resample * self.frame_length]

            out, extra = self._separate_frame(frame)
            padded_out = th.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = th.cat(outs, 1)
        else:
            out = th.zeros(chin, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        demucs = self.demucs
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = th.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        x = x.permute(2, 0, 1)
        x, self.lstm_state = demucs.lstm(x, self.lstm_state)
        x = x.permute(1, 2, 0)

        # In the following, x contains only correct samples, i.e. those for
        # which each time position is covered by two windows of the upper
        # layer. extra contains samples to the right, used only as better
        # padding for the online resampling.
        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            x += skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]
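
    # conv_state caches the trailing activations of the previous frame for
    # every layer: encoder tails are concatenated back in front of the new
    # frame, while decoder tails are overlap-added into the next output.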


def test():
    import argparse
    parser = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, "
                    "as well as checking the delta with the offline implementation.")
    parser.add_argument("--resample", default=4, type=int)
    parser.add_argument("--hidden", default=48, type=int)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("-t", "--num_threads", type=int)
    parser.add_argument("-f", "--num_frames", type=int, default=1)
    args = parser.parse_args()
    if args.num_threads:
        th.set_num_threads(args.num_threads)
    sr = 16_000
    sr_ms = sr / 1000
    demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device)
    x = th.randn(1, sr * 4).to(args.device)
    out = demucs(x[None])[0]
    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    out_rt = []
    frame_size = streamer.total_length
    with th.no_grad():
        while x.shape[1] > 0:
            out_rt.append(streamer.feed(x[:, :frame_size]))
            x = x[:, frame_size:]
            frame_size = streamer.demucs.total_stride
    out_rt.append(streamer.flush())
    out_rt = th.cat(out_rt, 1)
    print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='')
    print(f"stride: {streamer.stride / sr_ms:.1f}ms, ", end='')
    print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='')
    print(f"delta: {th.norm(out - out_rt) / th.norm(out):.2%}, ", end='')
    print(f"RTF: {1000 * streamer.time_per_frame / (streamer.stride / sr_ms):.1f}")


if __name__ == "__main__":
    test()