ryan0303 committed
Commit 9091ab2
1 Parent(s): 055a2fa

Update app.py

Files changed (1):
app.py +14 -3
app.py CHANGED
@@ -10,6 +10,7 @@ from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
+from langchain import PromptTemplate
 
 from pathlib import Path
 import chromadb
@@ -166,12 +167,22 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
     retriever=vector_db.as_retriever()
     progress(0.8, desc="Defining retrieval chain...")
+
+    # Plain template string, not an f-string: the "stuff" combine-docs chain
+    # fills {context} with the retrieved chunks and {question} with the query.
+    prompt_template_string = '''Your task is as follows: 1. Determine if the input {question} is compliant with the provided {context}. 2. If the requirement is compliant, report "This requirement is compliant." 3. If the requirement is not compliant, report "This requirement is not compliant." 4. If the requirement is not compliant, give the reason for non-compliance and return the specific rule or guideline the requirement violates. 5. If the requirement is not compliant, report a refined version of the requirement, delimited in quotes, that is compliant with the provided {context}.'''
+
+    # Create a PromptTemplate object
+    prompt_template = PromptTemplate(
+        template=prompt_template_string,
+        input_variables=["context", "question"]
+    )
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
         chain_type="stuff",
         memory=memory,
-        combine_docs_chain_kwargs={"prompt": """Your task is as follows: 1. Determine if the input {query} is compliant with the provided {context}. 2. If the requirement is compliant, report "This requirement is compliant." 3. If the requirement is not compliant report "This requirement is not compliant." 4. If the requirement is not compliant, give the reason for non compliance and return the specific rule or guideline the requirement violates. 5. If the requirement is not compliant, report a refined version of the requirement delimited in quotes that is compliant with the provided {context}."""},
+        combine_docs_chain_kwargs={"prompt": prompt_template},
         return_source_documents=True,
         #return_generated_question=False,
         verbose=False,
@@ -226,11 +237,11 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     return vector_db, collection_name, "Complete!"
 
 
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+def initialize_LLM(llm_option, prompt_template, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     # print("llm_option",llm_option)
     llm_name = list_llm[llm_option]
     print("llm_name: ",llm_name)
-    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress, prompt=prompt_template)
     return qa_chain, "Complete!"
 
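
For reference, here is a minimal standalone sketch of how the new PromptTemplate plugs into ConversationalRetrievalChain. It assumes `llm` and `vector_db` are already built elsewhere in app.py (as they are in the functions above); the shortened template text and the memory setup are illustrative assumptions, not part of this commit:

    from langchain import PromptTemplate
    from langchain.chains import ConversationalRetrievalChain
    from langchain.memory import ConversationBufferMemory

    # The "stuff" combine-docs chain substitutes the retrieved chunks for
    # {context} and the user's query for {question}; declare both variables.
    prompt_template = PromptTemplate(
        template="Determine if the input {question} is compliant with the provided {context}, and report the result.",
        input_variables=["context", "question"],
    )

    # With return_source_documents=True the chain returns two output keys
    # ("answer" and "source_documents"), so the memory must be told which
    # one to store via output_key, or it raises a ValueError.
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,                                 # assumed: the HuggingFaceEndpoint LLM built earlier
        retriever=vector_db.as_retriever(),  # assumed: the Chroma vector store
        chain_type="stuff",
        memory=memory,
        combine_docs_chain_kwargs={"prompt": prompt_template},
        return_source_documents=True,
    )

Calling qa_chain({"question": "..."}) then returns the compliance verdict under "answer" together with the retrieved "source_documents".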