# arabic-RAG / app.py
# Source: Hugging Face Space file view (raw / history / blame page links omitted).
# Last commit: 3cb8374 "Making updates to be more Arabic" by derek-thomas (HF staff).
# Scan status: no virus. File size: 6.03 kB.
import logging
from functools import partial
from pathlib import Path
from time import perf_counter
import gradio as gr
from jinja2 import Environment, FileSystemLoader
from transformers import AutoTokenizer
from backend.query_llm import check_endpoint_status, generate
from backend.semantic_search import retriever
# Directory that contains this file; the templates are looked up relative to it.
proj_dir = Path(__file__).parent

# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Jinja2 environment that loads templates from `<proj_dir>/templates`.
# NOTE(review): autoescaping is not enabled on this environment — confirm that
# template_html.j2 escapes the user query and document text itself.
_templates_dir = proj_dir / 'templates'
env = Environment(loader=FileSystemLoader(_templates_dir))

# `template` is rendered into the LLM prompt and `template_html` into the HTML
# preview; both are fetched once at import time.
template, template_html = (
    env.get_template(template_name)
    for template_name in ('template.j2', 'template_html.j2')
)

# Tokenizer used to count prompt tokens before the prompt is sent to the model.
tokenizer = AutoTokenizer.from_pretrained('derek-thomas/jais-13b-chat-hf')

# Example questions (in Arabic) offered next to each query box.
examples = [
    'ما هي عاصمة الصين؟',  # "What is the capital of China?"
    'لم السماء زرقاء؟',  # "Why is the sky blue?"
    "من فاز بكأس العالم للرجال في عام 2014؟",  # "Who won the men's World Cup in 2014?"
]
def add_text(history, text):
    """Record the user's message in the chat history and lock the input box.

    Args:
        history: The chat history so far (a list of message pairs), or
            ``None`` when no conversation exists yet.
        text: The message the user just submitted.

    Returns:
        A 2-tuple ``(new_history, textbox)``: ``new_history`` is a new list
        holding the previous entries followed by ``(text, None)``, and
        ``textbox`` is an emptied, non-interactive ``gr.Textbox``.
    """
    previous_turns = history if history is not None else []
    # Concatenate rather than append so the list passed in is left untouched.
    updated_history = previous_turns + [(text, None)]
    locked_input = gr.Textbox(value="", interactive=False)
    return updated_history, locked_input
def bot(history, hyde=False, *, top_k=5, max_prompt_tokens=2048):
    """Answer the latest user message with retrieved context, streaming the reply.

    Retrieves documents for the most recent query, renders them into a prompt
    (dropping documents from the end until the prompt fits the token budget),
    then yields the chat history each time ``generate`` produces a new item.

    Args:
        history: Chat history; the last entry holds the user's query at
            index 0 and receives the reply at index 1.
        hyde: When true, first ask ``generate`` for a hypothetical Wikipedia
            intro paragraph answering the query and retrieve with that text
            instead of the raw query.
        top_k: Number of documents requested from ``retriever``.
        max_prompt_tokens: Largest prompt size, in tokenizer tokens, before
            documents are dropped.

    Yields:
        ``(history, prompt_html)`` where ``history[-1][1]`` is the latest item
        produced by ``generate`` and ``prompt_html`` is the rendered HTML view
        of the query and the documents actually used.

    Raises:
        IndexError: If ``history`` is empty.
    """
    query = history[-1][0]

    logger.warning('Retrieving documents...')
    # Retrieve documents relevant to the query (or to the HyDE paragraph).
    document_start = perf_counter()
    if hyde:
        # NOTE(review): generate() is indexed with [-1] here but iterated at the
        # end of this function — confirm its return type supports both.
        hyde_document = generate(f"Write a wikipedia article intro paragraph to answer this query: {query}")[-1]
        logger.warning(hyde_document)
        documents = retriever(hyde_document, top_k=top_k)
    else:
        documents = retriever(query, top_k=top_k)
    document_time = perf_counter() - document_start
    logger.warning('Finished Retrieving documents in %s seconds...', round(document_time, 2))

    def count_tokens(text):
        """Return the number of tokens the tokenizer produces for *text*."""
        return len(tokenizer.encode(text))

    # Render the prompt, then drop documents from the end until it fits.
    prompt = template.render(documents=documents, query=query)
    token_count = count_tokens(prompt)
    # Stop once no documents remain: popping from an empty list would raise.
    while token_count > max_prompt_tokens and documents:
        documents.pop()
        prompt = template.render(documents=documents, query=query)
        token_count = count_tokens(prompt)
    if token_count > max_prompt_tokens:
        logger.warning('Prompt is still %s tokens (limit %s) with no documents left to drop',
                       token_count, max_prompt_tokens)

    prompt_html = template_html.render(documents=documents, query=query)

    # add_text appends (text, None) tuples; copy the entry into a list so the
    # reply slot can be assigned even if the entry arrives as a tuple.
    history[-1] = list(history[-1])
    history[-1][1] = ""
    # Each item from generate() replaces (does not extend) the reply so far.
    # NOTE(review): presumably generate() yields the cumulative text — confirm.
    for chunk in generate(prompt):
        history[-1][1] = chunk
        yield history, prompt_html
# Avatar images passed to the chat windows of both tabs.
_AVATAR_IMAGES = (
    'https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
    'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg',
)


def _make_chatbot(elem_id):
    """Create an empty chat window with the settings shared by both tabs."""
    return gr.Chatbot(
        [],
        elem_id=elem_id,
        avatar_images=_AVATAR_IMAGES,
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )


def _make_input_row():
    """Create the query textbox and its submit button in one row."""
    with gr.Row():
        textbox = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        button = gr.Button(value="Submit text", scale=1)
    return textbox, button


def _wire_chat_events(chat, textbox, button, prompt_view, respond):
    """Attach the submit -> respond -> unlock event chain to a chat tab.

    The same chain is registered for the button click and for pressing enter
    in the textbox: ``add_text`` stores the message and locks the textbox,
    ``respond`` streams the answer into ``chat`` and ``prompt_view``, and the
    final step makes the textbox interactive again.
    """
    for trigger in (button.click, textbox.submit):
        # Turn off interactivity while generating.
        pending = trigger(add_text, [chat, textbox], [chat, textbox], queue=False).then(
            respond, chat, [chat, prompt_view])
        # Turn it back on.
        pending.then(lambda: gr.Textbox(interactive=True), None, [textbox], queue=False)


with gr.Blocks() as demo:
    # Status box whose value comes from check_endpoint_status (with every=1).
    endpoint_status = gr.Textbox(check_endpoint_status, label="Endpoint Status (send chat to wake up)", every=1)

    with gr.Tab("Arabic-RAG"):
        chatbot = _make_chatbot("chatbot")
        txt, txt_btn = _make_input_row()
        gr.Examples(examples, txt)
        prompt_html = gr.HTML()
        _wire_chat_events(chatbot, txt, txt_btn, prompt_html, bot)

    with gr.Tab("Arabic-RAG + HyDE"):
        # Distinct elem_id: two components must not share one DOM id.
        hyde_chatbot = _make_chatbot("hyde_chatbot")
        hyde_txt, hyde_txt_btn = _make_input_row()
        gr.Examples(examples, hyde_txt)
        hyde_prompt_html = gr.HTML()
        _wire_chat_events(hyde_chatbot, hyde_txt, hyde_txt_btn, hyde_prompt_html, partial(bot, hyde=True))

# Enable the request queue, then start the app.
demo.queue()
demo.launch(debug=True)