reab5555 commited on
Commit
a831d89
·
verified ·
1 Parent(s): cb24908

Update knowledge_retrieval.py

Browse files
Files changed (1) hide show
  1. knowledge_retrieval.py +12 -7
knowledge_retrieval.py CHANGED
@@ -6,27 +6,32 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
6
  from langchain_core.prompts import ChatPromptTemplate
7
  from langchain.retrievers import EnsembleRetriever
8
  from ai_config import n_of_questions, openai_api_key
9
- from prompt_instructions import get_interview_prompt, get_report_prompt
10
 
11
  n_of_questions = n_of_questions()
12
 
13
- def setup_knowledge_retrieval(llm, language='english'):
14
  embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
15
 
16
  documents_faiss_index = FAISS.load_local("knowledge/faiss_index_all_documents", embedding_model,
17
  allow_dangerous_deserialization=True)
18
 
19
-
20
  documents_retriever = documents_faiss_index.as_retriever()
21
 
22
  combined_retriever = EnsembleRetriever(
23
  retrievers=[documents_retriever]
24
  )
25
 
26
- interview_prompt = ChatPromptTemplate.from_messages([
27
- ("system", get_interview_prompt(language, n_of_questions)),
28
- ("human", "{input}")
29
- ])
 
 
 
 
 
 
30
 
31
  report_prompt = ChatPromptTemplate.from_messages([
32
  ("system", get_report_prompt(language)),
 
6
  from langchain_core.prompts import ChatPromptTemplate
7
  from langchain.retrievers import EnsembleRetriever
8
  from ai_config import n_of_questions, openai_api_key
9
+ from prompt_instructions import get_interview_prompt_sarah, get_interview_prompt_aaron, get_report_prompt
10
 
11
  n_of_questions = n_of_questions()
12
 
13
+ def setup_knowledge_retrieval(llm, language='english', voice='Sarah'):
14
  embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
15
 
16
  documents_faiss_index = FAISS.load_local("knowledge/faiss_index_all_documents", embedding_model,
17
  allow_dangerous_deserialization=True)
18
 
 
19
  documents_retriever = documents_faiss_index.as_retriever()
20
 
21
  combined_retriever = EnsembleRetriever(
22
  retrievers=[documents_retriever]
23
  )
24
 
25
+ if voice == 'Sarah':
26
+ interview_prompt = ChatPromptTemplate.from_messages([
27
+ ("system", get_interview_prompt_sarah(language, n_of_questions)),
28
+ ("human", "{input}")
29
+ ])
30
+ else:
31
+ interview_prompt = ChatPromptTemplate.from_messages([
32
+ ("system", get_interview_prompt_aaron(language, n_of_questions)),
33
+ ("human", "{input}")
34
+ ])
35
 
36
  report_prompt = ChatPromptTemplate.from_messages([
37
  ("system", get_report_prompt(language)),