ylacombe's picture
ylacombe HF staff
Upload 17 files
db36668 verified
raw
history blame
No virus
1.71 kB
from pyannote.audio import Model
from pathlib import Path
from brouhaha.pipeline import RegressiveActivityDetectionPipeline
import torch
from huggingface_hub import hf_hub_download
# Process-wide cache for the Brouhaha model; snr_apply() loads it on first use
# so the checkpoint is not re-read on every call.
model = None
def snr_apply(batch, rank=None, audio_column_name="audio"):
    """Add mean ``snr`` and ``c50`` estimates to ``batch`` using Brouhaha.

    The model checkpoint is downloaded and loaded once per process (cached in
    the module-level ``model``) and wrapped in a
    ``RegressiveActivityDetectionPipeline`` on every call.

    Args:
        batch: Mapping holding the audio under ``audio_column_name``. The
            entry is either a single sample or a list of samples; each sample
            is indexed with ``"array"`` and ``"sampling_rate"``.
            NOTE(review): ``"array"`` is indexed with ``[None, :]``, so it is
            presumably a 1-D NumPy array -- confirm against the caller.
        rank: Optional worker rank. When given and at least one CUDA device is
            visible, the model and pipeline are placed on
            ``cuda:{rank % device_count}``. When no CUDA device is visible the
            pipeline's default device is used instead of failing.
        audio_column_name: Key of the audio entry in ``batch``.

    Returns:
        The same ``batch`` object, with ``"snr"`` and ``"c50"`` set to the
        mean of the pipeline's ``"snr"`` / ``"c50"`` outputs (a list of means
        when the audio entry is a list, a single mean otherwise).
    """
    global model
    if model is None:
        # Lazy, one-time load of the checkpoint for this process.
        model = Model.from_pretrained(
            Path(hf_hub_download(repo_id="ylacombe/brouhaha-best", filename="best.ckpt")),
            strict=False,
        )

    target_device = None
    if rank is not None:
        gpu_count = torch.cuda.device_count()
        # Guard against `rank % 0` (ZeroDivisionError) on machines without
        # CUDA devices; in that case keep the pipeline's default placement.
        if gpu_count > 0:
            target_device = torch.device(f"cuda:{rank % gpu_count}")
            # Move the model before building the pipeline: per the original
            # author's note, the pipeline otherwise moves to the first GPU it
            # finds.
            model.to(target_device)

    pipeline = RegressiveActivityDetectionPipeline(segmentation=model)
    if target_device is not None:
        # Also applies to rank 0 so that every rank is placed explicitly.
        pipeline.to(target_device)

    # Use the device the segmentation model actually ended up on for inputs.
    device = pipeline._models["segmentation"].device

    def _mean_scores(sample):
        """Run the pipeline on one sample and return (mean snr, mean c50)."""
        waveform = torch.tensor(sample["array"][None, :]).to(device).float()
        res = pipeline({"sample_rate": sample["sampling_rate"], "waveform": waveform})
        return res["snr"].mean(), res["c50"].mean()

    audio = batch[audio_column_name]
    if isinstance(audio, list):
        snr = []
        c50 = []
        for sample in audio:
            sample_snr, sample_c50 = _mean_scores(sample)
            snr.append(sample_snr)
            c50.append(sample_c50)
        batch["snr"] = snr
        batch["c50"] = c50
    else:
        batch["snr"], batch["c50"] = _mean_scores(audio)

    return batch