srd4 committed on
Commit
3fee4b7
·
verified ·
1 Parent(s): 85a4262

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -5
handler.py CHANGED
@@ -1,11 +1,12 @@
1
  from typing import Dict
2
- from faster_whisper import WhisperModel
3
  import io
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir=None):
7
  # Set model size, assuming installation has been done with appropriate model files and setup
8
- model_size = "medium" if model_dir is None else model_dir
9
  # Change to 'cuda' to use the GPU, and set compute_type for faster computation
10
  self.model = WhisperModel(model_size, device="cuda", compute_type="float16")
11
 
@@ -16,11 +17,14 @@ class EndpointHandler:
16
  # Convert bytes to a file-like object
17
  audio_file = io.BytesIO(audio_bytes)
18
 
19
- # Perform transcription using the model
20
- segments, info = self.model.transcribe(audio_file)
 
21
 
22
  # Compile the results into a text string and extract language information
23
- text = " ".join(segment.text for segment in segments)
 
 
24
  language_code = info.language
25
  language_prob = info.language_probability
26
 
 
1
  from typing import Dict
2
+ from faster_whisper import WhisperModel, Streaming
3
  import io
4
+ import re
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir=None):
8
  # Set model size, assuming installation has been done with appropriate model files and setup
9
+ model_size = "large-v2" if model_dir is None else model_dir
10
  # Change to 'cuda' to use the GPU, and set compute_type for faster computation
11
  self.model = WhisperModel(model_size, device="cuda", compute_type="float16")
12
 
 
17
  # Convert bytes to a file-like object
18
  audio_file = io.BytesIO(audio_bytes)
19
 
20
+ # Enable VAD and perform transcription using the model with a reduced beam size
21
+ streaming = Streaming(device="cuda", compute_type="float16", vad=True)
22
+ segments, info = streaming.transcribe(audio_file, beam_size=1)
23
 
24
  # Compile the results into a text string and extract language information
25
+ # Strip leading and trailing whitespace and replace multiple spaces with a single space
26
+ text = " ".join(segment.text.strip() for segment in segments)
27
+ text = re.sub(' +', ' ', text)
28
  language_code = info.language
29
  language_prob = info.language_probability
30