dennis111 commited on
Commit
20bd124
·
1 Parent(s): 47545b1
Files changed (3) hide show
  1. agent.py +83 -11
  2. app_playground.ipynb +0 -0
  3. requirements.txt +2 -1
agent.py CHANGED
@@ -4,6 +4,9 @@ from langchain.chat_models import init_chat_model
4
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage
5
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
  from langgraph.graph import add_messages, START, END, StateGraph
 
 
 
7
 
8
  from typing_extensions import TypedDict, Annotated
9
 
@@ -30,33 +33,102 @@ def get_graph(llm):
30
  ]
31
  )
32
 
33
- def call_model(state: State):
34
- print("\n-------------------- Agent has been called -----------------------------------\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  messages = state["messages"]
 
36
  messages.append(HumanMessage(content="Write a plan how to solve this qustion?"))
 
 
 
 
 
 
 
 
37
 
38
 
39
- prompt_plan = prompt_template.invoke(messages)
40
- plan = llm.invoke(prompt_plan).content
41
- messages.append(AIMessage(content=plan))
42
- messages.append(HumanMessage(content="Now give me the answer to the question."))
 
 
 
43
  prompt_answer = prompt_template.invoke(messages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  response = llm.invoke(prompt_answer)
 
45
 
 
 
 
 
 
46
 
47
- print("\nThe Prompt is: ", prompt_answer, "\n")
 
 
48
 
49
- print("Agent has made a decision: ",response.content)
50
- return {"messages": [response], "aggregate": ["Agent"]}
 
 
51
 
52
  # Build graph
53
  builder = StateGraph(State)
 
 
54
  builder.add_node("Agent", call_model)
 
 
55
 
56
 
57
  # Logic
58
- builder.add_edge(START, "Agent")
59
- builder.add_edge("Agent", END)
 
 
 
60
 
61
  return builder.compile()
62
 
 
4
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage
5
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
  from langgraph.graph import add_messages, START, END, StateGraph
7
+ from langchain_core.tools import tool
8
+ from langgraph.prebuilt import ToolNode
9
+
10
 
11
  from typing_extensions import TypedDict, Annotated
12
 
 
33
  ]
34
  )
35
 
36
+ from langchain_community.retrievers import WikipediaRetriever
37
+
38
+ # Wikipedia retriever
39
+ wiki_retriever = WikipediaRetriever(load_max_docs =20)
40
+
41
+ @tool
42
+ def retrieve(query: str):
43
+ """
44
+ This function retrieves Wikipedia entries based on the query.
45
+ """
46
+ print("\n-------------------- Tool has been called --------------------\n")
47
+ print("The query is: ", query)
48
+ docs = wiki_retriever.invoke(query)
49
+ serialized = "\n\n".join(
50
+ (f"\nContent:\n{doc.page_content}")
51
+ for doc in docs
52
+ )
53
+
54
+ return serialized
55
+
56
+ tools = [retrieve]
57
+ tool_node = ToolNode(tools)
58
+ llm_with_tools = llm.bind_tools(tools)
59
+
60
+ def make_plan(state: State):
61
+ print("\n-------------------- Starting to create a plan --------------------\n")
62
+ # get all messages from the state
63
  messages = state["messages"]
64
+ # append planning message
65
  messages.append(HumanMessage(content="Write a plan how to solve this qustion?"))
66
+ # create prompt
67
+ prompt = prompt_template.invoke(messages)
68
+ # invoke LLM
69
+ response = llm.invoke(prompt)
70
+ print("The plan is: ", response.content)
71
+ return {"messages": [response], "aggregate": ["Plan"]}
72
+
73
+
74
 
75
 
76
+ def call_model(state: State):
77
+ print("\n-------------------- Agent has been called -----------------------------------\n")
78
+ # get all messages from the state
79
+ messages = state["messages"]
80
+ # append instruction message
81
+ messages.append(HumanMessage(content="Please provide me the answer to the question in detail."))
82
+ # create prompt
83
  prompt_answer = prompt_template.invoke(messages)
84
+ # invoke LLM
85
+ response = llm_with_tools.invoke(prompt_answer)
86
+ print("\nThe Prompt is: ", prompt_answer, "\n")
87
+ print("Agent has made a decision:\n",response, response.content, response.tool_calls)
88
+ print("Type von der Antwort: ",type(response))
89
+ return {"messages": [response], "aggregate": ["Agent"]}
90
+
91
+ def get_answer(state: State):
92
+ # get all messages from the state
93
+ messages = state["messages"]
94
+ # add prompt message
95
+ messages.append(HumanMessage(content="Please provide me just the plain answer to the question"))
96
+ # create prompt
97
+ prompt_answer = prompt_template.invoke(messages)
98
+ # invoke LLM
99
  response = llm.invoke(prompt_answer)
100
+ return {"messages": [response], "aggregate": ["Answer"]}
101
 
102
+ def should_continue(state: State):
103
+ print("\n-------------------- Decision of forwarding has been made --------------------\n")
104
+ messages = state["messages"]
105
+ print(type(messages[-1]))
106
+ print("The last message is: ", messages[-1])
107
 
108
+ if len(state["aggregate"]) < 8:
109
+ last_message = messages[-1]
110
+ if last_message.tool_calls:
111
 
112
+ return "tools"
113
+ return "Answer"
114
+ else:
115
+ return "Answer"
116
 
117
  # Build graph
118
  builder = StateGraph(State)
119
+ builder.add_node("tools", tool_node)
120
+ builder.add_node("Plan", make_plan)
121
  builder.add_node("Agent", call_model)
122
+ builder.add_node("Answer", get_answer)
123
+
124
 
125
 
126
  # Logic
127
+ builder.add_edge(START, "Plan")
128
+ builder.add_edge("Plan", "Agent")
129
+ builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
130
+ builder.add_edge("tools", "Agent")
131
+ builder.add_edge("Answer", END)
132
 
133
  return builder.compile()
134
 
app_playground.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ python-dotenv~=1.1.0
7
  typing_extensions~=4.13.2
8
  langgraph~=0.3.34
9
  langchain-core~=0.3.56
10
- langchain-groq~=0.3.2
 
 
7
  typing_extensions~=4.13.2
8
  langgraph~=0.3.34
9
  langchain-core~=0.3.56
10
+ langchain-groq~=0.3.2
11
+ langchain-community ~=0.3.22