delphiclinic committed
Commit
d34e8c0
•
1 Parent(s): 5b2f040

Upload app.py


First demo of the CDSS (clinical decision support system) app

Files changed (1)
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
+ # LangChain ConversationalRetrievalChain app that streams output to a Gradio interface
+ from threading import Thread
+ import gradio as gr
+ from queue import SimpleQueue, Empty
+ from typing import Any, Dict, List, Union
+ from langchain.callbacks.base import BaseCallbackHandler
+ from langchain.schema import LLMResult
+ # from langchain_community.llms import HuggingFaceTextGenInference
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders import PyPDFLoader
+ from dotenv import load_dotenv, find_dotenv
+ import pickle
+ import os
+
+
+ ## loading the .env file
+ load_dotenv(find_dotenv())
+
+ huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
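+ # The load_dotenv call above expects a local .env file; a hypothetical entry would look like:
+ #   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx   (placeholder for a real Hugging Face access token)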
+
+
+ # loader = PyPDFLoader("data/stg.pdf")
+ # documents = loader.load_and_split()
+
+
+ # Define embedding model and vector store
+
+ embeddings = "BAAI/bge-base-en"
+ encode_kwargs = {'normalize_embeddings': True}
+ model_norm = HuggingFaceBgeEmbeddings(
+     model_name=embeddings,
+     model_kwargs={'device': 'cpu'},
+     encode_kwargs=encode_kwargs
+ )
+ # vector_store = FAISS.from_documents(documents, model_norm)
+ # job_done = object()  # signals the processing is done
+
+ ## saving the embeddings locally
+ # vector_store.save_local("cdssagent_database")
+
+ ## loading the saved FAISS index
+ vector_store = FAISS.load_local("cdssagent_database", model_norm, allow_dangerous_deserialization=True)
+ job_done = object()  # sentinel that signals the LLM has finished streaming
+
+
+ # Let's set up our streaming
+ class StreamingGradioCallbackHandler(BaseCallbackHandler):
+     """Callback handler - works with LLMs that support streaming."""
+
+     def __init__(self, q: SimpleQueue):
+         self.q = q
+
+     def on_llm_start(
+         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+     ) -> None:
+         """Run when LLM starts running: drain any tokens left over from a previous run."""
+         while not self.q.empty():
+             try:
+                 self.q.get(block=False)
+             except Empty:
+                 continue
+
+     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+         """Run on new LLM token. Only available when streaming is enabled."""
+         self.q.put(token)
+
+     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         """Run when LLM ends running."""
+         self.q.put(job_done)
+
+     def on_llm_error(
+         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+     ) -> None:
+         """Run when LLM errors."""
+         self.q.put(job_done)
+
+
+ # Queue that carries streamed tokens from the LLM callback to the Gradio generator
+ q = SimpleQueue()
+
+
+ # from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+ callbacks = [StreamingGradioCallbackHandler(q)]
+
+ # Initializes the LLM (Mixtral-8x7B via the HF Inference API) with streaming enabled
+ llm = HuggingFaceEndpoint(
+     endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
+     max_new_tokens=512,
+     top_k=10,
+     top_p=0.95,
+     typical_p=0.95,
+     temperature=0.01,
+     repetition_penalty=1.03,
+     callbacks=callbacks,
+     streaming=True,
+     huggingfacehub_api_token=huggingfacehub_api_token
+ )
+
+ # Define the system prompt and initialize the conversational retrieval chain
+ prompt = "You are a senior clinician. Only answer questions you have been asked, and always limit your answers to the document content. Never make up answers. If you do not have the answer, state that the data is not contained in your knowledge base and stop your response."
+ chain = ConversationalRetrievalChain.from_llm(llm=llm, chain_type='stuff',
+                                               retriever=vector_store.as_retriever(
+                                                   search_kwargs={"k": 3}))
+
+ # Run the chain on a single question; streamed tokens are emitted via the callback queue
+ def process_question(question):
+     chat_history = []
+     full_query = f"{prompt} {question}"
+     result = chain({"question": full_query, "chat_history": chat_history})
+     return result["answer"]
+
+
+ def add_text(history, text):
+     """Append the user's message to the chat history and clear the textbox."""
+     history = history + [(text, None)]
+     return history, ""
+
+
+ def streaming_chat(history):
+     """Run the chain in a background thread and yield the answer token by token."""
+     user_input = history[-1][0]
+     thread = Thread(target=process_question, args=(user_input,))
+     thread.start()
+     history[-1][1] = ""
+     while True:
+         next_token = q.get(block=True)  # blocks until a token is available
+         if next_token is job_done:
+             break
+         history[-1][1] += next_token
+         yield history
+     thread.join()
+
+
+ # Create the Gradio interface
+ with gr.Blocks() as demo:
+     Langchain = gr.Chatbot(label="Response", height=500)
+     Question = gr.Textbox(label="Question")
+     Question.submit(add_text, [Langchain, Question], [Langchain, Question]).then(
+         streaming_chat, Langchain, Langchain
+     )
+ demo.queue().launch(share=True, debug=True)
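
The commented-out PyPDFLoader and save_local lines in app.py suggest the cdssagent_database FAISS index was built once in a separate step. A minimal sketch of that one-time indexing step, in a hypothetical build_index.py, assuming data/stg.pdf is present locally and reusing the same BGE embedding settings as app.py:

# build_index.py - one-time indexing sketch; paths and settings mirror app.py, data/stg.pdf is assumed
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS

# Load and split the source PDF into page-sized documents
documents = PyPDFLoader("data/stg.pdf").load_and_split()

# Same normalized BGE embeddings the app uses at query time
model_norm = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

# Embed the chunks, build the FAISS index, and save it where app.py expects it
vector_store = FAISS.from_documents(documents, model_norm)
vector_store.save_local("cdssagent_database")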