Siddhant commited on
Commit
94b0033
1 Parent(s): 5092f40

Add warmup before start

Browse files
Files changed (1) hide show
  1. app.py +39 -6
app.py CHANGED
@@ -22,6 +22,7 @@ from espnet2.sds.eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_
22
  from espnet2.sds.utils.chat import Chat
23
  from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
24
  import argparse
 
25
 
26
  access_token = os.environ.get("HF_TOKEN")
27
  ASR_name="pyf98/owsm_ctc_v3.1_1B"
@@ -188,12 +189,44 @@ def handle_E2E_selection():
188
  client = MiniOmniE2EModel()
189
  client.warmup()
190
 
191
- for _ in handle_selection(TTS_name):
192
- continue
193
- for _ in handle_ASR_selection(ASR_name):
194
- continue
195
- for _ in handle_LLM_selection(LLM_name):
196
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  vad_model=WebrtcVADModel()
198
 
199
  callback = gr.CSVLogger()
 
22
  from espnet2.sds.utils.chat import Chat
23
  from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
24
  import argparse
25
+ import torch
26
 
27
  access_token = os.environ.get("HF_TOKEN")
28
  ASR_name="pyf98/owsm_ctc_v3.1_1B"
 
189
  client = MiniOmniE2EModel()
190
  client.warmup()
191
 
192
+ def start_warmup():
193
+ global client
194
+ for opt in ASR_options:
195
+ if opt==ASR_name:
196
+ continue
197
+ print(opt)
198
+ for _ in handle_ASR_selection(opt):
199
+ continue
200
+ for opt in LLM_options:
201
+ if opt==LLM_name:
202
+ continue
203
+ print(opt)
204
+ for _ in handle_LLM_selection(opt):
205
+ continue
206
+ for opt in TTS_options:
207
+ if opt==TTS_name:
208
+ continue
209
+ print(opt)
210
+ for _ in handle_selection(opt):
211
+ continue
212
+ handle_E2E_selection()
213
+ client=None
214
+ for _ in handle_selection(TTS_name):
215
+ continue
216
+ for _ in handle_ASR_selection(ASR_name):
217
+ continue
218
+ for _ in handle_LLM_selection(LLM_name):
219
+ continue
220
+ dummy_input = torch.randn(
221
+ (3000),
222
+ dtype=getattr(torch, "float16"),
223
+ device="cpu",
224
+ ).cpu().numpy()
225
+ dummy_text="This is dummy text"
226
+ for opt in Eval_options:
227
+ handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
228
+
229
+ start_warmup()
230
  vad_model=WebrtcVADModel()
231
 
232
  callback = gr.CSVLogger()