Spaces:
Sleeping
Sleeping
Siddhant
commited on
Commit
•
94b0033
1
Parent(s):
5092f40
Add warmup before start
Browse files
app.py
CHANGED
@@ -22,6 +22,7 @@ from espnet2.sds.eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_
|
|
22 |
from espnet2.sds.utils.chat import Chat
|
23 |
from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
|
24 |
import argparse
|
|
|
25 |
|
26 |
access_token = os.environ.get("HF_TOKEN")
|
27 |
ASR_name="pyf98/owsm_ctc_v3.1_1B"
|
@@ -188,12 +189,44 @@ def handle_E2E_selection():
|
|
188 |
client = MiniOmniE2EModel()
|
189 |
client.warmup()
|
190 |
|
191 |
-
|
192 |
-
|
193 |
-
for
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
vad_model=WebrtcVADModel()
|
198 |
|
199 |
callback = gr.CSVLogger()
|
|
|
22 |
from espnet2.sds.utils.chat import Chat
|
23 |
from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
|
24 |
import argparse
|
25 |
+
import torch
|
26 |
|
27 |
access_token = os.environ.get("HF_TOKEN")
|
28 |
ASR_name="pyf98/owsm_ctc_v3.1_1B"
|
|
|
189 |
client = MiniOmniE2EModel()
|
190 |
client.warmup()
|
191 |
|
192 |
+
def start_warmup():
|
193 |
+
global client
|
194 |
+
for opt in ASR_options:
|
195 |
+
if opt==ASR_name:
|
196 |
+
continue
|
197 |
+
print(opt)
|
198 |
+
for _ in handle_ASR_selection(opt):
|
199 |
+
continue
|
200 |
+
for opt in LLM_options:
|
201 |
+
if opt==LLM_name:
|
202 |
+
continue
|
203 |
+
print(opt)
|
204 |
+
for _ in handle_LLM_selection(opt):
|
205 |
+
continue
|
206 |
+
for opt in TTS_options:
|
207 |
+
if opt==TTS_name:
|
208 |
+
continue
|
209 |
+
print(opt)
|
210 |
+
for _ in handle_selection(opt):
|
211 |
+
continue
|
212 |
+
handle_E2E_selection()
|
213 |
+
client=None
|
214 |
+
for _ in handle_selection(TTS_name):
|
215 |
+
continue
|
216 |
+
for _ in handle_ASR_selection(ASR_name):
|
217 |
+
continue
|
218 |
+
for _ in handle_LLM_selection(LLM_name):
|
219 |
+
continue
|
220 |
+
dummy_input = torch.randn(
|
221 |
+
(3000),
|
222 |
+
dtype=getattr(torch, "float16"),
|
223 |
+
device="cpu",
|
224 |
+
).cpu().numpy()
|
225 |
+
dummy_text="This is dummy text"
|
226 |
+
for opt in Eval_options:
|
227 |
+
handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
|
228 |
+
|
229 |
+
start_warmup()
|
230 |
vad_model=WebrtcVADModel()
|
231 |
|
232 |
callback = gr.CSVLogger()
|