KIFF commited on
Commit
7f3cec9
·
verified ·
1 Parent(s): 1440acf

Create create_handler.ipynb

Browse files
Files changed (1) hide show
  1. create_handler.ipynb +280 -0
create_handler.ipynb ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## 1. Setup & Installation"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 1,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stdout",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "Overwriting requirements.txt\n"
20
+ ]
21
+ }
22
+ ],
23
+ "source": [
24
+ "%%writefile requirements.txt\n",
25
+ "torchaudio==0.11.*\n",
26
+ "git+https://github.com/philschmid/pyannote-audio.git"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "!pip install -r requirements.txt --upgrade"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {},
41
+ "source": [
42
+ "## 2. Create Custom Handler for Inference Endpoints\n"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 2,
48
+ "metadata": {},
49
+ "outputs": [
50
+ {
51
+ "name": "stdout",
52
+ "output_type": "stream",
53
+ "text": [
54
+ "Overwriting handler.py\n"
55
+ ]
56
+ }
57
+ ],
58
+ "source": [
59
+ "%%writefile handler.py\n",
60
+ "from typing import Dict\n",
61
+ "from pyannote.audio import Pipeline\n",
62
+ "from transformers.pipelines.audio_utils import ffmpeg_read\n",
63
+ "import torch \n",
64
+ "\n",
65
+ "SAMPLE_RATE = 16000\n",
66
+ "\n",
67
+ "\n",
68
+ "\n",
69
+ "class EndpointHandler():\n",
70
+ " def __init__(self, path=\"\"):\n",
71
+ " # load the model\n",
72
+ " self.pipeline = Pipeline.from_pretrained(\"pyannote/speaker-diarization\")\n",
73
+ "\n",
74
+ "\n",
75
+ " def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:\n",
76
+ " \"\"\"\n",
77
+ " Args:\n",
78
+ " data (:obj:):\n",
79
+ " includes the deserialized audio file as bytes\n",
80
+ " Return:\n",
81
+ " A :obj:`dict`:. base64 encoded image\n",
82
+ " \"\"\"\n",
83
+ " # process input\n",
84
+ " inputs = data.pop(\"inputs\", data)\n",
85
+ " parameters = data.pop(\"parameters\", None) # min_speakers=2, max_speakers=5\n",
86
+ "\n",
87
+ " \n",
88
+ " # prepare pynannote input\n",
89
+ " audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)\n",
90
+ " audio_tensor= torch.from_numpy(audio_nparray).unsqueeze(0)\n",
91
+ " pyannote_input = {\"waveform\": audio_tensor, \"sample_rate\": SAMPLE_RATE}\n",
92
+ " \n",
93
+ " # apply pretrained pipeline\n",
94
+ " # pass inputs with all kwargs in data\n",
95
+ " if parameters is not None:\n",
96
+ " diarization = self.pipeline(pyannote_input, **parameters)\n",
97
+ " else:\n",
98
+ " diarization = self.pipeline(pyannote_input)\n",
99
+ "\n",
100
+ " # postprocess the prediction\n",
101
+ " processed_diarization = [\n",
102
+ " {\"label\": str(label), \"start\": str(segment.start), \"stop\": str(segment.end)}\n",
103
+ " for segment, _, label in diarization.itertracks(yield_label=True)\n",
104
+ " ]\n",
105
+ " \n",
106
+ " return {\"diarization\": processed_diarization}"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "test custom pipeline"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": 1,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "from handler import EndpointHandler\n",
123
+ "\n",
124
+ "# init handler\n",
125
+ "my_handler = EndpointHandler(path=\".\")"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 2,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "import base64\n",
135
+ "from PIL import Image\n",
136
+ "from io import BytesIO\n",
137
+ "import json\n",
138
+ "\n",
139
+ "# file reader\n",
140
+ "with open(\"sample.wav\", \"rb\") as f:\n",
141
+ " request = {\"inputs\": f.read()}\n",
142
+ "\n",
143
+ "# test the handler\n",
144
+ "pred = my_handler(request)"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 3,
150
+ "metadata": {},
151
+ "outputs": [
152
+ {
153
+ "data": {
154
+ "text/plain": [
155
+ "{'diarization': [{'label': 'SPEAKER_01',\n",
156
+ " 'start': '0.4978125',\n",
157
+ " 'stop': '1.3921875'},\n",
158
+ " {'label': 'SPEAKER_01', 'start': '1.8984375', 'stop': '2.7590624999999998'},\n",
159
+ " {'label': 'SPEAKER_02', 'start': '2.9953125', 'stop': '3.5015625000000004'},\n",
160
+ " {'label': 'SPEAKER_01',\n",
161
+ " 'start': '3.5690625000000002',\n",
162
+ " 'stop': '4.311562500000001'},\n",
163
+ " {'label': 'SPEAKER_02', 'start': '4.6153125', 'stop': '6.7753125'},\n",
164
+ " {'label': 'SPEAKER_00', 'start': '7.1128125', 'stop': '7.551562500000001'},\n",
165
+ " {'label': 'SPEAKER_02',\n",
166
+ " 'start': '7.551562500000001',\n",
167
+ " 'stop': '9.475312500000001'},\n",
168
+ " {'label': 'SPEAKER_02',\n",
169
+ " 'start': '9.812812500000003',\n",
170
+ " 'stop': '10.555312500000003'},\n",
171
+ " {'label': 'SPEAKER_00',\n",
172
+ " 'start': '9.863437500000003',\n",
173
+ " 'stop': '10.420312500000001'},\n",
174
+ " {'label': 'SPEAKER_03', 'start': '12.411562500000002', 'stop': '15.5503125'},\n",
175
+ " {'label': 'SPEAKER_00', 'start': '15.786562500000002', 'stop': '16.1409375'},\n",
176
+ " {'label': 'SPEAKER_01', 'start': '16.1409375', 'stop': '16.1578125'},\n",
177
+ " {'label': 'SPEAKER_00', 'start': '17.1534375', 'stop': '17.4234375'},\n",
178
+ " {'label': 'SPEAKER_01', 'start': '17.7440625', 'stop': '20.3596875'},\n",
179
+ " {'label': 'SPEAKER_01', 'start': '20.6128125', 'stop': '20.6634375'},\n",
180
+ " {'label': 'SPEAKER_00', 'start': '20.6634375', 'stop': '20.8490625'},\n",
181
+ " {'label': 'SPEAKER_01', 'start': '20.8490625', 'stop': '20.8828125'},\n",
182
+ " {'label': 'SPEAKER_01', 'start': '21.1021875', 'stop': '22.1315625'},\n",
183
+ " {'label': 'SPEAKER_02', 'start': '22.4521875', 'stop': '22.7053125'},\n",
184
+ " {'label': 'SPEAKER_02', 'start': '23.2115625', 'stop': '23.4815625'},\n",
185
+ " {'label': 'SPEAKER_01', 'start': '23.4815625', 'stop': '24.0215625'},\n",
186
+ " {'label': 'SPEAKER_02', 'start': '24.3253125', 'stop': '25.5065625'},\n",
187
+ " {'label': 'SPEAKER_01', 'start': '25.8440625', 'stop': '27.3121875'},\n",
188
+ " {'label': 'SPEAKER_02', 'start': '27.3121875', 'stop': '27.4978125'},\n",
189
+ " {'label': 'SPEAKER_01', 'start': '29.7253125', 'stop': '29.9615625'}]}"
190
+ ]
191
+ },
192
+ "execution_count": 3,
193
+ "metadata": {},
194
+ "output_type": "execute_result"
195
+ }
196
+ ],
197
+ "source": [
198
+ "pred"
199
+ ]
200
+ }
201
+ ],
202
+ "metadata": {
203
+ "kernelspec": {
204
+ "display_name": "Python 3.9.13 ('dev': conda)",
205
+ "language": "python",
206
+ "name": "python3"
207
+ },
208
+ "language_info": {
209
+ "codemirror_mode": {
210
+ "name": "ipython",
211
+ "version": 3
212
+ },
213
+ "file_extension": ".py",
214
+ "mimetype": "text/x-python",
215
+ "name": "python",
216
+ "nbconvert_exporter": "python",
217
+ "pygments_lexer": "ipython3",
218
+ "version": "3.9.13"
219
+ },
220
+ "orig_nbformat": 4,
221
+ "vscode": {
222
+ "interpreter": {
223
+ "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
224
+ }
225
+ }
226
+ },
227
+ "nbformat": 4,
228
+ "nbformat_minor": 2
229
+ }
230
+
231
+
232
+ handler.py
233
+
234
+ from typing import Dict
235
+ from pyannote.audio import Pipeline
236
+ import torch
237
+ import base64
238
+ import numpy as np
239
+
240
+ SAMPLE_RATE = 16000
241
+
242
+ class EndpointHandler():
243
+ def __init__(self, path=""):
244
+ # load the model
245
+ self.pipeline = Pipeline.from_pretrained("KIFF/pyannote-speaker-diarization-endpoint")
246
+
247
+ def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
248
+ """
249
+ Args:
250
+ data (:obj:):
251
+ includes the deserialized audio file as bytes
252
+ Return:
253
+ A :obj:`dict`:. base64 encoded image
254
+ """
255
+ # process input
256
+ inputs = data.pop("inputs", data)
257
+ parameters = data.pop("parameters", None) # min_speakers=2, max_speakers=5
258
+
259
+ # decode the base64 audio data
260
+ audio_data = base64.b64decode(inputs)
261
+ audio_nparray = np.frombuffer(audio_data, dtype=np.int16)
262
+
263
+ # prepare pynannote input
264
+ audio_tensor= torch.from_numpy(audio_nparray).float().unsqueeze(0)
265
+ pyannote_input = {"waveform": audio_tensor, "sample_rate": SAMPLE_RATE}
266
+
267
+ # apply pretrained pipeline
268
+ # pass inputs with all kwargs in data
269
+ if parameters is not None:
270
+ diarization = self.pipeline(pyannote_input, **parameters)
271
+ else:
272
+ diarization = self.pipeline(pyannote_input)
273
+
274
+ # postprocess the prediction
275
+ processed_diarization = [
276
+ {"label": str(label), "start": str(segment.start), "stop": str(segment.end)}
277
+ for segment, _, label in diarization.itertracks(yield_label=True)
278
+ ]
279
+
280
+ return {"diarization": processed_diarization}