# RVC/tools/utils/noisereduce.py
# NeoPy — EXP — b1cded8 (verified)
import os
import sys
import torch
from torch.nn.functional import conv1d, conv2d
# Put the current working directory on the import path. Presumably this is so
# project-local packages such as `main.library` (imported lazily inside
# TorchGate.forward) resolve when the program is launched from the repo root
# — TODO confirm.
sys.path.append(os.getcwd())
@torch.no_grad()
def temperature_sigmoid(x, x0, temp_coeff):
    """Return ``sigmoid((x - x0) / temp_coeff)``.

    ``x0`` is the point where the output crosses 0.5. ``temp_coeff`` controls
    the steepness: smaller values give a sharper, more step-like transition.
    """
    shifted_and_scaled = (x - x0) / temp_coeff
    return torch.sigmoid(shifted_and_scaled)
@torch.no_grad()
def linspace(start, stop, num=50, endpoint=True, **kwargs):
    """Return ``num`` evenly spaced samples from ``start`` toward ``stop``.

    Behaves like ``torch.linspace`` but adds a NumPy-style ``endpoint`` flag.
    When ``endpoint`` is False, ``stop`` is excluded and the step becomes
    ``(stop - start) / num``. Extra keyword arguments (``dtype``, ``device``,
    ...) are forwarded to ``torch.linspace``.
    """
    if endpoint:
        return torch.linspace(start, stop, num, **kwargs)
    # Sample one extra point on the closed interval, then drop `stop` itself.
    closed_grid = torch.linspace(start, stop, num + 1, **kwargs)
    return closed_grid[:-1]
@torch.no_grad()
def amp_to_db(x, eps=torch.finfo(torch.float32).eps, top_db=40):
    """Convert amplitudes to decibels, clipping each row to a ``top_db`` range.

    Computes ``20 * log10(x + eps)`` and then raises every value to at least
    ``row_max - top_db``, where ``row_max`` is the maximum along the last
    dimension. The input is expected to be non-negative (the caller in this
    file passes ``X.abs()``); the function does not take an absolute value
    itself.
    """
    decibels = 20 * torch.log10(x + eps)
    floor = decibels.amax(dim=-1, keepdim=True) - top_db
    return torch.maximum(decibels, floor)
class TorchGate(torch.nn.Module):
    """Spectral-gating noise reduction implemented with PyTorch.

    ``forward`` computes an STFT of the input and estimates a per-bin signal
    mask (stationary or non-stationary mode). It optionally smooths that mask
    over frequency and time, applies it to the spectrogram, and inverts back to
    a waveform.

    Args:
        sr: Sample rate of the input, in Hz.
        nonstationary: If True, use the moving-average based mask; otherwise
            use the mean/std threshold mask.
        n_std_thresh_stationary: Stationary mode. A bin counts as signal when
            its dB value exceeds the mean of its frequency row (over time
            frames) plus this many standard deviations.
        n_thresh_nonstationary: Non-stationary mode. Centre of the sigmoid
            applied to the relative deviation from the moving average.
        temp_coeff_nonstationary: Temperature (steepness) of that sigmoid.
        n_movemean_nonstationary: Length, in STFT frames, of the moving-average
            window used in non-stationary mode.
        prop_decrease: In [0, 1]. A bin with mask value m is scaled by
            ``1 - prop_decrease * (1 - m)``. So 1.0 fully removes non-signal
            bins and 0.0 leaves the input unchanged.
        n_fft: FFT size.
        win_length: Window length; defaults to ``n_fft``.
        hop_length: Hop length; defaults to ``win_length // 4``.
        freq_mask_smooth_hz: Frequency extent of the mask-smoothing kernel.
            None is treated as an extent of one step.
        time_mask_smooth_ms: Time extent of the mask-smoothing kernel.
            None is treated as an extent of one step. If both extents resolve
            to one step, no smoothing is applied.
    """

    @torch.no_grad()
    def __init__(
        self,
        sr,
        nonstationary=False,
        n_std_thresh_stationary=1.5,
        n_thresh_nonstationary=1.3,
        temp_coeff_nonstationary=0.1,
        n_movemean_nonstationary=20,
        prop_decrease=1.0,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        freq_mask_smooth_hz=500,
        time_mask_smooth_ms=50,
    ):
        super().__init__()
        self.sr = sr
        self.nonstationary = nonstationary
        assert 0.0 <= prop_decrease <= 1.0, f"prop_decrease must be in [0, 1], got {prop_decrease}"
        self.prop_decrease = prop_decrease
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length
        self.n_std_thresh_stationary = n_std_thresh_stationary
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        # Registered as a buffer so it follows the module across .to(...) calls; may be None.
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self):
        """Build the kernel used to smooth the mask, or return None for no smoothing.

        The kernel is the outer product of two triangular ramps. Their lengths
        are ``2 * n_grad_freq + 1`` (frequency) and ``2 * n_grad_time + 1``
        (time). The kernel is normalised to sum to 1 and shaped ``(1, 1, F, T)``
        for ``conv2d``.

        Raises:
            ValueError: if a requested extent is shorter than one step.
        """
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        n_grad_freq = 1
        if self.freq_mask_smooth_hz is not None:
            # Step size used for the conversion: sr / (n_fft / 2) Hz, i.e.
            # twice the STFT bin spacing sr / n_fft.
            freq_step_hz = self.sr / (self.n_fft / 2)
            n_grad_freq = int(self.freq_mask_smooth_hz / freq_step_hz)
            if n_grad_freq < 1:
                raise ValueError(
                    f"freq_mask_smooth_hz needs to be at least {freq_step_hz:g} Hz, "
                    f"got {self.freq_mask_smooth_hz}"
                )

        n_grad_time = 1
        if self.time_mask_smooth_ms is not None:
            # Duration of one STFT hop, in milliseconds.
            hop_ms = (self.hop_length / self.sr) * 1000
            n_grad_time = int(self.time_mask_smooth_ms / hop_ms)
            if n_grad_time < 1:
                raise ValueError(
                    f"time_mask_smooth_ms needs to be at least {hop_ms:g} ms, "
                    f"got {self.time_mask_smooth_ms}"
                )

        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        def triangle(n):
            # Ramp 0 -> 1 -> 0 with both zero endpoints dropped: length 2n + 1.
            return torch.cat([
                linspace(0, 1, n + 1, endpoint=False),
                linspace(1, 0, n + 2),
            ])[1:-1]

        smoothing_filter = torch.outer(triangle(n_grad_freq), triangle(n_grad_time))
        smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(self, X_db):
        """Return a boolean mask of bins above ``mean + n_std * std`` of their row.

        The mean and std are taken over the last dimension. ``unsqueeze(2)``
        broadcasts them back, which assumes ``X_db`` is 3-D
        (batch, freq, frames).
        """
        std_freq_noise, mean_freq_noise = torch.std_mean(X_db, dim=-1)
        threshold = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary
        return X_db > threshold.unsqueeze(2)

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs):
        """Return a soft mask from each bin's deviation relative to its moving average.

        The moving average runs along the last (frame) dimension with a box
        window of ``n_movemean_nonstationary`` frames. The relative deviation
        ``(X_abs - avg) / avg`` is passed through ``temperature_sigmoid``.
        """
        window = self.n_movemean_nonstationary
        kernel = torch.ones(window, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1)
        X_smoothed = conv1d(
            X_abs.reshape(-1, 1, X_abs.shape[-1]),
            kernel,
            padding="same",
        ).view(X_abs.shape) / window
        # Avoid 0/0 -> NaN where the whole averaging window is exactly zero
        # (e.g. digital silence). The window includes the centre bin and X_abs
        # is non-negative, so X_abs is also zero there and the ratio becomes 0.
        denominator = X_smoothed.clamp_min(torch.finfo(X_smoothed.dtype).tiny)
        return temperature_sigmoid(
            (X_abs - X_smoothed) / denominator,
            self.n_thresh_nonstationary,
            self.temp_coeff_nonstationary,
        )

    def forward(self, x):
        """Denoise a batch of waveforms.

        Args:
            x: 2-D tensor, presumably (batch, samples) — TODO confirm with callers.

        Returns:
            The gated waveform. On the ``torch.stft`` path it is cast back to
            ``x.dtype``. On the ocl/privateuseone path it is whatever
            ``self.stft.inverse`` returns.

        Raises:
            AssertionError: if ``x`` is not 2-D.
            Exception: if ``x`` has fewer than ``2 * win_length`` samples.
        """
        assert x.ndim == 2, f"x must be a 2-D tensor, got {x.ndim} dimension(s)"
        if x.shape[-1] < self.win_length * 2:
            raise Exception(
                f"x must have at least {self.win_length * 2} samples along the last "
                f"dimension, got {x.shape[-1]}"
            )

        # Choose the STFT backend once per call. The cached `self.stft` persists
        # across calls, so testing hasattr(self, "stft") at inversion time would
        # wrongly take the custom inverse (with no `phase`) when a module first
        # used on an ocl/privateuseone device is later called on another device.
        use_backend_stft = str(x.device).startswith(("ocl", "privateuseone"))
        if use_backend_stft:
            # Presumably torch.stft/istft are unsupported on these devices, so a
            # project-specific STFT module is built lazily and cached.
            if not hasattr(self, "stft"):
                from main.library.backends.utils import STFT
                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    pad_mode="constant",
                ).to(x.device)
            # NOTE(review): X is presumably a magnitude spectrogram, given that
            # inverse(Y, phase) is called below — confirm.
            X, phase = self.stft.transform(x, eps=1e-9, return_phase=True)
        else:
            window = torch.hann_window(self.win_length).to(x.device)
            X = torch.stft(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                return_complex=True,
                pad_mode="constant",
                center=True,
                window=window,
            )

        X_abs = X.abs()
        if self.nonstationary:
            sig_mask = self._nonstationary_mask(X_abs)
        else:
            sig_mask = self._stationary_mask(amp_to_db(X_abs))
        # m -> 1 - prop_decrease * (1 - m): signal bins (m == 1) pass through,
        # noise bins (m == 0) are scaled by 1 - prop_decrease.
        sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0
        if self.smoothing_filter is not None:
            sig_mask = conv2d(
                sig_mask.unsqueeze(1),
                self.smoothing_filter.to(sig_mask.dtype),
                padding="same",
            )
        Y = X * sig_mask.squeeze(1)

        if use_backend_stft:
            return self.stft.inverse(Y, phase)
        return torch.istft(
            Y,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=window,
        ).to(dtype=x.dtype)