|
import random |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain.chains import create_retrieval_chain |
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain.retrievers import EnsembleRetriever |
|
from ai_config import n_of_questions, openai_api_key |
|
from prompt_instructions import get_interview_prompt_sarah, get_interview_prompt_aaron, get_report_prompt |
|
|
|
# Call the n_of_questions() function imported from ai_config once, at import
# time, and rebind the module-level name to its result. After this line the
# name refers to that value (not the function); it is passed to the interview
# prompt builders and compared against question_count in get_next_response.
n_of_questions = n_of_questions()
|
|
|
def setup_knowledge_retrieval(llm, language='english', voice='Sarah'):
    """Build the interview and report retrieval chains over the local FAISS index.

    Args:
        llm: Language model handed to both document-combining chains.
        language: Forwarded to the interview and report prompt builders.
        voice: ``'Sarah'`` selects ``get_interview_prompt_sarah``; any other
            value selects ``get_interview_prompt_aaron``.

    Returns:
        A tuple ``(interview_retrieval_chain, report_retrieval_chain,
        combined_retriever)``.
    """
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

    # NOTE(review): allow_dangerous_deserialization=True opts in to loading
    # serialized index data from disk; that is only safe when the index
    # directory is trusted -- confirm it is produced by this project.
    index = FAISS.load_local(
        "knowledge/faiss_index_all_documents",
        embeddings,
        allow_dangerous_deserialization=True,
    )

    # The ensemble currently wraps a single retriever built from the index.
    retriever = EnsembleRetriever(retrievers=[index.as_retriever()])

    # Only the system-prompt factory differs between the two voices.
    build_system_prompt = (
        get_interview_prompt_sarah if voice == 'Sarah' else get_interview_prompt_aaron
    )
    interview_messages = [
        ("system", build_system_prompt(language, n_of_questions)),
        ("human", "{input}"),
    ]
    interview_prompt = ChatPromptTemplate.from_messages(interview_messages)

    report_messages = [
        ("system", get_report_prompt(language)),
        ("human", "Please provide a concise clinical report based on the interview."),
    ]
    report_prompt = ChatPromptTemplate.from_messages(report_messages)

    interview_docs_chain = create_stuff_documents_chain(llm, interview_prompt)
    report_docs_chain = create_stuff_documents_chain(llm, report_prompt)

    interview_pipeline = create_retrieval_chain(retriever, interview_docs_chain)
    report_pipeline = create_retrieval_chain(retriever, report_docs_chain)

    return interview_pipeline, report_pipeline, retriever
|
|
|
|
|
def get_next_response(interview_chain, message, history, question_count):
    """Return the interviewer's next question, or the closing line when done.

    Args:
        interview_chain: Object exposing ``invoke(dict)`` that returns a
            mapping; the question is read from its ``"answer"`` key.
        message: The patient's latest reply.
        history: Mutable list of transcript strings. It is joined with
            newlines to form the ``"history"`` chain input and is extended
            in place with the new question and the reply.
        question_count: Compared against the module-level ``n_of_questions``
            limit; ``question_count + 1`` is sent as ``"question_number"``
            and used for the transcript labels.

    Returns:
        The next question, a generic follow-up when the chain yields no
        usable answer (missing, ``None`` or blank), or the fixed closing
        sentence once ``question_count`` has reached ``n_of_questions``.
        In the closing case the chain is not invoked and ``history`` is
        left untouched.
    """
    # Guard first: nothing below is needed once the interview is over.
    if question_count >= n_of_questions:
        return "Thank you for your responses. I will now prepare a report."

    fallback_question = "Could you provide more details on that?"
    question_number = question_count + 1

    instruction = (
        f"Based on the patient's last response: '{message}', and considering "
        "the full interview history, ask a specific, detailed question that "
        "hasn't been asked before and is relevant to the patient's situation."
    )
    result = interview_chain.invoke({
        "input": instruction,
        "history": "\n".join(history),
        "question_number": question_number,
    })

    next_question = result.get("answer")
    # A present-but-empty answer is as unusable as a missing one; without
    # this check a blank "question" would be returned and logged.
    if next_question is None or (
        isinstance(next_question, str) and not next_question.strip()
    ):
        next_question = fallback_question

    # NOTE(review): the reply is logged after the question it produced and
    # under the same number. Kept as-is because callers may rely on this
    # transcript layout -- confirm whether the reply should precede it.
    history.append(f"Q{question_number}: {next_question}")
    history.append(f"A{question_number}: {message}")

    return next_question
|
|
|
|
|
def generate_report(report_chain, history, language):
    """Ask the report chain for a clinical report based on the transcript.

    Args:
        report_chain: Object exposing ``invoke(dict)`` that returns a
            mapping; the report is read from its ``"answer"`` key.
        history: Iterable of transcript strings, joined with newlines and
            sent as the ``"history"`` input.
        language: Sent unchanged as the ``"language"`` input.

    Returns:
        The chain's answer, or a fixed "unable to generate" sentence when
        the answer is missing, ``None`` or a blank string.
    """
    result = report_chain.invoke({
        "input": "Please provide a clinical report based on the interview.",
        "history": "\n".join(history),
        "language": language,
    })

    report = result.get("answer")
    # A present-but-empty answer must not be handed back as the report.
    if report is None or (isinstance(report, str) and not report.strip()):
        return "Unable to generate report due to insufficient information."
    return report
|
|
|
|
|
def get_initial_question(interview_chain):
    """Ask the interview chain for the opening question of the interview.

    Args:
        interview_chain: Object exposing ``invoke(dict)`` that returns a
            mapping; the question is read from its ``"answer"`` key.

    Returns:
        The chain's answer, or a fixed opening question when the answer is
        missing, ``None`` or a blank string.
    """
    result = interview_chain.invoke({
        "input": "What should be the first question in a clinical psychology interview?",
        "history": "",
        "question_number": 1,
    })

    opening = result.get("answer")
    # A present-but-empty answer would start the interview with no question.
    if opening is None or (isinstance(opening, str) and not opening.strip()):
        return "Could you tell me a little bit about yourself and what brings you here today?"
    return opening
|
|