KIFF committed
Commit 5dbbf5e · verified · 1 Parent(s): 7f3cec9

Delete create_handler.ipynb

Files changed (1)
  1. create_handler.ipynb +0 -280
create_handler.ipynb DELETED
@@ -1,280 +0,0 @@
- {
-  "cells": [
-   {
-    "cell_type": "markdown",
-    "metadata": {},
-    "source": [
-     "## 1. Setup & Installation"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 1,
-    "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "Overwriting requirements.txt\n"
-      ]
-     }
-    ],
-    "source": [
-     "%%writefile requirements.txt\n",
-     "torchaudio==0.11.*\n",
-     "git+https://github.com/philschmid/pyannote-audio.git"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "!pip install -r requirements.txt --upgrade"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {},
-    "source": [
-     "## 2. Create Custom Handler for Inference Endpoints\n"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 2,
-    "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "Overwriting handler.py\n"
-      ]
-     }
-    ],
-    "source": [
-     "%%writefile handler.py\n",
-     "from typing import Dict\n",
-     "from pyannote.audio import Pipeline\n",
-     "from transformers.pipelines.audio_utils import ffmpeg_read\n",
-     "import torch\n",
-     "\n",
-     "SAMPLE_RATE = 16000\n",
-     "\n",
-     "\n",
-     "\n",
-     "class EndpointHandler():\n",
-     "    def __init__(self, path=\"\"):\n",
-     "        # load the model\n",
-     "        self.pipeline = Pipeline.from_pretrained(\"pyannote/speaker-diarization\")\n",
-     "\n",
-     "\n",
-     "    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:\n",
-     "        \"\"\"\n",
-     "        Args:\n",
-     "            data (:obj:):\n",
-     "                includes the deserialized audio file as bytes\n",
-     "        Return:\n",
-     "            A :obj:`dict` with the diarization segments\n",
-     "        \"\"\"\n",
-     "        # process input\n",
-     "        inputs = data.pop(\"inputs\", data)\n",
-     "        parameters = data.pop(\"parameters\", None)  # min_speakers=2, max_speakers=5\n",
-     "\n",
-     "\n",
-     "        # prepare pyannote input\n",
-     "        audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)\n",
-     "        audio_tensor = torch.from_numpy(audio_nparray).unsqueeze(0)\n",
-     "        pyannote_input = {\"waveform\": audio_tensor, \"sample_rate\": SAMPLE_RATE}\n",
-     "\n",
-     "        # apply pretrained pipeline\n",
-     "        # pass inputs with all kwargs in data\n",
-     "        if parameters is not None:\n",
-     "            diarization = self.pipeline(pyannote_input, **parameters)\n",
-     "        else:\n",
-     "            diarization = self.pipeline(pyannote_input)\n",
-     "\n",
-     "        # postprocess the prediction\n",
-     "        processed_diarization = [\n",
-     "            {\"label\": str(label), \"start\": str(segment.start), \"stop\": str(segment.end)}\n",
-     "            for segment, _, label in diarization.itertracks(yield_label=True)\n",
-     "        ]\n",
-     "\n",
-     "        return {\"diarization\": processed_diarization}"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {},
-    "source": [
-     "test custom pipeline"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 1,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "from handler import EndpointHandler\n",
-     "\n",
-     "# init handler\n",
-     "my_handler = EndpointHandler(path=\".\")"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 2,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "import base64\n",
-     "from PIL import Image\n",
-     "from io import BytesIO\n",
-     "import json\n",
-     "\n",
-     "# file reader\n",
-     "with open(\"sample.wav\", \"rb\") as f:\n",
-     "    request = {\"inputs\": f.read()}\n",
-     "\n",
-     "# test the handler\n",
-     "pred = my_handler(request)"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 3,
-    "metadata": {},
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "{'diarization': [{'label': 'SPEAKER_01',\n",
-        "   'start': '0.4978125',\n",
-        "   'stop': '1.3921875'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '1.8984375', 'stop': '2.7590624999999998'},\n",
-        "  {'label': 'SPEAKER_02', 'start': '2.9953125', 'stop': '3.5015625000000004'},\n",
-        "  {'label': 'SPEAKER_01',\n",
-        "   'start': '3.5690625000000002',\n",
-        "   'stop': '4.311562500000001'},\n",
-        "  {'label': 'SPEAKER_02', 'start': '4.6153125', 'stop': '6.7753125'},\n",
-        "  {'label': 'SPEAKER_00', 'start': '7.1128125', 'stop': '7.551562500000001'},\n",
-        "  {'label': 'SPEAKER_02',\n",
-        "   'start': '7.551562500000001',\n",
-        "   'stop': '9.475312500000001'},\n",
-        "  {'label': 'SPEAKER_02',\n",
-        "   'start': '9.812812500000003',\n",
-        "   'stop': '10.555312500000003'},\n",
-        "  {'label': 'SPEAKER_00',\n",
-        "   'start': '9.863437500000003',\n",
-        "   'stop': '10.420312500000001'},\n",
-        "  {'label': 'SPEAKER_03', 'start': '12.411562500000002', 'stop': '15.5503125'},\n",
-        "  {'label': 'SPEAKER_00', 'start': '15.786562500000002', 'stop': '16.1409375'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '16.1409375', 'stop': '16.1578125'},\n",
-        "  {'label': 'SPEAKER_00', 'start': '17.1534375', 'stop': '17.4234375'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '17.7440625', 'stop': '20.3596875'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '20.6128125', 'stop': '20.6634375'},\n",
-        "  {'label': 'SPEAKER_00', 'start': '20.6634375', 'stop': '20.8490625'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '20.8490625', 'stop': '20.8828125'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '21.1021875', 'stop': '22.1315625'},\n",
-        "  {'label': 'SPEAKER_02', 'start': '22.4521875', 'stop': '22.7053125'},\n",
-        "  {'label': 'SPEAKER_02', 'start': '23.2115625', 'stop': '23.4815625'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '23.4815625', 'stop': '24.0215625'},\n",
-        "  {'label': 'SPEAKER_02', 'start': '24.3253125', 'stop': '25.5065625'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '25.8440625', 'stop': '27.3121875'},\n",
-        "  {'label': 'SPEAKER_02', 'start': '27.3121875', 'stop': '27.4978125'},\n",
-        "  {'label': 'SPEAKER_01', 'start': '29.7253125', 'stop': '29.9615625'}]}"
-       ]
-      },
-      "execution_count": 3,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
-    "source": [
-     "pred"
-    ]
-   }
-  ],
-  "metadata": {
-   "kernelspec": {
-    "display_name": "Python 3.9.13 ('dev': conda)",
-    "language": "python",
-    "name": "python3"
-   },
-   "language_info": {
-    "codemirror_mode": {
-     "name": "ipython",
-     "version": 3
-    },
-    "file_extension": ".py",
-    "mimetype": "text/x-python",
-    "name": "python",
-    "nbconvert_exporter": "python",
-    "pygments_lexer": "ipython3",
-    "version": "3.9.13"
-   },
-   "orig_nbformat": 4,
-   "vscode": {
-    "interpreter": {
-     "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
-    }
-   }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 2
- }
-
-
- handler.py
-
- from typing import Dict
- from pyannote.audio import Pipeline
- import torch
- import base64
- import numpy as np
-
- SAMPLE_RATE = 16000
-
- class EndpointHandler():
-     def __init__(self, path=""):
-         # load the model
-         self.pipeline = Pipeline.from_pretrained("KIFF/pyannote-speaker-diarization-endpoint")
-
-     def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
-         """
-         Args:
-             data (:obj:):
-                 includes the deserialized audio file as bytes
-         Return:
-             A :obj:`dict` with the diarization segments
-         """
-         # process input
-         inputs = data.pop("inputs", data)
-         parameters = data.pop("parameters", None)  # min_speakers=2, max_speakers=5
-
-         # decode the base64 audio data
-         audio_data = base64.b64decode(inputs)
-         audio_nparray = np.frombuffer(audio_data, dtype=np.int16)
-
-         # prepare pyannote input
-         audio_tensor = torch.from_numpy(audio_nparray).float().unsqueeze(0)
-         pyannote_input = {"waveform": audio_tensor, "sample_rate": SAMPLE_RATE}
-
-         # apply pretrained pipeline
-         # pass inputs with all kwargs in data
-         if parameters is not None:
-             diarization = self.pipeline(pyannote_input, **parameters)
-         else:
-             diarization = self.pipeline(pyannote_input)
-
-         # postprocess the prediction
-         processed_diarization = [
-             {"label": str(label), "start": str(segment.start), "stop": str(segment.end)}
-             for segment, _, label in diarization.itertracks(yield_label=True)
-         ]
-
-         return {"diarization": processed_diarization}