|
import sys |
|
import os |
|
|
|
from langgraph.graph import START, END, StateGraph |
|
from langchain_openai import OpenAIEmbeddings |
|
from chains import simple_chain, llm_with_tools |
|
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage |
|
from typing import TypedDict, Optional, Dict, List, Union, Annotated |
|
from langchain_core.messages import AnyMessage |
|
from langgraph.graph.message import add_messages |
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
from langchain.agents import initialize_agent, Tool |
|
from langchain.agents.agent_types import AgentType |
|
from langgraph.checkpoint.memory import MemorySaver |
|
import json |
|
import langchain |
|
from tools import json_to_table, goal_feasibility, rag_tool, save_data |
|
import re |
|
|
|
from dotenv import load_dotenv |
|
load_dotenv()

# Checkpointer passed to graph.compile(); MemorySaver keeps graph state per
# thread_id in process memory.
memory = MemorySaver()

# NOTE(review): LangGraph's invoke() looks for the thread id under
# config["configurable"]; confirm how callers use this flat dict.
config = {"thread_id": "sample"}

tools = [json_to_table, rag_tool]

# The graph state keeps its message history under "query", not ToolNode's
# default "messages" key, so tell the node where to read tool calls from and
# where to write its ToolMessage results.
rag_tool_node = ToolNode([rag_tool], messages_key="query")
|
class Graph(TypedDict):
    """State shared by the nodes of the chat graph."""

    # Conversation history. The add_messages reducer merges the messages each
    # node returns into this list instead of replacing it.
    query: Annotated[list[AnyMessage], add_messages]

    # Forwarded to the chat chain on every turn.
    user_data: Dict

    # Forwarded to the chat chain and rendered by the json_to_table tool.
    allocations: Dict

    # Content of the latest chain reply (chat() stores result.content here,
    # normally a string, so a Dict annotation was wrong).
    output: str

    # Extra context for the chain; chat() reads it and resets it to "".
    retrieved_context: str
|
|
|
def chat(state):
    """Run the conversational chain on the current state and record its reply.

    Args:
        state: graph state (see ``Graph``). ``query`` holds the message
            history; ``user_data`` and ``allocations`` are forwarded to the
            chain; ``retrieved_context`` is optional.

    Returns:
        A partial state update. The chain's reply message is returned under
        ``query`` so the ``add_messages`` reducer appends it to the history,
        where ``tools_condition`` (which inspects ``query[-1]``) can see any
        tool calls it carries. Its content is also exposed as ``output``.
    """
    history = state["query"]
    inputs = {
        "query": history,
        "user_data": state["user_data"],
        "allocations": state["allocations"],
        # The chain receives the same message list under both keys.
        "chat_history": history,
        "retrieved_context": state.get("retrieved_context", ""),
    }

    result = simple_chain.invoke(inputs)

    return {
        # Only the new message: re-returning the existing history (as before)
        # never added the reply, so it was lost from memory and routing.
        "query": [result],
        "user_data": state["user_data"],
        "allocations": state["allocations"],
        # Reset after each turn so stale context is not passed to the next one.
        "retrieved_context": "",
        "output": result.content,
    }
|
|
|
def json_to_table_node(state):
    """Render ``state["allocations"]`` with json_to_table and add it to the history.

    If the latest message is an AIMessage that requested ``json_to_table``,
    each such call is answered with a ToolMessage carrying the table and the
    call's id (chat models expect every tool call to be followed by a matching
    tool result). Otherwise the table is added as a plain AIMessage.

    Returns:
        A partial state update ``{"query": [...]}``. LangGraph nodes must return
        a dict of state updates; a bare message object is rejected.
    """
    tool_output = json_to_table(state["allocations"])

    history = state.get("query") or []
    last_message = history[-1] if history else None
    table_calls = [
        call
        for call in (getattr(last_message, "tool_calls", None) or [])
        if call.get("name") == "json_to_table"
    ]

    if table_calls:
        replies = [
            ToolMessage(content=tool_output, tool_call_id=call["id"])
            for call in table_calls
        ]
    else:
        replies = [AIMessage(content=tool_output)]
    return {"query": replies}
|
|
|
def tools_condition(state):
    """Choose the node that runs after ``chat``, based on the latest message.

    Note: this shadows the ``tools_condition`` imported from
    ``langgraph.prebuilt``.

    Returns:
        "show_allocation_table" if the newest message is an AIMessage whose
        first tool call is ``json_to_table``; "query_rag" if it is
        ``rag_tool``; "tools" for any other tool call; otherwise ``END``.
        (The string "END" is not a node name, so the END sentinel is used.)
    """
    messages = state.get("query") or []
    if not messages:
        return END

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return END

    tool_calls = getattr(last_message, "tool_calls", None)
    if not tool_calls:
        return END

    # Only the first tool call decides the route.
    tool_name = tool_calls[0].get("name", "")
    if tool_name == "json_to_table":
        return "show_allocation_table"
    if tool_name == "rag_tool":
        return "query_rag"
    return "tools"
|
|
|
|
|
|
|
# Wire the agent loop: chat runs first, tools_condition picks the next step,
# and every tool node hands its result back to chat.
graph = StateGraph(Graph)

graph.add_node("chat", chat)
graph.add_node("show_allocation_table", json_to_table_node)
graph.add_node("query_rag", rag_tool_node)
# Catch-all for tool calls that tools_condition does not route explicitly.
# ToolNode answers tool names it does not know with an error ToolMessage,
# which goes back to chat instead of leaving the call unanswered.
# NOTE(review): goal_feasibility / save_data are imported but not registered
# in any node here; add them if the chat model is bound to them.
graph.add_node("tools", ToolNode(tools, messages_key="query"))

graph.add_edge(START, "chat")
# Explicit path_map naming every label tools_condition can return. Both the
# literal "END" label and the END sentinel finish the turn.
graph.add_conditional_edges(
    "chat",
    tools_condition,
    {
        "show_allocation_table": "show_allocation_table",
        "query_rag": "query_rag",
        "tools": "tools",
        "END": END,
        END: END,
    },
)
graph.add_edge("show_allocation_table", "chat")
graph.add_edge("query_rag", "chat")
graph.add_edge("tools", "chat")
# There is deliberately no unconditional chat -> END edge: it would fire
# alongside the conditional route above, which already handles finishing.

app = graph.compile(checkpointer=memory)
|
|
|
def _run_demo(user_data_path, allocations_path, query="display my investments."):
    """Invoke the compiled graph once on local JSON files and print its reply.

    Args:
        user_data_path: path to a JSON file loaded into ``user_data``.
        allocations_path: path to a JSON file loaded into ``allocations``.
        query: the user's message for this turn.
    """
    from langchain_core.globals import set_debug

    with open(user_data_path, "r", encoding="utf-8") as f:
        user_data = json.load(f)
    with open(allocations_path, "r", encoding="utf-8") as f:
        allocations = json.load(f)

    # Only keys declared on the Graph state are passed in.
    inputs = {
        "query": query,
        "user_data": user_data,
        "allocations": allocations,
    }

    set_debug(True)
    result = app.invoke(inputs, config={"configurable": {"thread_id": "sample"}})
    print(result.get("output"))


if __name__ == "__main__":
    # Manual smoke test; does nothing on import.
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} USER_DATA_JSON ALLOCATIONS_JSON")
    _run_demo(sys.argv[1], sys.argv[2])