Spaces:
Running
Running
| import html | |
| from functools import partial | |
| from typing import Any, Callable | |
| from fish_speech.i18n import i18n | |
| from tools.schema import ServeReferenceAudio, ServeTTSRequest | |
def inference_wrapper(
    text,
    normalize,
    reference_id,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    use_memory_cache,
    engine,
):
    """
    Run a single TTS request through ``engine`` on behalf of the Gradio UI.

    The UI widget values are packed into a ``ServeTTSRequest``. The results
    yielded by ``engine.inference`` are then scanned until the first one
    whose ``code`` is ``"final"`` or ``"error"``.

    Returns:
        A two-element tuple ``(audio, message)``:
        ``(result.audio, None)`` for the first ``"final"`` result,
        ``(None, <HTML error snippet>)`` for the first ``"error"`` result, or
        ``(None, i18n("No audio generated"))`` if the results run out
        without either.
    """
    # The reference file is read only when the UI actually supplied one.
    references = (
        get_reference_audio(reference_audio, reference_text)
        if reference_audio
        else []
    )

    request = ServeTTSRequest(
        text=text,
        normalize=normalize,
        # A falsy id (e.g. an empty textbox) is sent as None.
        reference_id=reference_id or None,
        references=references,
        max_new_tokens=max_new_tokens,
        chunk_length=chunk_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        # A falsy seed (0, None, "") is sent as None; otherwise it is coerced to int.
        seed=int(seed) if seed else None,
        use_memory_cache=use_memory_cache,
    )

    for event in engine.inference(request):
        code = event.code
        if code == "final":
            return event.audio, None
        if code == "error":
            return None, build_html_error_message(i18n(event.error))
        # Any other code (presumably intermediate/streaming results) is ignored.

    return None, i18n("No audio generated")
def get_reference_audio(reference_audio: str, reference_text: str) -> list:
    """
    Load a reference clip from disk and wrap it for a TTS request.

    Args:
        reference_audio: Path of the audio file; it is read in binary mode.
        reference_text: Text stored alongside the audio bytes.

    Returns:
        A one-element list holding a ``ServeReferenceAudio`` built from the
        file's raw bytes and ``reference_text``.
    """
    with open(reference_audio, "rb") as handle:
        raw = handle.read()

    reference = ServeReferenceAudio(audio=raw, text=reference_text)
    return [reference]
def build_html_error_message(error: Any) -> str:
    """
    Render an error as a small red, bold HTML snippet for the Gradio UI.

    Args:
        error: An ``Exception`` or a non-empty message string. Any other value,
            including ``None`` and ``""``, is shown as ``"Unknown error"``.

    Returns:
        An HTML ``<div>`` containing the message. The message is passed
        through ``html.escape`` so error text cannot inject markup.
    """
    if isinstance(error, Exception):
        message = str(error)
    elif isinstance(error, str) and error:
        # Plain-string messages (e.g. a translated message) are shown verbatim
        # instead of being collapsed into "Unknown error".
        message = error
    else:
        message = "Unknown error"

    return f"""
    <div style="color: red;
    font-weight: bold;">
        {html.escape(message)}
    </div>
    """
def get_inference_wrapper(engine) -> Callable:
    """
    Bind ``engine`` into ``inference_wrapper`` for use as a UI callback.

    Args:
        engine: Object whose ``inference`` method ``inference_wrapper``
            calls. It is fixed for every call made through the returned
            callable.

    Returns:
        A ``functools.partial`` of ``inference_wrapper`` with ``engine``
        pre-filled as a keyword argument. Callers supply only the remaining
        request parameters.
    """
    bound_wrapper = partial(inference_wrapper, engine=engine)
    return bound_wrapper