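"""Patient-intake conversation graph built with LangGraph.

A supervisor agent routes each turn to either a verification agent (which can
call a verification tool) or a medical question agent (which can call a RAG
tool over the document vector store). The intake questions are extracted from
a patient packet PDF at import time.
"""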
import asyncio
import os

from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END

from data.preprocessing.vectorstore.get import retriever

from .agents.medical import MedicalQuestionAgent, medical_route
from .agents.rag import RAGTool
from .agents.state.state import GraphState
from .agents.supervisor import SupervisorAgent
from .agents.verification import VerificationAgent, process_tool, verification_route
from .upload_pdf.ingest_documents import PDFProcessor
# from .prompt import SYSTEM_PROMPT
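# Pull the intake questions out of the patient packet PDF; the medical agent
# below is constructed with this question list.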
PDF_PATH = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__),
        "..", "data", "combined_forms", "temp", "ACTC-Patient-Packet.pdf",
    )
)
pdf_processor = PDFProcessor(file_path=PDF_PATH)
questions = [q.content for q in pdf_processor.extract_questions()]
print("Extracted questions:", questions)
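# Build the multi-agent state graph over GraphState.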
memory = MemorySaver()
graph = StateGraph(GraphState)
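# Register the agent nodes and the tool nodes they call into.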
supervisor = SupervisorAgent()
graph.add_node("supervisor_agent", supervisor)
graph.add_node("verification_agent", VerificationAgent())
graph.add_node("verification_tool_node", process_tool)
graph.add_node("medical_agent", MedicalQuestionAgent(questions=questions))
graph.add_node("rag_tool_node", RAGTool(retriever=retriever,
llm=ChatOpenAI(model=os.environ["MODEL"])))
graph.set_entry_point("supervisor_agent")
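# Tool nodes always return control to the agent that invoked them.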
graph.add_edge("verification_tool_node", "verification_agent")
graph.add_edge("rag_tool_node", "medical_agent")
graph.add_conditional_edges(
    "supervisor_agent",
    supervisor.route,
)
graph.add_conditional_edges(
    "verification_agent",
    verification_route,
    {"__end__": END, "verification_tool_node": "verification_tool_node"},
)
graph.add_conditional_edges(
    "medical_agent",
    medical_route,
    {"__end__": END, "rag_tool_node": "rag_tool_node"},
)
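# Alternative driver that streams token-by-token via astream_events (v2 event
# schema), printing chunks as they arrive. The default fields/values are
# sample data for exercising the verification flow.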
async def run_verification(app, fields="full name, birthdate", values="John Doe, 1990-01-01"):
    config = {"configurable": {"thread_id": 1}}
    _input = input("User: ")
    while _input != "quit":
        async for event in app.astream_events(
            {"messages": [("user", _input)], "fields": fields, "values": values},
            config=config,
            version="v2",
        ):
            # Print streamed model tokens as they arrive.
            if event["event"] == "on_chat_model_stream":
                data = event["data"]
                if data["chunk"].content:
                    print(data["chunk"].content.replace("\n", ""), end="", flush=True)
        _input = input("\nUser: ")
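# Simpler driver: stream whole message chunks (stream_mode="messages") and
# print the assembled assistant reply once per turn.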
async def run(app):
    config = {"configurable": {"thread_id": 1}}
    _user_input = input("User: ")
    while _user_input != "quit":
        out = ""
        astream = app.astream(
            {
                "messages": [HumanMessage(content=_user_input)],
                "fields": "full name, birthdate",
                "values": "John Doe, 1990-01-01",
            },
            config=config,
            stream_mode="messages",
        )
        # Accumulate streamed AI chunks into one reply before printing.
        async for msg, metadata in astream:
            if isinstance(msg, AIMessageChunk):
                out += msg.content
        print("Assistant:", out)
        _user_input = input("User: ")
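# Compile with the in-memory checkpointer so conversation state persists
# across turns of the same thread_id.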
if __name__ == "__main__":
    app = graph.compile(checkpointer=memory)
    asyncio.run(run(app))