import gradio as gr
import os
import time

import pandas as pd
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
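
# Build an in-memory FAQ index: embed every FAQ question with a
# sentence-transformers model so user queries can be matched against it.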
def load_qa_model():
    document_store = InMemoryDocumentStore()
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        use_gpu=False,
        scale_score=False,
    )
    # Get dataframe with columns "question", "answer" and some custom metadata
    df = pd.read_csv('/content/social-faq.csv', on_bad_lines='skip', delimiter=';')
    # Minimal cleaning
    df.fillna(value="", inplace=True)
    df["question"] = df["question"].apply(lambda x: x.strip())
    questions = list(df["question"].values)
    df["embedding"] = retriever.embed_queries(queries=questions).tolist()
    df = df.rename(columns={"question": "content"})
    # Convert dataframe to a list of dicts and index them in our DocumentStore
    docs_to_index = df.to_dict(orient="records")
    document_store.write_documents(docs_to_index)
    return retriever
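
# add_text appends the user's message to the chat history and temporarily
# disables the textbox; add_file adds an uploaded file to the chat.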
def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def add_file(history, file):
    history = history + [((file.name,), None)]
    return history
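
# Stream the retrieved answer back one character at a time so the chatbot
# appears to "type" its reply.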
def bot(history):
    print(history)
    # response = "**That's cool!**"
    history[-1][1] = ""
    global retriever
    # Answer the latest user message, not the first one in the history
    response = get_answers(retriever, history[-1][0])
    for character in response:
        history[-1][1] += character
        time.sleep(0.01)
        yield history
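
# Retrieve the best-matching FAQ entry for a query using Haystack's FAQPipeline.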
def get_answers(retriever, query):
    from haystack.pipelines import FAQPipeline
    from haystack.utils import print_answers

    pipe = FAQPipeline(retriever=retriever)
    # Run any question and change top_k to see more or less answers
    prediction = pipe.run(query=query, params={"Retriever": {"top_k": 1}})
    answers = prediction['answers']
    if answers:
        return answers[0].answer
    else:
        return "I don't have an answer to that question"
retriever = load_qa_model()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        # avatar_images=(None, "/content/avatar.png"),
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
    inputRecord = gr.Audio(label="Record a question", source="microphone", type="filepath")
    audioOutput = gr.Audio(label="Listen to the answer", interactive=False)
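    # Note: the two audio components above are not wired to any event handler
    # in this demo; they are placeholders for voice input/output.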
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
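
# Enable queuing so the generator output of bot() can be streamed incrementally.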
demo.queue()
demo.launch()