change handler to use whisper model and include segments
Browse files- handler.py +25 -12
handler.py
CHANGED
@@ -2,15 +2,18 @@ from typing import Dict
|
|
2 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
3 |
import whisper
|
4 |
import torch
|
5 |
-
|
6 |
-
SAMPLE_RATE = 16000
|
7 |
-
|
8 |
|
9 |
|
10 |
class EndpointHandler():
|
11 |
def __init__(self, path=""):
|
12 |
# load the model
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
|
@@ -19,15 +22,25 @@ class EndpointHandler():
|
|
19 |
data (:obj:):
|
20 |
includes the deserialized audio file as bytes
|
21 |
Return:
|
22 |
-
A :obj:`dict`:.
|
23 |
"""
|
24 |
# process input
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# postprocess the prediction
|
33 |
-
return {"
|
|
|
2 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
3 |
import whisper
|
4 |
import torch
|
5 |
+
import pytube
|
|
|
|
|
6 |
|
7 |
|
8 |
class EndpointHandler():
    """Inference-endpoint handler: downloads a YouTube video's audio track
    and transcribes it with an OpenAI Whisper model."""

    def __init__(self, path=""):
        # load the model
        MODEL_NAME = "tiny.en"

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f'whisper will use: {device}')

        # Bug fix: the model must be stored on `self` — the original bound a
        # local `whisper_model`, so `self.model` in __call__ raised AttributeError.
        self.model = whisper.load_model(MODEL_NAME).to(device)

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                includes the YouTube video URL under the "inputs" key
        Return:
            A :obj:`dict` with the transcription result under "transcript".
        """
        # process input
        print('data', data)
        video_url = data.pop("inputs", data)
        decode_options = {
            # Set language to None to support multilingual,
            # but it will take longer to process while it detects the language.
            # Realized this by running in verbose mode and seeing how much time
            # was spent on the decoding language step
            "language": "en",
            # Bug fix: `verbose` must be a quoted string key — the bare name
            # raised NameError at runtime.
            "verbose": True,
        }
        # Bug fix: the module is imported as `pytube`, not `pt`.
        yt = pytube.YouTube(video_url)
        # Grab the first audio-only stream and save it locally for Whisper.
        stream = yt.streams.filter(only_audio=True)[0]
        path_to_audio = f"{yt.video_id}.mp3"
        stream.download(filename=path_to_audio)

        transcript = self.model.transcribe(path_to_audio, **decode_options)

        # postprocess the prediction
        return {"transcript": transcript}
|