import os
from queue import SimpleQueue, Empty
from threading import Thread
from typing import Any, Dict, List, Union

import gradio as gr
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import LLMResult
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS



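# Hugging Face Inference API token, read from the environment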
huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")



# Define the embedding model used to index and query the documents
embeddings = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}
model_norm = HuggingFaceBgeEmbeddings(
    model_name=embeddings,
    model_kwargs={'device': 'cpu'},
    encode_kwargs=encode_kwargs
)

job_done = object()  # sentinel signalling that generation has finished

try:
    # Load a previously saved FAISS index
    vector_store = FAISS.load_local("wolo_database", model_norm, allow_dangerous_deserialization=True)
except Exception:
    # Build the index from the PDFs in wolo/ and save it locally for reuse
    loader = PyPDFDirectoryLoader("wolo/")
    documents = loader.load_and_split()
    vector_store = FAISS.from_documents(documents, model_norm)
    vector_store.save_local("wolo_database")





# Set up streaming: the callback handler pushes generated tokens onto a queue
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """Callback handler - works with LLMs that support streaming."""

    def __init__(self, q: SimpleQueue):
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        # Drain any tokens left over from a previous generation
        while not self.q.empty():
            try:
                self.q.get(block=False)
            except Empty:
                continue

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.q.put(job_done)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.q.put(job_done)


# Queue used to pass streamed tokens from the callback handler to the Gradio UI
q = SimpleQueue()



# from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

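# Initialize the Mixtral endpoint LLM with the streaming callback handler attached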
callbacks = [StreamingGradioCallbackHandler(q)]
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    callbacks=callbacks,
    streaming=True,
    huggingfacehub_api_token=huggingfacehub_api_token
)

# Define the system prompt and initialize the conversational retrieval chain
prompt = ("You are a senior clinician. Only answer the questions you are asked, "
          "and limit your answers to the document content. Never make up answers. "
          "If you do not have the answer, state that the information is not contained "
          "in your knowledge base and stop your response.")
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    chain_type='stuff',
    retriever=vector_store.as_retriever(search_kwargs={"k": 2})
)

# Chat helpers: run the chain, update the chat history, and stream tokens to the Gradio display
def process_question(question):
    chat_history = []
    full_query = f"{prompt} {question}"
    result = chain({"question": full_query, "chat_history": chat_history})
    return result["answer"]


def add_text(history, text):
    history = history + [(text, None)]
    return history, ""


def streaming_chat(history):
    user_input = history[-1][0]
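    # Run the chain in a background thread; its callbacks stream tokens onto q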
    thread = Thread(target=process_question, args=(user_input,))
    thread.start()
    history[-1][1] = ""
    while True:
        next_token = q.get(block=True)  # Blocks until an input is available
        if next_token is job_done:
            break
        history[-1][1] += next_token
        yield history
    thread.join()


# Create the Gradio interface
with gr.Blocks(title="Clinical Decision Support System", head="intro.html") as demo:
    Langchain = gr.Chatbot(label="Response", height=500)
    Question = gr.Textbox(label="Question")
    Question.submit(add_text, [Langchain, Question], [Langchain, Question]).then(
        streaming_chat, Langchain, Langchain
    )
demo.queue().launch(share=True, debug=True, favicon_path='thumbnail.jpg')