Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -308,16 +308,18 @@ class DocumentRAG:
|
|
| 308 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
| 309 |
self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
|
| 310 |
|
| 311 |
-
#
|
| 312 |
story_graph = StateGraph(StoryState)
|
| 313 |
story_graph.add_node("Retrieve", self.retrieve_docs)
|
|
|
|
| 314 |
story_graph.add_node("Generate", self.generate_story)
|
| 315 |
story_graph.set_entry_point("Retrieve")
|
| 316 |
-
story_graph.add_edge("Retrieve", "
|
|
|
|
| 317 |
story_graph.set_finish_point("Generate")
|
| 318 |
story_subgraph = story_graph.compile()
|
| 319 |
|
| 320 |
-
# Main graph
|
| 321 |
graph = StateGraph(MultiAgentState)
|
| 322 |
graph.add_node("beginner_topic", self.beginner_topic)
|
| 323 |
graph.add_node("middle_topic", self.middle_topic)
|
|
@@ -334,17 +336,24 @@ class DocumentRAG:
|
|
| 334 |
["story_generator"])
|
| 335 |
graph.add_edge("story_generator", END)
|
| 336 |
|
| 337 |
-
|
| 338 |
compiled = graph.compile(checkpointer=MemorySaver())
|
| 339 |
thread = {"configurable": {"thread_id": "storygraph-session"}}
|
| 340 |
|
|
|
|
| 341 |
result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
|
| 342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
return result
|
| 344 |
|
| 345 |
|
| 346 |
|
| 347 |
|
|
|
|
| 348 |
# Initialize RAG system in session state
|
| 349 |
if "rag_system" not in st.session_state:
|
| 350 |
st.session_state.rag_system = DocumentRAG()
|
|
|
|
| 308 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
| 309 |
self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)
|
| 310 |
|
| 311 |
+
# Define the story subgraph with reranking
|
| 312 |
story_graph = StateGraph(StoryState)
|
| 313 |
story_graph.add_node("Retrieve", self.retrieve_docs)
|
| 314 |
+
story_graph.add_node("Rerank", self.rerank_docs) # Add rerank step
|
| 315 |
story_graph.add_node("Generate", self.generate_story)
|
| 316 |
story_graph.set_entry_point("Retrieve")
|
| 317 |
+
story_graph.add_edge("Retrieve", "Rerank")
|
| 318 |
+
story_graph.add_edge("Rerank", "Generate")
|
| 319 |
story_graph.set_finish_point("Generate")
|
| 320 |
story_subgraph = story_graph.compile()
|
| 321 |
|
| 322 |
+
# Main graph setup
|
| 323 |
graph = StateGraph(MultiAgentState)
|
| 324 |
graph.add_node("beginner_topic", self.beginner_topic)
|
| 325 |
graph.add_node("middle_topic", self.middle_topic)
|
|
|
|
| 336 |
["story_generator"])
|
| 337 |
graph.add_edge("story_generator", END)
|
| 338 |
|
|
|
|
| 339 |
compiled = graph.compile(checkpointer=MemorySaver())
|
| 340 |
thread = {"configurable": {"thread_id": "storygraph-session"}}
|
| 341 |
|
| 342 |
+
# Initial run to extract subtopics
|
| 343 |
result = compiled.invoke({"topic": [topic], "context": [context]}, thread)
|
| 344 |
|
| 345 |
+
# Fallback if no subtopics were extracted
|
| 346 |
+
if not result.get("sub_topic_list"):
|
| 347 |
+
fallback_subs = ["Neural Networks", "Reinforcement Learning", "Supervised vs Unsupervised"]
|
| 348 |
+
compiled.update_state(thread, {"sub_topic_list": fallback_subs})
|
| 349 |
+
result = compiled.invoke(None, thread, stream_mode="values")
|
| 350 |
+
|
| 351 |
return result
|
| 352 |
|
| 353 |
|
| 354 |
|
| 355 |
|
| 356 |
+
|
| 357 |
# Initialize RAG system in session state
|
| 358 |
if "rag_system" not in st.session_state:
|
| 359 |
st.session_state.rag_system = DocumentRAG()
|