# Rifky's picture
# Added autoplay audio and fixed responsiveness issue
# bbac4a7
# raw
# history blame
# 9.32 kB
import gradio as gr
import base64
import requests
import secrets
import os
import argparse
from io import BytesIO
from pydub import AudioSegment
# Backend API locations. The public host is the default; running the script with
# `--local` rebinds API_ENDPOINT to the localhost server.
LOCAL_API_ENDPOINT: str = "http://localhost:5000"
PUBLIC_API_ENDPOINT: str = "http://121.176.153.117:5000"
API_ENDPOINT: str = PUBLIC_API_ENDPOINT

# Module-level UI state read and rebound by the event handlers below
# (active backend session id and the chatbot transcript mirror).
session_id: str = ""
chat_history: list = []
# Custom styles for the Gradio UI: pull the microphone recorder up next to the text
# box, hide its button text and show a 🎤 glyph instead, and re-balance the text/audio
# columns on narrow screens (<= 480px).
css = """
#audio_input {
    margin-top: -30px !important;
    margin-left: -15px !important;
    width: 100% !important;
}
#audio_input button {
    height:50px !important;
    font-size: 0px !important;
    width: 110% !important;
}
#audio_input button:after {
    content: '🎤' !important;
    font-size: 16px !important;
}
audio {
    min-width: 200px !important;
}
@media (max-width : 480px) {
    #audio_input {
        width: 120% !important;
    }
    #audio_input button:after {
        content: '' !important;
    }
    /* flex-grow accepts a unitless number; a percentage makes the declaration invalid */
    #txt_input_container {
        flex-grow: 0.7 !important;
    }
    #audio_input_container {
        flex-grow: 0.3 !important;
    }
}
"""
# Client-side hook (passed as `_js` after each bot reply in main): mark the last
# <audio> element in document order for autoplay. Returns early when the page has no
# <audio> element yet, instead of throwing on `undefined.setAttribute`.
js_audio_auto_play = """
() => {
    // select last audio element
    const audios = document.getElementsByTagName('audio');
    if (audios.length === 0) {
        return;
    }
    const last_audio = audios[audios.length - 1];
    // set autoplay attribute
    last_audio.setAttribute('autoplay', true);
}
"""
def create_chat_session():
    """Create a new chat session on the backend and its local audio folder.

    Returns:
        The session id taken from the ``"id"`` field of the API's JSON reply.

    Raises:
        Exception: if ``POST /create`` does not answer with HTTP 201.
    """
    response = requests.post(API_ENDPOINT + "/create")
    if response.status_code != 201:
        raise Exception("Failed to create chat session")
    # Named differently from the module-level `session_id` to avoid confusion;
    # the caller decides whether to make this the active session.
    new_session_id = response.json()["id"]
    # Folder for this session's temporary audio clips; exist_ok so a leftover
    # folder from an earlier run does not abort session creation.
    os.makedirs(f"./temp_audio/{new_session_id}", exist_ok=True)
    return new_session_id
def create_new_or_change_session(history, id):
    """Switch to the session typed in the textbox, or create a fresh one if it is blank.

    Args:
        history: Current chatbot value; forwarded to ``change_session``.
        id: Text from the session textbox. Surrounding whitespace is ignored and a
            blank (or ``None``) value means "create a new session".

    Returns:
        ``(history, update)``: the chatbot transcript for the active session and
        a ``gr.update`` that clears and disables the textbox.
    """
    global session_id
    global chat_history
    requested_id = (id or "").strip()
    if requested_id:
        # change_session also rebinds the global session_id / chat_history.
        history, _ = change_session(history, requested_id)
    else:
        session_id = create_chat_session()
        history = []
    chat_history = history
    return history, gr.update(value="", interactive=False)
def add_text(history, text):
    """Queue a typed user message in the chat transcript.

    Args:
        history: Chatbot transcript as a list of ``(user, bot)`` pairs.
        text: The message the user submitted.

    Returns:
        A new transcript with ``(text, None)`` appended (the bot side is filled
        in later), and a ``gr.update`` that clears and disables the textbox.
    """
    pending_turn = (text, None)
    cleared_textbox = gr.update(value="", interactive=False)
    return history + [pending_turn], cleared_textbox
def add_audio(history, audio):
    """Save a recorded clip, transcribe it, and add both to the chat transcript.

    Args:
        history: Chatbot transcript as a list of ``(user, bot)`` pairs.
        audio: Raw (unpreprocessed) value of the ``gr.Audio`` component. It is
            expected to be a dict whose ``'data'`` entry holds base64 audio,
            presumably as a data URL; only the text after the last comma is decoded.
            ``None``/empty values (e.g. when the component is cleared) are ignored.

    Returns:
        ``(history, update)``. Normally the transcript gains the saved MP3 (as a
        one-element tuple, which the chatbot renders as a file) followed by the
        transcribed text, and the update clears and disables the component.

    Raises:
        Exception: with the response body when ``/transcribe`` does not return HTTP 200.
    """
    # Clearing the component also fires its change event; there is nothing to
    # add then, and an empty update avoids touching the component again.
    if not isinstance(audio, dict) or not audio.get('data'):
        return history, gr.update()

    audio_bytes = base64.b64decode(audio['data'].split(',')[-1].encode('utf-8'))
    # Decode once and encode to MP3 in a *fresh* buffer: exporting back into the
    # source buffer overwrites from offset 0 without truncating it, leaving stale
    # bytes of the original recording after the MP3 data.
    mp3_buffer = BytesIO()
    AudioSegment.from_file(BytesIO(audio_bytes)).export(mp3_buffer, format="mp3")
    mp3_bytes = mp3_buffer.getvalue()

    # save audio file temporary to disk
    audio_id = secrets.token_hex(8)
    audio_path = f"temp_audio/{session_id}/audio_input_{audio_id}.mp3"
    os.makedirs(os.path.dirname(audio_path), exist_ok=True)
    with open(audio_path, "wb") as audio_out:
        audio_out.write(mp3_bytes)
    history = history + [((audio_path,), None)]

    response = requests.post(
        API_ENDPOINT + "/transcribe",
        files={'audio': mp3_bytes}
    )
    if response.status_code != 200:
        raise Exception(response.text)
    text = response.json()['text']
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)
def reset_chat_session(history):
    """Clear the active session on the backend and empty the local transcript.

    Args:
        history: Current chatbot value; not used, the chat is always cleared.

    Returns:
        An empty list for the chatbot component.

    Raises:
        Exception: with the response body when ``POST /reset/<session_id>``
            does not return HTTP 200 (the local transcript is left untouched then).
    """
    global session_id, chat_history
    reset_url = f"{API_ENDPOINT}/reset/{session_id}"
    reply = requests.post(reset_url)
    if reply.status_code == 200:
        chat_history = []
        return []
    raise Exception(reply.text)
def bot(history):
    """Send the pending user message to the backend and append its reply.

    The message is the user text of the last transcript entry (typed text, or the
    transcription appended by ``add_audio``), falling back to the entry before it.
    When neither holds user text (empty transcript, or the last turn was already
    answered) nothing is sent and ``history`` is returned unchanged.

    Args:
        history: Chatbot transcript as a list of ``(user, bot)`` pairs.

    Returns:
        A new transcript with the reply appended as an MP3 file entry followed by
        the reply text; it is also copied into the module-level ``chat_history``.

    Raises:
        Exception: when ``/send/text/<session_id>`` does not return HTTP 200.
    """
    global chat_history
    message = next(
        (entry[0] for entry in reversed(history[-2:]) if isinstance(entry[0], str)),
        None,
    )
    if message is None:
        return history

    response = requests.post(
        API_ENDPOINT + f"/send/text/{session_id}",
        headers={'Content-type': 'application/json'},
        json={
            'message': message,
            'role': 'user'
        }
    )
    if response.status_code != 200:
        raise Exception(f"Failed to send message, {response.text}")
    payload = response.json()
    text, audio = payload['text'], payload['audio']

    # The reply audio arrives base64-encoded; store it as an MP3 the chatbot can play.
    audio_bytes = base64.b64decode(audio.encode('utf-8'))
    audio_id = secrets.token_hex(8)
    audio_path = f"temp_audio/{session_id}/audio_input_{audio_id}.mp3"
    os.makedirs(os.path.dirname(audio_path), exist_ok=True)
    # export() returns the file handle it opened for the path; close it explicitly.
    AudioSegment.from_file(BytesIO(audio_bytes)).export(audio_path, format="mp3").close()

    history = history + [(None, (audio_path,)), (None, text)]
    chat_history = history.copy()
    return history
def change_session(history, id):
    """Make ``id`` the active session and rebuild the transcript from the backend.

    Each stored message (a dict with ``'role'``, ``'message'`` and base64
    ``'audio'``) becomes a text entry, preceded by an MP3 file entry when it
    carries audio. User messages go on the left side of the chatbot and
    assistant messages on the right.

    Args:
        history: Current chatbot value; replaced by the rebuilt transcript.
        id: Session id to load.

    Returns:
        ``(history, update)``: the rebuilt transcript and a ``gr.update`` that
        clears and disables the session textbox.

    Raises:
        Exception: when ``GET /<id>`` does not return HTTP 200, or when the stored
            messages cannot be converted; the original error is chained. The
            active session is only switched once the transcript was rebuilt.
    """
    global session_id
    global chat_history
    reply = requests.get(
        API_ENDPOINT + f"/{id}"
    )
    if reply.status_code != 200:
        raise Exception(reply.text)
    messages = reply.json()

    # The session may have been created by another run/process, so its audio
    # folder is not guaranteed to exist yet.
    audio_dir = f"temp_audio/{id}"
    os.makedirs(audio_dir, exist_ok=True)

    def save_audio(encoded):
        """Decode one base64 clip, write it as MP3 in ``audio_dir``, return its path."""
        audio_path = f"{audio_dir}/audio_input_{secrets.token_hex(8)}.mp3"
        segment = AudioSegment.from_file(BytesIO(base64.b64decode(encoded.encode('utf-8'))))
        # export() returns the file handle it opened for the path; close it explicitly.
        segment.export(audio_path, format="mp3").close()
        return audio_path

    history = []
    try:
        for chat in messages:
            if chat['role'] == 'user':
                if chat.get('audio'):
                    history = history + [((save_audio(chat['audio']),), None)]
                history = history + [(chat['message'], None)]
            elif chat['role'] == 'assistant':
                if chat.get('audio'):
                    history = history + [(None, (save_audio(chat['audio']),))]
                history = history + [(None, chat['message'])]
            else:
                raise Exception("Invalid chat role")
    except Exception as e:
        raise Exception(f"Failed to load session {id}: {e!r}. Response: {messages}") from e

    session_id = id
    chat_history = history.copy()
    print(f"len(chat_history): {len(chat_history)}\nlen(history): {len(history)}\nlen(response): {len(messages)}")
    return history, gr.update(value="", interactive=False)
def load_chat_history(history):
    """Pick the longer transcript when the page loads.

    Args:
        history: The chatbot's current value.

    Returns:
        The module-level ``chat_history`` if it holds more entries than
        ``history``, otherwise ``history`` itself.
    """
    server_side_is_longer = len(chat_history) > len(history)
    return chat_history if server_side_is_longer else history
def main():
    """Create the initial chat session and build the Gradio UI.

    Returns:
        The ``gr.Blocks`` app; the caller launches it.
    """
    global session_id
    global chat_history
    session_id = create_chat_session()
    chat_history = []
    with gr.Blocks(css=css) as demo:
        with gr.Row():
            # change session id (the placeholder shows the active session id)
            change_session_txt = gr.Textbox(
                show_label=False,
                placeholder=session_id,
            ).style(container=False)
        with gr.Row():
            # button to create new or change session id.
            # gr.Button picks its colour with `variant` ("primary"/"secondary"/"stop");
            # a `type=` keyword is not a Button parameter and was being ignored.
            change_session_button = gr.Button(
                "Create new or change session", variant="primary", size="sm"
            ).style(container=False)
        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
        # On page load, restore the transcript kept in the module-level chat_history.
        demo.load(load_chat_history, [chatbot], [chatbot], queue=False)
        with gr.Row():
            with gr.Column(scale=0.85, min_width=0, elem_id="txt_input_container"):
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter, or record audio",
                    elem_id="txt_input"
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0, elem_id="audio_input_container"):
                audio = gr.Audio(
                    source="microphone", type="numpy", show_label=False, format="mp3", min_width=0, container=False, elem_id="audio_input"
                )
        with gr.Row():
            reset_button = gr.Button(
                "Reset Chat Session", variant="stop", size="sm"
            ).style(container=False)

        # Text: queue message -> get bot reply -> autoplay its audio -> re-enable the textbox.
        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, chatbot
        ).then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

        # Audio: save + transcribe -> get bot reply -> autoplay -> clear and re-enable the recorder.
        # add_audio receives the raw component value (preprocess=False).
        audio_msg = audio.change(add_audio, [chatbot, audio], [chatbot, audio], queue=False, preprocess=False, postprocess=False).then(
            bot, chatbot, chatbot
        ).then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        audio_msg.then(lambda: gr.update(interactive=True, value=None), None, [audio], queue=False)

        reset_button.click(reset_chat_session, [chatbot], [chatbot], queue=False)

        # After switching/creating a session, show the (global) active id as placeholder.
        chgn_msg = change_session_txt.submit(change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
        chgn_msg.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
        create_new_or_change_session_btn = change_session_button.click(create_new_or_change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
        create_new_or_change_session_btn.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
    return demo
if __name__ == "__main__":
    # Command-line options: `--local` targets the API on localhost instead of the public host.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local", action="store_true", help="Use local API endpoint")
    args = parser.parse_args()
    # This runs at module scope, so rebinding here changes the endpoint every handler reads.
    API_ENDPOINT = LOCAL_API_ENDPOINT if args.local else API_ENDPOINT
    demo = main()
    # Bind on all network interfaces.
    demo.launch(show_error=True, server_name="0.0.0.0")