ashishkblink committed on
Commit
e0794df
·
verified ·
1 Parent(s): c652221

Upload f5_tts/api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/api.py +9 -1
f5_tts/api.py CHANGED
@@ -3,6 +3,7 @@ import sys
3
  from importlib.resources import files
4
 
5
  import soundfile as sf
 
6
  import tqdm
7
  from cached_path import cached_path
8
 
@@ -10,6 +11,7 @@ from f5_tts.infer.utils_infer import (
10
  hop_length,
11
  infer_process,
12
  load_model,
 
13
  load_vocoder,
14
  preprocess_ref_audio_text,
15
  remove_silence_for_generated_wav,
@@ -81,9 +83,15 @@ class F5TTS:
81
  else:
82
  raise ValueError(f"Unknown model type: {model_type}")
83
 
 
84
  self.ema_model = load_model(
85
- model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
86
  )
 
 
 
 
 
87
 
88
  def transcribe(self, ref_audio, language=None):
89
  return transcribe(ref_audio, language)
 
3
  from importlib.resources import files
4
 
5
  import soundfile as sf
6
+ import torch
7
  import tqdm
8
  from cached_path import cached_path
9
 
 
11
  hop_length,
12
  infer_process,
13
  load_model,
14
+ load_checkpoint,
15
  load_vocoder,
16
  preprocess_ref_audio_text,
17
  remove_silence_for_generated_wav,
 
83
  else:
84
  raise ValueError(f"Unknown model type: {model_type}")
85
 
86
+ # Load model architecture
87
  self.ema_model = load_model(
88
+ model_cls, model_cfg, mel_spec_type, vocab_file, ode_method, use_ema, self.device
89
  )
90
+
91
+ # Load checkpoint weights if provided
92
+ if ckpt_file:
93
+ dtype = torch.float32 if mel_spec_type == "bigvgan" else None
94
+ self.ema_model = load_checkpoint(self.ema_model, ckpt_file, self.device, dtype=dtype, use_ema=use_ema)
95
 
96
  def transcribe(self, ref_audio, language=None):
97
  return transcribe(ref_audio, language)