import io
import re
import wave

import gradio as gr

from fish_speech.utils.schema import ServeMessage, ServeTextPart, ServeVQPart

from .fish_e2e import FishE2EAgent, FishE2EEventType


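# Build a WAV container header with no audio frames; it is prepended to each raw
# PCM chunk below so the streaming gr.Audio output can decode the bytes.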
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


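# Accumulates the multi-turn conversation as ServeMessage objects and renders it
# into the message-dict format expected by gr.Chatbot(type="messages").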
class ChatState:
    def __init__(self):
        self.conversation = []
        self.added_systext = False
        self.added_sysaudio = False

    def get_history(self):
        results = []
        for msg in self.conversation:
            results.append({"role": msg.role, "content": self.repr_message(msg)})

        # The assistant is prompted to reply as "Question: ...\n\nResponse: ...".
        # Move the transcribed question back into the preceding user message and
        # keep only the response text in the assistant message.
        for i, msg in enumerate(results):
            if msg["role"] == "assistant":
                match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
                if match and i > 0 and results[i - 1]["role"] == "user":
                    results[i - 1]["content"] += "\n" + match.group(1)
                    msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]

        return results

    def repr_message(self, msg: ServeMessage):
        response = ""
        for part in msg.parts:
            if isinstance(part, ServeTextPart):
                response += part.text
            elif isinstance(part, ServeVQPart):
                # VQ codes are emitted at roughly 21 frames per second, so the
                # code count approximates the clip duration.
                response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
        return response


def clear_fn():
    return [], ChatState(), None, None, None


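# Core handler: injects the system prompt, appends the user turn, streams events
# from FishE2EAgent, and yields incremental (history, audio chunk, audio input,
# text input) updates to the UI.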
async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()  # Create a fresh agent instance for each request

    # Gradio delivers microphone audio as a (sample_rate, numpy array) tuple.
    if isinstance(audio_input, tuple):
        sr, audio_data = audio_input
    elif text_input:
        sr = 44100
        audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    if isinstance(sys_audio_input, tuple):
        sr, sys_audio_data = sys_audio_input
    else:
        sr = 44100
        sys_audio_data = None

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    # Add the system prompt once, on the first turn it is provided.
    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
    if text_input:
        # Typed text takes precedence over any recorded audio.
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        audio_data = None

    result_audio = b""
    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx={
            "messages": state.conversation,
            "added_sysaudio": state.added_sysaudio,
        },
    ):
        if event.type == FishE2EEventType.USER_CODES:
            # VQ codes for the user's speech, recorded into the conversation.
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            # Assistant speech: store the codes and stream the PCM frame to the UI.
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
            yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            # Assistant text: append it and refresh the chat history.
            append_to_chat_ctx(ServeTextPart(text=event.text))
            yield state.get_history(), None, None, None

    yield state.get_history(), None, None, None


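# Thin wrapper so text-only turns go through the same streaming handler, with no
# user audio attached.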
async def process_text_input(
    sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
    async for event in process_audio_input(
        sys_audio_input, sys_text_input, None, state, text_input
    ):
        yield event


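# Gradio Blocks UI: chat history and usage notes on the left, voice/system inputs
# and controls on the right.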
def create_demo():
    with gr.Blocks() as demo:
        state = gr.State(ChatState())

        with gr.Row():
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )

                notes = gr.Markdown(
                    """
# Fish Agent
1. This demo is Fish Audio's self-researched end-to-end language model, Fish Agent version 3B.
2. You can find the code and weights in our official repos on [GitHub](https://github.com/fishaudio/fish-speech) and [Hugging Face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b); note that they are released under a CC BY-NC-SA 4.0 licence.
3. This demo is an early alpha test version; the inference speed still needs to be optimised.

# Features
1. The model integrates the ASR and TTS stages natively, so no extra models need to be plugged in; it is truly end-to-end rather than a three-stage pipeline (ASR + LLM + TTS).
2. The model can use a reference audio clip to control the speech timbre.
3. The model can generate speech with strong emotion.
"""
                )

            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                # The default prompt instructs the model to reply in the
                # "Question: ...\n\nResponse: ..." format that get_history() parses.
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nResponse: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )

                text_input = gr.Textbox(label="Or type your message", type="text")

                output_audio = gr.Audio(
                    label="Assistant's Voice",
                    streaming=True,
                    autoplay=True,
                    interactive=False,
                )

                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

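        # Event wiring: finishing a recording, clicking Send, or submitting the
        # textbox all stream responses through the handlers above; Clear resets
        # the conversation state and all inputs/outputs.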
        audio_input.stop_recording(
            process_audio_input,
            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        send_button.click(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        text_input.submit(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo


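# Launch a local server; share=True additionally creates a temporary public
# Gradio link.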
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="127.0.0.1", server_port=7860, share=True)