Tan Gezerman committed
Commit ff92b60 · verified · 1 Parent(s): 3edb6a4

Create app.py

Files changed (1): app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
+ import chainlit as cl
+ from chainlit.types import ThreadDict
+ from langchain_chroma import Chroma
+ from langchain_core.prompts import PromptTemplate
+ from langchain_community.llms import LlamaCpp
+ from langchain.chains import RetrievalQA, ConversationChain
+ from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
+ from langchain.chains.conversation.memory import ConversationBufferMemory
+
+
+ # ctransformers is no longer used
+ """ from langchain_community.llms import CTransformers
+
+ # Initialize the language model
+ llm = CTransformers(model='Model/llama-2-7b-chat.ggmlv3.q2_K.bin',  # 2-bit quantized model
+                     model_type='llama',
+                     config={'max_new_tokens': 256,  # max tokens in the reply
+                             'temperature': 0.01})   # randomness of the reply
+ """
+
+ # Initialize the language model with LlamaCpp (verbose=True streams tokens to the terminal)
+ llm = LlamaCpp(
+     model_path="Model/llama-2-7b-chat.Q4_K_M.gguf",  # 4-bit quantized GGUF model
+     temperature=0.75,  # randomness of the reply
+     max_tokens=4096,   # max tokens in the reply
+     verbose=True,
+ )
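+
+ # Hedged variant: to offload all layers to the GPU (the original setup notes
+ # ~6 GB of VRAM for this model), a CUDA/Metal build of llama-cpp-python
+ # supports something like:
+ # llm = LlamaCpp(model_path="Model/llama-2-7b-chat.Q4_K_M.gguf",
+ #                n_gpu_layers=-1, n_ctx=4096, temperature=0.75,
+ #                max_tokens=4096, verbose=True)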
+
+ DATA_PATH = 'Data/'
+ DB_CHROMA_PATH = 'vectorstore/db_chroma'
+
+ # Embeddings must match the model used when the vector store was built
+ embedding_function = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs={'device': 'cpu'})
+
+ # Load the persisted Chroma vector store
+ db = Chroma(persist_directory=DB_CHROMA_PATH, embedding_function=embedding_function)
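+
+ # Hedged sketch: one way the store at DB_CHROMA_PATH could have been built from
+ # DATA_PATH (loader choice and chunk sizes are assumptions, not part of this app):
+ # from langchain_community.document_loaders import PyPDFDirectoryLoader
+ # from langchain_text_splitters import RecursiveCharacterTextSplitter
+ # docs = PyPDFDirectoryLoader(DATA_PATH).load()
+ # chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
+ # Chroma.from_documents(chunks, embedding_function, persist_directory=DB_CHROMA_PATH)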
+
+
+ # Retrieval-augmented generation: "stuff" packs the retrieved chunks into a single prompt
+ rag_pipeline = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type='stuff',
+     retriever=db.as_retriever(),
+     return_source_documents=True,
+ )
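+
+ # Usage sketch (shape follows from return_source_documents=True): .invoke returns
+ # a dict with "result" (the answer) and "source_documents" (retrieved Documents):
+ # out = rag_pipeline.invoke({"query": "What are common causes of anemia?"})
+ # print(out["result"], len(out["source_documents"]))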
+
+
+ template = """
+ You are an AI specialized in the medical domain.
+ Your purpose is to provide accurate, clear, and helpful responses to medical-related inquiries.
+ You must avoid misinformation at all costs. Do not respond to questions outside of the medical domain.
+ If you are unsure or lack information about a query, you must clearly state that you do not know the answer.
+
+ Question: {query}
+
+ Answer:
+
+ """
+
+ prompt_template = PromptTemplate(input_variables=["query"], template=template)
+
+
+ # Conversation chain with buffer memory
+ conversation_buf = ConversationChain(
+     llm=llm,
+     memory=ConversationBufferMemory(),
+ )
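+
+ # Note: conversation_buf is not wired into the Chainlit handlers below; one
+ # hypothetical use is replacing the bare llm.invoke call in get_response so
+ # replies keep chat history, e.g.:
+ # reply = conversation_buf.predict(input=query)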
+
+
+ @cl.on_chat_start
+ async def on_chat_start():
+     pass
+
+
+ @cl.step(type="llm")
+ def get_response(query):
+     """
+     Generates a response from the language model based on the user's input. If the input
+     includes '-rag', it uses the retrieval-augmented generation pipeline; otherwise, it
+     directly invokes the language model.
+
+     Args:
+         query (str): The user's input text.
+
+     Returns:
+         str: The language model's response, including source documents if '-rag' was used.
+     """
+     if "-rag" in query.lower():
+         response = rag_pipeline.invoke({"query": prompt_template.format(query=query)})
+         result = response["result"]
+         sources = response["source_documents"]
+         if sources:
+             source_details = "\n\nSources:"
+             for doc in sources:
+                 page_content = doc.page_content
+                 page_number = doc.metadata['page']
+                 source_book = doc.metadata['source']
+                 source_details += f"\n- Page {page_number} from {source_book}: \"{page_content}\""
+             result += source_details
+         return result
+
+     return llm.invoke(prompt_template.format(query=query))
+
+
+ @cl.on_message
+ async def on_message(message: cl.Message):
+     """
+     Fetches the response from the language model and shows it in the web UI.
+     """
+     try:
+         response = get_response(message.content)
+         msg = cl.Message(content=response)
+     except Exception as e:
+         msg = cl.Message(content=str(e))
+
+     await msg.send()
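+
+ # Hedged alternative: Chainlit can also stream tokens as they are generated; a
+ # sketch using the LLM's async stream (not how the handler above works today):
+ # msg = cl.Message(content="")
+ # async for tok in llm.astream(prompt_template.format(query=message.content)):
+ #     await msg.stream_token(tok)
+ # await msg.send()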
+
+
+ @cl.on_chat_resume
+ async def on_chat_resume(thread: ThreadDict):
+     pass  # TODO: user history gets fed to the LLM
+
+
+ @cl.on_chat_end
+ def on_chat_end():
+     pass
+
+
+ @cl.password_auth_callback
+ def auth_callback(username: str, password: str):
+     # In production, fetch the user matching `username` from your database and
+     # compare a hashed password with the stored value; the hard-coded
+     # credentials below are a placeholder for local testing only.
+     if (username, password) == ("karcan", "karcan123"):
+         return cl.User(
+             identifier="admin", metadata={"role": "admin", "provider": "credentials"}
+         )
+     else:
+         return None
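+
+ # Hedged sketch of a safer check (assumes a hypothetical fetch_user helper and
+ # passlib's bcrypt; neither ships with this app):
+ # from passlib.hash import bcrypt
+ # user = fetch_user(username)
+ # if user and bcrypt.verify(password, user.password_hash):
+ #     return cl.User(identifier=user.id, metadata={"role": user.role})
+ # return None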