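# Gradio front end for the "3D Printing Revolution" course assistant.
# Questions can be typed or spoken; answers come from index() and run() in main.py.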
import time
import uuid
import gradio as gr
from gtts import gTTS
from transformers import pipeline
from main import index, run
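
# Whisper (base) speech-recognition pipeline for transcribing microphone input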
p = pipeline("automatic-speech-recognition", model="openai/whisper-base")
"""Use text to call chat method from main.py"""
models = ["GPT-3.5", "Flan UL2", "Flan T5"]
with gr.Blocks(theme='snehilsanyal/scikit-learn') as demo:
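    # Shared state object; a session id is attached below and passed to index() and run()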
    state = gr.State([])

    def create_session_id():
        return str(uuid.uuid4())
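
    # Answer a typed question and append the (question, answer) pair to the chat history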
    def add_text(history, text, model):
        print("Question asked: " + text)
        response = run_model(text, model)
        history = history + [(text, response)]
        print(history)
        return history, ""
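
    # Query run() from main.py with the selected model and log the response time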
    def run_model(text, model):
        start_time = time.time()
        print("start time:" + str(start_time))
        response = run(text, model, state.session_id)
        end_time = time.time()
        # If the response contains `SOURCES:`, move the sources to their own line
        if "SOURCES:" in response:
            response = response.replace("SOURCES:", "\nSOURCES:")
        print(response)
        print("Time taken: " + str(end_time - start_time))
        return response
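
    # Answer a spoken question: transcribe it, query the model, and reply with synthesized speech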
    def get_output(history, audio, model):
        # Transcribe the recorded audio with the Whisper pipeline
        txt = p(audio)["text"]
        audio_path = 'response.wav'
        response = run_model(txt, model)
        # Drop everything from "SOURCES:" onwards before converting the answer to speech
        trimmed_response = response.split("SOURCES:")[0]
        myobj = gTTS(text=trimmed_response, lang='en', slow=False)
        myobj.save(audio_path)
        # Show the recorded question and the synthesized answer as audio messages
        history.append(((audio,), (audio_path,)))
        print(history)
        return history
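
    # Switch models: reset the chat history and call index() for the selected model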
    def set_model(history, model):
        print("Model selected: " + model)
        history = get_first_message(history)
        index(model, state.session_id)
        return history
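
    # Welcome message shown whenever the chat (re)starts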
    def get_first_message(history):
        history = [(None,
                    'Learn about the course and get answers with referred sources.\nWarning! Use the bot wisely. It might give incorrect answers.')]
        return history
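
    # No-op passthrough used as the second step of each event chain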
    def bot(history):
        return history

    state.session_id = create_session_id()
    print("Session ID: " + state.session_id)

    # Chat window, seeded with the welcome message
    chatbot = gr.Chatbot(get_first_message([]), elem_id="chatbot",
                         label='3D Printing Revolution').style(height=300, container=False)
    # Hidden radio button for choosing the backing model
    radio = gr.Radio(models, label="Choose a model", value="GPT-3.5", type="value", visible=False)
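
    # Text input and (hidden) microphone input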
    with gr.Row():
        txt = gr.Textbox(
            label="Ask your question here and press enter",
            placeholder="Enter text and press enter", lines=1
        ).style(container=False)
        audio = gr.Audio(source="microphone", type="filepath", visible=False)
    with gr.Row():
        gr.Examples(
            examples=['What is 3D printing?', 'Who are the instructors of the course?', 'What is the course about?',
                      'Which software can be used to create a design file for 3D printing?',
                      'What are the key takeaways from the course?', 'How to create a 3D printing design file?'],
            inputs=[txt],
            label="Examples")
    txt.submit(add_text, [chatbot, txt, radio], [chatbot, txt], postprocess=False).then(
        bot, chatbot, chatbot
    )

    audio.change(fn=get_output, inputs=[chatbot, audio, radio], outputs=[chatbot], show_progress=True).then(
        bot, chatbot, chatbot
    )
    radio.change(fn=set_model, inputs=[chatbot, radio], outputs=[chatbot]).then(bot, chatbot, chatbot)
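
    # Clear the microphone input once it has been processed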
    audio.change(lambda: None, None, audio)
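
    # Initialise with the default model (calls index() via set_model) before launch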
    set_model(chatbot, radio.value)

if __name__ == "__main__":
    demo.queue(concurrency_count=5)
    demo.launch(debug=True)