Spaces:
Build error
Build error
Upload 7 files
Browse files
dataspeech/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
from .cpu_enrichments import rate_apply
|
2 |
-
from .gpu_enrichments import pitch_apply, snr_apply
|
|
|
1 |
from .cpu_enrichments import rate_apply
|
2 |
+
from .gpu_enrichments import pitch_apply, snr_apply, squim_apply
|
dataspeech/gpu_enrichments/__init__.py
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
from .pitch import pitch_apply
|
2 |
-
from .snr_and_reverb import snr_apply
|
|
|
|
1 |
from .pitch import pitch_apply
|
2 |
+
from .snr_and_reverb import snr_apply
|
3 |
+
from .squim import squim_apply
|
dataspeech/gpu_enrichments/squim.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchaudio.pipelines import SQUIM_OBJECTIVE
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
|
5 |
+
# Module-level cache for the SQUIM model: stays None until squim_apply
# builds it on first use, so later calls reuse the same instance.
model = None
|
6 |
+
|
7 |
+
def _squim_estimate(audio, device):
    """Run the SQUIM objective model on one decoded audio sample.

    Args:
        audio: Mapping with an ``"array"`` entry (the waveform) and a
            ``"sampling_rate"`` entry.
            NOTE(review): the array is presumably 1-D mono, since a leading
            batch axis is added here -- confirm against the dataset.
        device: Torch device string on which inference is run.

    Returns:
        Tuple ``(sdr, pesq, stoi)`` of tensors moved back to the CPU.
    """
    # Convert first, then add the batch axis, so that plain Python lists
    # (which do not support ``[None, :]`` indexing) are accepted too.
    waveform = torch.tensor(audio["array"])[None, :].to(device).float()
    waveform = torchaudio.functional.resample(
        waveform, audio["sampling_rate"], SQUIM_OBJECTIVE.sample_rate
    )
    with torch.no_grad():
        # The model output is unpacked in (stoi, pesq, sdr) order.
        stoi_sample, pesq_sample, sdr_sample = model(waveform)
    return sdr_sample.cpu(), pesq_sample.cpu(), stoi_sample.cpu()


def squim_apply(batch, rank=None, audio_column_name="audio"):
    """Add SQUIM objective estimates to a single example or a batch.

    The model is created lazily on the first call and cached in the
    module-level ``model`` variable.

    Args:
        batch: Mapping holding the audio under ``audio_column_name``. The
            value is either one audio dict (``"array"``, ``"sampling_rate"``)
            or a list of such dicts.
        rank: Optional worker rank used to pick a GPU as
            ``rank % torch.cuda.device_count()``. With ``None``, or when no
            CUDA device is visible, inference runs on the CPU.
            NOTE(review): presumably the rank passed by
            ``datasets.Dataset.map(with_rank=True)`` -- confirm with caller.
        audio_column_name: Key of the audio column in ``batch``.

    Returns:
        The same ``batch`` object with ``"sdr"``, ``"pesq"`` and ``"stoi"``
        entries added: lists of CPU tensors for a batched input, single CPU
        tensors otherwise.
    """
    global model
    if model is None:
        model = SQUIM_OBJECTIVE.get_model()

    num_gpus = torch.cuda.device_count()
    if rank is not None and num_gpus > 0:
        # Spread workers over the visible GPUs. The ``num_gpus > 0`` guard
        # avoids a modulo-by-zero when a rank is given on a CPU-only host.
        device = f"cuda:{rank % num_gpus}"
    else:
        device = "cpu"
    # Always sync the cached model with the chosen device so that it never
    # ends up on a different device than the input waveform.
    model.to(device)

    audio = batch[audio_column_name]
    if isinstance(audio, list):
        sdr = []
        pesq = []
        stoi = []
        for sample in audio:
            sdr_sample, pesq_sample, stoi_sample = _squim_estimate(sample, device)
            sdr.append(sdr_sample)
            pesq.append(pesq_sample)
            stoi.append(stoi_sample)

        batch["sdr"] = sdr
        batch["pesq"] = pesq
        batch["stoi"] = stoi
    else:
        # Single example: same computation, results stored without a list.
        batch["sdr"], batch["pesq"], batch["stoi"] = _squim_estimate(audio, device)

    # TODO: carried over from the original code; its intent was not stated.
    return batch
|
44 |
+
|