|
import os |
|
import torch |
|
import numpy as np |
|
import scipy.stats |
|
from scipy.signal import butter, sosfilt |
|
|
|
from pesq import pesq |
|
from pystoi import stoi |
|
|
|
|
|
def si_sdr_components(s_hat, s, n):
    """Split an estimate into target, noise and artifact components.

    Args:
        s_hat: The estimated signal.
        s: The reference (target) signal.
        n: The reference noise signal.

    Returns:
        A tuple ``(s_target, e_noise, e_art)`` where ``s_target`` is the
        least-squares projection of ``s_hat`` onto ``s``, ``e_noise`` is the
        projection of ``s_hat`` onto ``n``, and ``e_art`` is whatever remains
        (``s_hat - s_target - e_noise``).

    Note:
        ``n`` is projected directly (it is not orthogonalized against ``s``),
        so the three parts are only mutually orthogonal when ``s`` and ``n``
        are. Presumably all inputs are 1-D arrays of equal length, so that
        ``np.dot`` yields a scalar scale factor -- confirm at call sites.
    """
    def _project(onto):
        # Scale `onto` by the least-squares coefficient <s_hat, onto> / ||onto||^2.
        return np.dot(s_hat, onto) / np.linalg.norm(onto)**2 * onto

    target_part = _project(s)
    noise_part = _project(n)
    artifact_part = s_hat - target_part - noise_part
    return target_part, noise_part, artifact_part
|
|
|
def energy_ratios(s_hat, s, n):
    """Compute the SI-SDR, SI-SIR and SI-SAR of an estimate, all in dB.

    The estimate is first decomposed with ``si_sdr_components`` into a target
    part, a noise part and an artifact part. Each ratio compares the energy of
    the target part against a different error term:

    * SI-SDR: against noise + artifacts combined,
    * SI-SIR: against the noise part only,
    * SI-SAR: against the artifact part only.

    Returns:
        Tuple ``(si_sdr, si_sir, si_sar)``.
    """
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
    target_energy = np.linalg.norm(s_target)**2

    def _ratio_db(error):
        # 10*log10 of target energy over the energy of the given error term.
        return 10*np.log10(target_energy / np.linalg.norm(error)**2)

    return _ratio_db(e_noise + e_art), _ratio_db(e_noise), _ratio_db(e_art)
|
|
|
def mean_conf_int(data, confidence=0.95):
    """Return the sample mean and the half-width of its confidence interval.

    The interval is two-sided and based on Student's t distribution with
    ``len(data) - 1`` degrees of freedom, scaled by the standard error of the
    mean (``scipy.stats.sem``).

    Args:
        data: Sequence of numeric samples.
        confidence: Confidence level in (0, 1); defaults to 0.95.

    Returns:
        ``(mean, half_width)`` -- the interval is ``mean +/- half_width``.
    """
    samples = 1.0 * np.array(data)  # force a floating-point array
    dof = len(samples) - 1
    # Two-sided interval: use the upper (1 + confidence) / 2 quantile.
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    half_width = scipy.stats.sem(samples) * t_crit
    return np.mean(samples), half_width
|
|
|
class Method():
    """Accumulates per-item metric values for one named method.

    Attributes:
        name: Label of the method.
        base_dir: Directory associated with the method (stored as given).
        metrics: Dict mapping each metric name to a list of collected values.
    """

    def __init__(self, name, base_dir, metrics):
        """Create an empty value list for every metric name in ``metrics``.

        ``metrics`` may be any iterable of hashable names (list, tuple,
        generator, ...).
        """
        self.name = name
        self.base_dir = base_dir
        # A fresh, independent list per metric so appends never alias.
        self.metrics = {metric: [] for metric in metrics}

    def append(self, matric, value):
        """Record ``value`` under the metric named ``matric``.

        Raises:
            KeyError: if ``matric`` was not declared in the constructor.
        """
        # The parameter keeps its original (misspelled) name so existing
        # keyword-argument callers keep working.
        self.metrics[matric].append(value)

    def get_mean_ci(self, metric):
        """Return ``(mean, ci_half_width)`` of the values collected for ``metric``."""
        return mean_conf_int(np.array(self.metrics[metric]))
|
|
|
def hp_filter(signal, cut_off=80, order=10, sr=16000):
    """High-pass filter ``signal`` with a Butterworth filter.

    Args:
        signal: Input samples.
        cut_off: Cut-off frequency in Hz (default 80).
        order: Filter order (default 10).
        sr: Sampling rate in Hz (default 16000).

    Returns:
        The filtered signal, computed with second-order sections
        (``sosfilt``) for numerical stability at high orders.
    """
    # butter() takes the cut-off normalized to the Nyquist frequency (sr / 2).
    normalized_cutoff = cut_off / sr * 2
    sections = butter(order, normalized_cutoff, btype='hp', output='sos')
    return sosfilt(sections, signal)
|
|
|
def si_sdr(s, s_hat):
    """Return the scale-invariant SDR (in dB) of ``s_hat`` against reference ``s``.

    The reference is rescaled by the least-squares factor
    ``<s_hat, s> / ||s||^2``. The result is the energy of that scaled target
    divided by the energy of its difference to ``s_hat``.
    """
    scale = np.dot(s_hat, s) / np.linalg.norm(s)**2
    target = scale * s
    distortion = target - s_hat
    return 10*np.log10(np.linalg.norm(target)**2 / np.linalg.norm(distortion)**2)
|
|
|
def snr_dB(s, n):
    """Return the signal-to-noise ratio of ``s`` over ``n`` in dB.

    Each signal's power is its mean squared sample value, so ``s`` and ``n``
    may have different lengths.
    """
    def _avg_power(x):
        # Mean of squared samples, written as (1/N) * sum(x^2).
        return 1/len(x)*np.sum(x**2)

    return 10*np.log10(_avg_power(s) / _avg_power(n))
|
|
|
def pad_spec(Y, mode="zero_pad"):
    """Right-pad dimension 3 of ``Y`` up to the next multiple of 64.

    Presumably ``Y`` is a 4-D tensor, so dim 3 is the last axis that the 2-D
    padding modules pad on the right -- confirm at call sites.

    Args:
        Y: Tensor to pad.
        mode: ``"zero_pad"``, ``"reflection"`` or ``"replication"``.

    Returns:
        The padded tensor (unchanged length if dim 3 is already a multiple of 64).

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    frames = Y.size(3)
    # Distance to the next multiple of 64; 0 when already aligned.
    num_pad = (64 - frames % 64) % 64
    padders = (
        ("zero_pad", torch.nn.ZeroPad2d),
        ("reflection", torch.nn.ReflectionPad2d),
        ("replication", torch.nn.ReplicationPad2d),
    )
    for name, pad_cls in padders:
        if mode == name:
            # Padding tuple is (left, right, top, bottom): pad right only.
            return pad_cls((0, num_pad, 0, 0))(Y)
    raise NotImplementedError("This function hasn't been implemented yet.")
|
|
|
def ensure_dir(file_path):
    """Create directory ``file_path`` (and any missing parents) if it does not exist.

    An already-existing path is left untouched, as before. The exists-then-create
    sequence is replaced by "create and tolerate FileExistsError". This closes
    the race where another process creates the directory between the check and
    ``os.makedirs``, which previously raised.
    """
    try:
        os.makedirs(file_path)
    except FileExistsError:
        # Path already exists (directory or otherwise): nothing to do,
        # matching the original os.path.exists() short-circuit.
        pass
|
|
|
|
|
def _speech_metrics(reference, estimate, sr):
    """Return ``(wide-band PESQ, ESTOI, SI-SDR)`` of ``estimate`` vs ``reference``."""
    # Evaluation order matches the original: SI-SDR, PESQ, then ESTOI.
    sdr = si_sdr(reference, estimate)
    wb_pesq = pesq(sr, reference, estimate, 'wb')
    estoi = stoi(reference, estimate, sr, extended=True)
    return wb_pesq, estoi, sdr


def print_metrics(x, y, x_hat_list, labels, sr=16000):
    """Print PESQ, ESTOI and SI-SDR for a mixture and for each estimate.

    Args:
        x: Reference signal that every other signal is scored against.
        y: Mixture signal, reported on the "Mixture" line.
        x_hat_list: Sequence of estimates to score.
        labels: Display label for each estimate; ``labels[i]`` names
            ``x_hat_list[i]``.
        sr: Sampling rate in Hz passed to PESQ and ESTOI (default 16000).
    """
    pesq_mix, estoi_mix, si_sdr_mix = _speech_metrics(x, y, sr)
    print(f'Mixture: PESQ: {pesq_mix:.2f}, ESTOI: {estoi_mix:.2f}, SI-SDR: {si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        pesq_hat, estoi_hat, si_sdr_hat = _speech_metrics(x, x_hat, sr)
        # Same "PESQ:" tag as the mixture line (it was missing before).
        print(f'{labels[i]}: PESQ: {pesq_hat:.2f}, ESTOI: {estoi_hat:.2f}, SI-SDR: {si_sdr_hat:.2f}')
|
|
|
def mean_std(data):
    """Return ``(mean, std)`` of ``data`` with NaN entries dropped.

    ``std`` is NumPy's default population standard deviation (``ddof=0``).
    ``data`` may be an array or any array-like (list, tuple). Lists previously
    failed because ``~`` and boolean-mask indexing require an ndarray, unlike
    ``print_mean_std``, which already converts its input.
    """
    values = np.asarray(data)
    values = values[~np.isnan(values)]
    return np.mean(values), np.std(values)
|
|
|
def print_mean_std(data, decimal=2):
    """Format the NaN-ignoring mean and std of ``data`` as ``"mean ± std"``.

    Args:
        data: Array-like of numbers; NaN entries are dropped.
        decimal: Number of digits after the decimal point (default 2). Any
            non-negative integer is accepted. Previously only 1 and 2 worked,
            and other values crashed with UnboundLocalError.

    Returns:
        The formatted string, e.g. ``"2.00 ± 1.00"``.

    Raises:
        ValueError: if ``decimal`` is not a non-negative integer value.
    """
    precision = int(decimal)
    if precision != decimal or precision < 0:
        raise ValueError(f"decimal must be a non-negative integer, got {decimal!r}")
    values = np.array(data)
    values = values[~np.isnan(values)]
    mean = np.mean(values)
    std = np.std(values)  # population std (NumPy default ddof=0)
    return f'{mean:.{precision}f} ± {std:.{precision}f}'
|
|
|
def set_torch_cuda_arch_list():
    """Set ``TORCH_CUDA_ARCH_LIST`` to the compute capabilities of all visible GPUs.

    Each device contributes ``"<major>.<minor>"``; the entries are joined with
    ``";"`` and written to ``os.environ``. The chosen value is printed. When
    CUDA is unavailable, a message is printed and the environment is left
    untouched.
    """
    if torch.cuda.is_available():
        capabilities = [
            f"{major}.{minor}"
            for major, minor in (
                torch.cuda.get_device_capability(idx)
                for idx in range(torch.cuda.device_count())
            )
        ]
        arch_list = ";".join(capabilities)
        os.environ['TORCH_CUDA_ARCH_LIST'] = arch_list
        print(f"Set TORCH_CUDA_ARCH_LIST to: {arch_list}")
    else:
        print("CUDA is not available. No GPUs found.")