# Source page metadata: file size 3,735 bytes, revision 7ad2e16.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import random
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.retrievers import EnsembleRetriever
from ai_config import n_of_questions, openai_api_key
from prompt_instructions import get_interview_prompt_sarah, get_interview_prompt_aaron, get_report_prompt

# Replace the imported ai_config.n_of_questions callable with the value it
# returns. From here on the name holds the configured question limit: it is
# handed to the interview prompt builders and compared against question_count
# in get_next_response to decide when the interview ends.
# NOTE(review): shadowing the imported function means this module cannot
# re-read the setting later; a distinct constant name would be clearer.
n_of_questions = n_of_questions()

def setup_knowledge_retrieval(llm, language='english', voice='Sarah'):
    """Build the interview and report retrieval chains over the local FAISS index.

    Args:
        llm: Chat model used by both document-stuffing chains.
        language: Forwarded to the interview and report system-prompt builders.
        voice: ``'Sarah'`` selects the Sarah interview prompt; any other value
            selects the Aaron prompt.

    Returns:
        A tuple ``(interview_retrieval_chain, report_retrieval_chain, retriever)``.
        Both chains share the same retriever.
    """
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

    # NOTE(review): allow_dangerous_deserialization opts into LangChain's
    # pickle-based index loading. This is only safe while this directory holds
    # trusted, locally built content. The path is relative to the working directory.
    faiss_store = FAISS.load_local(
        "knowledge/faiss_index_all_documents",
        embeddings,
        allow_dangerous_deserialization=True,
    )
    retriever = EnsembleRetriever(retrievers=[faiss_store.as_retriever()])

    # Only the interview system prompt depends on the voice.
    build_interview_system = (
        get_interview_prompt_sarah if voice == 'Sarah' else get_interview_prompt_aaron
    )
    interview_template = ChatPromptTemplate.from_messages([
        ("system", build_interview_system(language, n_of_questions)),
        ("human", "{input}"),
    ])
    report_template = ChatPromptTemplate.from_messages([
        ("system", get_report_prompt(language)),
        ("human", "Please provide a concise clinical report based on the interview."),
    ])

    # NOTE(review): create_stuff_documents_chain expects the prompt to expose a
    # `context` variable. Presumably the prompt_instructions system prompts
    # contain it; confirm.
    interview_docs_chain = create_stuff_documents_chain(llm, interview_template)
    report_docs_chain = create_stuff_documents_chain(llm, report_template)

    interview_rag = create_retrieval_chain(retriever, interview_docs_chain)
    report_rag = create_retrieval_chain(retriever, report_docs_chain)

    return interview_rag, report_rag, retriever


def get_next_response(interview_chain, message, history, question_count, max_questions=None):
    """Record the patient's reply and ask the interview chain for the next question.

    Args:
        interview_chain: Object whose ``invoke`` takes a dict with ``input``,
            ``history`` and ``question_number`` keys and returns a dict that
            may carry the generated text under ``"answer"``.
        message: The patient's reply to the most recently asked question.
        history: Mutable list of transcript lines. Updated in place with
            ``A<n>: ...`` / ``Q<n>: ...`` entries.
        question_count: Number of questions already asked. The next question
            is number ``question_count + 1``, so ``message`` answers question
            ``question_count``.
        max_questions: Question limit. Defaults to the module-level
            ``n_of_questions`` setting.

    Returns:
        The next question (or a generic follow-up if the chain returned no
        ``"answer"``). Once ``question_count`` reaches the limit, returns the
        closing message instead.
    """
    if max_questions is None:
        max_questions = n_of_questions

    # Snapshot the transcript before this turn is recorded; the latest reply
    # is passed to the chain separately inside the input instruction.
    combined_history = "\n".join(history)

    # Record the reply against the question it actually answers. Do this
    # before the end-of-interview check so the final answer is not dropped
    # from the transcript (presumably the list later given to generate_report).
    history.append(f"A{question_count}: {message}")

    if question_count >= max_questions:
        return "Thank you for your responses. I will now prepare a report."

    next_number = question_count + 1
    result = interview_chain.invoke({
        "input": f"Based on the patient's last response: '{message}', and considering the full interview history, ask a specific, detailed question that hasn't been asked before and is relevant to the patient's situation.",
        "history": combined_history,
        "question_number": next_number,
    })
    next_question = result.get("answer", "Could you provide more details on that?")

    # The new question follows the answer chronologically in the transcript.
    history.append(f"Q{next_number}: {next_question}")

    return next_question


def generate_report(report_chain, history, language):
    """Ask the report chain for a clinical report summarizing the interview.

    Args:
        report_chain: Object whose ``invoke`` accepts a dict with ``input``,
            ``history`` and ``language`` keys and returns a dict that may hold
            the report under ``"answer"``.
        history: Transcript lines; joined with newlines before being sent.
        language: Passed through to the chain unchanged.

    Returns:
        The generated report. If the result has no ``"answer"`` key, returns a
        fixed fallback message.
    """
    payload = {
        "input": "Please provide a clinical report based on the interview.",
        "history": "\n".join(history),
        "language": language,
    }
    fallback = "Unable to generate report due to insufficient information."
    return report_chain.invoke(payload).get("answer", fallback)


def get_initial_question(interview_chain):
    """Generate the opening question of the interview.

    The chain is invoked with an empty history and ``question_number`` 1.

    Args:
        interview_chain: Object whose ``invoke`` accepts a dict with
            ``input``, ``history`` and ``question_number`` keys and returns a
            dict that may carry the question under ``"answer"``.

    Returns:
        The chain's ``"answer"``, or a generic opening question if it is absent.
    """
    opening_request = {
        "input": "What should be the first question in a clinical psychology interview?",
        "history": "",
        "question_number": 1,
    }
    default_question = "Could you tell me a little bit about yourself and what brings you here today?"
    response = interview_chain.invoke(opening_request)
    return response.get("answer", default_question)