mrfakename commited on
Commit
daaf1ba
·
verified ·
1 Parent(s): c940f75

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (2) hide show
  1. pyproject.toml +1 -1
  2. src/f5_tts/socket_server.py +48 -17
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
  name = "f5-tts"
7
- version = "0.3.1"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
 
4
 
5
  [project]
6
  name = "f5-tts"
7
+ version = "0.3.2"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
src/f5_tts/socket_server.py CHANGED
@@ -1,13 +1,14 @@
 
 
1
  import socket
2
  import struct
3
  import torch
4
  import torchaudio
5
- from threading import Thread
6
-
7
-
8
- import gc
9
  import traceback
 
 
10
 
 
11
 
12
  from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
13
  from model.backbones.dit import DiT
@@ -15,7 +16,9 @@ from model.backbones.dit import DiT
15
 
16
  class TTSStreamingProcessor:
17
  def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
18
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
 
19
 
20
  # Load the model using the provided checkpoint and vocab files
21
  self.model = load_model(
@@ -137,23 +140,51 @@ def start_server(host, port, processor):
137
 
138
 
139
  if __name__ == "__main__":
140
- try:
141
- # Load the model and vocoder using the provided files
142
- ckpt_file = "" # pointing your checkpoint "ckpts/model/model_1096.pt"
143
- vocab_file = "" # Add vocab file path if needed
144
- ref_audio = "" # add ref audio"./tests/ref_audio/reference.wav"
145
- ref_text = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
 
147
  # Initialize the processor with the model and vocoder
148
  processor = TTSStreamingProcessor(
149
- ckpt_file=ckpt_file,
150
- vocab_file=vocab_file,
151
- ref_audio=ref_audio,
152
- ref_text=ref_text,
153
- dtype=torch.float32,
 
154
  )
155
 
156
  # Start the server
157
- start_server("0.0.0.0", 9998, processor)
 
158
  except KeyboardInterrupt:
159
  gc.collect()
 
1
+ import argparse
2
+ import gc
3
  import socket
4
  import struct
5
  import torch
6
  import torchaudio
 
 
 
 
7
  import traceback
8
+ from importlib.resources import files
9
+ from threading import Thread
10
 
11
+ from cached_path import cached_path
12
 
13
  from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
14
  from model.backbones.dit import DiT
 
16
 
17
  class TTSStreamingProcessor:
18
  def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
19
+ self.device = device or (
20
+ "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
21
+ )
22
 
23
  # Load the model using the provided checkpoint and vocab files
24
  self.model = load_model(
 
140
 
141
 
142
  if __name__ == "__main__":
143
+ parser = argparse.ArgumentParser()
144
+
145
+ parser.add_argument("--host", default="0.0.0.0")
146
+ parser.add_argument("--port", default=9998)
147
+
148
+ parser.add_argument(
149
+ "--ckpt_file",
150
+ default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
151
+ help="Path to the model checkpoint file",
152
+ )
153
+ parser.add_argument(
154
+ "--vocab_file",
155
+ default="",
156
+ help="Path to the vocab file if customized",
157
+ )
158
+
159
+ parser.add_argument(
160
+ "--ref_audio",
161
+ default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
162
+ help="Reference audio to provide model with speaker characteristics",
163
+ )
164
+ parser.add_argument(
165
+ "--ref_text",
166
+ default="",
167
+ help="Reference audio subtitle, leave empty to auto-transcribe",
168
+ )
169
+
170
+ parser.add_argument("--device", default=None, help="Device to run the model on")
171
+ parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference")
172
+
173
+ args = parser.parse_args()
174
 
175
+ try:
176
  # Initialize the processor with the model and vocoder
177
  processor = TTSStreamingProcessor(
178
+ ckpt_file=args.ckpt_file,
179
+ vocab_file=args.vocab_file,
180
+ ref_audio=args.ref_audio,
181
+ ref_text=args.ref_text,
182
+ device=args.device,
183
+ dtype=args.dtype,
184
  )
185
 
186
  # Start the server
187
+ start_server(args.host, args.port, processor)
188
+
189
  except KeyboardInterrupt:
190
  gc.collect()