Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ from langchain_community.llms import HuggingFacePipeline
|
|
10 |
from langchain.chains import ConversationChain
|
11 |
from langchain.memory import ConversationBufferMemory
|
12 |
from langchain_community.llms import HuggingFaceEndpoint
|
|
|
13 |
|
14 |
from pathlib import Path
|
15 |
import chromadb
|
@@ -166,12 +167,22 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
|
|
166 |
# retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
|
167 |
retriever=vector_db.as_retriever()
|
168 |
progress(0.8, desc="Defining retrieval chain...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
170 |
llm,
|
171 |
retriever=retriever,
|
172 |
chain_type="stuff",
|
173 |
memory=memory,
|
174 |
-
combine_docs_chain_kwargs={"prompt":
|
175 |
return_source_documents=True,
|
176 |
#return_generated_question=False,
|
177 |
verbose=False,
|
@@ -226,11 +237,11 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Pr
|
|
226 |
return vector_db, collection_name, "Complete!"
|
227 |
|
228 |
|
229 |
-
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
|
230 |
# print("llm_option",llm_option)
|
231 |
llm_name = list_llm[llm_option]
|
232 |
print("llm_name: ",llm_name)
|
233 |
-
qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
|
234 |
return qa_chain, "Complete!"
|
235 |
|
236 |
|
|
|
10 |
from langchain.chains import ConversationChain
|
11 |
from langchain.memory import ConversationBufferMemory
|
12 |
from langchain_community.llms import HuggingFaceEndpoint
|
13 |
+
from langchain import PromptTemplate
|
14 |
|
15 |
from pathlib import Path
|
16 |
import chromadb
|
|
|
167 |
# retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
|
168 |
retriever=vector_db.as_retriever()
|
169 |
progress(0.8, desc="Defining retrieval chain...")
|
170 |
+
|
171 |
+
prompt_template_string = f'''Your task is as follows: 1. Determine if the input {query} is compliant with the provided {context}. 2. If the requirement is compliant, report "This requirement is compliant." 3. If the requirement is not compliant report "This requirement is not compliant." 4. If the requirement is not compliant, give the reason for non compliance and return the specific rule or guideline the requirement violates. 5. If the requirement is not compliant, report a refined version of the requirement delimited in quotes that is compliant with the provided {context}.'''
|
172 |
+
|
173 |
+
# Create a PromptTemplate object
|
174 |
+
prompt_template = PromptTemplate(
|
175 |
+
template=prompt_template_string,
|
176 |
+
input_variables=["user_question"]
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
181 |
llm,
|
182 |
retriever=retriever,
|
183 |
chain_type="stuff",
|
184 |
memory=memory,
|
185 |
+
combine_docs_chain_kwargs={"prompt": prompt_template},
|
186 |
return_source_documents=True,
|
187 |
#return_generated_question=False,
|
188 |
verbose=False,
|
|
|
237 |
return vector_db, collection_name, "Complete!"
|
238 |
|
239 |
|
240 |
+
def initialize_LLM(llm_option, prompt_template, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
|
241 |
# print("llm_option",llm_option)
|
242 |
llm_name = list_llm[llm_option]
|
243 |
print("llm_name: ",llm_name)
|
244 |
+
qa_chain = initialize_llmchain(llm_name, prompt = prompt_template, llm_temperature, max_tokens, top_k, vector_db, progress)
|
245 |
return qa_chain, "Complete!"
|
246 |
|
247 |
|