import gradio as gr
import random

import json
from pathlib import Path
from pprint import pprint

import uuid
import chromadb
from chromadb.utils import embedding_functions

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

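# Sanity check: report whether a CUDA-capable GPU is visible before loading the model.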
print(f"Is CUDA available: {torch.cuda.is_available()}") |
|
print( |
|
f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") |
|
|
|
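# Candidate Hugging Face model repositories; model_name below selects which one is loaded.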
models = {
    "wizardLM-7B-HF": "TheBloke/wizardLM-7B-HF",
    "wizard-vicuna-13B-GPTQ": "TheBloke/wizard-vicuna-13B-GPTQ",
    "Wizard-Vicuna-13B-Uncensored": "ehartford/Wizard-Vicuna-13B-Uncensored",
    "WizardLM-13B": "TheBloke/WizardLM-13B-V1.0-Uncensored-GPTQ",
    "Llama-2-7B": "TheBloke/Llama-2-7b-Chat-GPTQ",
    "Vicuna-13B": "TheBloke/vicuna-13B-v1.5-GPTQ",
    "WizardLM-13B-V1.2": "TheBloke/WizardLM-13B-V1.2-GPTQ",
    "Mistral-7B": "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
}

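# Load the tokenizer and the GPTQ-quantised weights; device_map="auto" places the model on the available GPU(s).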
model_name = "Mistral-7B"

tokenizer = AutoTokenizer.from_pretrained(models[model_name])

# Fall back to the tokenizer's default chat template, which accepts a system message
# (the Mistral instruct template only allows alternating user/assistant turns).
tokenizer.chat_template = tokenizer.default_chat_template

model = AutoModelForCausalLM.from_pretrained(models[model_name],
                                             torch_dtype=torch.float16,
                                             device_map="auto")

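# Load the FAQ dataset: a JSON file with a top-level "questions" list of question/answer entries.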
file_path = './data/faq_dataset.json'
data = json.loads(Path(file_path).read_text())

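# Set up an in-memory Chroma collection that embeds text with BGE and ranks results by cosine similarity.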
client = chromadb.Client()

emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-small-en-v1.5")

collection = client.create_collection(
    name="retrieval_qa",
    embedding_function=emb_fn,
    metadata={"hnsw:space": "cosine"}
)

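# Index every FAQ entry: the serialised Q&A pair is the document, the raw dict is kept as metadata.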
documents = [json.dumps(q) for q in data['questions']]
metadatas = data['questions']
ids = [str(uuid.uuid1()) for _ in documents]

collection.add(
    documents=documents,
    metadatas=metadatas,
    ids=ids
)

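# Seed questions shown in the "Similar questions" panel before the first query.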
samples = [
    ["How can I return a product?"],
    ["What is the return policy?"],
    ["How can I contact customer support?"],
]


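# Retrieve the closest FAQ entries, build a context-grounded prompt, and generate an answer.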
def respond(query):
    global samples
    # Fetch the three FAQ entries most similar to the user's question.
    docs = collection.query(query_texts=[query], n_results=3)
    chat = []
    related_questions = []
    references = "## References\n"

    system_message = "You are a helpful, respectful and honest support executive. Always be as helpful as possible, while being correct. If a question does not make any sense, or is not factually coherent, explain why instead of answering something that is not correct. Use the following pieces of context to answer the questions. If the information is not present in the provided context, answer that you don't know. Please don't share false information."

    # Add the retrieved Q&A pairs to the system prompt and collect them for the references panel.
    for d in docs['metadatas'][0]:
        system_message += f"\n Question: {d['question']} \n Answer: {d['answer']}"

        references += f"**{d['question']}**\n\n"
        references += f"> {d['answer']}\n\n"

        related_questions.append([d['question']])

    chat.append({"role": "system", "content": system_message})
    chat.append({"role": "user", "content": query})

    encodeds = tokenizer.apply_chat_template(chat, return_tensors="pt")

    model_inputs = encodeds.to(model.device)
    streamer = TextStreamer(tokenizer)

    generated_ids = model.generate(
        model_inputs, streamer=streamer, temperature=0.01, max_new_tokens=100, do_sample=True)
    # Decode only the newly generated tokens and strip the end-of-sequence marker.
    answer = tokenizer.batch_decode(
        generated_ids[:, model_inputs.shape[1]:])[0]
    answer = answer.replace('</s>', '')
    samples = related_questions

    related = gr.update(samples=related_questions)

    yield [answer, references, related]


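# Map a click on the "Similar questions" dataset back to the question text.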
def load_example(example_id):
    global samples
    return samples[example_id][0]


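# Gradio UI: question box and related questions on the left, references for the retrieved context on the right.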
with gr.Blocks() as chatbot:
    with gr.Row():
        with gr.Column():
            answer_block = gr.Textbox(label="Answers", lines=2)
            question = gr.Textbox(label="Question")
            examples = gr.Dataset(samples=samples, components=[
                question], label="Similar questions", type="index")
            generate = gr.Button(value="Ask")
        with gr.Column():
            references_block = gr.Markdown("## References\n", label="References")

    examples.click(load_example, inputs=[examples], outputs=[question])
    generate.click(respond, inputs=question, outputs=[
        answer_block, references_block, examples])

chatbot.queue()
chatbot.launch()