File size: 4,186 Bytes
6bccf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import sys
import os

from langgraph.graph import START, END, StateGraph
from langchain_openai import OpenAIEmbeddings
from chains import simple_chain, llm_with_tools
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage
from typing import TypedDict, Optional, Dict, List, Union, Annotated
from langchain_core.messages import AnyMessage #human or AI message
from langgraph.graph.message import add_messages # reducer in langgraph 
from langgraph.prebuilt import ToolNode, tools_condition
from langchain.agents import initialize_agent, Tool
from langchain.agents.agent_types import AgentType
from langgraph.checkpoint.memory import MemorySaver
import json
import langchain
from tools import json_to_table, goal_feasibility, rag_tool, save_data
import re

from dotenv import load_dotenv
# Load environment variables from a local .env file (presumably API keys for
# the imported chains/tools) before anything below needs them.
load_dotenv()

# In-process checkpointer: graph state is saved per thread_id in memory.
memory = MemorySaver()
# LangGraph checkpointers read the thread id from config["configurable"];
# a flat {"thread_id": ...} dict is not picked up.
config = {"configurable": {"thread_id": "sample"}}
tools = [json_to_table, rag_tool]

# The graph state keeps its messages under "query" rather than ToolNode's
# default "messages" key, so point the node at that key for both the tool-call
# input and the ToolMessage output.
rag_tool_node = ToolNode([rag_tool], messages_key="query")

class Graph(TypedDict):
    """State schema shared by every node of the chat graph.

    Only ``query`` declares a reducer (``add_messages``). Every other key is
    simply replaced by whatever value a node returns for it.
    """

    # Conversation messages; node updates are merged in by add_messages
    # instead of replacing the whole list.
    query: Annotated[list[AnyMessage], add_messages]
    # User profile payload supplied by the caller (structure not visible here).
    user_data: dict
    # Allocation payload supplied by the caller (structure not visible here).
    allocations: dict
    # Text of the latest model reply, written by chat().
    # NOTE(review): chat() stores result.content here, which is presumably a
    # str — the dict annotation may be wrong.
    output: dict
    # Context handed to the next chat turn; chat() resets it to "" after use.
    retrieved_context: str

def chat(state):
    """Run one conversational turn through ``simple_chain``.

    Args:
        state: Current graph state. ``query`` is used both as the latest input
            and as the chat history; ``retrieved_context`` is optional.

    Returns:
        A partial state update that appends the model's reply to ``query``,
        stores the reply text in ``output`` and clears ``retrieved_context``.
    """
    history = state["query"]
    inputs = {
        "query": history,
        "user_data": state["user_data"],
        "allocations": state["allocations"],
        # The graph keeps no separate history; the message list doubles as it.
        "chat_history": history,
        "retrieved_context": state.get("retrieved_context", ""),
    }

    result = simple_chain.invoke(inputs)

    return {
        # add_messages appends this reply, so later turns see the assistant's
        # side of the conversation and tools_condition can inspect any
        # tool_calls it carries. Re-returning state["query"] (the old
        # behaviour) was a no-op under add_messages.
        "query": [result],
        "user_data": state["user_data"],
        "allocations": state["allocations"],
        "retrieved_context": "",  # consumed by this turn
        "output": result.content,
    }

def json_to_table_node(state):
    """Render ``state["allocations"]`` with the ``json_to_table`` tool.

    Returns a partial state update for ``query``. If the latest message
    requested ``json_to_table``, the result is returned as a ``ToolMessage``
    answering that tool call. OpenAI-style chat APIs reject a history in which
    a tool call is left unanswered. Otherwise the result falls back to a plain
    ``AIMessage``.
    """
    # NOTE(review): the tool is called with the allocations dict directly
    # rather than with the arguments the model placed in the tool call —
    # confirm this matches json_to_table's signature.
    tool_output = json_to_table(state["allocations"])

    messages = state.get("query") or []
    tool_calls = getattr(messages[-1], "tool_calls", None) if messages else None
    call_id = next(
        (
            call.get("id")
            for call in (tool_calls or [])
            if call.get("name") == "json_to_table"
        ),
        None,
    )

    if call_id:
        reply = ToolMessage(
            content=tool_output, tool_call_id=call_id, name="json_to_table"
        )
    else:
        reply = AIMessage(content=tool_output)

    # LangGraph nodes must return a dict of state updates (returning a bare
    # message fails); add_messages appends the reply to the conversation.
    return {"query": [reply]}

def tools_condition(state):
    """Choose the node to run after ``chat``.

    If the most recent message is an AI message with tool calls, route on the
    first call's tool name. Otherwise end the run.

    Returns:
        ``"show_allocation_table"``, ``"query_rag"`` or the ``END`` sentinel.
        The plain string ``"END"`` is not a node name, so the sentinel is
        returned instead.
    """
    messages = state.get("query") or []
    if not messages:
        return END

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage):
        return END

    tool_calls = getattr(last_message, "tool_calls", None)
    if not tool_calls:
        return END

    tool_name = tool_calls[0].get("name", "")
    routes = {
        "json_to_table": "show_allocation_table",
        "rag_tool": "query_rag",
    }
    # No generic "tools" node is registered on the graph, so an unknown tool
    # name ends the run instead of routing to a node that does not exist.
    return routes.get(tool_name, END)


# ---- GRAPH SETUP ----
graph = StateGraph(Graph)

# Nodes. Each tool node answers the tool call that chat's reply requested.
graph.add_node("chat", chat)
graph.add_node("show_allocation_table", json_to_table_node)
graph.add_node("query_rag", rag_tool_node)

# Main flow: chat runs first, then tools_condition either routes to a tool
# node or ends the run. Ending is decided solely by the router, so no
# separate chat -> END edge is needed.
graph.add_edge(START, "chat")
graph.add_conditional_edges("chat", tools_condition)

# Each tool node hands its result back to chat for the next turn.
graph.add_edge("show_allocation_table", "chat")
graph.add_edge("query_rag", "chat")

# Compile with the in-memory checkpointer so state persists per thread_id.
app = graph.compile(checkpointer=memory)

# Manual smoke test, disabled by wrapping it in a string literal.
# NOTE(review): it relies on machine-specific absolute paths, loads the same
# sample_alloc.json for both user_data and allocations, and passes "data" and
# "chat_history" keys that are not declared on the Graph state schema.
'''
with open('/home/pavan/Desktop/FOLDERS/RUBIC/RAG_without_profiler/RAG_rubik/sample_data/sample_alloc.json', 'r') as f:
    data = json.load(f)
with open('/home/pavan/Desktop/FOLDERS/RUBIC/RAG_without_profiler/RAG_rubik/sample_data/sample_alloc.json', 'r') as f:
    allocs = json.load(f)
inputs = {
    "query":"display my investments.",
    "user_data":data,
    "allocations":allocs,
    "data":"",
    "chat_history": [],
    
}

langchain.debug = True
print(app.invoke(inputs, config={"configurable": {"thread_id": "sample"}}).get('output'))
#print(json_to_table.args_schema.model_json_schema())
'''