HCho committed on
Commit
5bd8e40
·
verified ·
1 Parent(s): 05b2c49

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +13 -6
agent.py CHANGED
@@ -100,11 +100,18 @@ tools = [
100
  ]
101
 
102
 
103
- def build_graph(provider: str = "google"):
104
- llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-05-20", temperature=0, api_key=google_api_key)
105
-
106
- llm_with_tools = llm.bind_tools(tools)
107
-
 
 
 
 
 
 
 
108
  def assistant(state: MessagesState):
109
  """ Use the tools to answer the query. you have add,subtract,multiply,divide,web_search,wikipedia_search,arxiv_search tools."""
110
  response = llm_with_tools.invoke([system_message]+state["messages"])
@@ -128,7 +135,7 @@ def build_graph(provider: str = "google"):
128
  if __name__ == "__main__":
129
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
130
  # Build the graph
131
- graph = build_graph(provider="google")
132
  # Run the graph
133
  messages = [HumanMessage(content=question)]
134
  messages = graph.invoke({"messages": messages})
 
100
  ]
101
 
102
 
103
+ def build_graph(provider: str):
104
+ if provider == "google":
105
+ # Google Gemini
106
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0,api_key=google_api_key)
107
+ elif provider == "huggingface":
108
+ # TODO: Add huggingface endpoint
109
+ llm = ChatHuggingFace(
110
+ llm=HuggingFaceEndpoint(
111
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
112
+ temperature=0,
113
+ ),
114
+ )
115
  def assistant(state: MessagesState):
116
  """ Use the tools to answer the query. you have add,subtract,multiply,divide,web_search,wikipedia_search,arxiv_search tools."""
117
  response = llm_with_tools.invoke([system_message]+state["messages"])
 
135
  if __name__ == "__main__":
136
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
137
  # Build the graph
138
+ graph = build_graph(provider="huggingface")
139
  # Run the graph
140
  messages = [HumanMessage(content=question)]
141
  messages = graph.invoke({"messages": messages})