Vageesh1 commited on
Commit
59907c6
1 Parent(s): 2dbfc9c

Update helper.py

Browse files
Files changed (1) hide show
  1. helper.py +13 -0
helper.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import matplotlib.pyplot as plt
3
  from pesq import pesq
4
  from pystoi import stoi
@@ -41,7 +42,19 @@ def si_snr(estimate, reference, epsilon=1e-8):
41
  si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
42
  return si_snr.item()
43
 
 
 
 
 
 
 
 
44
  def generate_mixture(waveform_clean, waveform_noise, target_snr):
 
 
 
 
 
45
  power_clean_signal = waveform_clean.pow(2).mean()
46
  power_noise_signal = waveform_noise.pow(2).mean()
47
  current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
 
1
  import torch
2
+ import torchaudio.functional as F
3
  import matplotlib.pyplot as plt
4
  from pesq import pesq
5
  from pystoi import stoi
 
42
  si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
43
  return si_snr.item()
44
 
45
+ # def generate_mixture(waveform_clean, waveform_noise, target_snr):
46
+ # power_clean_signal = waveform_clean.pow(2).mean()
47
+ # power_noise_signal = waveform_noise.pow(2).mean()
48
+ # current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
49
+ # waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
50
+ # return waveform_clean + waveform_noise
51
+
52
  def generate_mixture(waveform_clean, waveform_noise, target_snr):
53
+ if waveform_clean.size(1) > waveform_noise.size(1):
54
+ waveform_noise = F.pad(waveform_noise, (0, waveform_clean.size(1) - waveform_noise.size(1)))
55
+ elif waveform_noise.size(1) > waveform_clean.size(1):
56
+ waveform_clean = F.pad(waveform_clean, (0, waveform_noise.size(1) - waveform_clean.size(1)))
57
+
58
  power_clean_signal = waveform_clean.pow(2).mean()
59
  power_noise_signal = waveform_noise.pow(2).mean()
60
  current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)