import html from functools import partial from typing import Any, Callable from fish_speech.i18n import i18n from tools.schema import ServeReferenceAudio, ServeTTSRequest def inference_wrapper( text, normalize, reference_id, reference_audio, reference_text, max_new_tokens, chunk_length, top_p, repetition_penalty, temperature, seed, use_memory_cache, engine, ): """ Wrapper for the inference function. Used in the Gradio interface. """ if reference_audio: references = get_reference_audio(reference_audio, reference_text) else: references = [] req = ServeTTSRequest( text=text, normalize=normalize, reference_id=reference_id if reference_id else None, references=references, max_new_tokens=max_new_tokens, chunk_length=chunk_length, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, seed=int(seed) if seed else None, use_memory_cache=use_memory_cache, ) for result in engine.inference(req): match result.code: case "final": return result.audio, None case "error": return None, build_html_error_message(i18n(result.error)) case _: pass return None, i18n("No audio generated") def get_reference_audio(reference_audio: str, reference_text: str) -> list: """ Get the reference audio bytes. """ with open(reference_audio, "rb") as audio_file: audio_bytes = audio_file.read() return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)] def build_html_error_message(error: Any) -> str: error = error if isinstance(error, Exception) else Exception("Unknown error") return f"""