# Load in packages

import os
from typing import Type
from langchain_community.embeddings import HuggingFaceEmbeddings  # , HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import FAISS
import gradio as gr
import pandas as pd
from transformers import AutoTokenizer
from ctransformers import AutoModelForCausalLM

PandasDataFrame = Type[pd.DataFrame]

# Disable cuda devices if necessary
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# from chatfuncs.chatfuncs import *
import chatfuncs.ingest as ing

## Load preset embeddings, vectorstore, and model

embeddings_name = "BAAI/bge-base-en-v1.5"

def load_embeddings(embeddings_name=embeddings_name):

    embeddings_func = HuggingFaceEmbeddings(model_name=embeddings_name)

    global embeddings
    embeddings = embeddings_func

    return embeddings
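
# The prebuilt FAISS index is expected to ship as a zip archive at
# <faiss_vstore_folder>/<faiss_vstore_folder>.zip containing index.faiss and index.pkl.
# get_faiss_store extracts it, loads it into memory, then deletes the extracted files.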
def get_faiss_store(faiss_vstore_folder, embeddings):
    import zipfile
    with zipfile.ZipFile(faiss_vstore_folder + '/' + faiss_vstore_folder + '.zip', 'r') as zip_ref:
        zip_ref.extractall(faiss_vstore_folder)

    faiss_vstore = FAISS.load_local(folder_path=faiss_vstore_folder, embeddings=embeddings)
    os.remove(faiss_vstore_folder + "/index.faiss")
    os.remove(faiss_vstore_folder + "/index.pkl")

    global vectorstore
    vectorstore = faiss_vstore

    return vectorstore
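
# chatfuncs.chatfuncs holds the shared state (embeddings, vectorstore, model, tokenizer,
# run configs) that the chat functions read at runtime; it is populated at module level below.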
import chatfuncs.chatfuncs as chatf

chatf.embeddings = load_embeddings(embeddings_name)
chatf.vectorstore = get_faiss_store(faiss_vstore_folder="faiss_embedding", embeddings=globals()["embeddings"])

def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_device=None):
    print("Loading model")

    # Default values inside the function
    if gpu_config is None:
        gpu_config = chatf.gpu_config
    if cpu_config is None:
        cpu_config = chatf.cpu_config
    if torch_device is None:
        torch_device = chatf.torch_device
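
    # gpu_config and cpu_config are ctransformers run-settings objects defined in chatfuncs;
    # vars(config) below unpacks their attributes (e.g. gpu_layers) as keyword arguments
    # to AutoModelForCausalLM.from_pretrained.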

    if model_type == "Mistral Open Orca (larger, slow)":
        if torch_device == "cuda":
            gpu_config.update_gpu(gpu_layers)
        else:
            gpu_config.update_gpu(gpu_layers)
            cpu_config.update_gpu(gpu_layers)

        print("Loading with", cpu_config.gpu_layers, "model layers sent to GPU.")

        print(vars(gpu_config))
        print(vars(cpu_config))

        # Alternative GGUF checkpoints tried previously: Aryanne/Orca-Mini-3B-gguf,
        # Aryanne/Wizard-Orca-3B-gguf and TheBloke/MistralLite-7B-GGUF.
        try:
            model = AutoModelForCausalLM.from_pretrained('TheBloke/Mistral-7B-OpenOrca-GGUF', model_type='mistral', model_file='mistral-7b-openorca.Q4_K_M.gguf', **vars(gpu_config))
        except Exception:
            # Fall back to the CPU run settings if loading with the GPU config fails
            model = AutoModelForCausalLM.from_pretrained('TheBloke/Mistral-7B-OpenOrca-GGUF', model_type='mistral', model_file='mistral-7b-openorca.Q4_K_M.gguf', **vars(cpu_config))

        tokenizer = []

    if model_type == "Flan Alpaca (small, fast)":
        # Huggingface chat model
        hf_checkpoint = 'declare-lab/flan-alpaca-large'  # 'declare-lab/flan-alpaca-base'

        def create_hf_model(model_name):
            from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM

            if torch_device == "cuda":
                if "flan" in model_name:
                    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
                else:
                    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
            else:
                if "flan" in model_name:
                    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
                else:
                    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

            tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=chatf.context_length)

            return model, tokenizer, model_type

        model, tokenizer, model_type = create_hf_model(model_name=hf_checkpoint)

    chatf.model = model
    chatf.tokenizer = tokenizer
    chatf.model_type = model_type

    load_confirmation = "Finished loading model: " + model_type

    print(load_confirmation)
    return model_type, load_confirmation, model_type

# Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
model_type = "Mistral Open Orca (larger, slow)"
load_model(model_type, chatf.gpu_layers, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)

model_type = "Flan Alpaca (small, fast)"
load_model(model_type, 0, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
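
# docs_to_faiss_save rebuilds the vectorstore from newly ingested documents and swaps it into
# chatfuncs, so subsequent questions are answered against the uploaded source rather than the
# preset Lambeth Borough Plan index.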
def docs_to_faiss_save(docs_out: PandasDataFrame, embeddings=embeddings):

    print(f"> Total split documents: {len(docs_out)}")
    print(docs_out)

    vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)

    chatf.vectorstore = vectorstore_func

    out_message = "Document processing complete"

    return out_message, vectorstore_func

# Gradio chat

block = gr.Blocks(theme=gr.themes.Base())  # css=".gradio-container {background-color: black}"

with block:
    ingest_text = gr.State()
    ingest_metadata = gr.State()
    ingest_docs = gr.State()

    model_type_state = gr.State(model_type)
    embeddings_state = gr.State(chatf.embeddings)  # globals()["embeddings"]
    vectorstore_state = gr.State(chatf.vectorstore)  # globals()["vectorstore"]

    model_state = gr.State()  # chatf.model (gives error)
    tokenizer_state = gr.State()  # chatf.tokenizer (gives error)

    chat_history_state = gr.State()
    instruction_prompt_out = gr.State()

    gr.Markdown("<h1><center>Lightweight PDF / web page QA bot</center></h1>")
    gr.Markdown("Chat with PDF, web page or (new) csv/Excel documents.")

    with gr.Row():
        current_source = gr.Textbox(label="Current data source(s)", value="Lambeth_2030-Our_Future_Our_Lambeth.pdf", scale=10)
        current_model = gr.Textbox(label="Current model", value=model_type, scale=3)

    with gr.Tab("Chatbot"):

        with gr.Row():
            # chat_height = 500
            chatbot = gr.Chatbot(avatar_images=('user.jfif', 'bot.jpg'), bubble_full_width=False, scale=1)  # height=chat_height

        with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open=False):
            sources = gr.HTML(value="Source paragraphs with the most relevant text will appear here", scale=1)  # height=chat_height

        with gr.Row():
            message = gr.Textbox(
                label="Enter your question here",
                lines=1,
            )
        with gr.Row():
            submit = gr.Button(value="Send message", variant="secondary", scale=1)
            clear = gr.Button(value="Clear chat", variant="secondary", scale=0)
            stop = gr.Button(value="Stop generating", variant="secondary", scale=0)

        examples_set = gr.Radio(label="Examples for the Lambeth Borough Plan",
                                # value="What were the five pillars of the previous borough plan?",
                                choices=["What were the five pillars of the previous borough plan?",
                                         "What is the vision statement for Lambeth?",
                                         "What are the commitments for Lambeth?",
                                         "What are the 2030 outcomes for Lambeth?"])

        current_topic = gr.Textbox(label="Feature currently disabled - Keywords related to current conversation topic.", placeholder="Keywords related to the conversation topic will appear here")

    with gr.Tab("Load in a different file to chat with"):
        with gr.Accordion("PDF file", open=False):
            in_pdf = gr.File(label="Upload pdf", file_count="multiple", file_types=['.pdf'])
            load_pdf = gr.Button(value="Load in file", variant="secondary", scale=0)

        with gr.Accordion("Web page", open=False):
            with gr.Row():
                in_web = gr.Textbox(label="Enter web page url")
                in_div = gr.Textbox(label="(Advanced) Web page div for text extraction", value="p", placeholder="p")
            load_web = gr.Button(value="Load in webpage", variant="secondary", scale=0)

        with gr.Accordion("CSV/Excel file", open=False):
            in_csv = gr.File(label="Upload CSV/Excel file", file_count="multiple", file_types=['.csv', '.xlsx'])
            in_text_column = gr.Textbox(label="Enter column name where text is stored")
            load_csv = gr.Button(value="Load in CSV/Excel file", variant="secondary", scale=0)

        ingest_embed_out = gr.Textbox(label="File/web page preparation progress")

    with gr.Tab("Advanced features"):
        out_passages = gr.Slider(minimum=1, value=2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
        temp_slide = gr.Slider(minimum=0.1, value=0.1, maximum=1, step=0.1, label="Choose temperature setting for response generation.")

        with gr.Row():
            model_choice = gr.Radio(label="Choose a chat model", value="Flan Alpaca (small, fast)", choices=["Flan Alpaca (small, fast)", "Mistral Open Orca (larger, slow)"])
            change_model_button = gr.Button(value="Load model", scale=0)

        with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open=False):
            gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step=1, visible=True)

        load_text = gr.Text(label="Load status")

    gr.HTML(
        "<center>This app uses the Mistral Open Orca and Flan Alpaca models, and is powered by Gradio, Transformers, Ctransformers, and Langchain.</center>"
    )

    examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])

    change_model_button.click(fn=chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
        then(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs=[model_type_state, load_text, current_model]).\
        then(lambda: chatf.restore_interactivity(), None, [message], queue=False).\
        then(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
        then(lambda: None, None, chatbot, queue=False)

    # Load in a pdf
    load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
        then(ing.text_to_docs, inputs=[ingest_text], outputs=[ingest_docs]).\
        then(docs_to_faiss_save, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state]).\
        then(chatf.hide_block, outputs=[examples_set])

    # Load in a webpage
    load_web_click = load_web.click(ing.parse_html, inputs=[in_web, in_div], outputs=[ingest_text, ingest_metadata, current_source]).\
        then(ing.html_text_to_docs, inputs=[ingest_text, ingest_metadata], outputs=[ingest_docs]).\
        then(docs_to_faiss_save, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state]).\
        then(chatf.hide_block, outputs=[examples_set])

    # Load in a csv/excel file
    load_csv_click = load_csv.click(ing.parse_csv_or_excel, inputs=[in_csv, in_text_column], outputs=[ingest_text, current_source]).\
        then(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
        then(docs_to_faiss_save, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state]).\
        then(chatf.hide_block, outputs=[examples_set])

    # Click/enter to send message action
    response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False, api_name="retrieval").\
        then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
        then(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide], outputs=chatbot)
    response_click.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
        then(chatf.add_inputs_answer_to_history, [message, chatbot, current_topic], [chat_history_state, current_topic]).\
        then(lambda: chatf.restore_interactivity(), None, [message], queue=False)

    response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False).\
        then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
        then(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide], chatbot)
    response_enter.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
        then(chatf.add_inputs_answer_to_history, [message, chatbot, current_topic], [chat_history_state, current_topic]).\
        then(lambda: chatf.restore_interactivity(), None, [message], queue=False)

    # Stop box
    stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])

    # Clear box
    clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
    clear.click(lambda: None, None, chatbot, queue=False)

    # Thumbs up or thumbs down voting function
    chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None)

block.queue(concurrency_count=1).launch(debug=True)