S-Fry committed
Commit e4b911e
1 Parent(s): d9964f9

Update handler.py

Files changed (1):
  1. handler.py +23 -28
handler.py CHANGED
@@ -1,35 +1,30 @@
+import torch
 from typing import Dict
-from transformers import WhisperProcessor, WhisperForConditionalGeneration
-from transformers.pipelines.audio_utils import ffmpeg_read
-#import Torch
-#from datasets import load_dataset
-
-
-SAMPLE_RATE = 16000
+from transformers import pipeline
+from datasets import load_dataset
 
+SAMPLE_RATE=16000
 class EndpointHandler():
     def __init__(self, path=""):
-        # load the model
-        self.processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
-        self.model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
-        self.classifier = AudioClassificationPipeline(model=self.model, processor=self.processor, device=0)
-        self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="Danish", task="transcribe")
-
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model="openai/whisper-large",
+            chunk_length_s=30,
+            device=device,
+        )
+
     def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
-        """
-        Args:
-            data (:obj:):
-                includes the deserialized audio file as bytes
-        Return:
-            A :obj:`dict`:. base64 encoded image
-        """
-        # process input
+        #ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        #sample = ds[0]["audio"]
         inputs = data.pop("inputs", data)
         audio_nparray = ffmpeg_read(inputs, sample_rate=SAMPLE_RATE)
-        audio_tensor= torch.from_numpy(audio_nparray)
-
-        # run inference pipeline
-        result = self.classifier(audio_nparray)
-
-        # postprocess the prediction
-        return {"txt": result[0]["transcription"]}
+        audio_tensor = torch.from_numpy(audio_nparray)
+
+        prediction = pipe(audio_nparray, return_timestamps=True)
+        return {"text": prediction[0]["transcription"]}
+
+        # we can also return timestamps for the predictions
+        #prediction = pipe(sample, return_timestamps=True)["chunks"]
+        #[{'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.',
+        # 'timestamp': (0.0, 5.44)}]
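
A note on the updated handler: as committed it will fail at runtime. `pipe` is a local variable in `__init__` but is referenced in `__call__`; `ffmpeg_read` is still called although its import was removed; the ASR pipeline returns a dict (with `return_timestamps=True`, roughly `{"text": ..., "chunks": [...]}`), so `prediction[0]["transcription"]` raises a KeyError; and everything after the `return` is unreachable. Below is a minimal corrected sketch, keeping the committed model name, chunk length, and device selection; it is untested, drops the unused `audio_tensor` and `load_dataset` lines, and calls `ffmpeg_read` positionally because the sample-rate keyword name has varied across transformers versions.

import torch
from typing import Dict
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read

SAMPLE_RATE = 16000

class EndpointHandler():
    def __init__(self, path=""):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # keep the pipeline on self so __call__ can reach it
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large",
            chunk_length_s=30,
            device=device,
        )

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
        # decode the raw audio bytes into a float32 numpy array at 16 kHz
        inputs = data.pop("inputs", data)
        audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)

        # the pipeline returns a dict; with return_timestamps=True the
        # per-segment timestamps are available under prediction["chunks"]
        prediction = self.pipe(audio_nparray, return_timestamps=True)
        return {"text": prediction["text"]}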
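
For reference, a hypothetical client-side call against such a deployment; the endpoint URL, token, and file name below are placeholders, not values from this repository.

import requests

# placeholders: substitute your own endpoint URL, token, and audio file
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

with open("sample.flac", "rb") as f:
    audio_bytes = f.read()

# send the raw audio bytes; the handler decodes them with ffmpeg_read
response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "audio/flac",
    },
    data=audio_bytes,
)
print(response.json())  # e.g. {"text": "..."}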