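"""Gradio RAG chat demo.

For each user message, a retrieval query is built, matching documents are fetched
from a FAISS index, and the model's answer is streamed into the chat while the
rendered retrieval prompt is shown alongside it.
"""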
import logging
from pathlib import Path
import gradio as gr
import os
from jinja2 import Environment, FileSystemLoader
from src.chat import Chat
from src.rag import FaissDB, AICompletion, define_query
from src.prompts import SYSTEM_PROMPT, USER_PROMPT
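
# Completion model, chat session, and prebuilt FAISS index are created once at startup.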
chat_model = AICompletion()
chat = Chat(system_prompt=SYSTEM_PROMPT)
faiss_index = FaissDB(emb_model=os.environ["OPENAI_EMBEDDINGS_MODEL"])
faiss_index.load_index(os.environ["PATH_TO_INDEX"])
proj_dir = Path(__file__).parent
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
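
# Jinja2 template that renders the retrieved documents and the retrieval query as HTML.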
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
template_html = env.get_template('template_html.j2')


def add_text(history, text):
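    """Append the user's message to the chat history and clear/lock the textbox while the bot responds."""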
history = [] if history is None else history
history = history + [(text, None)]
return history, gr.Textbox(value="", interactive=False)


def bot(history):
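    """Retrieve documents for the latest user message, build the RAG prompt,
    and stream the model's answer into the chat alongside the rendered prompt HTML."""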
user_query = history[-1][0]
    if not user_query:
        # gr.Warning shows a toast notification; it is not an exception, so warn and stop here.
        gr.Warning("Please submit a non-empty string")
        return
logger.info('Retrieving documents...')
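    # define_query builds a retrieval query from the user message; it may return
    # an empty value, in which case retrieval is skipped below.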
retrieve_query = define_query(user_query, chat_model)
    documents = faiss_index.similarity_search(retrieve_query) if retrieve_query else []
user_prompt = USER_PROMPT(user_query, '\n'.join(documents))
prompt_html = template_html.render(documents=documents, query=retrieve_query if retrieve_query else 'No query')
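
    # chat.stream is assumed to yield the accumulated answer so far, so each
    # chunk replaces the bot message rather than being appended to it.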
stream = chat.stream(user_prompt)
history[-1][1] = ""
for character in stream:
history[-1][1] = character
yield history, prompt_html


with gr.Blocks() as demo:
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
bubble_full_width=False,
show_copy_button=True,
show_share_button=True,
)
with gr.Row():
txt = gr.Textbox(
scale=3,
show_label=False,
placeholder="Enter text and press enter",
container=False,
)
txt_btn = gr.Button(value="Submit text", scale=1)
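
    # Panel showing the rendered retrieval prompt (query and retrieved documents).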
prompt_html = gr.HTML()
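
    # Button click: append the user message, stream the bot reply, then re-enable the textbox.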
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
bot, [chatbot], [chatbot, prompt_html])
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
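
    # Pressing Enter in the textbox triggers the same chain.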
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
bot, [chatbot], [chatbot, prompt_html])
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
demo.queue()
demo.launch(debug=True)