angry-meow commited on
Commit
06dc1b7
·
1 Parent(s): 954011f

updated graph stucture

Browse files
Files changed (1) hide show
  1. graph.py +37 -16
graph.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, TypedDict, Annotated, Sequence
2
  from langgraph.graph import Graph, StateGraph, END
3
  from langgraph.prebuilt import ToolExecutor
4
  from langchain.schema import StrOutputParser
@@ -12,11 +12,14 @@ from operator import itemgetter
12
  # Define the state structure
13
  class State(TypedDict):
14
  messages: Sequence[str]
 
15
  research_data: Dict[str, str]
16
- draft_post: str
 
17
  final_post: str
18
 
19
 
 
20
  # Research Agent Pieces
21
  qdrant_research_chain = (
22
  {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
@@ -34,10 +37,10 @@ tavily_chain = (
34
 
35
  def query_qdrant(state: State) -> State:
36
  # Extract the last message as the input
37
- input_text = state["messages"][-1]
38
 
39
  # Run the chain
40
- result = qdrant_research_chain.invoke({"topic": input_text})
41
 
42
  # Update the state with the research results
43
  state["research_data"]["qdrant_results"] = result
@@ -46,7 +49,7 @@ def query_qdrant(state: State) -> State:
46
 
47
  def web_search(state: State) -> State:
48
  # Extract the last message as the topic
49
- topic = state["messages"][-1]
50
 
51
  # Get the Qdrant results from the state
52
  qdrant_results = state["research_data"].get("qdrant_results", "No previous results available.")
@@ -99,11 +102,15 @@ research_graph.add_node("research_supervisor", research_supervisor)
99
 
100
  research_graph.add_edge("query_qdrant", "research_supervisor")
101
  research_graph.add_edge("web_search", "research_supervisor")
102
- research_graph.add_edge("research_supervisor", "query_qdrant")
103
- research_graph.add_edge("research_supervisor", "web_search")
104
- research_graph.add_edge("research_supervisor", END)
 
 
 
105
 
106
  research_graph.set_entry_point("research_supervisor")
 
107
 
108
  # Create the writing team graph
109
  writing_graph = StateGraph(State)
@@ -114,15 +121,25 @@ writing_graph.add_node("voice_editing", voice_editing)
114
  writing_graph.add_node("post_review", post_review)
115
  writing_graph.add_node("writing_supervisor", writing_supervisor)
116
 
117
- writing_graph.add_edge("writing_supervisor", "post_creation")
118
- writing_graph.add_edge("post_creation", "copy_editing")
119
- writing_graph.add_edge("copy_editing", "voice_editing")
120
- writing_graph.add_edge("voice_editing", "post_review")
121
  writing_graph.add_edge("post_review", "writing_supervisor")
122
- writing_graph.add_edge("writing_supervisor", END)
 
 
 
 
 
 
 
 
 
123
 
124
  writing_graph.set_entry_point("writing_supervisor")
125
 
 
 
126
  # Create the overall graph
127
  overall_graph = StateGraph(State)
128
 
@@ -136,11 +153,15 @@ overall_graph.add_node("overall_supervisor", overall_supervisor)
136
  overall_graph.set_entry_point("overall_supervisor")
137
 
138
  # Connect the nodes
139
- overall_graph.add_edge("overall_supervisor", "research_team")
140
  overall_graph.add_edge("research_team", "overall_supervisor")
141
- overall_graph.add_edge("overall_supervisor", "writing_team")
142
  overall_graph.add_edge("writing_team", "overall_supervisor")
143
- overall_graph.add_edge("overall_supervisor", END)
 
 
 
 
 
 
144
 
145
  # Compile the graph
146
  app = overall_graph.compile()
 
1
+ from typing import Dict, List, TypedDict, Annotated, Sequence
2
  from langgraph.graph import Graph, StateGraph, END
3
  from langgraph.prebuilt import ToolExecutor
4
  from langchain.schema import StrOutputParser
 
12
  # Define the state structure
13
  class State(TypedDict):
14
  messages: Sequence[str]
15
+ topic: str
16
  research_data: Dict[str, str]
17
+ team_members: List[str]
18
+ draft_posts: Sequence[str]
19
  final_post: str
20
 
21
 
22
+ research_members = ["Qdrant_researcher", "Web_researcher"]
23
  # Research Agent Pieces
24
  qdrant_research_chain = (
25
  {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
 
37
 
38
  def query_qdrant(state: State) -> State:
39
  # Extract the last message as the input
40
+ topic = state["topic"]
41
 
42
  # Run the chain
43
+ result = qdrant_research_chain.invoke({"topic": topic})
44
 
45
  # Update the state with the research results
46
  state["research_data"]["qdrant_results"] = result
 
49
 
50
  def web_search(state: State) -> State:
51
  # Extract the last message as the topic
52
+ topic = state["topic"]
53
 
54
  # Get the Qdrant results from the state
55
  qdrant_results = state["research_data"].get("qdrant_results", "No previous results available.")
 
102
 
103
  research_graph.add_edge("query_qdrant", "research_supervisor")
104
  research_graph.add_edge("web_search", "research_supervisor")
105
+ research_graph.add_conditional_edges(
106
+ "research_supervisor",
107
+ lambda x: x["next"],
108
+ {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": END},
109
+ )
110
+ #research_graph.add_edge("research_supervisor", END)
111
 
112
  research_graph.set_entry_point("research_supervisor")
113
+ research_graph_comp = research_graph.compile()
114
 
115
  # Create the writing team graph
116
  writing_graph = StateGraph(State)
 
121
  writing_graph.add_node("post_review", post_review)
122
  writing_graph.add_node("writing_supervisor", writing_supervisor)
123
 
124
+ writing_graph.add_edge("post_creation", "writing_supervisor")
125
+ writing_graph.add_edge("copy_editing", "writing_supervisor")
126
+ writing_graph.add_edge("voice_editing", "writing_supervisor")
 
127
  writing_graph.add_edge("post_review", "writing_supervisor")
128
+ writing_graph.add_conditional_edges(
129
+ "writing_supervisor",
130
+ lambda x: x["next"],
131
+ {"post_creation": "post_creation",
132
+ "copy_editing": "copy_editing",
133
+ "voice_editing": "voice_editing",
134
+ "post_review": "post_review",
135
+ "FINISH": END},
136
+ )
137
+ #writing_graph.add_edge("writing_supervisor", END)
138
 
139
  writing_graph.set_entry_point("writing_supervisor")
140
 
141
+ writing_graph_comp = research_graph.compile()
142
+
143
  # Create the overall graph
144
  overall_graph = StateGraph(State)
145
 
 
153
  overall_graph.set_entry_point("overall_supervisor")
154
 
155
  # Connect the nodes
 
156
  overall_graph.add_edge("research_team", "overall_supervisor")
 
157
  overall_graph.add_edge("writing_team", "overall_supervisor")
158
+ overall_graph.add_conditional_edges(
159
+ "overall_supervisor",
160
+ lambda x: x["next"],
161
+ {"research_team": "research_team",
162
+ "writing_team": "writing_team",
163
+ "FINISH": END},
164
+ )
165
 
166
  # Compile the graph
167
  app = overall_graph.compile()