wt002 commited on
Commit
fe25c9a
·
verified ·
1 Parent(s): 7b1f7dd

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +50 -7
agent.py CHANGED
@@ -55,18 +55,54 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndB
55
  from langchain_huggingface import HuggingFaceEndpoint
56
  from langchain.agents import initialize_agent
57
  from langchain.agents import AgentType
 
 
 
58
 
59
 
60
  load_dotenv()
61
 
62
 
 
 
63
  @tool
64
- def calculator(inputs: dict):
65
- """Perform mathematical operations based on the operation provided."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  a = inputs.get("a")
67
  b = inputs.get("b")
68
- operation = inputs.get("operation")
69
-
 
 
 
70
  if operation == "add":
71
  return a + b
72
  elif operation == "subtract":
@@ -80,7 +116,8 @@ def calculator(inputs: dict):
80
  elif operation == "modulus":
81
  return a % b
82
  else:
83
- return "Unknown operation"
 
84
 
85
 
86
 
@@ -638,8 +675,14 @@ def select_tool_and_run(question: str, tools: list) -> any:
638
  print(f"No matching tool found for {intent}")
639
  return None
640
 
641
- # Step 4: Call the tool function and return the result
642
- return tool_func.run(question) # or tool_func(question), depending on how your tools are structured
 
 
 
 
 
 
643
 
644
  # Run the function to select and execute the tool
645
  result = select_tool_and_run(question, tools)
 
55
  from langchain_huggingface import HuggingFaceEndpoint
56
  from langchain.agents import initialize_agent
57
  from langchain.agents import AgentType
58
+ from typing import Union, List
59
+ from functools import reduce
60
+ import operator
61
 
62
 
63
  load_dotenv()
64
 
65
 
66
+
67
+
68
  @tool
69
+ def calculator(inputs: Union[str, dict]):
70
+ """
71
+ Perform mathematical operations based on the operation provided.
72
+ Supports both binary (a, b) operations and list operations.
73
+ """
74
+
75
+ # If input is a JSON string, parse it
76
+ if isinstance(inputs, str):
77
+ try:
78
+ import json
79
+ inputs = json.loads(inputs)
80
+ except Exception as e:
81
+ return f"Invalid input format: {e}"
82
+
83
+ # Handle list-based operations like SUM
84
+ if "list" in inputs:
85
+ nums = inputs.get("list", [])
86
+ op = inputs.get("operation", "").lower()
87
+
88
+ if not isinstance(nums, list) or not all(isinstance(n, (int, float)) for n in nums):
89
+ return "Invalid list input. Must be a list of numbers."
90
+
91
+ if op == "sum":
92
+ return sum(nums)
93
+ elif op == "multiply":
94
+ return reduce(operator.mul, nums, 1)
95
+ else:
96
+ return f"Unsupported list operation: {op}"
97
+
98
+ # Handle basic two-number operations
99
  a = inputs.get("a")
100
  b = inputs.get("b")
101
+ operation = inputs.get("operation", "").lower()
102
+
103
+ if a is None or b is None or not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
104
+ return "Both 'a' and 'b' must be numbers."
105
+
106
  if operation == "add":
107
  return a + b
108
  elif operation == "subtract":
 
116
  elif operation == "modulus":
117
  return a % b
118
  else:
119
+ return f"Unknown operation: {operation}"
120
+
121
 
122
 
123
 
 
675
  print(f"No matching tool found for {intent}")
676
  return None
677
 
678
+ # If question was transformed into JSON for tool input
679
+ try:
680
+ parsed_input = json.loads(question)
681
+ except json.JSONDecodeError:
682
+ parsed_input = question # fallback if question is not JSON
683
+
684
+ return tool_func.run(parsed_input)
685
+
686
 
687
  # Run the function to select and execute the tool
688
  result = select_tool_and_run(question, tools)