File size: 4,715 Bytes
d34e8c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9eae4d
d34e8c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7913ee
d34e8c0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# LangChain ConversationalRetrievalChain app that streams output to gradio interface
import os
import pickle
from queue import Empty, SimpleQueue
from threading import Thread
from typing import Any, Dict, List, Union

import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import LLMResult
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain_community.llms import HuggingFaceTextGenInference
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS


# Pull variables such as HUGGINGFACEHUB_API_TOKEN from the nearest .env file.
load_dotenv(find_dotenv())
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# One-off index build, kept for reference: load and split the source PDF,
# embed the chunks, and persist the FAISS index to disk.
# loader = PyPDFLoader("data/stg.pdf")
# documents = loader.load_and_split()
# vector_store = FAISS.from_documents(documents, model_norm)
# vector_store.save_local("cdssagent_database")

# BGE embedding model, run on CPU with normalized output vectors.
embeddings = "BAAI/bge-base-en"
encode_kwargs = dict(normalize_embeddings=True)
model_norm = HuggingFaceBgeEmbeddings(
    model_name=embeddings,
    model_kwargs=dict(device="cpu"),
    encode_kwargs=encode_kwargs,
)

# Load the prebuilt index from disk. allow_dangerous_deserialization opts in to
# pickle-based loading, so only point this at an index you built yourself.
vector_store = FAISS.load_local(
    "cdssagent_database",
    model_norm,
    allow_dangerous_deserialization=True,
)

# Sentinel put on the token queue by the streaming callback when a run ends.
job_done = object()


# Streaming: forward generated tokens to the Gradio UI through a queue.
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """Callback handler that forwards streamed LLM tokens to a queue.

    Works with LLMs that support streaming. Every new token is put on ``q``.
    The module-level ``job_done`` sentinel is put on ``q`` when the LLM run
    ends or errors, so a consumer blocking on the queue knows to stop reading.
    """

    def __init__(self, q: SimpleQueue):
        super().__init__()
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running: discard anything still queued."""
        # Drain with non-blocking gets until the queue reports it is empty.
        # queue.Empty is the exception SimpleQueue raises when nothing is left.
        while True:
            try:
                self.q.get_nowait()
            except Empty:
                break

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running: signal the consumer that output is complete."""
        self.q.put(job_done)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors: still signal completion so the consumer unblocks."""
        self.q.put(job_done)


# Token queue shared by the streaming callback (producer) and the Gradio
# generator in streaming_chat (consumer).
q = SimpleQueue()

callbacks = [StreamingGradioCallbackHandler(q)]

# Mixtral-8x7B-Instruct via the Hugging Face inference endpoint. With
# streaming=True, the callbacks above receive each generated token.
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    callbacks=callbacks,
    streaming=True,
    huggingfacehub_api_token=huggingfacehub_api_token
)

# Instruction prepended to every user question (see process_question), and the
# retrieval chain that answers from the top-3 FAISS matches ("stuff" chain).
prompt = "You are a senior clinician, you only answer questions you have been asked, and always limit your answers to the document content only. Never make up answers. If you do not have the answer, state that the data is not contained in your knowledge base and stop your response."
chain = ConversationalRetrievalChain.from_llm(llm=llm, chain_type='stuff',
                                              retriever=vector_store.as_retriever(
                                                  search_kwargs={"k": 3}))

# Set up chat history and streaming for Gradio Display
def process_question(question):
    chat_history = []
    full_query = f"{prompt} {question}"
    result = chain({"question": full_query, "chat_history": chat_history})
    return result["answer"]


def add_text(history, text):
    """Append the user's message to the chat as a new, not-yet-answered turn.

    Args:
        history: Current Chatbot history (a list of turns), or None if empty.
        text: The submitted question.

    Returns:
        A new history list ending with ``[text, None]``, and ``""`` to clear
        the textbox.
    """
    # Use a list (not a tuple) so streaming_chat can fill in the reply in place.
    # ``or []`` tolerates a Chatbot that has no value yet. list() copies, so the
    # caller's list is not mutated.
    history = list(history or []) + [[text, None]]
    return history, ""


def streaming_chat(history):
    """Answer the latest user message and stream the reply into the chat.

    The chain runs on a background thread. Its tokens arrive on the module-level
    queue ``q`` and are appended to the reply of the last turn. The updated
    history is yielded after every token so Gradio can re-render it.

    Args:
        history: Chatbot history whose last entry is ``[user_message, reply]``.

    Yields:
        The history list, with the last turn's reply extended by one token.
    """
    user_input = history[-1][0]

    # Discard anything left over from an earlier request (e.g. a duplicate
    # job_done from a failed run) so it cannot end this stream prematurely.
    while True:
        try:
            q.get_nowait()
        except Empty:
            break

    def _run_chain():
        # If the chain fails before any LLM callback fires (e.g. in the
        # retriever), nothing would put job_done on the queue and the loop
        # below would block forever. Signal completion here, then re-raise so
        # the traceback is still reported by the threading excepthook.
        try:
            process_question(user_input)
        except Exception:
            q.put(job_done)
            raise

    thread = Thread(target=_run_chain)
    thread.start()
    # Replace the turn instead of assigning into it, so an immutable
    # (tuple) entry is also handled.
    history[-1] = [user_input, ""]
    while True:
        next_token = q.get(block=True)  # Blocks until a token or job_done arrives
        if next_token is job_done:
            break
        history[-1][1] += next_token
        yield history
    thread.join()


# Gradio UI: a chat transcript plus a question box. Submitting the box first
# appends the question to the transcript and clears the box (add_text), then
# streams the model's reply into the transcript (streaming_chat).
with gr.Blocks() as demo:
    Langchain = gr.Chatbot(label="Response", height=500)
    Question = gr.Textbox(label="Question")
    Question.submit(
        fn=add_text,
        inputs=[Langchain, Question],
        outputs=[Langchain, Question],
    ).then(
        fn=streaming_chat,
        inputs=Langchain,
        outputs=Langchain,
    )

# Enable the request queue, then launch with a public share link and debug on.
demo.queue()
demo.launch(share=True, debug=True)