srd4 committed
Commit
2f2c6bc
1 Parent(s): 4f4670a

Update handler.py

Files changed (1)
  1. handler.py +14 -17
handler.py CHANGED
@@ -1,34 +1,31 @@
 from typing import Dict
 from faster_whisper import WhisperModel
 import io
-import re
 
 class EndpointHandler:
     def __init__(self, model_dir=None):
-        # The compute_type is set to "float16" for efficient GPU computation
-        # For "int8" computation on CPU, the compute_type would be "int8"
-        compute_type = "float16"
-
-        # Initialize WhisperModel with large-v2 model size and specified compute_type
-        model_size = "large-v2" if model_dir is None else model_dir
-        self.model = WhisperModel(model_size, device="cuda", compute_type=compute_type)
+        # Set model size, assuming installation has been done with appropriate model files and setup
+        model_size = "medium" if model_dir is None else model_dir
+        # Change to 'cuda' to use the GPU, and set compute_type for faster computation
+        self.model = WhisperModel(model_size, device="cpu", compute_type="int8")
 
     def __call__(self, data: Dict) -> Dict[str, str]:
+        # Process the input data expected to be in 'inputs' key containing audio file bytes
         audio_bytes = data["inputs"]
+
+        # Convert bytes to a file-like object
         audio_file = io.BytesIO(audio_bytes)
-
-        # Transcribe audio file with a smaller beam size for faster inference
-        # Note: Adjust beam_size based on desired accuracy vs speed trade-off
-        beam_size = 1
-        segments, info = self.model.transcribe(audio_file, beam_size=beam_size)
-
-        # Aggregate transcribed text and remove any extra spaces
+
+        # Perform transcription using the model
+        segments, info = self.model.transcribe(audio_file)
+
+        # Compile the results into a text string and extract language information
+        # Strip whitespace from each segment before joining them
         text = " ".join(segment.text.strip() for segment in segments)
-        text = re.sub(' +', ' ', text)
-
         language_code = info.language
         language_prob = info.language_probability
 
+        # Compile the response dictionary
         result = {
             "text": text,
             "language": language_code,