# fish-audio-t / tools/webui/inference.py
# Last commit by kiylu (b128c76): "add project files"
import html
from functools import partial
from typing import Any, Callable
from fish_speech.i18n import i18n
from tools.schema import ServeReferenceAudio, ServeTTSRequest
def inference_wrapper(
    text,
    normalize,
    reference_id,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    use_memory_cache,
    engine,
):
    """
    Run one TTS request for the Gradio interface.

    Builds a ``ServeTTSRequest`` from the UI widget values and drives
    ``engine.inference`` until it yields a terminal result.

    Args:
        reference_audio: Path to a reference clip; when falsy, no inline
            reference audio is sent.
        reference_id: Stored reference identifier; falsy values become None.
        seed: Converted with ``int``; any falsy value (e.g. 0) becomes None.
            Presumably None means "random seed" — confirm with the engine.
        engine: Object exposing ``inference(request)``, an iterable of
            results carrying ``code``, ``audio`` and ``error`` attributes.
        (Remaining arguments are forwarded unchanged to ``ServeTTSRequest``.)

    Returns:
        ``(audio, None)`` on a "final" result, ``(None, html_error)`` on an
        "error" result, or ``(None, i18n("No audio generated"))`` if the
        engine finishes without yielding either.
    """
    refs = (
        get_reference_audio(reference_audio, reference_text)
        if reference_audio
        else []
    )

    request = ServeTTSRequest(
        text=text,
        normalize=normalize,
        reference_id=reference_id or None,
        references=refs,
        max_new_tokens=max_new_tokens,
        chunk_length=chunk_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        seed=None if not seed else int(seed),
        use_memory_cache=use_memory_cache,
    )

    for outcome in engine.inference(request):
        status = outcome.code
        if status == "final":
            return outcome.audio, None
        if status == "error":
            return None, build_html_error_message(i18n(outcome.error))
        # Any other status (intermediate results) is skipped.

    return None, i18n("No audio generated")
def get_reference_audio(reference_audio: str, reference_text: str) -> list:
    """
    Load a reference clip from disk and package it for a TTS request.

    Args:
        reference_audio: Filesystem path of the audio file; it is read as
            raw bytes without decoding.
        reference_text: Transcript attached to the clip.

    Returns:
        A one-element list holding a ``ServeReferenceAudio``.
    """
    with open(reference_audio, "rb") as handle:
        payload = handle.read()

    reference = ServeReferenceAudio(audio=payload, text=reference_text)
    return [reference]
def build_html_error_message(error: Any) -> str:
    """
    Render an error as a red, bold HTML snippet for the web UI.

    Args:
        error: An ``Exception`` or a plain message string. Anything else,
            including ``None`` and empty/blank strings, is shown as
            "Unknown error".

    Returns:
        An HTML ``<div>`` fragment whose text is the HTML-escaped message.
    """
    if isinstance(error, Exception):
        # Exceptions raised without arguments stringify to "", which would
        # render an empty box; show the exception class name instead.
        message = str(error) or type(error).__name__
    elif isinstance(error, str) and error.strip():
        # Accept plain message strings rather than masking them as unknown.
        message = error
    else:
        message = "Unknown error"

    # Escape so that error text can never inject markup into the page.
    return f"""
    <div style="color: red;
    font-weight: bold;">
        {html.escape(message)}
    </div>
    """
def get_inference_wrapper(engine) -> Callable:
    """
    Bind ``engine`` to ``inference_wrapper`` for use as a UI callback.

    Args:
        engine: Inference engine passed through as the ``engine`` keyword.

    Returns:
        A ``functools.partial`` of ``inference_wrapper`` with ``engine``
        fixed; every other parameter is still supplied by the caller.
    """
    bound_wrapper = partial(inference_wrapper, engine=engine)
    return bound_wrapper