mrfakename committed on
Commit
e2287e3
·
verified ·
1 Parent(s): cf0b618

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.

src/f5_tts/api.py CHANGED
@@ -49,10 +49,10 @@ class F5TTS:
49
  self.load_vocoder_model(vocoder_name, local_path=local_path)
50
  self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
51
 
52
- def load_vocoder_model(self, vocoder_name, local_path):
53
  self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
54
 
55
- def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path):
56
  if model_type == "F5-TTS":
57
  if not ckpt_file:
58
  if mel_spec_type == "vocos":
 
49
  self.load_vocoder_model(vocoder_name, local_path=local_path)
50
  self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
51
 
52
+ def load_vocoder_model(self, vocoder_name, local_path=None):
53
  self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
54
 
55
+ def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path=None):
56
  if model_type == "F5-TTS":
57
  if not ckpt_file:
58
  if mel_spec_type == "vocos":
src/f5_tts/infer/SHARED.md CHANGED
@@ -18,6 +18,8 @@
18
  - [Multilingual](#multilingual)
19
  - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
20
  - [Mandarin](#mandarin)
 
 
21
  - [English](#english)
22
  - [French](#french)
23
  - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
@@ -67,6 +69,6 @@ MODEL_CKPT: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.p
67
  VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
68
  ```
69
 
70
- - Online Inference with [Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
71
  - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
72
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
 
18
  - [Multilingual](#multilingual)
19
  - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
20
  - [Mandarin](#mandarin)
21
+ - [Japanese](#japanese)
22
+ - [F5-TTS Base @ pretrain/finetune @ ja](#f5-tts-base--pretrainfinetune--ja)
23
  - [English](#english)
24
  - [French](#french)
25
  - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
 
69
  VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
70
  ```
71
 
72
+ - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
73
  - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
74
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
src/f5_tts/infer/utils_infer.py CHANGED
@@ -90,36 +90,41 @@ def chunk_text(text, max_chars=135):
90
 
91
 
92
  # load vocoder
93
- def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device):
94
  if vocoder_name == "vocos":
95
- if is_local:
 
96
  print(f"Load vocos from local path {local_path}")
97
- repo_id = "charactr/vocos-mel-24khz"
98
- revision = None
99
- config_path = hf_hub_download(
100
- repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
101
- )
102
- model_path = hf_hub_download(
103
- repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
104
- )
105
- vocoder = Vocos.from_hparams(config_path=config_path)
106
- state_dict = torch.load(model_path, map_location="cpu")
107
- vocoder.load_state_dict(state_dict)
108
- vocoder = vocoder.eval().to(device)
109
  else:
110
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
111
- vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  elif vocoder_name == "bigvgan":
113
  try:
114
  from third_party.BigVGAN import bigvgan
115
  except ImportError:
116
  print("You need to follow the README to init submodule and change the BigVGAN source code.")
117
- if is_local:
118
  """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
119
- local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
120
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
121
  else:
122
- vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
 
123
 
124
  vocoder.remove_weight_norm()
125
  vocoder = vocoder.eval().to(device)
 
90
 
91
 
92
  # load vocoder
93
+ def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=device):
94
  if vocoder_name == "vocos":
95
+ # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
96
+ if is_local and local_path is not None:
97
  print(f"Load vocos from local path {local_path}")
98
+ config_path = f"{local_path}/config.yaml"
99
+ model_path = f"{local_path}/pytorch_model.bin"
 
 
 
 
 
 
 
 
 
 
100
  else:
101
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
102
+ repo_id = "charactr/vocos-mel-24khz"
103
+ config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml")
104
+ model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin")
105
+ vocoder = Vocos.from_hparams(config_path)
106
+ state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
107
+ from vocos.feature_extractors import EncodecFeatures
108
+
109
+ if isinstance(vocoder.feature_extractor, EncodecFeatures):
110
+ encodec_parameters = {
111
+ "feature_extractor.encodec." + key: value
112
+ for key, value in vocoder.feature_extractor.encodec.state_dict().items()
113
+ }
114
+ state_dict.update(encodec_parameters)
115
+ vocoder.load_state_dict(state_dict)
116
+ vocoder = vocoder.eval().to(device)
117
  elif vocoder_name == "bigvgan":
118
  try:
119
  from third_party.BigVGAN import bigvgan
120
  except ImportError:
121
  print("You need to follow the README to init submodule and change the BigVGAN source code.")
122
+ if is_local and local_path is not None:
123
  """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
 
124
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
125
  else:
126
+ local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
127
+ vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
128
 
129
  vocoder.remove_weight_norm()
130
  vocoder = vocoder.eval().to(device)