ylacombe HF staff commited on
Commit
2df3fbd
1 Parent(s): 0300002

Upload 7 files

Browse files
dataspeech/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  from .cpu_enrichments import rate_apply
2
- from .gpu_enrichments import pitch_apply, snr_apply
 
1
  from .cpu_enrichments import rate_apply
2
+ from .gpu_enrichments import pitch_apply, snr_apply, squim_apply
dataspeech/gpu_enrichments/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from .pitch import pitch_apply
2
- from .snr_and_reverb import snr_apply
 
 
1
  from .pitch import pitch_apply
2
+ from .snr_and_reverb import snr_apply
3
+ from .squim import squim_apply
dataspeech/gpu_enrichments/squim.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchaudio.pipelines import SQUIM_OBJECTIVE
2
+ import torch
3
+ import torchaudio
4
+
5
+ model = None
6
+
7
+ def squim_apply(batch, rank=None, audio_column_name="audio"):
8
+ global model
9
+ if model is None:
10
+ model = SQUIM_OBJECTIVE.get_model()
11
+ if rank is not None:
12
+ # move the model to the right GPU if not there already
13
+ device = f"cuda:{(rank or 0)% torch.cuda.device_count()}"
14
+ # move to device and create pipeline here because the pipeline moves to the first GPU it finds anyway
15
+ model.to(device)
16
+ else:
17
+ device = "cpu"
18
+
19
+ if isinstance(batch[audio_column_name], list):
20
+ sdr = []
21
+ pesq = []
22
+ stoi = []
23
+ for sample in batch[audio_column_name]:
24
+ waveform = torchaudio.functional.resample(torch.tensor(sample["array"][None, :]).to(device).float(), sample["sampling_rate"], SQUIM_OBJECTIVE.sample_rate)
25
+ with torch.no_grad():
26
+ stoi_sample, pesq_sample, sdr_sample = model(waveform)
27
+ sdr.append(sdr_sample.cpu())
28
+ pesq.append(pesq_sample.cpu())
29
+ stoi.append(stoi_sample.cpu())
30
+
31
+ batch["sdr"] = sdr
32
+ batch["pesq"] = pesq
33
+ batch["stoi"] = stoi
34
+ else:
35
+
36
+ waveform = torchaudio.functional.resample(torch.tensor(batch[audio_column_name]["array"][None, :]).to(device).float(), batch[audio_column_name]["sampling_rate"], SQUIM_OBJECTIVE.sample_rate)
37
+ with torch.no_grad():
38
+ stoi_sample, pesq_sample, sdr_sample = model(waveform)
39
+ batch["sdr"] = sdr_sample
40
+ batch["pesq"] = pesq_sample
41
+ batch["stoi"] = stoi_sample
42
+ # TODO
43
+ return batch
44
+