File size: 3,176 Bytes
8b1e853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
# from .prompt import SYSTEM_PROMPT
import asyncio
from .agents.supervisor import SupervisorAgent
from .agents.verification import VerificationAgent, process_tool, verification_route
from .agents.medical import MedicalQuestionAgent, medical_route
from .agents.rag import RAGTool
from .agents.state.state import GraphState
from data.preprocessing.vectorstore.get import retriever
from langchain_openai import ChatOpenAI
from .upload_pdf.ingest_documents import PDFProcessor


# Module-level graph construction. Everything below runs at import time,
# including reading the PDF and constructing the ChatOpenAI client.

# Load the intake PDF at <this file's dir>/../data/combined_forms/temp/ACTC-Patient-Packet.pdf
# and keep only the `.content` of each extracted question.
pdf_processor = PDFProcessor(file_path=os.path.abspath(os.path.join(os.path.dirname(__file__), '..',  'data', 'combined_forms', 'temp', 'ACTC-Patient-Packet.pdf')))
questions = pdf_processor.extract_questions()
questions = [q.content for q in questions]
# NOTE(review): debug output emitted on every import of this module; consider
# removing it or switching to logging.
print('QUESTIONS**********************', questions)
# Checkpointer handed to graph.compile() in the __main__ block. Presumably it keeps
# per-thread state in process memory only (lost on exit) — confirm.
memory = MemorySaver()

# Graph whose shared state schema is GraphState.
graph = StateGraph(GraphState)

# Nodes: a supervisor (whose `route` picks the next node, see below), a
# verification agent and a medical-question agent, each paired with a tool node
# that loops back to it (see the plain edges below).
supervisor = SupervisorAgent()
graph.add_node("supervisor_agent", supervisor)
graph.add_node("verification_agent", VerificationAgent())
graph.add_node("verification_tool_node", process_tool)
graph.add_node("medical_agent", MedicalQuestionAgent(questions=questions))
# os.environ["MODEL"] raises KeyError at import time if MODEL is not set.
graph.add_node("rag_tool_node", RAGTool(retriever=retriever,
               llm=ChatOpenAI(model=os.environ["MODEL"])))

# Every run starts at the supervisor.
graph.set_entry_point("supervisor_agent")

# Tool nodes always hand control back to the agent that precedes them.
graph.add_edge("verification_tool_node", "verification_agent")
graph.add_edge("rag_tool_node", "medical_agent")
# No path map: whatever supervisor.route returns is used directly as the next
# destination (assumed to be one of the node names above — verify).
graph.add_conditional_edges(
    'supervisor_agent',
     supervisor.route
)
# verification_route's result is mapped through this dict: "__end__" ends the
# run, "verification_tool_node" goes to the verification tool node.
graph.add_conditional_edges(
    "verification_agent",
    verification_route,
    {"__end__": END, "verification_tool_node": "verification_tool_node"}
)
# medical_route's result is mapped the same way: "__end__" ends the run,
# "rag_tool_node" goes to the RAG tool node.
graph.add_conditional_edges(
    "medical_agent",
    medical_route,
    {"__end__": END, "rag_tool_node": "rag_tool_node"}
)


async def run_verfication(app, fields="", values=""):
    """Run an interactive console loop that streams chat-model tokens from ``app``.

    Reads lines from stdin until the user types ``quit`` (or stdin is
    exhausted). Each line is sent to ``app.astream_events`` together with the
    fields/values to verify. Streamed chat-model chunks are printed as they
    arrive, with newlines removed.

    Args:
        app: A compiled graph exposing ``astream_events`` (e.g. the result of
            ``graph.compile(...)``).
        fields: Comma-separated names of the fields to verify. An empty string
            (the default) falls back to the demo value ``"full name, birthdate"``.
        values: Comma-separated values matching ``fields``. An empty string
            (the default) falls back to the demo value ``"John Doe, 1990-01-01"``.
    """
    # These arguments used to be ignored in favour of hard-coded demo values;
    # keep those demo values as the fallback so argument-less callers see no change.
    fields = fields or "full name, birthdate"
    values = values or "John Doe, 1990-01-01"
    # Fixed thread_id: every turn of this session uses the same conversation thread.
    config = {"configurable": {"thread_id": 1}}

    prompt = 'User: '
    while True:
        try:
            _input = input(prompt)
        except EOFError:
            # End of input (Ctrl-D / closed pipe) ends the session like "quit".
            break
        if _input == 'quit':
            break
        async for event in app.astream_events(
            {"messages": [('user', _input)], "fields": fields, "values": values},
            config=config,
            version="v2",
        ):
            if event['event'] == "on_chat_model_stream":
                data = event["data"]
                if data["chunk"].content:
                    # Strip newlines so the streamed reply prints on a single line.
                    print(data["chunk"].content.replace(
                        "\n", ""), end="", flush=True)
        # After the first turn, start the prompt on a fresh line below the reply.
        prompt = '\nUser: '


async def run(app, fields="full name, birthdate", values="John Doe, 1990-01-01"):
    """Run an interactive console chat loop against the compiled graph ``app``.

    Reads lines from stdin until the user types ``quit`` (or stdin is
    exhausted). Each line is sent as a ``HumanMessage`` via
    ``app.astream(..., stream_mode="messages")``. The content of every
    ``AIMessageChunk`` in the stream is concatenated and printed once the
    stream finishes.

    Args:
        app: A compiled graph exposing ``astream`` (e.g. ``graph.compile(...)``).
        fields: Comma-separated names of the fields to verify; defaults to the
            demo value previously hard-coded here.
        values: Comma-separated values matching ``fields``; defaults to the
            demo value previously hard-coded here.
    """
    # Imported lazily so the dependency is only needed when the loop runs.
    from langchain_core.messages import AIMessageChunk, HumanMessage
    # Fixed thread_id: every turn of this session uses the same conversation thread.
    config = {"configurable": {"thread_id": 1}}

    while True:
        try:
            _user_input = input("User: ")
        except EOFError:
            # End of input (Ctrl-D / closed pipe) ends the session like "quit".
            break
        if _user_input == "quit":
            break
        parts = []
        astream = app.astream(
            {"messages": [HumanMessage(content=_user_input)], "fields": fields, "values": values},
            config=config,
            stream_mode="messages",
        )
        # With stream_mode="messages" each item is a (message, metadata) pair;
        # only AI message chunks make up the assistant's reply.
        async for msg, _metadata in astream:
            if isinstance(msg, AIMessageChunk):
                parts.append(msg.content)
        print('Assistant: ', "".join(parts))

if __name__ == "__main__":
    # Script entry point: compile the module-level graph with the in-memory
    # checkpointer, then drive the interactive chat loop on an event loop.
    app = graph.compile(checkpointer=memory)
    chat_session = run(app)
    asyncio.run(chat_session)