Update knowledge_retrieval.py
Browse files- knowledge_retrieval.py +12 -7
knowledge_retrieval.py
CHANGED
@@ -6,27 +6,32 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
6 |
from langchain_core.prompts import ChatPromptTemplate
|
7 |
from langchain.retrievers import EnsembleRetriever
|
8 |
from ai_config import n_of_questions, openai_api_key
|
9 |
-
from prompt_instructions import
|
10 |
|
11 |
n_of_questions = n_of_questions()
|
12 |
|
13 |
-
def setup_knowledge_retrieval(llm, language='english'):
|
14 |
embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
15 |
|
16 |
documents_faiss_index = FAISS.load_local("knowledge/faiss_index_all_documents", embedding_model,
|
17 |
allow_dangerous_deserialization=True)
|
18 |
|
19 |
-
|
20 |
documents_retriever = documents_faiss_index.as_retriever()
|
21 |
|
22 |
combined_retriever = EnsembleRetriever(
|
23 |
retrievers=[documents_retriever]
|
24 |
)
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
report_prompt = ChatPromptTemplate.from_messages([
|
32 |
("system", get_report_prompt(language)),
|
|
|
6 |
from langchain_core.prompts import ChatPromptTemplate
|
7 |
from langchain.retrievers import EnsembleRetriever
|
8 |
from ai_config import n_of_questions, openai_api_key
|
9 |
+
from prompt_instructions import get_interview_prompt_sarah, get_interview_prompt_aaron, get_report_prompt
|
10 |
|
11 |
n_of_questions = n_of_questions()
|
12 |
|
13 |
+
def setup_knowledge_retrieval(llm, language='english', voice='Sarah'):
|
14 |
embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
15 |
|
16 |
documents_faiss_index = FAISS.load_local("knowledge/faiss_index_all_documents", embedding_model,
|
17 |
allow_dangerous_deserialization=True)
|
18 |
|
|
|
19 |
documents_retriever = documents_faiss_index.as_retriever()
|
20 |
|
21 |
combined_retriever = EnsembleRetriever(
|
22 |
retrievers=[documents_retriever]
|
23 |
)
|
24 |
|
25 |
+
if voice == 'Sarah':
|
26 |
+
interview_prompt = ChatPromptTemplate.from_messages([
|
27 |
+
("system", get_interview_prompt_sarah(language, n_of_questions)),
|
28 |
+
("human", "{input}")
|
29 |
+
])
|
30 |
+
else:
|
31 |
+
interview_prompt = ChatPromptTemplate.from_messages([
|
32 |
+
("system", get_interview_prompt_aaron(language, n_of_questions)),
|
33 |
+
("human", "{input}")
|
34 |
+
])
|
35 |
|
36 |
report_prompt = ChatPromptTemplate.from_messages([
|
37 |
("system", get_report_prompt(language)),
|