| import base64 |
| import io |
| from typing import Any |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| import torchaudio |
|
|
|
|
| |
# Compatibility shim: some torchaudio builds do not expose
# list_audio_backends, which SpeechBrain probes at import time.
# Provide a stub reporting the soundfile backend so the import succeeds.
if not hasattr(torchaudio, "list_audio_backends"):
    def _list_audio_backends():
        return ["soundfile"]

    torchaudio.list_audio_backends = _list_audio_backends
|
|
| from speechbrain.inference.separation import SepformerSeparation |
|
|
|
|
| TARGET_SAMPLE_RATE = 8000 |
|
|
|
|
class EndpointHandler:
    """Inference-endpoint wrapper around a SpeechBrain Sepformer model.

    Accepts raw audio bytes, or a JSON body carrying base64-encoded audio,
    and responds with one base64 WAV payload per estimated speaker.
    """

    def __init__(self, path: str = ""):
        """Load the separation model from *path* (falling back to ".")."""
        run_device = "cuda" if torch.cuda.is_available() else "cpu"
        model_dir = path or "."
        self.model = SepformerSeparation.from_hparams(
            source=model_dir,
            savedir=model_dir,
            run_opts={"device": run_device},
        )

    def __call__(self, data: Any) -> dict:
        """Separate the request audio into per-speaker WAV payloads."""
        raw_bytes = self._extract_audio_bytes(data)
        mono, _rate = self._load_audio(raw_bytes)

        with torch.no_grad():
            separated = self.model.separate_batch(mono.unsqueeze(0))

        separated = separated.squeeze(0).detach().cpu()
        # A lone estimated source can come back 1-D; normalize to
        # (time, n_sources) so the per-speaker loop below always works.
        if separated.ndim == 1:
            separated = separated.unsqueeze(-1)

        sources = [
            self._encode_source(separated[:, speaker_idx].numpy(), speaker_idx)
            for speaker_idx in range(separated.shape[-1])
        ]

        return {
            "num_speakers": len(sources),
            "sources": sources,
        }

    def _encode_source(self, samples, speaker_idx: int) -> dict:
        """Serialize one separated source as a base64-encoded in-memory WAV."""
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, samples, TARGET_SAMPLE_RATE, format="WAV")
        return {
            "speaker": speaker_idx,
            "audio_base64": base64.b64encode(wav_buffer.getvalue()).decode("utf-8"),
            "sample_rate": TARGET_SAMPLE_RATE,
            "mime_type": "audio/wav",
        }

    def _extract_audio_bytes(self, data: Any) -> bytes:
        """Pull raw audio bytes out of the request body.

        Supported shapes: raw bytes; {"inputs": <bytes | base64 str | dict>};
        or a dict holding base64 audio under "audio"/"audio_base64"/"data".

        Raises:
            ValueError: if none of the supported shapes match.
        """
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)

        if isinstance(data, dict):
            # With no "inputs" key, treat the whole body as the payload.
            payload = data.get("inputs", data)

            if isinstance(payload, (bytes, bytearray)):
                return bytes(payload)

            if isinstance(payload, str):
                return self._decode_base64_audio(payload)

            if isinstance(payload, dict):
                candidates = (payload.get(key) for key in ("audio", "audio_base64", "data"))
                for candidate in candidates:
                    if isinstance(candidate, str):
                        return self._decode_base64_audio(candidate)

        raise ValueError("Unsupported request format. Send raw audio bytes or a JSON body with base64 audio.")

    def _decode_base64_audio(self, value: str) -> bytes:
        """Decode a base64 string, tolerating a data-URI prefix."""
        # Drop the "data:<mime>;base64," header if the client sent a data URI.
        if value.startswith("data:") and "," in value:
            value = value.split(",", 1)[1]
        return base64.b64decode(value)

    def _load_audio(self, audio_bytes: bytes) -> tuple[torch.Tensor, int]:
        """Decode bytes into a mono 1-D float32 tensor at TARGET_SAMPLE_RATE."""
        samples, native_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        signal = torch.from_numpy(samples.T)  # -> (channels, frames)

        # Downmix multi-channel audio to mono by averaging the channels.
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)

        if native_rate != TARGET_SAMPLE_RATE:
            signal = torchaudio.transforms.Resample(native_rate, TARGET_SAMPLE_RATE)(signal)

        return signal.squeeze(0), TARGET_SAMPLE_RATE
|
|