import os
from argparse import ArgumentParser
from pathlib import Path

import pyrootutils
import torch
from loguru import logger
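
# Register the repository root (marked by a `.project-root` file) and add it to
# `sys.path`, so the `fish_speech` and `tools` imports below resolve regardless
# of the working directory.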
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.models.vqgan.inference import load_model as load_decoder_model
from fish_speech.utils.schema import ServeTTSRequest
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper
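
# Disable einx's traceback filtering so full stack traces are shown when debugging.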
os.environ["EINX_FILTER_TRACEBACK"] = "false"
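
# Download the Fish Speech 1.5 checkpoints from the Hugging Face Hub
# (requires `huggingface-cli` to be available on PATH).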
os.system(
    "huggingface-cli download fishaudio/fish-speech-1.5 "
    "--local-dir checkpoints/fish-speech-1.5"
)


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()
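
# Example invocation (assuming this file is saved as tools/run_webui.py):
#   python tools/run_webui.py --device cuda --half --compile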


if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16
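
    # Prefer MPS on Apple silicon; otherwise fall back to CPU when CUDA is unavailable.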
    if torch.backends.mps.is_available():
        args.device = "mps"
        logger.info("mps is available, running on mps.")
    elif not torch.cuda.is_available():
        logger.info("CUDA is not available, running on CPU.")
        args.device = "cpu"
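
    # The text2semantic (Llama) model runs behind a thread-safe queue, so
    # requests coming from the UI can be submitted from multiple threads.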
logger.info("Loading Llama model...") |
|
llama_queue = launch_thread_safe_queue( |
|
checkpoint_path=args.llama_checkpoint_path, |
|
device=args.device, |
|
precision=args.precision, |
|
compile=args.compile, |
|
) |
|
|
|
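
    # The VQ-GAN decoder turns generated semantic tokens back into a waveform.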
logger.info("Loading VQ-GAN model...") |
|
decoder_model = load_decoder_model( |
|
config_name=args.decoder_config_name, |
|
checkpoint_path=args.decoder_checkpoint_path, |
|
device=args.device, |
|
) |
|
|
|
logger.info("Decoder model loaded, warming up...") |
|
|
|
|
|

    inference_engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=args.compile,
        precision=args.precision,
    )
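
    # Warm-up request: inference() yields results lazily, so list() drains the
    # generator, forcing weight loading and (with --compile) kernel compilation
    # before the first real request arrives.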
    list(
        inference_engine.inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=1024,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                format="wav",
            )
        )
    )
logger.info("Warming up done, launching the web UI...") |
|
|
|
|
|
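
    # get_inference_wrapper adapts the engine to the callback signature
    # expected by the Gradio app that build_app constructs.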
    inference_fct = get_inference_wrapper(inference_engine)

    app = build_app(inference_fct, args.theme)
    app.launch(show_api=True)