JacobLinCool committed on
Commit
ecd3224
1 Parent(s): bd8dcd1

perf: model lazy load

Browse files
Files changed (1) hide show
  1. app.py +49 -31
app.py CHANGED
@@ -27,38 +27,48 @@ from huggingface_hub import HfApi
27
 
28
  # will use the API to restart the space on an unrecoverable error
29
  api = HfApi(token=HF_TOKEN)
30
- repo_id = "coqui/xtts"
31
 
32
- print("loading model")
 
33
 
34
- model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
35
- model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
36
 
37
- config = XttsConfig()
38
- config.load_json(os.path.join(model_path, "config.json"))
 
39
 
40
- model = Xtts.init_from_config(config)
41
- model.load_checkpoint(
42
- config,
43
- checkpoint_path=os.path.join(model_path, "model.pth"),
44
- vocab_path=os.path.join(model_path, "vocab.json"),
45
- eval=True,
46
- use_deepspeed=False,
47
- )
48
 
49
- if torch.cuda.is_available():
50
- model.cuda()
51
- else:
52
- model.cpu()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- print("Model loaded")
55
 
56
  # This is for debugging purposes only
57
  DEVICE_ASSERT_DETECTED = 0
58
  DEVICE_ASSERT_PROMPT = None
59
  DEVICE_ASSERT_LANG = None
60
 
61
- supported_languages = config.languages
62
 
63
  def predict(
64
  prompt,
@@ -68,6 +78,9 @@ def predict(
68
  no_lang_auto_detect,
69
  agree,
70
  ):
 
 
 
71
  if agree == True:
72
  if language not in supported_languages:
73
  gr.Warning(
@@ -184,7 +197,7 @@ def predict(
184
 
185
  # HF Space specific: this error is unrecoverable, so the space needs to be restarted
186
  space = api.get_space_runtime(repo_id=repo_id)
187
- if space.stage!="BUILDING":
188
  api.restart_space(repo_id=repo_id)
189
  else:
190
  print("TRIED TO RESTART but space is building")
@@ -198,7 +211,9 @@ def predict(
198
  (
199
  gpt_cond_latent,
200
  speaker_embedding,
201
- ) = model.get_conditioning_latents(audio_path=speaker_wav, gpt_cond_len=30, max_ref_length=60)
 
 
202
  except Exception as e:
203
  print("Speaker encoding error", str(e))
204
  gr.Warning(
@@ -215,7 +230,7 @@ def predict(
215
  # metrics_text=f"Embedding calculation time: {latent_calculation_time:.2f} seconds\n"
216
 
217
  # temporary comma fix
218
- prompt= re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)",r"\1 \2\2",prompt)
219
 
220
  wav_chunks = []
221
  ## Direct mode
@@ -260,9 +275,9 @@ def predict(
260
  print(
261
  f"I: Time to generate audio: {round(inference_time*1000)} milliseconds"
262
  )
263
- #metrics_text += (
264
  # f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
265
- #)
266
 
267
  wav = torch.cat(wav_chunks, dim=0)
268
  print(wav.shape)
@@ -330,11 +345,11 @@ def predict(
330
 
331
  # HF Space specific: this error is unrecoverable, so the space needs to be restarted
332
  space = api.get_space_runtime(repo_id=repo_id)
333
- if space.stage!="BUILDING":
334
  api.restart_space(repo_id=repo_id)
335
  else:
336
  print("TRIED TO RESTART but space is building")
337
-
338
  else:
339
  if "Failed to decode" in str(e):
340
  print("Speaker encoding error", str(e))
@@ -459,7 +474,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
459
  "zh-cn",
460
  "ja",
461
  "ko",
462
- "hu"
463
  ],
464
  value="en",
465
  )
@@ -487,14 +502,17 @@ with gr.Blocks(analytics_enabled=False) as demo:
487
 
488
  tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
489
 
490
-
491
  with gr.Column():
492
  video_gr = gr.Video(label="Waveform Visual")
493
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
494
  out_text_gr = gr.Text(label="Metrics")
495
  ref_audio_gr = gr.Audio(label="Reference Audio Used")
496
 
497
- tts_button.click(predict, [input_text_gr, language_gr, ref_gr, clean_ref_gr, auto_det_lang_gr, tos_gr], outputs=[video_gr, audio_gr, out_text_gr, ref_audio_gr])
498
-
 
 
 
 
499
  print("Starting server")
500
  demo.queue().launch(debug=True, show_api=True)
 
27
 
28
  # will use the API to restart the space on an unrecoverable error
29
  api = HfApi(token=HF_TOKEN)
30
+ repo_id = "JacobLinCool/xtts-v2"
31
 
32
+ model = None
33
+ supported_languages = None
34
 
 
 
35
 
36
+ def load_model():
37
+ global model
38
+ global supported_languages
39
 
40
+ print("loading model")
 
 
 
 
 
 
 
41
 
42
+ model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
43
+ model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
44
+
45
+ config = XttsConfig()
46
+ config.load_json(os.path.join(model_path, "config.json"))
47
+
48
+ model = Xtts.init_from_config(config)
49
+ model.load_checkpoint(
50
+ config,
51
+ checkpoint_path=os.path.join(model_path, "model.pth"),
52
+ vocab_path=os.path.join(model_path, "vocab.json"),
53
+ eval=True,
54
+ use_deepspeed=False,
55
+ )
56
+
57
+ if torch.cuda.is_available():
58
+ model.cuda()
59
+ else:
60
+ model.cpu()
61
+
62
+ supported_languages = config.languages
63
+
64
+ print("Model loaded")
65
 
 
66
 
67
  # This is for debugging purposes only
68
  DEVICE_ASSERT_DETECTED = 0
69
  DEVICE_ASSERT_PROMPT = None
70
  DEVICE_ASSERT_LANG = None
71
 
 
72
 
73
  def predict(
74
  prompt,
 
78
  no_lang_auto_detect,
79
  agree,
80
  ):
81
+ if model is None:
82
+ load_model()
83
+
84
  if agree == True:
85
  if language not in supported_languages:
86
  gr.Warning(
 
197
 
198
  # HF Space specific: this error is unrecoverable, so the space needs to be restarted
199
  space = api.get_space_runtime(repo_id=repo_id)
200
+ if space.stage != "BUILDING":
201
  api.restart_space(repo_id=repo_id)
202
  else:
203
  print("TRIED TO RESTART but space is building")
 
211
  (
212
  gpt_cond_latent,
213
  speaker_embedding,
214
+ ) = model.get_conditioning_latents(
215
+ audio_path=speaker_wav, gpt_cond_len=30, max_ref_length=60
216
+ )
217
  except Exception as e:
218
  print("Speaker encoding error", str(e))
219
  gr.Warning(
 
230
  # metrics_text=f"Embedding calculation time: {latent_calculation_time:.2f} seconds\n"
231
 
232
  # temporary comma fix
233
+ prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
234
 
235
  wav_chunks = []
236
  ## Direct mode
 
275
  print(
276
  f"I: Time to generate audio: {round(inference_time*1000)} milliseconds"
277
  )
278
+ # metrics_text += (
279
  # f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
280
+ # )
281
 
282
  wav = torch.cat(wav_chunks, dim=0)
283
  print(wav.shape)
 
345
 
346
  # HF Space specific: this error is unrecoverable, so the space needs to be restarted
347
  space = api.get_space_runtime(repo_id=repo_id)
348
+ if space.stage != "BUILDING":
349
  api.restart_space(repo_id=repo_id)
350
  else:
351
  print("TRIED TO RESTART but space is building")
352
+
353
  else:
354
  if "Failed to decode" in str(e):
355
  print("Speaker encoding error", str(e))
 
474
  "zh-cn",
475
  "ja",
476
  "ko",
477
+ "hu",
478
  ],
479
  value="en",
480
  )
 
502
 
503
  tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
504
 
 
505
  with gr.Column():
506
  video_gr = gr.Video(label="Waveform Visual")
507
  audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
508
  out_text_gr = gr.Text(label="Metrics")
509
  ref_audio_gr = gr.Audio(label="Reference Audio Used")
510
 
511
+ tts_button.click(
512
+ predict,
513
+ [input_text_gr, language_gr, ref_gr, clean_ref_gr, auto_det_lang_gr, tos_gr],
514
+ outputs=[video_gr, audio_gr, out_text_gr, ref_audio_gr],
515
+ )
516
+
517
  print("Starting server")
518
  demo.queue().launch(debug=True, show_api=True)