Blaxzter committed on
Commit 8d61b73
1 Parent(s): e202d39

Update handler.py

Files changed (1)
  1. handler.py +13 -51
handler.py CHANGED
@@ -4,27 +4,31 @@ import os
 from io import StringIO
 from typing import Dict, Any
 
+import torch
 from transformers import pipeline
 
 
 class EndpointHandler:
 
     def __init__(self, asr_model_path: str = "./whisper-large-v2"):
+        device = 0 if torch.cuda.is_available() else -1
+        device = -1
+        print("Using device:", device)
         # Create an ASR pipeline using the model located in the specified directory
         self.asr_pipeline = pipeline(
             "automatic-speech-recognition",
             model = asr_model_path,
+            device = device
         )
 
     def __call__(self, data: Dict[str, Any]) -> str:
 
-        json_data = json.loads(data)
-        if "audio_data" not in json_data.keys():
+        if "audio_data" not in data.keys():
             raise Exception("Request must contain a top-level key named 'audio_data'")
 
         # Get the audio data from the input
-        audio_data = json_data["audio_data"]
-        language = json_data["language"]
+        audio_data = data["audio_data"]
+        options = data["options"]
 
         # Decode the binary audio data if it's provided as a base64 string
         if isinstance(audio_data, str):
@@ -33,12 +37,11 @@ class EndpointHandler:
         # Process the audio data with the ASR pipeline
         transcription = self.asr_pipeline(
             audio_data,
-            return_timestamps=False,
-            chunk_length_s=30,
-            batch_size=8,
-            max_length=10000,
-            max_new_tokens=10000,
-            generate_kwargs={"task": "transcribe", "language": "<|language|>"}
+            return_timestamps = True,
+            chunk_length_s = 30,
+            batch_size = 8,
+            max_new_tokens = 10000,
+            generate_kwargs = options
         )
 
         # Convert the transcription to JSON
@@ -46,44 +49,3 @@ class EndpointHandler:
         json.dump(transcription, result)
 
         return result.getvalue()
-
-    def init():
-        global asr_pipeline
-        # Set the path to the directory where the model is stored
-        model_path = os.getenv("AZUREML_MODEL_DIR", "./whisper-large-v2")
-
-        # Create an ASR pipeline using the model located in the specified directory
-        asr_pipeline = pipeline(
-            "automatic-speech-recognition",
-            model = model_path,
-        )
-
-
-    def run(raw_data):
-        json_data = json.loads(raw_data)
-        if "audio_data" not in json_data.keys():
-            raise Exception("Request must contain a top level key named 'audio_data'")
-
-        # Get the audio data from the input
-        audio_data = json_data["audio_data"]
-
-        # Decode the binary audio data if it's provided as a base64 string
-        if isinstance(audio_data, str):
-            import base64
-            audio_data = base64.b64decode(audio_data)
-
-        # Process the audio data with the ASR pipeline
-        transcription = asr_pipeline(
-            audio_data,
-            return_timestamps = False,
-            chunk_length_s = 30,
-            batch_size = 8,
-            max_new_tokens = 1000,
-            generate_kwargs = {"task": "transcribe", "language": "<|de|>"}
-        )
-
-        # Convert the transcription to JSON
-        result = StringIO()
-        json.dump(transcription, result)
-
-        return result.getvalue()
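
For reference, after this change __call__ expects a plain dict rather than a JSON string: a base64-encoded "audio_data" payload plus an "options" dict that is forwarded verbatim as generate_kwargs. The minimal local invocation sketch below is illustrative only; "sample.wav" is a placeholder file name, the option values simply mirror the task/language pair the previous revision hard-coded, and it assumes the whisper-large-v2 weights are present at the default "./whisper-large-v2" path.

import base64

from handler import EndpointHandler

# Build the handler; assumes ./whisper-large-v2 holds the model weights.
handler = EndpointHandler()

# __call__ base64-decodes string payloads, so encode a local file first.
# "sample.wav" is a placeholder path.
with open("sample.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "audio_data": audio_b64,
    # Passed through unchanged as generate_kwargs; these values are
    # illustrative and mirror what the previous revision hard-coded.
    "options": {"task": "transcribe", "language": "<|de|>"},
}

# Returns a JSON string with the transcribed text and chunk timestamps.
print(handler(payload))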