subashpoudel commited on
Commit
fb491f0
·
1 Parent(s): d604e49

Updated the react agent with tool in node

Browse files
Files changed (1) hide show
  1. my_agent/utils/nodes.py +23 -28
my_agent/utils/nodes.py CHANGED
@@ -1,13 +1,17 @@
1
  import pandas as pd
2
  import ast
3
  from .state import State
4
- from .tools import StoryFormatter, BrainstromTopicFormatter
5
  from langchain_core.messages import SystemMessage
6
  from .models_loader import llm , ST
7
  from .data_loader import load_influencer_data
8
  from groq import Groq
9
  import os
10
  from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt , final_story_prompt
 
 
 
 
11
 
12
 
13
  def caption_image(state: State) -> State:
@@ -79,36 +83,27 @@ def retrieve(state: State) -> State:
79
  return state
80
 
81
  def generate_story(state:State)-> State:
82
- print('The state retrieval is:',state.retrievals)
83
- retrieval_list= state.retrievals[-1]
84
- agentic_stories = []
85
-
86
- for item in retrieval_list:
87
- print('item:', item[-1].values())
88
-
89
- agentic_stories.extend(item[-1].values()) # Add all stories to the list
90
-
91
- retrieval = " ".join(agentic_stories)
92
-
93
- if len(state.preferred_topics)==0:
94
- template = initial_story_prompt(retrieval , state)
95
- else:
96
- template = refined_story_prompt(retrieval , state)
97
 
98
  # and {state.image_captions[-1]}
99
 
100
- messages = [SystemMessage(content=template)]
101
- response = llm.bind_tools([StoryFormatter]).invoke(messages)
102
- print('The response is:',response)
103
- if hasattr(response, 'tool_calls') and response.tool_calls:
104
- response = response.tool_calls[0]['args']
105
- elif hasattr(response, 'content'):
106
- response = response.content
107
- else:
108
- response = "No response"
109
- state.stories.append(response)
110
- # return State(messages="Story generated", topic=state.topic,stories=state.stories)
111
- return state
112
 
113
 
114
 
 
1
  import pandas as pd
2
  import ast
3
  from .state import State
4
+ from .tools import StoryFormatter, BrainstromTopicFormatter , retrieve_tool
5
  from langchain_core.messages import SystemMessage
6
  from .models_loader import llm , ST
7
  from .data_loader import load_influencer_data
8
  from groq import Groq
9
  import os
10
  from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt , final_story_prompt
11
+ from langgraph.prebuilt import create_react_agent
12
+ from pydantic import BaseModel , Field
13
+ from langchain_core.tools import tool
14
+
15
 
16
 
17
  def caption_image(state: State) -> State:
 
83
  return state
84
 
85
  def generate_story(state:State)-> State:
86
+ tools=[retrieve_tool]
87
+
88
+ react_agent=create_react_agent(
89
+ model=llm.bind_tools(tools),
90
+ tools=tools
91
+
92
+ )
93
+ if len(state.preferred_topics)==0:
94
+ template = initial_story_prompt(state)
95
+ else:
96
+ template = refined_story_prompt(state)
 
 
 
 
97
 
98
  # and {state.image_captions[-1]}
99
 
100
+ messages = [SystemMessage(content=template)]
101
+
102
+ response = react_agent.invoke({'messages':messages})
103
+ response = response['messages'][-1]
104
+ state.stories.append(response)
105
+ # return State(messages="Story generated", topic=state.topic,stories=state.stories)
106
+ return state
 
 
 
 
 
107
 
108
 
109