|
from typing import Callable |
|
|
|
import gradio as gr |
|
|
|
from fish_speech.i18n import i18n |
|
from fish_speech.inference_engine.utils import normalize_text |
|
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER |
|
|
|
|
|
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    """Assemble the Gradio web UI for text-to-speech inference.

    Args:
        inference_fct: Callback fired by the "Generate" button. It receives the
            normalized text plus all generation/reference settings (see
            ``generation_inputs`` below) and yields ``(audio, error_html)``.
        theme: Display theme name written into the page URL's ``__theme``
            query parameter on load (e.g. ``"light"`` or ``"dark"``).

    Returns:
        The fully wired ``gr.Blocks`` application, ready to ``.launch()``.
    """
    # Client-side snippet that forces the requested theme by appending
    # ?__theme=<theme> to the URL (only when the user hasn't set one).
    force_theme_js = (
        "() => {const params = new URLSearchParams(window.location.search);"
        "if (!params.has('__theme')) {params.set('__theme', '%s');"
        "window.location.search = params.toString();}}" % theme
    )

    with gr.Blocks(theme=gr.themes.Base()) as ui:
        gr.Markdown(HEADER_MD)

        ui.load(None, None, js=force_theme_js)

        with gr.Row():
            # --- Left column: text input and generation settings -------------
            with gr.Column(scale=3):
                input_text = gr.Textbox(
                    label=i18n("Input Text"),
                    placeholder=TEXTBOX_PLACEHOLDER,
                    lines=10,
                )
                preview_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    normalize_checkbox = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                    )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length_slider = gr.Slider(
                                    label=i18n(
                                        "Iterative Prompt Length, 0 means off"
                                    ),
                                    minimum=0,
                                    maximum=300,
                                    value=200,
                                    step=8,
                                )
                                max_tokens_slider = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p_slider = gr.Slider(
                                    label="Top-P",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                repetition_penalty_slider = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.5,
                                    value=1.2,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature_slider = gr.Slider(
                                    label="Temperature",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                seed_number = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id_box = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                memory_cache_radio = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                reference_audio_input = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text_box = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            # --- Right column: outputs ---------------------------------------
            with gr.Column(scale=3):
                with gr.Row():
                    error_html = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    output_audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

        with gr.Row():
            with gr.Column(scale=3):
                generate_button = gr.Button(
                    value="\U0001F3A7 " + i18n("Generate"),
                    variant="primary",
                )

        # Live normalization preview: refresh the read-only textbox on typing.
        input_text.input(
            fn=normalize_text,
            inputs=[input_text, normalize_checkbox],
            outputs=[preview_text],
        )

        # Everything the inference callback consumes, in its expected order.
        generation_inputs = [
            preview_text,
            normalize_checkbox,
            reference_id_box,
            reference_audio_input,
            reference_text_box,
            max_tokens_slider,
            chunk_length_slider,
            top_p_slider,
            repetition_penalty_slider,
            temperature_slider,
            seed_number,
            memory_cache_radio,
        ]

        generate_button.click(
            inference_fct,
            generation_inputs,
            [output_audio, error_html],
            concurrency_limit=1,
        )

    return ui
|
|