import torch
import torchaudio
import torch.nn.functional as F

from pesq import pesq
from pystoi import stoi

from .other import si_sdr, pad_spec


# Shared evaluation settings: 8 kHz sampling rate and the default
# predictor-corrector sampler configuration (N reverse steps, one annealed
# Langevin corrector step per predictor step, corrector SNR of 0.5).
sr = 8000
snr = 0.5
N = 30
corrector_steps = 1

def evaluate_model(model, num_eval_files):
    """Enhance `num_eval_files` validation utterances with the predictor-corrector
    sampler and return the average PESQ (narrowband), SI-SDR and ESTOI scores."""

    clean_files = model.data_module.valid_set.clean_files
    noisy_files = model.data_module.valid_set.noisy_files
    mixture_files = model.data_module.valid_set.mixture_files

    # Pick `num_eval_files` utterances evenly spread over the validation set.
    total_num_files = len(clean_files)
    indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
    clean_files = list(clean_files[i] for i in indices)
    noisy_files = list(noisy_files[i] for i in indices)
    mixture_files = list(mixture_files[i] for i in indices)

    _pesq = 0
    _si_sdr = 0
    _estoi = 0

    for (clean_file, noisy_file, mixture_file) in zip(clean_files, noisy_files, mixture_files):
        # Load clean, noisy and mixture signals, resampling to `sr` if necessary.
        x, sr_ = torchaudio.load(clean_file)
        if sr_ != sr:
            x = torchaudio.transforms.Resample(sr_, sr)(x)
        y, sr_ = torchaudio.load(noisy_file)
        if sr_ != sr:
            y = torchaudio.transforms.Resample(sr_, sr)(y)
        m, sr_ = torchaudio.load(mixture_file)
        if sr_ != sr:
            m = torchaudio.transforms.Resample(sr_, sr)(m)

        # Trim all three signals to a common length.
        min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_leng]
        y = y[..., :min_leng]
        m = m[..., :min_leng]

        T_orig = x.size(1)

        # Normalize by the peak of the noisy signal.
        norm_factor = y.abs().max()
        y = y / norm_factor
        m = m / norm_factor

        # Compute the (padded) complex spectrograms used to condition the sampler.
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)

        M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
        M = pad_spec(M)

        y = y * norm_factor

        # Reverse diffusion with an annealed-Langevin-dynamics corrector.
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'ald', Y.cuda(), M.cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()

        sample = sample.squeeze()

        # Back to the time domain, undoing the normalization.
        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor

        x_hat = x_hat.squeeze().cpu().numpy()
        x = x.squeeze().cpu().numpy()
        y = y.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'nb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
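
# Illustrative usage of evaluate_model (a sketch, not part of the pipeline here):
# the checkpoint path and the `ScoreModel.load_from_checkpoint` call are
# assumptions; substitute however `model` is actually constructed in this repo.
#
#     model = ScoreModel.load_from_checkpoint("path/to/checkpoint.ckpt").cuda()
#     model.eval()
#     avg_pesq, avg_si_sdr, avg_estoi = evaluate_model(model, num_eval_files=10)
#     print(f"PESQ {avg_pesq:.2f}, SI-SDR {avg_si_sdr:.2f} dB, ESTOI {avg_estoi:.2f}")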

def evaluate_model2(model, num_eval_files, inference_N, inference_start=0.5):
    """Same evaluation as `evaluate_model`, but with a plain Euler-Maruyama reverse
    run of `inference_N` steps that starts the reverse SDE at t = `inference_start`
    from the noisy spectrogram plus scaled Gaussian noise. Returns the average
    PESQ, SI-SDR and ESTOI together with a running loss (currently always zero)."""

    N = inference_N
    reverse_start_time = inference_start

    clean_files = model.data_module.valid_set.clean_files
    noisy_files = model.data_module.valid_set.noisy_files
    mixture_files = model.data_module.valid_set.mixture_files

    # Pick `num_eval_files` utterances evenly spread over the validation set.
    total_num_files = len(clean_files)
    indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
    clean_files = list(clean_files[i] for i in indices)
    noisy_files = list(noisy_files[i] for i in indices)
    mixture_files = list(mixture_files[i] for i in indices)

    _pesq = 0
    _si_sdr = 0
    _estoi = 0

    for (clean_file, noisy_file, mixture_file) in zip(clean_files, noisy_files, mixture_files):
        # Load clean, noisy and mixture signals, resampling to `sr` if necessary.
        x, sr_ = torchaudio.load(clean_file)
        if sr_ != sr:
            x = torchaudio.transforms.Resample(sr_, sr)(x)
        y, sr_ = torchaudio.load(noisy_file)
        if sr_ != sr:
            y = torchaudio.transforms.Resample(sr_, sr)(y)
        m, sr_ = torchaudio.load(mixture_file)
        if sr_ != sr:
            m = torchaudio.transforms.Resample(sr_, sr)(m)

        # Trim all three signals to a common length.
        min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_leng]
        y = y[..., :min_leng]
        m = m[..., :min_leng]

        T_orig = x.size(1)

        # Normalize by the peak of the noisy signal.
        norm_factor = y.abs().max()
        y = y / norm_factor
        x = x / norm_factor
        m = m / norm_factor

        # Compute the (padded) complex spectrograms.
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)

        X = torch.unsqueeze(model._forward_transform(model._stft(x.cuda())), 0)
        X = pad_spec(X)

        M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
        M = pad_spec(M)

        # Undo the normalization before computing the metrics.
        y = y * norm_factor
        x = x * norm_factor

        x = x.squeeze().cpu().numpy()
        y = y.squeeze().cpu().numpy()

        total_loss = 0  # kept only for the return signature; never accumulated
        timesteps = torch.linspace(reverse_start_time, 0.03, N, device=Y.device)

        # Start the reverse process at t = reverse_start_time from the noisy
        # spectrogram: X_t = Y + sigma(t) * z, with z ~ N(0, I).
        std = model.sde._std(reverse_start_time*torch.ones((Y.shape[0],), device=Y.device))
        z = torch.randn_like(Y)
        X_t = Y + z * std[:, None, None, None]

        # Euler-Maruyama discretization of the reverse SDE:
        # X_{t-dt} = X_t - (f(X_t, t) - g(t)^2 * score) * dt + g(t) * sqrt(dt) * z
        for i in range(len(timesteps)):
            t = timesteps[i]
            if i != len(timesteps) - 1:
                dt = t - timesteps[i+1]
            else:
                dt = timesteps[-1]
            with torch.no_grad():
                f, g = model.sde.sde(X_t, t, Y)
                vec_t = torch.ones(Y.shape[0], device=Y.device) * t
                score = model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])
                mean_x_tm1 = X_t - (f - g**2*score)*dt
                if i == len(timesteps) - 1:
                    # Final step: keep the mean only, without injecting noise.
                    X_t = mean_x_tm1
                    break
                z = torch.randn_like(X)
                X_t = mean_x_tm1 + z*g*torch.sqrt(dt)

        sample = X_t
        sample = sample.squeeze()
        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'nb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files, total_loss/num_eval_files
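
# Illustrative sweep over the number of reverse steps using evaluate_model2
# (a sketch; `model` is assumed to be a trained score model with its
# data_module attached, exactly as required by evaluate_model above):
#
#     for n_steps in (5, 15, 30):
#         p, s, e, _ = evaluate_model2(model, num_eval_files=10,
#                                      inference_N=n_steps, inference_start=0.5)
#         print(f"N={n_steps}: PESQ {p:.2f}, SI-SDR {s:.2f} dB, ESTOI {e:.2f}")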

def convert_to_audio(X, deemp, T_orig, model, norm_factor):
    """Weight the spectrogram `X` with the frequency-wise (de-emphasis) vector
    `deemp`, invert it to a waveform of length `T_orig`, and undo the peak
    normalization."""

    sample = X

    sample = sample.squeeze()
    # Broadcast `deemp` over the frequency axis for 4-D, 3-D or 2-D inputs.
    if len(sample.shape) == 4:
        sample = sample*deemp[None, None, :, None].to(device=sample.device)
    elif len(sample.shape) == 3:
        sample = sample*deemp[None, :, None].to(device=sample.device)
    else:
        sample = sample*deemp[:, None].to(device=sample.device)

    x_hat = model.to_audio(sample.squeeze(), T_orig)
    x_hat = x_hat * norm_factor

    x_hat = x_hat.squeeze().cpu().numpy()
    return x_hat
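
# Illustrative call to convert_to_audio (a sketch): `spec`, `T_orig`, `model` and
# `norm_factor` are assumed to come from an evaluation loop like the ones above,
# and `deemp` is assumed to be a per-frequency-bin weight vector (here flat).
#
#     deemp = torch.ones(spec.shape[-2], device=spec.device)
#     x_hat = convert_to_audio(spec, deemp, T_orig, model, norm_factor)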