| import os |
| import queue |
| from huggingface_hub import snapshot_download |
| import numpy as np |
| import wave |
| import io |
| import gc |
| from typing import Callable |
|
|
| |
# Fetch the "fishaudio/s1-mini" snapshot from the Hugging Face Hub into
# ./checkpoints/openaudio-s1-mini (the directory parse_args() defaults to).
# This runs at import time, before the heavy model imports below.
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(
    repo_id="fishaudio/s1-mini",
    local_dir="./checkpoints/openaudio-s1-mini",
)
print("All checkpoints downloaded")
|
|
| import html |
| import os |
| from argparse import ArgumentParser |
| from pathlib import Path |
|
|
| import gradio as gr |
| import torch |
| import torchaudio |
|
|
# Prefer the soundfile backend for torchaudio I/O. `set_audio_backend` is
# deprecated in torchaudio 2.x (a no-op once the backend dispatcher is enabled)
# and may be missing from newer releases, so only call it when it exists;
# otherwise startup would fail with AttributeError.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("soundfile")
|
|
| from loguru import logger |
| from fish_speech.i18n import i18n |
| from fish_speech.inference_engine import TTSInferenceEngine |
| from fish_speech.models.dac.inference import load_model as load_decoder_model |
| from fish_speech.models.text2semantic.inference import launch_thread_safe_queue |
| from tools.webui.inference import get_inference_wrapper |
| from fish_speech.utils.schema import ServeTTSRequest |
|
|
| |
# Presumably disables einx's traceback filtering so errors raised inside einx
# keep their full stack traces.
# NOTE(review): this runs after the fish_speech imports above; if einx reads
# the variable at import time it may already be too late — confirm.
os.environ.update(EINX_FILTER_TRACEBACK="false")
|
|
|
|
# Markdown banner rendered at the top of the web UI (build_app passes it to
# gr.Markdown). Most English lines are followed by a Chinese translation.
HEADER_MD = """# Fish Audio S1

## The demo in this space is Fish Audio S1, Please check [Fish Audio](https://fish.audio) for the best model.
## 该 Demo 为 Fish Audio S1 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.

A text-to-speech model based on DAC & Qwen3 developed by [Fish Audio](https://fish.audio).
由 [Fish Audio](https://fish.audio) 研发的 DAC & Qwen3 多语种语音合成.

You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/s1-mini).
你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/s1-mini) 找到模型.

Related code and weights are released under CC BY-NC-SA 4.0 License.
相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.

We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.

The model running in this WebUI is Fish Audio S1 Mini.
在此 WebUI 中运行的模型是 Fish Audio S1 Mini.
"""


# Placeholder shown in the empty input textbox (English, then Chinese).
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
|
try:
    import spaces

    # Use the Hugging Face `spaces` GPU decorator when that package is installed.
    GPU_DECORATOR = spaces.GPU
except ImportError:
    from functools import wraps

    def GPU_DECORATOR(func=None, **_decorator_kwargs):
        """No-op stand-in for ``spaces.GPU`` when ``spaces`` is not installed.

        Supports both bare use (``@GPU_DECORATOR``) and parameterized use
        (``@GPU_DECORATOR(duration=...)``); any keyword arguments are ignored.

        Args:
            func: The function to wrap, or ``None`` when called with keyword
                arguments only (decorator-factory form).

        Returns:
            A wrapper that calls ``func`` unchanged, or a decorator producing
            such a wrapper when ``func`` is ``None``.
        """
        if func is None:
            # Called as GPU_DECORATOR(**options): return the real decorator.
            return lambda f: GPU_DECORATOR(f)

        # functools.wraps keeps func's __name__/__doc__ so introspecting tools
        # (e.g. Gradio's default API endpoint naming) see the original function
        # rather than "wrapper".
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper
|
|
def build_html_error_message(error):
    """Render *error* as a red, bold HTML ``<div>`` snippet.

    The error is converted with ``str()`` and HTML-escaped first, so text
    from exceptions cannot inject markup into the page.

    Args:
        error: Any object (typically an exception) to display.

    Returns:
        The HTML snippet as a string.
    """
    safe_text = html.escape(str(error))
    return f"""
<div style="color: red;
font-weight: bold;">
{safe_text}
</div>
"""
|
|
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Return the bytes of a WAV (RIFF) header that describes zero frames.

    The header comes from the stdlib ``wave`` writer closed without writing
    any frames, so its data-chunk length field is 0. Presumably used as the
    first chunk of a streamed WAV response — confirm with callers.

    Args:
        sample_rate: Frames per second written to the header.
        bit_depth: Bits per sample; converted to a byte width with ``// 8``.
        channels: Number of interleaved channels.

    Returns:
        The raw header bytes.
    """
    with io.BytesIO() as sink:
        with wave.open(sink, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bit_depth // 8)
            writer.setframerate(sample_rate)
        # Read the bytes before `sink` is closed by its context manager.
        header = sink.getvalue()
    return header
|
|
|
|
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    """Build the Gradio Blocks UI for text-to-speech generation.

    Args:
        inference_fct: Callback bound to the Generate button. Gradio calls it
            with the values of the input components positionally, in the order
            listed in ``generate.click`` below, and routes its return values to
            the ``audio`` and ``error`` components. Its exact signature comes
            from ``get_inference_wrapper`` (not visible here) — keep the two in
            sync.
        theme: Value written into the ``__theme`` URL query parameter on page
            load when the URL does not already set one (e.g. "light"/"dark").

    Returns:
        The constructed ``gr.Blocks`` app; the caller queues and launches it.
    """
    # NOTE: Gradio renders components in the order they are created inside the
    # nested `with` layout contexts, so statement order here defines the page.
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # On page load: if the URL has no `__theme` query parameter, add one set
        # to `theme`; assigning window.location.search reloads the page with it.
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % theme,
        )

        # Main area: text + settings column on the left, outputs on the right.
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )

                with gr.Row():
                    with gr.Column():
                        # Generation / sampling settings.
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                # Per its label, 0 disables iterative prompting.
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=500,
                                    value=0,
                                    step=8,
                                )

                                # Per its label, 0 means no token limit.
                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.7,
                                    maximum=0.95,
                                    value=0.9,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.2,
                                    value=1.1,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.7,
                                    maximum=1.0,
                                    value=0.9,
                                    step=0.01,
                                )
                                # gr.Number with default precision; per its info
                                # text, 0 requests randomized inference.
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        # Speaker-reference inputs: either a reference ID, or
                        # (when the ID is empty) an uploaded clip + transcript.
                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                # "on"/"off" string passed through to inference_fct.
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                # type="filepath": the callback receives a path string.
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                # Transcript of the reference clip (the placeholder
                                # is a Chinese example sentence).
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            # Output column: error message, generated audio, Generate button.
            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    # type="numpy": presumably inference_fct returns a
                    # (sample_rate, ndarray) pair for this component — confirm.
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001f3a7 " + i18n("Generate"),
                            variant="primary",
                        )

        # Wire the button. Input order must match inference_fct's positional
        # parameters; outputs go to [audio, error]. concurrency_limit=1 lets
        # only one generation run at a time for this event.
        generate.click(
            inference_fct,
            [
                text,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app
|
|
def parse_args():
    """Parse the web UI's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with:
            llama_checkpoint_path / decoder_checkpoint_path (Path),
            decoder_config_name, device, theme (str),
            half, compile (bool), max_gradio_length (int).
    """
    from argparse import BooleanOptionalAction

    parser = ArgumentParser()
    # String defaults are passed through `type`, so these become Path objects.
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini/codec.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    # Previously `action="store_true", default=True`, which made the flag a
    # no-op (always True) and compilation impossible to turn off.
    # BooleanOptionalAction keeps `--compile` working and adds `--no-compile`.
    parser.add_argument("--compile", action=BooleanOptionalAction, default=True)
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="dark")

    return parser.parse_args()
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    # `--half` selects float16; otherwise the models are loaded in bfloat16.
    args.precision = torch.half if args.half else torch.bfloat16

    # Text-to-semantic ("Llama") model, exposed through launch_thread_safe_queue.
    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    # Codec decoder (presumably turns generated tokens into audio).
    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )
    logger.info("Decoder model loaded, warming up...")

    inference_engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=args.compile,
        precision=args.precision,
    )

    # Warm-up: push one short request through the whole pipeline and discard
    # the output, presumably to surface loading errors and absorb first-call
    # latency before the UI starts serving.
    warmup_request = ServeTTSRequest(
        text="Hello world.",
        references=[],
        reference_id=None,
        max_new_tokens=1024,
        chunk_length=200,
        top_p=0.7,
        repetition_penalty=1.5,
        temperature=0.7,
        format="wav",
    )
    for _ in inference_engine.inference(warmup_request):
        pass

    logger.info("Warming up done, launching the web UI...")

    inference_fct = get_inference_wrapper(inference_engine)

    app = build_app(inference_fct, args.theme)
    app.queue(api_open=True).launch(show_error=True, show_api=True)
|
|