# LangChain ConversationalRetrievalChain app that streams output to a Gradio interface
import os
from queue import Empty, SimpleQueue
from threading import Thread
from typing import Any, Dict, List, Union

import gradio as gr
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import LLMResult
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain_community.llms import HuggingFaceTextGenInference
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS

# from dotenv import load_dotenv, find_dotenv

## Loading the .env file (optional)
# load_dotenv(find_dotenv())
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Define the embedding model and vector store
EMBEDDING_MODEL = "BAAI/bge-base-en"
INDEX_DIR = "cdssagent_database"

# normalize_embeddings=True returns unit-length vectors, so FAISS's default
# L2 ranking matches cosine-similarity ranking.
model_norm = HuggingFaceBgeEmbeddings(
    model_name=EMBEDDING_MODEL,
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

if os.path.isdir(INDEX_DIR):
    # Reuse the index saved by a previous run. Loading unpickles the docstore,
    # so only enable this for index folders you created yourself.
    vector_store = FAISS.load_local(
        INDEX_DIR, model_norm, allow_dangerous_deserialization=True
    )
else:
    # First run: split the PDF, embed the chunks and save the index locally.
    # Delete INDEX_DIR to rebuild the index after stg.pdf changes.
    documents = PyPDFLoader("stg.pdf").load_and_split()
    vector_store = FAISS.from_documents(documents, model_norm)
    vector_store.save_local(INDEX_DIR)

job_done = object()  # Sentinel that signals the end of a streamed answer


# Let's set up our streaming
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """Callback handler that forwards streamed LLM tokens to a queue.

    Works with LLMs that support streaming.
    """

    def __init__(self, q: SimpleQueue):
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running: drop any stale items from the queue."""
        while not self.q.empty():
            try:
                self.q.get(block=False)
            except Empty:  # queue.Empty; SimpleQueue.empty is a method, not an exception
                continue

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.q.put(job_done)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.q.put(job_done)


# Initialize the LLM
q = SimpleQueue()
# from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
callbacks = [StreamingGradioCallbackHandler(q)]

llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    callbacks=callbacks,
    streaming=True,
    huggingfacehub_api_token=huggingfacehub_api_token,
)

# Define the instruction prompt that is prepended to every question
prompt = (
    "You are a senior clinician. You only answer the questions you have been asked, "
    "and you always limit your answers to the document content. Never make up answers. "
    "If you do not have the answer, state that the data is not contained in your "
    "knowledge base and stop your response."
)
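
# Illustrative sketch (not used by the app): the hand-off between
# StreamingGradioCallbackHandler and streaming_chat, reduced to the standard
# library. `fake_llm`, `stream_answer` and `_DONE` are hypothetical names:
# fake_llm stands in for a streaming LLM and pushes tokens into a SimpleQueue
# from a worker thread, then a sentinel, as on_llm_new_token and on_llm_end do.
# The consumer yields the growing answer until it sees the sentinel.
from queue import SimpleQueue
from threading import Thread
from typing import Iterator, List

_DONE = object()  # Sentinel playing the role of job_done


def fake_llm(out: SimpleQueue, tokens: List[str]) -> None:
    """Hypothetical producer: emit each token, then the end-of-stream sentinel."""
    for token in tokens:
        out.put(token)
    out.put(_DONE)


def stream_answer(tokens: List[str]) -> Iterator[str]:
    """Run the producer in a thread and yield the partial answer after each token."""
    buffer: SimpleQueue = SimpleQueue()
    worker = Thread(target=fake_llm, args=(buffer, tokens))
    worker.start()
    text = ""
    while True:
        item = buffer.get(block=True)  # Blocks until the producer puts something
        if item is _DONE:
            break
        text += item
        yield text
    worker.join()


# Example: list(stream_answer(["Dose", " is", " 5", " mg"])) returns
# ["Dose", "Dose is", "Dose is 5", "Dose is 5 mg"]. The identity check
# (`is _DONE`) keeps a sentinel object from colliding with any real token string.
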
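# Build the retrieval chain: the "stuff" chain type inserts the top-3 retrieved
# chunks into a single prompt.
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
)


# Set up chat history and streaming for the Gradio display
def process_question(question):
    # An empty history makes the chain skip its question-condensing LLM call, so
    # the handler sees a single on_llm_start/on_llm_end pair. The instruction
    # prompt is part of the question and is therefore also used for retrieval.
    chat_history = []
    full_query = f"{prompt} {question}"
    try:
        result = chain.invoke({"question": full_query, "chat_history": chat_history})
    except Exception:
        q.put(job_done)  # Unblock streaming_chat if the chain fails early
        raise
    return result["answer"]


def add_text(history, text):
    # Store [question, answer] as a list so the answer can be filled in place.
    history = history + [[text, None]]
    return history, ""


def streaming_chat(history):
    user_input = history[-1][0]
    while not q.empty():  # Drop leftovers from an earlier failed request
        q.get_nowait()
    thread = Thread(target=process_question, args=(user_input,))
    thread.start()
    history[-1][1] = ""
    while True:
        next_token = q.get(block=True)  # Blocks until a token or job_done arrives
        if next_token is job_done:
            break
        history[-1][1] += next_token
        yield history
    thread.join()


# Create a Gradio interface
with gr.Blocks(title="Clinical Decision Support System", theme=gr.themes.Base()) as demo:
    chatbot = gr.Chatbot(label="Response", height=500)
    question_box = gr.Textbox(label="Question", placeholder="Type your question here")
    # Show the question first, then stream the answer into the same chat row.
    question_box.submit(
        add_text, [chatbot, question_box], [chatbot, question_box]
    ).then(streaming_chat, chatbot, chatbot)

# All requests share the single queue `q`; Gradio's default concurrency limit
# of 1 runs one streaming_chat at a time, so answers do not interleave.
demo.queue().launch(share=True, debug=True, favicon_path="algorithm.png")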
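
# Illustrative sketch (not used by the app): what the retriever settings amount
# to. With normalize_embeddings=True every vector has unit length, so for the
# default L2 distance ||a - b||^2 = 2 - 2 * (a . b), and ranking by smallest L2
# distance gives the same top-k as ranking by largest cosine similarity. The
# "stuff" chain type then pastes the k retrieved chunks into one prompt. The
# helper names and the `stuff_prompt` template are hypothetical; LangChain's
# own default QA prompt is worded differently.
import math
from typing import List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def top_k_l2(query: Sequence[float], docs: Sequence[Sequence[float]], k: int) -> List[int]:
    """Indices of the k documents closest to `query` in squared L2 distance."""
    dist = [sum((a - b) ** 2 for a, b in zip(query, d)) for d in docs]
    return sorted(range(len(docs)), key=lambda i: dist[i])[:k]


def top_k_cosine(query: Sequence[float], docs: Sequence[Sequence[float]], k: int) -> List[int]:
    """Indices of the k documents with the largest dot product (cosine for unit vectors)."""
    sims = [sum(a * b for a, b in zip(query, d)) for d in docs]
    return sorted(range(len(docs)), key=lambda i: -sims[i])[:k]


def stuff_prompt(chunks: List[str], question: str) -> str:
    """Hypothetical 'stuff' template: all retrieved chunks joined into one context block."""
    context = "\n\n".join(chunks)
    return (
        "Use the following context to answer the question.\n\n"
        f"{context}\n\nQuestion: {question}\nAnswer:"
    )


# Usage: with unit vectors, top_k_l2 and top_k_cosine return the same indices,
# which is why normalizing BGE embeddings makes the L2 index behave like a
# cosine-similarity search.
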