maolin.liu committed
Commit: 705afb7
Parent(s): e335f35
[bugfix]Revise response body field value.
server.py CHANGED
@@ -21,6 +21,8 @@ async def register_init(app: FastAPI):
 
     :return:
     """
+    print('Loading ASR model...')
+    setup_asr_model()
 
     yield
 
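For context: the `yield` in `register_init` suggests it is used as a FastAPI lifespan hook, so everything before the `yield` runs once at startup and everything after it at shutdown. A minimal runnable sketch of that pattern, with a stubbed `setup_asr_model` (the real loader is added further down in this commit) and assuming the app wires the hook via `FastAPI(lifespan=...)`:

import contextlib

from fastapi import FastAPI


def setup_asr_model():
    # Stub standing in for the lazy model loader defined in server.py.
    pass


@contextlib.asynccontextmanager
async def register_init(app: FastAPI):
    # Startup phase: runs once before the first request is served.
    print('Loading ASR model...')
    setup_asr_model()
    yield
    # Shutdown phase: cleanup would go here.


app = FastAPI(lifespan=register_init)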
@@ -51,7 +53,16 @@ app = create_app()
 
 model_size = os.environ.get('WHISPER-MODEL-SIZE', 'large-v3')
 # Run on GPU with FP16
-
+asr_model: typing.Optional[WhisperModel] = None
+
+
+def setup_asr_model():
+    global asr_model
+    if asr_model is None:
+        logging.info('Loading ASR model...')
+        asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
+        logging.info('Load ASR model finished.')
+    return asr_model
 
 
 class TranscribeRequestParams(BaseModel):
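The new `setup_asr_model` follows a lazy-singleton pattern: the module-level `asr_model` starts as `None`, the `WhisperModel` is constructed only on the first call, and every later call reuses the one loaded instance. A self-contained sketch of the same shape, with a stand-in class so it runs without a GPU or model weights:

import logging
import typing


class FakeModel:
    """Stand-in for WhisperModel so this sketch runs anywhere."""


_model: typing.Optional[FakeModel] = None


def setup_model() -> FakeModel:
    # Load once, then reuse: mirrors setup_asr_model() in this commit.
    global _model
    if _model is None:
        logging.info('Loading model...')
        _model = FakeModel()
    return _model


assert setup_model() is setup_model()  # the second call returns the cached instance

One caveat of the pattern: it is not thread-safe, so two concurrent first calls could load the model twice. Calling it from the startup hook before any request is served closes that window.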
@@ -68,7 +79,12 @@ async def transcribe_api(
     try:
         audio_file = io.BytesIO(base64.b64decode(obj.audio_file))
 
-
+        segments, _ = asr_model.transcribe(audio_file, language=obj.language)
+
+        transcribed_text = ''
+        for segment in segments:
+            transcribed_text = segment.text
+            break
     except Exception as exc:
         logging.exception(exc)
         response_body = {
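`transcribe_api` reads a base64-encoded payload (`obj.audio_file`) and a language code (`obj.language`) from the request body. A hypothetical client call, assuming the route is mounted at `/transcribe` on port 8000 (neither the path nor the port appears in this diff):

import base64

import requests

with open('sample.wav', 'rb') as f:
    payload = {
        'audio_file': base64.b64encode(f.read()).decode('ascii'),
        'language': 'en',
    }

# '/transcribe' is an assumption; only the handler name is visible here.
resp = requests.post('http://localhost:8000/transcribe', json=payload)
print(resp.json())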
@@ -93,7 +109,12 @@ async def transcribe_file_api(
     language: typing.Literal['en', 'zh']
 ):
     try:
-
+        segments, _ = asr_model.transcribe(audio_file.file, language=language)
+
+        transcribed_text = ''
+        for segment in segments:
+            transcribed_text = segment.text
+            break
     except Exception as exc:
         logging.exception(exc)
         response_body = {
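`transcribe_file_api` passes `audio_file.file` to the model, which suggests `audio_file` is a FastAPI `UploadFile` (its `.file` attribute is the underlying file object that faster-whisper can read). A hypothetical multipart upload, again with an assumed path; whether `language` travels as a query or form field depends on the handler signature not shown in this hunk, so this sketch sends it as a query parameter:

import requests

with open('sample.wav', 'rb') as f:
    # '/transcribe-file' is an assumed path; the diff only shows the handler name.
    resp = requests.post(
        'http://localhost:8000/transcribe-file',
        params={'language': 'en'},
        files={'audio_file': ('sample.wav', f, 'audio/wav')},
    )
print(resp.json())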
@@ -134,7 +155,12 @@ async def transcribe_ws_api(
     try:
         audio_file = io.BytesIO(base64.b64decode(form.audio_file))
 
-
+        segments, _ = asr_model.transcribe(audio_file, language=form.language)
+
+        transcribed_text = ''
+        for segment in segments:
+            transcribed_text = segment.text
+            break
     except Exception as exc:
         logging.exception(exc)
         response_body = {
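In all three handlers the new loop keeps only the first segment: faster-whisper's `transcribe` returns a lazy generator of segments, and the `break` stops iteration after one item. If the full transcript were ever wanted instead, a variant (not part of this commit) would join every segment:

def transcribe_all(model, audio, language: str) -> str:
    # Variant of the committed loop: join every segment instead of breaking
    # after the first one. faster-whisper yields segments lazily, so this
    # pass consumes the whole clip.
    segments, _ = model.transcribe(audio, language=language)
    return ''.join(segment.text for segment in segments)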