#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Temporal Convolutional Neural Network (TCNN) for single-channel speech enhancement.

https://github.com/LXP-Never/TCNN
https://github.com/LXP-Never/TCNN/blob/main/TCNN_model.py
https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement
https://ieeexplore.ieee.org/abstract/document/8683634

Reference sources:
https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
from typing import Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.common_types import _size_1_t
class Chomp1d(nn.Module):
    """Crop the trailing `chomp_size` time steps introduced by symmetric
    padding, which turns a padded convolution into a causal one."""

    def __init__(self, chomp_size: int):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x: torch.Tensor):
        # x shape: [b, c, t + chomp_size] -> [b, c, t]
        return x[:, :, :-self.chomp_size].contiguous()
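

# Quick sanity check (illustrative sketch, not part of the original model):
# pairing a Conv1d padded with (kernel_size - 1) * dilation on both sides
# with Chomp1d should preserve the input length while making the convolution
# causal, i.e. each output step only sees current and past inputs.
def _check_chomp_length():
    kernel_size, dilation = 3, 4
    padding = (kernel_size - 1) * dilation
    conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, dilation=dilation)
    chomp = Chomp1d(padding)
    x = torch.randn(1, 1, 100)
    assert chomp(conv(x)).shape == x.shape
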
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    When `causal` is True, `padding` is expected to equal
    (kernel_size - 1) * dilation so that Chomp1d can crop the look-ahead
    introduced by the symmetric padding.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: Union[str, _size_1_t] = 0,
                 dilation: _size_1_t = 1,
                 causal: bool = False,
                 ):
        super(DepthwiseSeparableConv, self).__init__()
        # Use the `groups` option to implement depthwise convolution.
        self.depthwise_conv = nn.Conv1d(
            in_channels=in_channels, out_channels=in_channels,
            kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        self.chomp1d = Chomp1d(padding) if causal else nn.Identity()
        self.prelu = nn.PReLU()
        self.norm = nn.BatchNorm1d(in_channels)
        self.pointwise_conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x: torch.Tensor):
        # x shape: [b, c, t]
        x = self.depthwise_conv(x)
        # x shape: [b, c, t_pad]
        x = self.chomp1d(x)
        # x shape: [b, c, t]
        x = self.prelu(x)
        x = self.norm(x)
        x = self.pointwise_conv(x)
        return x
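

# Parameter-count sketch (illustrative, not part of the original model):
# a depthwise + pointwise pair costs in*k + in*out weights, versus in*out*k
# for a dense Conv1d with bias=False, which is where the efficiency of the
# separable convolution comes from.
def _separable_vs_dense_params(in_channels: int = 512, out_channels: int = 512, kernel_size: int = 3):
    separable = in_channels * kernel_size + in_channels * out_channels
    dense = in_channels * out_channels * kernel_size
    return separable, dense  # (263680, 786432) for the defaults
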
class ResBlock(nn.Module):
    """1x1 bottleneck -> causal depthwise separable conv -> residual add."""

    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 kernel_size: _size_1_t,
                 dilation: _size_1_t = 1,
                 ):
        super(ResBlock, self).__init__()
        self.conv1d = nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1)
        self.prelu = nn.PReLU(num_parameters=1)
        self.norm = nn.BatchNorm1d(num_features=hidden_channels)
        self.sconv = DepthwiseSeparableConv(
            in_channels=hidden_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) * dilation,
            dilation=dilation,
            causal=True,
        )

    def forward(self, inputs: torch.Tensor):
        x = inputs
        # x shape: [b, in_channels, t]
        x = self.conv1d(x)
        # x shape: [b, hidden_channels, t]
        x = self.prelu(x)
        x = self.norm(x)
        # x shape: [b, hidden_channels, t]
        x = self.sconv(x)
        # x shape: [b, in_channels, t]
        result = x + inputs
        return result
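

# Shape sanity check (illustrative sketch): the residual add requires the
# separable conv to map hidden_channels back to in_channels and, thanks to
# the causal padding, to leave the time dimension unchanged.
def _check_resblock_shape():
    block = ResBlock(in_channels=256, hidden_channels=512, kernel_size=3, dilation=4)
    x = torch.randn(2, 256, 10)
    assert block(x).shape == x.shape
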
class TCNNBlock(nn.Module):
    """Stack of ResBlocks with exponentially growing dilation."""

    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 kernel_size: int = 3,
                 init_dilation: int = 2,
                 num_layers: int = 6
                 ):
        super(TCNNBlock, self).__init__()
        self.layers = nn.ModuleList(modules=[])
        for i in range(num_layers):
            dilation_size = init_dilation ** i
            self.layers.append(
                ResBlock(
                    in_channels,
                    hidden_channels,
                    kernel_size,
                    dilation=dilation_size,
                )
            )

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            # x shape: [b, c, t]
            x = layer(x)
        # x shape: [b, c, t]
        return x
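

# Receptive-field note (illustrative): each ResBlock adds
# (kernel_size - 1) * dilation frames of causal context, so one block with
# kernel_size=3, init_dilation=2, num_layers=6 spans
# 1 + 2 * (1 + 2 + 4 + 8 + 16 + 32) = 127 frames.
def _tcnn_block_receptive_field(kernel_size: int = 3, init_dilation: int = 2, num_layers: int = 6) -> int:
    return 1 + sum((kernel_size - 1) * init_dilation ** i for i in range(num_layers))
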
class TCNN(nn.Module):
    """2D convolutional encoder -> three TCNN blocks on the flattened
    encoder features -> transposed-convolution decoder with skip
    connections, applied to overlapping time-domain frames."""

    def __init__(self):
        super(TCNN, self).__init__()
        self.win_size = 320
        self.hop_size = 160

        self.conv2d_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.conv2d_2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.conv2d_3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.conv2d_4 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.conv2d_5 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.conv2d_6 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.PReLU()
        )
        self.conv2d_7 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.PReLU()
        )

        # 256 = 64 channels * 4 frequency bins remaining after the encoder.
        self.tcnn_block_1 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
        self.tcnn_block_2 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
        self.tcnn_block_3 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)

        self.dconv2d_7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
                               output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=64),
            nn.PReLU()
        )
        self.dconv2d_6 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
                               output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.dconv2d_5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
                               output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.dconv2d_4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
                               output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.dconv2d_3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
                               output_padding=(0, 1)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.dconv2d_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2),
                               output_padding=(0, 1)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.dconv2d_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2),
                               output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=1),
            nn.PReLU()
        )

    def signal_prepare(self, signal: torch.Tensor) -> Tuple[torch.Tensor, int]:
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape
        # Zero-pad the tail so the signal covers a whole number of hops.
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal, n_samples
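
    # Worked example (illustrative): for a 16000-sample input with
    # win_size=320 and hop_size=160, (16000 - 320) % 160 == 0, so nothing is
    # padded and unfold in forward() yields (16000 - 320) // 160 + 1 = 99 frames.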

    def forward(self,
                noisy: torch.Tensor,
                ):
        noisy, num_samples = self.signal_prepare(noisy)
        batch_size, _, num_samples_pad = noisy.shape
        # n_frame = (num_samples_pad - self.win_size) // self.hop_size + 1

        # Split the waveform into overlapping frames with unfold.
        # noisy shape: [b, 1, num_samples_pad]
        noisy = noisy.unsqueeze(1)
        # noisy shape: [b, 1, 1, num_samples_pad]
        noisy_frame = F.unfold(
            input=noisy,
            kernel_size=(1, self.win_size),
            padding=(0, 0),
            stride=(1, self.hop_size),
        )
        # noisy_frame shape: [b, win_size, n_frame]
        noisy_frame = noisy_frame.unsqueeze(1)
        # noisy_frame shape: [b, 1, win_size, n_frame]
        noisy_frame = noisy_frame.permute(0, 1, 3, 2)
        # noisy_frame shape: [b, 1, n_frame, win_size]
        denoise_frame = self.forward_chunk(noisy_frame)
        # denoise_frame shape: [b, 1, n_frame, win_size]
        denoise_frame = denoise_frame.squeeze(1)
        # denoise_frame shape: [b, n_frame, win_size]
        denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
        # denoise shape: [b, num_samples_pad]
        denoise = denoise[:, :num_samples]
        # denoise shape: [b, num_samples]
        return denoise

    def forward_chunk(self, inputs: torch.Tensor):
        # inputs shape: [b, 1, t, win_size]
        conv2d_1 = self.conv2d_1(inputs)
        conv2d_2 = self.conv2d_2(conv2d_1)
        conv2d_3 = self.conv2d_3(conv2d_2)
        conv2d_4 = self.conv2d_4(conv2d_3)
        conv2d_5 = self.conv2d_5(conv2d_4)
        conv2d_6 = self.conv2d_6(conv2d_5)
        conv2d_7 = self.conv2d_7(conv2d_6)
        # conv2d_7 shape: [b, 64, t, 4]

        # Flatten channels and frequency so the TCNN blocks see [b, 256, t].
        reshape_1 = conv2d_7.permute(0, 1, 3, 2)
        # reshape_1 shape: [b, 64, 4, t]
        batch_size, C, frame_len, frame_num = reshape_1.shape
        reshape_1 = reshape_1.reshape(batch_size, C * frame_len, frame_num)
        # reshape_1 shape: [b, 256, t]
        tcnn_block_1 = self.tcnn_block_1(reshape_1)
        tcnn_block_2 = self.tcnn_block_2(tcnn_block_1)
        tcnn_block_3 = self.tcnn_block_3(tcnn_block_2)
        # tcnn_block_3 shape: [b, 256, t]
        reshape_2 = tcnn_block_3.reshape(batch_size, C, frame_len, frame_num)
        reshape_2 = reshape_2.permute(0, 1, 3, 2)
        # reshape_2 shape: [b, 64, t, 4]

        # Decoder with skip connections: concatenate the matching encoder
        # feature map along the channel dimension before each deconvolution.
        dconv2d_7 = self.dconv2d_7(torch.cat((conv2d_7, reshape_2), dim=1))
        dconv2d_6 = self.dconv2d_6(torch.cat((conv2d_6, dconv2d_7), dim=1))
        dconv2d_5 = self.dconv2d_5(torch.cat((conv2d_5, dconv2d_6), dim=1))
        dconv2d_4 = self.dconv2d_4(torch.cat((conv2d_4, dconv2d_5), dim=1))
        dconv2d_3 = self.dconv2d_3(torch.cat((conv2d_3, dconv2d_4), dim=1))
        dconv2d_2 = self.dconv2d_2(torch.cat((conv2d_2, dconv2d_3), dim=1))
        dconv2d_1 = self.dconv2d_1(torch.cat((conv2d_1, dconv2d_2), dim=1))
        # dconv2d_1 shape: [b, 1, t, win_size]
        return dconv2d_1

    def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
        # Overlap-add: sum the overlapping frames, then normalize each sample
        # by the number of frames that covered it.
        # https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement/blob/main/TCNN/util/utils.py#L40
        b, t, f = denoise_frame.shape
        if f != self.win_size:
            raise AssertionError(f"expected frame length {self.win_size}, got {f}")
        denoise = torch.zeros(size=(b, num_samples), dtype=denoise_frame.dtype, device=denoise_frame.device)
        count = torch.zeros(size=(b, num_samples), dtype=denoise_frame.dtype, device=denoise_frame.device)

        start = 0
        end = start + self.win_size
        for i in range(t):
            denoise[..., start:end] += denoise_frame[:, i, :]
            count[..., start:end] += 1.
            start += self.hop_size
            end = start + self.win_size
        denoise = denoise / count
        return denoise
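
    # Overlap-add note (illustrative): with hop_size = win_size // 2, every
    # interior sample is covered by exactly two frames, so dividing by
    # `count` averages the two overlapping estimates; only the first and
    # last hop_size samples are covered once.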


def main():
    model = TCNN()

    # Frame-level forward: [b, 1, n_frame, win_size] in and out.
    x = torch.randn(64, 1, 5, 320)
    y = model.forward_chunk(x)
    print("output", y.shape)

    # Waveform-level forward: [b, num_samples] in and out.
    noisy = torch.randn(size=(2, 16000), dtype=torch.float32)
    denoise = model(noisy)
    print(f"denoise.shape: {denoise.shape}")
    return


if __name__ == "__main__":
    main()