import os
import time
from threading import Thread
from datetime import datetime
from uuid import uuid4
import gradio as gr
from time import sleep
import pprint
import torch
from torch import cuda, bfloat16
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain.document_loaders.pdf import UnstructuredPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.llms import HuggingFacePipeline

# model_names = ["tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct", "tiiuae/falcon-rw-1b"]
model_names = ["tiiuae/falcon-7b-instruct"]
models = {}
embedding_function_name = "all-mpnet-base-v2"
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
max_new_tokens = 1024
repetition_penalty = 10.0
temperature = 0
chunk_size = 512
chunk_overlap = 32


def get_uuid():
    return str(uuid4())


def create_embedding_function(embedding_function_name):
    return HuggingFaceEmbeddings(model_name=embedding_function_name,
                                 model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"})


def create_models():
    for model_name in model_names:
        if model_name == "tiiuae/falcon-40b-instruct":
            bnb_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type='nf4',
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=bfloat16
            )
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                quantization_config=bnb_config,
                device_map='auto'
            )
        else:
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map='auto'
            )
        model.eval()
        print(f"Model loaded on {device}")
        models[model_name] = model


create_models()
embedding_function = create_embedding_function(embedding_function_name)


def user(message, history):
    # Append the user's message to the conversation history
    if history is None:
        history = []
    return "", history + [[message, None]]


def bot(model_name, db_path, chat_mode, history):
    if not history or history[-1][0] == "":
        gr.Info("Please start the conversation by saying something.")
        return None

    chat_hist = history[:-1]
    if chat_hist:
        chat_hist = [tuple([y.replace("\n", ' ').strip(" ") for y in x]) for x in chat_hist]

    print("@" * 20)
    print(f"chat_hist:\n {chat_hist}")
    print("@" * 20)

    print('------------------------------------')
    print(model_name)
    print(db_path)
    print(chat_mode)
    print('------------------------------------')

    # Need to create langchain model from db for each session
    db = Chroma(persist_directory=db_path, embedding_function=embedding_function)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    stop_token_ids = [
        tokenizer.convert_tokens_to_ids(x) for x in [
            ['Question', ':'],
            ['Answer', ':'],
            ['User', ':'],
        ]
    ]

    class StopOnTokens(StoppingCriteria):
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            for stop_ids in stop_token_ids:
                if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                    return True
            return False

    stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
    stopping_criteria = StoppingCriteriaList([StopOnTokens()])
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_text = transformers.pipeline(
        model=models[model_name],
        tokenizer=tokenizer,
        return_full_text=True,
        task='text-generation',
        stopping_criteria=stopping_criteria,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        streamer=streamer
    )
    pipeline = HuggingFacePipeline(pipeline=generate_text)

    if chat_mode.lower() == 'basic':
        print("chat mode: basic")
        qa = RetrievalQA.from_llm(
            llm=pipeline,
            retriever=db.as_retriever(),
            return_source_documents=True
        )

        def run_basic(history):
            a = qa({"query": history[-1][0]})
            pprint.pprint(a['source_documents'])

        # Run the chain in a background thread so the streamer can be consumed below
        t = Thread(target=run_basic, args=(history,))
        t.start()
    else:
        print("chat mode: conversational")
        qa = ConversationalRetrievalChain.from_llm(
            llm=pipeline,
            retriever=db.as_retriever(),
            return_source_documents=True
        )

        def run_conv(history, chat_hist):
            a = qa({"question": history[-1][0], "chat_history": chat_hist})
            pprint.pprint(a['source_documents'])

        t = Thread(target=run_conv, args=(history, chat_hist))
        t.start()

    # Stream generated text into the last chat turn as it arrives
    history[-1][1] = ""
    for new_text in streamer:
        history[-1][1] += new_text
        time.sleep(0.01)
        yield history


def pdf_changes(pdf_doc):
    print("pdf changes, loading documents")

    # Persistently store the db next to the uploaded pdf
    db_path, file_ext = os.path.splitext(pdf_doc.name)

    # Include minutes in the suffix so uploads within the same hour get distinct paths
    timestamp = datetime.now()
    db_path += "_" + timestamp.strftime("%Y-%m-%d-%H-%M-%S")

    loader = UnstructuredPDFLoader(pdf_doc.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)

    db = Chroma.from_documents(texts, embedding_function, persist_directory=db_path)
    db.persist()
    return db_path


def init():
    with gr.Blocks(
            theme=gr.themes.Soft(),
            css=".disclaimer {font-variant-caps: all-small-caps;}",
    ) as demo:
        gr.HTML(
            """

Chat With FalconPDF

            """
        )
        pdf_doc = gr.File(label="Load a pdf", file_types=['.pdf'], type="file")
        model_id = gr.Radio(label="LLM", choices=model_names, value=model_names[0], interactive=True)
        db_path = gr.Textbox(label="DB_PATH", visible=False)
        chat_mode = gr.Radio(label="Chat mode", choices=['Basic', 'Conversational'], value='Basic',
                             info="Basic: no conversational context. Conversational: uses conversational context.")
        chatbot = gr.Chatbot(height=500)
        with gr.Row():
            with gr.Column():
                msg = gr.Textbox(
                    label="Chat Message Box",
                    placeholder="Chat Message Box",
                    show_label=False,
                    container=False
                )
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit")
                    stop = gr.Button("Stop")
                    clear = gr.Button("Clear")

        gr.Examples(['What is the summary of the paper?', 'What is the motivation of the paper?'], inputs=msg)

        def clear_input():
            sleep(1)
            return ""

        with gr.Row():
            gr.HTML(
                """

It is based on Falcon 7B/40B. More information can be found here.

                """
            )

        model_id.change(clear_input, inputs=[], outputs=[msg])
        pdf_doc.upload(pdf_changes, inputs=[pdf_doc], outputs=[db_path]). \
            then(clear_input, inputs=[], outputs=[msg]). \
            then(lambda: None, None, chatbot)

        # enter key event
        submit_event = msg.submit(
            fn=user,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            fn=bot,
            inputs=[
                model_id,
                db_path,
                chat_mode,
                chatbot,
            ],
            outputs=chatbot,
            queue=True,
        )

        # click submit button event
        submit_click_event = submit.click(
            fn=user,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            fn=bot,
            inputs=[
                model_id,
                db_path,
                chat_mode,
                chatbot,
            ],
            outputs=chatbot,
            queue=True,
        )

        stop.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[submit_event, submit_click_event],
            queue=False,
        )

        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue(max_size=32, concurrency_count=2)
    demo.launch(server_port=8266, inline=False, share=True)


if __name__ == "__main__":
    init()