Hemant0000 commited on
Commit
fc8d5c8
·
verified ·
1 Parent(s): c0c9ba8

Delete src/f5_tts/src_f5_tts_api.py

Browse files
Files changed (1) hide show
  1. src/f5_tts/src_f5_tts_api.py +0 -151
src/f5_tts/src_f5_tts_api.py DELETED
@@ -1,151 +0,0 @@
1
- import random
2
- import sys
3
- from importlib.resources import files
4
-
5
- import soundfile as sf
6
- import torch
7
- import tqdm
8
- from cached_path import cached_path
9
-
10
- from f5_tts.infer.utils_infer import (
11
- hop_length,
12
- infer_process,
13
- load_model,
14
- load_vocoder,
15
- preprocess_ref_audio_text,
16
- remove_silence_for_generated_wav,
17
- save_spectrogram,
18
- target_sample_rate,
19
- )
20
- from f5_tts.model import DiT, UNetT
21
- from f5_tts.model.utils import seed_everything
22
-
23
-
24
- class F5TTS:
25
- def __init__(
26
- self,
27
- model_type="F5-TTS",
28
- ckpt_file="",
29
- vocab_file="",
30
- ode_method="euler",
31
- use_ema=True,
32
- vocoder_name="vocos",
33
- local_path=None,
34
- device=None,
35
- ):
36
- # Initialize parameters
37
- self.final_wave = None
38
- self.target_sample_rate = target_sample_rate
39
- self.hop_length = hop_length
40
- self.seed = -1
41
- self.mel_spec_type = vocoder_name
42
-
43
- # Set device
44
- self.device = device or (
45
- "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
46
- )
47
-
48
- # Load models
49
- self.load_vocoder_model(vocoder_name, local_path)
50
- self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
51
-
52
- def load_vocoder_model(self, vocoder_name, local_path):
53
- self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
54
-
55
- def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
56
- if model_type == "F5-TTS":
57
- if not ckpt_file:
58
- if mel_spec_type == "vocos":
59
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
60
- elif mel_spec_type == "bigvgan":
61
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
62
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
63
- model_cls = DiT
64
- elif model_type == "E2-TTS":
65
- if not ckpt_file:
66
- ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
67
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
68
- model_cls = UNetT
69
- else:
70
- raise ValueError(f"Unknown model type: {model_type}")
71
-
72
- self.ema_model = load_model(
73
- model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
74
- )
75
-
76
- def export_wav(self, wav, file_wave, remove_silence=False):
77
- sf.write(file_wave, wav, self.target_sample_rate)
78
-
79
- if remove_silence:
80
- remove_silence_for_generated_wav(file_wave)
81
-
82
- def export_spectrogram(self, spect, file_spect):
83
- save_spectrogram(spect, file_spect)
84
-
85
- def infer(
86
- self,
87
- ref_file,
88
- ref_text,
89
- gen_text,
90
- show_info=print,
91
- progress=tqdm,
92
- target_rms=0.1,
93
- cross_fade_duration=0.15,
94
- sway_sampling_coef=-1,
95
- cfg_strength=2,
96
- nfe_step=32,
97
- speed=1.0,
98
- fix_duration=None,
99
- remove_silence=False,
100
- file_wave=None,
101
- file_spect=None,
102
- seed=-1,
103
- ):
104
- if seed == -1:
105
- seed = random.randint(0, sys.maxsize)
106
- seed_everything(seed)
107
- self.seed = seed
108
-
109
- ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
110
-
111
- wav, sr, spect = infer_process(
112
- ref_file,
113
- ref_text,
114
- gen_text,
115
- self.ema_model,
116
- self.vocoder,
117
- self.mel_spec_type,
118
- show_info=show_info,
119
- progress=progress,
120
- target_rms=target_rms,
121
- cross_fade_duration=cross_fade_duration,
122
- nfe_step=nfe_step,
123
- cfg_strength=cfg_strength,
124
- sway_sampling_coef=sway_sampling_coef,
125
- speed=speed,
126
- fix_duration=fix_duration,
127
- device=self.device,
128
- )
129
-
130
- if file_wave is not None:
131
- self.export_wav(wav, file_wave, remove_silence)
132
-
133
- if file_spect is not None:
134
- self.export_spectrogram(spect, file_spect)
135
-
136
- return wav, sr, spect
137
-
138
-
139
- if __name__ == "__main__":
140
- f5tts = F5TTS()
141
-
142
- wav, sr, spect = f5tts.infer(
143
- ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
144
- ref_text="some call me nature, others call me mother nature.",
145
- gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
146
- file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
147
- file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
148
- seed=-1, # random seed = -1
149
- )
150
-
151
- print("seed :", f5tts.seed)