# LangChain ConversationalRetrievalChain app that streams output to gradio interface
from threading import Thread
import gradio as gr
from queue import SimpleQueue, Empty
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
# from langchain_community.llms import HuggingFaceTextGenInference
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
# from dotenv import load_dotenv, find_dotenv
import pickle
import os
## loading the .env file
# load_dotenv(find_dotenv())
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
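# The token is read from the environment (e.g. a Space secret or an exported
# HUGGINGFACEHUB_API_TOKEN variable); it must be set for the Inference API calls below.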
loader = PyPDFLoader("stg.pdf")
documents = loader.load_and_split()
# Define model and vector store
embedding_model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}
model_norm = HuggingFaceBgeEmbeddings(
    model_name=embedding_model_name,
    model_kwargs={'device': 'cpu'},
    encode_kwargs=encode_kwargs
)
vector_store = FAISS.from_documents(documents, model_norm)
## saving the embeddings locally
vector_store.save_local("cdssagent_database")
## loading the saved index
vector_store = FAISS.load_local("cdssagent_database", model_norm, allow_dangerous_deserialization=True)
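# NOTE: allow_dangerous_deserialization is needed because FAISS.load_local unpickles
# the saved docstore; only enable it for index files you created and trust.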
job_done = object()  # sentinel that signals the LLM has finished streaming
# Let's set up streaming from the LLM callbacks to the Gradio UI
class StreamingGradioCallbackHandler(BaseCallbackHandler):
"""Callback handler - works with LLMs that support streaming."""
def __init__(self, q: SimpleQueue):
self.q = q
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
while not self.q.empty():
try:
self.q.get(block=False)
except SimpleQueue.empty:
continue
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.q.put(token)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.q.put(job_done)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.q.put(job_done)
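# The handler above pushes each generated token onto the queue as it arrives;
# on_llm_end and on_llm_error push the job_done sentinel so the consumer in
# streaming_chat() knows when to stop reading.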
# Initialize the token queue, callback handler, and LLM
q = SimpleQueue()
# from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
callbacks = [StreamingGradioCallbackHandler(q)]
llm = HuggingFaceEndpoint(
endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
max_new_tokens=512,
top_k=10,
top_p=0.95,
typical_p=0.95,
temperature=0.01,
repetition_penalty=1.03,
callbacks=callbacks,
streaming=True,
huggingfacehub_api_token=huggingfacehub_api_token
)
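# With streaming=True, the endpoint yields tokens as they are generated and each one is
# forwarded to on_llm_new_token of the registered callback handler.
# Optional sanity check (assumes the endpoint is reachable): uncomment to stream a test
# completion straight to stdout.
# for chunk in llm.stream("What is this document about?"):
#     print(chunk, end="", flush=True)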
# Define prompts and initialize conversation chain
prompt = "You are a senior clinician. Answer only the questions you are asked, and limit your answers to the document content. Never make up answers. If you do not have the answer, state that the data is not contained in your knowledge base and stop your response."
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    chain_type='stuff',
    retriever=vector_store.as_retriever(search_kwargs={"k": 3})
)
# Set up chat history and streaming for Gradio Display
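# process_question runs on a worker thread started in streaming_chat; its return value is
# not displayed directly, since the UI reads the streamed tokens from the queue instead.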
def process_question(question):
chat_history = []
full_query = f"{prompt} {question}"
result = chain({"question": full_query, "chat_history": chat_history})
return result["answer"]
def add_text(history, text):
history = history + [(text, None)]
return history, ""
def streaming_chat(history):
user_input = history[-1][0]
thread = Thread(target=process_question, args=(user_input,))
thread.start()
history[-1][1] = ""
while True:
next_token = q.get(block=True) # Blocks until an input is available
if next_token is job_done:
break
history[-1][1] += next_token
yield history
thread.join()
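# streaming_chat is a generator: every `yield history` pushes the partially built answer to
# the Chatbot component, so the response appears token by token in the UI.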
# Create the Gradio interface
with gr.Blocks(title="Clinical Decision Support System", theme=gr.themes.Base()) as demo:
Langchain = gr.Chatbot(label="Response", height=500)
Question = gr.Textbox(label="Question", placeholder="Type your question here")
Question.submit(add_text, [Langchain, Question], [Langchain, Question]).then(
streaming_chat, Langchain, Langchain
)
demo.queue().launch(share=True, debug=True, favicon_path='algorithm.png')