z00mP's picture
change completion model interface
58cde81
raw
history blame
7.03 kB
"""
Credit to Derek Thomas, derek@huggingface.co
"""
import os
import logging
from pathlib import Path
from time import perf_counter
import gradio as gr
from jinja2 import Environment, FileSystemLoader
from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import retrieve
from backend.reranker import rerank_documents
# Default retrieval depth; can be overridden through the TOP_K environment variable.
TOP_K = int(os.environ.get("TOP_K", 4))

# Directory that contains this file; the templates folder is resolved against it.
proj_dir = Path(__file__).parent

# Logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Jinja2 environment rooted at the project's 'templates' directory.
_template_loader = FileSystemLoader(proj_dir / 'templates')
env = Environment(loader=_template_loader)

# Templates used to render the LLM prompt and its HTML preview.
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')
def add_text(history, text):
    """Append the submitted message to the chat history and lock the textbox.

    Args:
        history: Current chat history, or ``None`` when nothing has been
            said yet.
        text: The message the user just submitted.

    Returns:
        A pair ``(history, textbox)``: a new history list ending with
        ``(text, None)`` (the answer slot is left empty), and a cleared,
        non-interactive ``gr.Textbox``.
    """
    if history is None:
        previous_turns = []
    else:
        previous_turns = history
    pending_turn = (text, None)
    # Build a new list rather than mutating the caller's history in place.
    updated_history = previous_turns + [pending_turn]
    locked_textbox = gr.Textbox(value="", interactive=False)
    return updated_history, locked_textbox
def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk):
    """Stream an answer to the newest chat message using retrieved documents.

    This is a generator: it retrieves documents for the last user message,
    optionally reranks them, renders the prompt templates, and then yields
    the updated history each time the selected generation backend produces
    a new value.

    Args:
        history: Chat history; the last entry holds the pending user message
            at index 0 and receives the bot answer at index 1.
        chunk_table: Chunk table name, forwarded to ``retrieve``.
        embedding_model: Embedding model name, forwarded to ``retrieve``.
        llm_model: Name of the completion model; selects the generation
            backend and is also forwarded to it.
        cross_encoder: Reranker model name, or the string ``"None"`` to skip
            reranking.
        top_k_param: Number of documents to retrieve (int or numeric string).
        rerank_topk: Number of documents to keep after reranking (int or
            numeric string).

    Yields:
        ``(history, prompt_html)`` tuples, where ``history[-1][1]`` holds the
        latest value produced by the generation backend and ``prompt_html``
        is the rendered HTML view of the prompt.

    Raises:
        gr.Error: If the prompt is empty or ``llm_model`` is not supported.
    """
    # The UI radios deliver their values as strings; normalise both to int.
    top_k_param = int(top_k_param)
    # NOTE(review): assumes rerank_documents expects a count for
    # top_k_rerank, like retrieve does for its top-k -- confirm.
    rerank_topk = int(rerank_topk)
    query = history[-1][0]

    logger.info("bot launched ...")
    logger.info(f"embedding model: {embedding_model}")
    logger.info(f"LLM model: {llm_model}")
    logger.info(f"Cross encoder model: {cross_encoder}")
    logger.info(f"TopK: {top_k_param}")
    logger.info(f"ReRank TopK: {rerank_topk}")

    if not query:
        # gr.Warning is not used as an exception elsewhere in this file;
        # gr.Error is the raisable way to surface the problem to the user.
        raise gr.Error("Please submit a non-empty string as a prompt")

    # Resolve the generation backend up front so an unsupported model fails
    # fast with a clear message instead of an UnboundLocalError later on.
    generators = {
        "mistralai/Mistral-7B-Instruct-v0.2": generate_hf,
        "mistralai/Mistral-7B-v0.1": generate_hf,
        "mistralai/Mixtral-8x7B-Instruct-v0.1": generate_hf,
        "gpt-3.5-turbo": generate_openai,
        "gpt-4-turbo-preview": generate_openai,
    }
    generate_fn = generators.get(llm_model)
    if generate_fn is None:
        raise gr.Error(f"LLM model {llm_model} is not supported")

    logger.info('Retrieving documents...')
    # Retrieve documents relevant to query
    document_start = perf_counter()
    documents = retrieve(query, top_k_param, chunk_table, embedding_model)
    # Use a %-placeholder so the count is actually formatted into the record.
    logger.info('Retrived document count: %d', len(documents))

    # Reranking is only meaningful when there is more than one candidate.
    if cross_encoder != "None" and len(documents) > 1:
        documents = rerank_documents(cross_encoder, documents, query, top_k_rerank=rerank_topk)
        logger.info('ReRank done, document count: %d', len(documents))

    document_time = perf_counter() - document_start
    logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')

    # Create Prompt
    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    logger.info(f'Complition started. llm_model: {llm_model}, prompt: {prompt}')

    history[-1][1] = ""
    # Each value from the backend replaces (not extends) the answer so far --
    # presumably the backend yields the cumulative text; verify in query_llm.
    for character in generate_fn(prompt, history[:-1], llm_model):
        history[-1][1] = character
        yield history, prompt_html
def _reenable_textbox():
    """Return a textbox update that makes the input interactive again."""
    return gr.Textbox(interactive=True)


with gr.Blocks() as demo:
    # Conversation view with avatars for the user and the assistant.
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )

    # Prompt entry: a textbox plus an explicit submit button on one row.
    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)

    # Pipeline settings; their values are handed to `bot` on every request.
    chunk_table = gr.Radio(
        choices=[
            "BGE_CharacterTextSplitter",
            "BGE_FixedSizeSplitter",
            "BGE_RecursiveCharacterTextSplitter",
            "MiniLM_CharacterTextSplitter",
            "MiniLM_FixedSizeSplitter",
            "MiniLM_RecursiveCharacterSplitter",
        ],
        value="MiniLM_CharacterTextSplitter",
        label="Chunk table",
    )

    embedding_model = gr.Radio(
        choices=[
            "BAAI/bge-large-en-v1.5",
            "sentence-transformers/all-MiniLM-L6-v2",
        ],
        value="sentence-transformers/all-MiniLM-L6-v2",
        label='Embedding model',
    )

    llm_model = gr.Radio(
        choices=[
            "mistralai/Mistral-7B-Instruct-v0.2",
            "gpt-3.5-turbo",
            "gpt-4-turbo-preview",
            "mistralai/Mistral-7B-v0.1",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
        ],
        value="gpt-3.5-turbo",
        label='LLM',
    )

    cross_encoder = gr.Radio(
        choices=[
            "None",
            "BAAI/bge-reranker-large",
            "cross-encoder/ms-marco-MiniLM-L-6-v2",
        ],
        value="None",
        label='Cross-encoder model',
    )

    top_k_param = gr.Radio(
        choices=["5", "10", "20", "50"],
        value="5",
        label='top-K',
    )

    rerank_topk = gr.Radio(
        choices=["5", "10", "20", "50"],
        value="5",
        label='rerank-top-K',
    )

    # Shows the rendered prompt that was sent to the model.
    prompt_html = gr.HTML()

    # Components read by `bot` and components it writes to.
    bot_inputs = [chatbot, chunk_table, embedding_model, llm_model,
                  cross_encoder, top_k_param, rerank_topk]
    bot_outputs = [chatbot, prompt_html]

    # The button click and the Enter key run the same pipeline: record the
    # message and lock the textbox, generate the answer, then unlock it.
    for trigger in (txt_btn.click, txt.submit):
        txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, bot_inputs, bot_outputs)
        txt_msg.then(_reenable_textbox, None, [txt], queue=False)

demo.queue()
demo.launch(debug=True)