# cdss-app / app.py
# Author: delphiclinic
# Commit ccafaa2 (verified): "Update app.py"
# LangChain ConversationalRetrievalChain app that streams output to gradio interface
import os
import pickle
from queue import Empty, SimpleQueue
from threading import Thread
from typing import Any, Dict, List, Union

import gradio as gr
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import LLMResult
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain_community.llms import HuggingFaceTextGenInference
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
# from dotenv import load_dotenv, find_dotenv
# Hugging Face Inference API token, taken from the process environment
# (None when the variable is not set). To read it from a local .env file
# instead, re-enable python-dotenv:
# load_dotenv(find_dotenv())
huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
# Embedding model used both to build and to query the FAISS index:
# BAAI/bge-base-en on CPU, with normalize_embeddings=True passed to encode().
embeddings = "BAAI/bge-base-en"
encode_kwargs = dict(normalize_embeddings=True)
model_norm = HuggingFaceBgeEmbeddings(
    encode_kwargs=encode_kwargs,
    model_kwargs=dict(device="cpu"),
    model_name=embeddings,
)
# Load the saved FAISS index; if that fails (e.g. first run), build it from
# stg.pdf and save it so later starts can load it directly.
try:
    # allow_dangerous_deserialization=True lets FAISS unpickle its saved
    # docstore. That is acceptable only because "wolo_database" is written by
    # this script (see the except branch); never load an index from an
    # untrusted source this way.
    vector_store = FAISS.load_local(
        "wolo_database", model_norm, allow_dangerous_deserialization=True
    )
except Exception:
    # Deliberate broad fallback: any failure to load means "rebuild".
    # Alternative source kept from earlier versions:
    #   PyPDFDirectoryLoader("wolo/").load_and_split()
    loader = PyPDFLoader("stg.pdf")
    documents = loader.load_and_split()
    vector_store = FAISS.from_documents(documents, model_norm)
    # Persist the embeddings locally for the next start.
    vector_store.save_local("wolo_database")

# Sentinel put on the token queue to mark the end of an LLM run. Defined
# unconditionally so it exists whether the index was loaded or rebuilt.
job_done = object()
# Streaming: forward tokens from LangChain LLM callbacks to the Gradio handler.
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """Callback handler that pushes streamed LLM tokens onto a queue.

    Works with LLMs that support streaming. Every new token is put on ``q``.
    When the run ends or errors, the module-level ``job_done`` sentinel is put
    so the consumer knows to stop reading.
    """

    def __init__(self, q: SimpleQueue):
        # Queue shared with the consumer that displays the tokens.
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running: discard anything left on the queue."""
        # get_nowait() raises queue.Empty once nothing is left. Looping on that
        # avoids the race between an empty() check and a later get().
        while True:
            try:
                self.q.get_nowait()
            except Empty:
                break

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running: signal the consumer to stop."""
        self.q.put(job_done)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors: still signal the consumer so it does not hang."""
        self.q.put(job_done)
# Token queue: the streaming callback puts tokens on it and streaming_chat
# reads them off.
q = SimpleQueue()
callbacks = [StreamingGradioCallbackHandler(q)]

# Generation settings for the hosted Mixtral endpoint; temperature=0.01 keeps
# output close to deterministic.
_sampling = dict(
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
)
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    huggingfacehub_api_token=huggingfacehub_api_token,
    streaming=True,
    callbacks=callbacks,
    **_sampling,
)
# Instruction text prepended to every user question before it is sent to the
# chain (see process_question).
# NOTE(review): because it is part of the "question", this text also goes into
# the retrieval query; consider passing it as the combine-docs prompt instead.
prompt = "You are a senior clinician, you only answer questions you have been asked, and always limit your answers to the document content only. Never make up answers. If you do not have the answer, state that the data is not contained in your knowledge base and stop your response."
# Retrieval chain: fetch the 3 most similar chunks from the FAISS store and
# combine them with the question using the "stuff" strategy.
_retriever = vector_store.as_retriever(search_kwargs={"k": 3})
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=_retriever,
    chain_type="stuff",
)
# Set up chat history and streaming for the Gradio display.
def process_question(question):
    """Answer a single question with the retrieval chain.

    The instruction text in ``prompt`` is prepended to the question. Tokens
    reach the UI through the streaming callbacks attached to ``llm``; when
    this runs as the ``streaming_chat`` worker thread, the return value is
    ignored.

    Args:
        question: The user's question text.

    Returns:
        The ``"answer"`` entry of the chain's output.
    """
    # NOTE: history is intentionally empty. With prior turns,
    # ConversationalRetrievalChain first asks the LLM to condense the question,
    # and that call's on_llm_end would enqueue job_done before the real answer
    # streams.
    chat_history = []
    full_query = f"{prompt} {question}"
    # Chain.__call__ is deprecated in LangChain >= 0.1; invoke() is the
    # supported entry point and returns the same output dict.
    result = chain.invoke({"question": full_query, "chat_history": chat_history})
    return result["answer"]
def add_text(history, text):
    """Append the user's message to the chat and clear the input box.

    Args:
        history: Current chat as a list of ``[user, bot]`` pairs. ``None`` is
            treated as an empty chat.
        text: The message typed by the user.

    Returns:
        A tuple ``(new_history, "")``. ``new_history`` is a new list ending in
        ``[text, None]``, whose reply slot ``streaming_chat`` fills in later.
        The empty string resets the textbox.
    """
    # Build a new list rather than mutating the caller's. Use a list pair, not
    # a tuple, so the reply slot can be assigned in place.
    new_history = list(history or []) + [[text, None]]
    return new_history, ""
def streaming_chat(history):
    """Answer the latest question and stream the reply into the chat.

    This is a generator used as a Gradio event handler. It runs
    ``process_question`` on a background thread and yields ``history`` again
    after every token that arrives on the shared queue ``q``. It stops at the
    ``job_done`` sentinel, or when the worker thread has exited and the queue
    is empty.

    Args:
        history: Chat as a list of ``[user, bot]`` pairs. The last pair holds
            the question to answer.

    Yields:
        ``history``, with the last bot message growing token by token.
    """
    # Drop leftovers from an earlier run (e.g. a stray job_done) so they cannot
    # end this stream before it starts.
    while True:
        try:
            q.get_nowait()
        except Empty:
            break

    user_input = history[-1][0]
    # Rebuild the last pair as a list so the reply slot is assignable even if
    # the pair arrives as a tuple.
    history[-1] = [user_input, ""]

    worker = Thread(target=process_question, args=(user_input,))
    worker.start()
    while True:
        try:
            # Poll instead of blocking forever. If the worker dies before any
            # LLM callback fires (e.g. retrieval fails), no job_done is queued.
            next_token = q.get(timeout=0.5)
        except Empty:
            # Check liveness first: once the worker has exited, all of its puts
            # are done, so an empty queue means nothing more will arrive.
            if not worker.is_alive() and q.empty():
                break
            continue
        if next_token is job_done:
            break
        history[-1][1] += next_token
        yield history
    worker.join()
# Gradio UI: a chat panel and a question box. Submitting the box first adds
# the question to the chat (add_text), then streams the answer (streaming_chat).
with gr.Blocks(title="Clinical Decision Support System") as demo:
    chatbot = gr.Chatbot(label="Response", height=500)
    question_box = gr.Textbox(
        label="Question",
        placeholder='What are the symptoms of endocarditis?',
    )
    submitted = question_box.submit(
        add_text, [chatbot, question_box], [chatbot, question_box]
    )
    submitted.then(streaming_chat, chatbot, chatbot)

demo.queue().launch(share=True, debug=True, favicon_path='thumbnail.jpg')