Katsuya Oda
fix: use config.yaml
50e7175
from typing import Dict
from pyannote.audio import Pipeline
from io import BytesIO
import torch
import torchaudio
class EndpointHandler:
    """Inference handler that runs a pyannote.audio pipeline on audio bytes.

    The pipeline is built once in ``__init__`` from a ``config.yaml`` file and
    reused for every call.
    """

    def __init__(self, path=""):
        """Load the pipeline described by ``config.yaml``.

        Args:
            path: Directory expected to contain ``config.yaml``. When
                ``<path>/config.yaml`` is an existing file it is loaded;
                otherwise the bare ``"config.yaml"`` value is passed to
                ``Pipeline.from_pretrained`` exactly as before, so behavior
                with the default ``path=""`` is unchanged.
        """
        # Imported locally so the file's top-level import block stays as-is.
        import os

        config_file = os.path.join(path, "config.yaml") if path else "config.yaml"
        if not os.path.isfile(config_file):
            # Fall back to the original working-directory-relative lookup.
            config_file = "config.yaml"

        # load the model
        self.pipeline = Pipeline.from_pretrained(config_file)

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, list]:
        """Run the pipeline on one audio payload and return its tracks.

        Args:
            data: Request payload. ``data["inputs"]`` holds the audio file as
                bytes; when that key is absent, ``data`` itself is used as
                the input. The optional ``data["parameters"]`` is a dict of
                keyword arguments forwarded to the pipeline
                (e.g. ``min_speakers=2, max_speakers=5``).

        Returns:
            ``{"diarization": [...]}`` where each list entry is a dict with
            the keys ``"label"``, ``"start"`` and ``"stop"``, all rendered as
            strings, one entry per track yielded by the pipeline output.
        """
        # process input -- .get() leaves the caller's payload dict unmodified
        inputs = data.get("inputs", data)
        parameters = data.get("parameters")

        # Decode the in-memory audio file into a waveform tensor.
        waveform, sample_rate = torchaudio.load(BytesIO(inputs))
        pyannote_input = {"waveform": waveform, "sample_rate": sample_rate}

        # apply pretrained pipeline, forwarding any caller-supplied kwargs
        extra_kwargs = parameters if parameters is not None else {}
        diarization = self.pipeline(pyannote_input, **extra_kwargs)

        # postprocess the prediction into plain, serializable strings
        processed_diarization = [
            {
                "label": str(label),
                "start": str(segment.start),
                "stop": str(segment.end),
            }
            for segment, _, label in diarization.itertracks(yield_label=True)
        ]
        return {"diarization": processed_diarization}