Hemant0000 commited on
Commit
c0c9ba8
·
verified ·
1 Parent(s): 2af09cc

Create api.py

Browse files
Files changed (1) hide show
  1. src/f5_tts/api.py +166 -0
src/f5_tts/api.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import sys
3
+ from importlib.resources import files
4
+
5
+ import soundfile as sf
6
+ import tqdm
7
+ from cached_path import cached_path
8
+
9
+ from f5_tts.infer.utils_infer import (
10
+ hop_length,
11
+ infer_process,
12
+ load_model,
13
+ load_vocoder,
14
+ preprocess_ref_audio_text,
15
+ remove_silence_for_generated_wav,
16
+ save_spectrogram,
17
+ transcribe,
18
+ target_sample_rate,
19
+ )
20
+ from f5_tts.model import DiT, UNetT
21
+ from f5_tts.model.utils import seed_everything
22
+
23
+
24
+ class F5TTS:
25
+ def __init__(
26
+ self,
27
+ model_type="F5-TTS",
28
+ ckpt_file="",
29
+ vocab_file="",
30
+ ode_method="euler",
31
+ use_ema=True,
32
+ vocoder_name="vocos",
33
+ local_path=None,
34
+ device=None,
35
+ hf_cache_dir=None,
36
+ ):
37
+ # Initialize parameters
38
+ self.final_wave = None
39
+ self.target_sample_rate = target_sample_rate
40
+ self.hop_length = hop_length
41
+ self.seed = -1
42
+ self.mel_spec_type = vocoder_name
43
+
44
+ # Set device
45
+ if device is not None:
46
+ self.device = device
47
+ else:
48
+ import torch
49
+
50
+ self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
51
+
52
+ # Load models
53
+ self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
54
+ self.load_ema_model(
55
+ model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
56
+ )
57
+
58
+ def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
59
+ self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
60
+
61
+ def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
62
+ if model_type == "F5-TTS":
63
+ if not ckpt_file:
64
+ if mel_spec_type == "vocos":
65
+ ckpt_file = str(
66
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
67
+ )
68
+ elif mel_spec_type == "bigvgan":
69
+ ckpt_file = str(
70
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
71
+ )
72
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
73
+ model_cls = DiT
74
+ elif model_type == "E2-TTS":
75
+ if not ckpt_file:
76
+ ckpt_file = str(
77
+ cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
78
+ )
79
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
80
+ model_cls = UNetT
81
+ else:
82
+ raise ValueError(f"Unknown model type: {model_type}")
83
+
84
+ self.ema_model = load_model(
85
+ model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
86
+ )
87
+
88
+ def transcribe(self, ref_audio, language=None):
89
+ return transcribe(ref_audio, language)
90
+
91
+ def export_wav(self, wav, file_wave, remove_silence=False):
92
+ sf.write(file_wave, wav, self.target_sample_rate)
93
+
94
+ if remove_silence:
95
+ remove_silence_for_generated_wav(file_wave)
96
+
97
+ def export_spectrogram(self, spect, file_spect):
98
+ save_spectrogram(spect, file_spect)
99
+
100
+ def infer(
101
+ self,
102
+ ref_file,
103
+ ref_text,
104
+ gen_text,
105
+ show_info=print,
106
+ progress=tqdm,
107
+ target_rms=0.1,
108
+ cross_fade_duration=0.15,
109
+ sway_sampling_coef=-1,
110
+ cfg_strength=2,
111
+ nfe_step=32,
112
+ speed=1.0,
113
+ fix_duration=None,
114
+ remove_silence=False,
115
+ file_wave=None,
116
+ file_spect=None,
117
+ seed=-1,
118
+ ):
119
+ if seed == -1:
120
+ seed = random.randint(0, sys.maxsize)
121
+ seed_everything(seed)
122
+ self.seed = seed
123
+
124
+ ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
125
+
126
+ wav, sr, spect = infer_process(
127
+ ref_file,
128
+ ref_text,
129
+ gen_text,
130
+ self.ema_model,
131
+ self.vocoder,
132
+ self.mel_spec_type,
133
+ show_info=show_info,
134
+ progress=progress,
135
+ target_rms=target_rms,
136
+ cross_fade_duration=cross_fade_duration,
137
+ nfe_step=nfe_step,
138
+ cfg_strength=cfg_strength,
139
+ sway_sampling_coef=sway_sampling_coef,
140
+ speed=speed,
141
+ fix_duration=fix_duration,
142
+ device=self.device,
143
+ )
144
+
145
+ if file_wave is not None:
146
+ self.export_wav(wav, file_wave, remove_silence)
147
+
148
+ if file_spect is not None:
149
+ self.export_spectrogram(spect, file_spect)
150
+
151
+ return wav, sr, spect
152
+
153
+
154
+ if __name__ == "__main__":
155
+ f5tts = F5TTS()
156
+
157
+ wav, sr, spect = f5tts.infer(
158
+ ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
159
+ ref_text="some call me nature, others call me mother nature.",
160
+ gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
161
+ file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
162
+ file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
163
+ seed=-1, # random seed = -1
164
+ )
165
+
166
+ print("seed :", f5tts.seed)