# LangChain ConversationalRetrievalChain app that streams output to a Gradio interface
from threading import Thread
import gradio as gr
from queue import SimpleQueue, Empty
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
# from langchain_community.llms import HuggingFaceTextGenInference
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
# from dotenv import load_dotenv, find_dotenv
import pickle
import os

## loading the .env file
# load_dotenv(find_dotenv())
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
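# The token is read from the environment (e.g. a Hugging Face Space secret or an
# exported shell variable). A quick guard like the one below (a hypothetical
# addition, not part of the original app) surfaces a missing token early instead
# of failing later with an authorization error from the endpoint:
# if not huggingfacehub_api_token:
#     raise RuntimeError("Set HUGGINGFACEHUB_API_TOKEN before launching the app.")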
loader = PyPDFLoader("stg.pdf")
documents = loader.load_and_split()

# Define the embedding model and vector store
embedding_model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}
model_norm = HuggingFaceBgeEmbeddings(
    model_name=embedding_model_name,
    model_kwargs={'device': 'cpu'},
    encode_kwargs=encode_kwargs
)
vector_store = FAISS.from_documents(documents, model_norm)

## saving the embeddings locally
vector_store.save_local("cdssagent_database")
## loading the saved index back
vector_store = FAISS.load_local("cdssagent_database", model_norm, allow_dangerous_deserialization=True)
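# Optional sanity check (a sketch, not part of the original flow): query the
# reloaded index directly to confirm retrieval works before wiring it into the
# chain. The query string here is only a placeholder.
# docs = vector_store.similarity_search("example query about the stg.pdf content", k=3)
# print(len(docs), docs[0].page_content[:200])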
job_done = object()  # sentinel object that signals the LLM has finished streaming

# Let's set up streaming
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """Callback handler - works with LLMs that support streaming."""

    def __init__(self, q: SimpleQueue):
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running. Drain any tokens left over from a previous run."""
        while not self.q.empty():
            try:
                self.q.get(block=False)
            except Empty:
                continue

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.q.put(job_done)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.q.put(job_done)
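
# How the streaming works in this app: the endpoint pushes each generated token
# into on_llm_new_token, which puts it on the shared queue `q`; the Gradio
# generator below (streaming_chat) pops tokens off that queue and appends them
# to the chat history, and the `job_done` sentinel tells it when to stop.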

# Initializes the LLM
q = SimpleQueue()
# from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
callbacks = [StreamingGradioCallbackHandler(q)]
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    callbacks=callbacks,
    streaming=True,
    huggingfacehub_api_token=huggingfacehub_api_token
)
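
# Quick smoke test (commented out; a sketch only, not part of the original app):
# a single direct call verifies the token and endpoint before launching Gradio.
# Note that the streaming callback above would also push these tokens onto `q`.
# print(llm.invoke("Say hello in one short sentence."))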

# Define the prompt and initialize the conversation chain
prompt = "You are a senior clinician. Only answer questions you have been asked, and always limit your answers to the document content. Never make up answers. If you do not have the answer, state that the data is not contained in your knowledge base and stop your response."
chain = ConversationalRetrievalChain.from_llm(llm=llm, chain_type='stuff',
                                              retriever=vector_store.as_retriever(
                                                  search_kwargs={"k": 3}))

# Set up chat history and streaming for the Gradio display
def process_question(question):
    chat_history = []
    full_query = f"{prompt} {question}"
    result = chain({"question": full_query, "chat_history": chat_history})
    return result["answer"]

def add_text(history, text):
    history = history + [(text, None)]
    return history, ""

def streaming_chat(history):
    user_input = history[-1][0]
    thread = Thread(target=process_question, args=(user_input,))
    thread.start()
    history[-1][1] = ""
    while True:
        next_token = q.get(block=True)  # Blocks until an input is available
        if next_token is job_done:
            break
        history[-1][1] += next_token
        yield history
    thread.join()
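
# Note on the pattern above: process_question runs in a background thread purely
# to drive the LLM call; its returned answer is discarded, because the visible
# text arrives token by token through `q`. thread.join() just makes sure the
# worker thread has finished before the generator exits.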

# Create the Gradio interface
with gr.Blocks() as demo:
    Langchain = gr.Chatbot(label="Response", height=500)
    Question = gr.Textbox(label="Question")
    Question.submit(add_text, [Langchain, Question], [Langchain, Question]).then(
        streaming_chat, Langchain, Langchain
    )
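
# Event wiring, for reference: submitting the textbox first runs add_text, which
# echoes the question into the chat history and clears the box; .then() chains
# streaming_chat, a generator whose successive yields update the Chatbot so the
# answer appears token by token.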
demo.queue().launch(share=True, debug=True)