{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup & Installation\n",
    "\n",
    "`requirements.txt` lists the extra packages the endpoint needs; the same file is installed locally below so the handler can be tested in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting requirements.txt\n"
     ]
    }
   ],
   "source": [
    "%%writefile requirements.txt\n",
    "torchaudio==0.11.*\n",
    "git+https://github.com/philschmid/pyannote-audio.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -r requirements.txt --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create Custom Handler for Inference Endpoints\n",
    "\n",
    "`handler.py` defines `EndpointHandler`: Inference Endpoints constructs it once and then calls it with each request payload."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting handler.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile handler.py\n",
    "from typing import Any, Dict, List\n",
    "\n",
    "import torch\n",
    "from pyannote.audio import Pipeline\n",
    "from transformers.pipelines.audio_utils import ffmpeg_read\n",
    "\n",
    "# Uploaded audio is decoded and resampled to this rate (Hz) before diarization.\n",
    "SAMPLE_RATE = 16000\n",
    "\n",
    "\n",
    "class EndpointHandler:\n",
    "    \"\"\"Custom Inference Endpoints handler that runs pyannote speaker diarization.\"\"\"\n",
    "\n",
    "    def __init__(self, path: str = \"\"):\n",
    "        \"\"\"Load the pretrained diarization pipeline.\n",
    "\n",
    "        Args:\n",
    "            path: Repository path passed in by Inference Endpoints. Unused here:\n",
    "                the pipeline is always loaded as \"pyannote/speaker-diarization\".\n",
    "        \"\"\"\n",
    "        self.pipeline = Pipeline.from_pretrained(\"pyannote/speaker-diarization\")\n",
    "\n",
    "    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, str]]]:\n",
    "        \"\"\"Run speaker diarization on one audio file.\n",
    "\n",
    "        Args:\n",
    "            data: Request payload. data[\"inputs\"] must hold the raw bytes of the\n",
    "                audio file (any format ffmpeg can decode). The optional\n",
    "                data[\"parameters\"] dict is forwarded as keyword arguments to the\n",
    "                pyannote pipeline, e.g. {\"min_speakers\": 2, \"max_speakers\": 5}.\n",
    "\n",
    "        Returns:\n",
    "            {\"diarization\": [...]} with one dict per diarization track holding\n",
    "            the speaker \"label\" and the \"start\" / \"stop\" times in seconds, all\n",
    "            encoded as strings.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If data[\"inputs\"] is missing or is not raw bytes.\n",
    "        \"\"\"\n",
    "        inputs = data.pop(\"inputs\", None)\n",
    "        parameters = data.pop(\"parameters\", None) or {}\n",
    "\n",
    "        # Fail with a clear message instead of an opaque error from ffmpeg.\n",
    "        if not isinstance(inputs, (bytes, bytearray)):\n",
    "            raise ValueError(\n",
    "                \"Expected the raw audio file bytes under the 'inputs' key, \"\n",
    "                f\"got {type(inputs).__name__}\"\n",
    "            )\n",
    "\n",
    "        # ffmpeg_read returns a 1-D mono array; pyannote expects a (channel, time)\n",
    "        # waveform tensor, hence the added leading axis.\n",
    "        audio = ffmpeg_read(inputs, SAMPLE_RATE)\n",
    "        waveform = torch.from_numpy(audio).unsqueeze(0)\n",
    "        pyannote_input = {\"waveform\": waveform, \"sample_rate\": SAMPLE_RATE}\n",
    "\n",
    "        diarization = self.pipeline(pyannote_input, **parameters)\n",
    "\n",
    "        # Times are stringified to keep the original response format.\n",
    "        return {\n",
    "            \"diarization\": [\n",
    "                {\"label\": str(label), \"start\": str(segment.start), \"stop\": str(segment.end)}\n",
    "                for segment, _, label in diarization.itertracks(yield_label=True)\n",
    "            ]\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Test the custom handler locally\n",
    "\n",
    "Requires a `sample.wav` file next to this notebook. `handler.py` also imports `transformers` (its `ffmpeg_read` helper shells out to the `ffmpeg` binary); neither is listed in `requirements.txt`, so make sure both are available locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from handler import EndpointHandler\n",
    "\n",
    "# init handler\n",
    "my_handler = EndpointHandler(path=\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read a local sample recording as raw bytes\n",
    "with open(\"sample.wav\", \"rb\") as f:\n",
    "    request = {\"inputs\": f.read()}\n",
    "\n",
    "# test the handler\n",
    "pred = my_handler(request)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [ "{'diarization': [{'label': 'SPEAKER_01',\n", " 'start': '0.4978125',\n", " 'stop': '1.3921875'},\n", " {'label': 'SPEAKER_01', 'start': '1.8984375', 'stop': '2.7590624999999998'},\n", " {'label': 'SPEAKER_02', 'start': '2.9953125', 'stop': '3.5015625000000004'},\n", 
{'label': 'SPEAKER_01',\n", " 'start': '3.5690625000000002',\n", " 'stop': '4.311562500000001'},\n", " {'label': 'SPEAKER_02', 'start': '4.6153125', 'stop': '6.7753125'},\n", " {'label': 'SPEAKER_00', 'start': '7.1128125', 'stop': '7.551562500000001'},\n", " {'label': 'SPEAKER_02',\n", " 'start': '7.551562500000001',\n", " 'stop': '9.475312500000001'},\n", " {'label': 'SPEAKER_02',\n", " 'start': '9.812812500000003',\n", " 'stop': '10.555312500000003'},\n", " {'label': 'SPEAKER_00',\n", " 'start': '9.863437500000003',\n", " 'stop': '10.420312500000001'},\n", " {'label': 'SPEAKER_03', 'start': '12.411562500000002', 'stop': '15.5503125'},\n", " {'label': 'SPEAKER_00', 'start': '15.786562500000002', 'stop': '16.1409375'},\n", " {'label': 'SPEAKER_01', 'start': '16.1409375', 'stop': '16.1578125'},\n", " {'label': 'SPEAKER_00', 'start': '17.1534375', 'stop': '17.4234375'},\n", " {'label': 'SPEAKER_01', 'start': '17.7440625', 'stop': '20.3596875'},\n", " {'label': 'SPEAKER_01', 'start': '20.6128125', 'stop': '20.6634375'},\n", " {'label': 'SPEAKER_00', 'start': '20.6634375', 'stop': '20.8490625'},\n", " {'label': 'SPEAKER_01', 'start': '20.8490625', 'stop': '20.8828125'},\n", " {'label': 'SPEAKER_01', 'start': '21.1021875', 'stop': '22.1315625'},\n", " {'label': 'SPEAKER_02', 'start': '22.4521875', 'stop': '22.7053125'},\n", " {'label': 'SPEAKER_02', 'start': '23.2115625', 'stop': '23.4815625'},\n", " {'label': 'SPEAKER_01', 'start': '23.4815625', 'stop': '24.0215625'},\n", " {'label': 'SPEAKER_02', 'start': '24.3253125', 'stop': '25.5065625'},\n", " {'label': 'SPEAKER_01', 'start': '25.8440625', 'stop': '27.3121875'},\n", " {'label': 'SPEAKER_02', 'start': '27.3121875', 'stop': '27.4978125'},\n", " {'label': 'SPEAKER_01', 'start': '29.7253125', 'stop': '29.9615625'}]}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.13 ('dev': conda)", "language": 
"python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc" } } }, "nbformat": 4, "nbformat_minor": 2 }