File size: 6,758 Bytes
f0eb1da
 
 
 
 
 
 
fb62e9e
 
 
 
 
 
 
 
 
 
 
6138f05
f0eb1da
fb62e9e
6138f05
d10ae74
5359065
f0eb1da
fb62e9e
2ae392c
fb62e9e
 
 
 
 
6138f05
2ae392c
d10ae74
2f68dd7
b17b706
d10ae74
 
 
 
 
 
 
 
b17b706
 
 
 
 
3929291
d10ae74
f0eb1da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d10ae74
 
 
f0eb1da
d10ae74
 
 
f0eb1da
fb62e9e
 
 
 
 
 
 
 
 
d10ae74
 
6138f05
 
f0eb1da
d10ae74
 
6138f05
 
f0eb1da
2f68dd7
d10ae74
 
 
 
b17b706
d10ae74
2f68dd7
d10ae74
 
b17b706
d10ae74
fb62e9e
d10ae74
f0eb1da
 
 
d10ae74
f0eb1da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d10ae74
 
3929291
d10ae74
 
 
 
 
f0eb1da
d10ae74
 
 
 
f0eb1da
d10ae74
 
f0eb1da
d10ae74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0eb1da
d10ae74
 
f0eb1da
d10ae74
 
f0eb1da
d10ae74
 
 
 
 
 
 
 
 
f0eb1da
 
d10ae74
f0eb1da
 
 
d10ae74
f0eb1da
d10ae74
2ae392c
 
d10ae74
2ae392c
d10ae74
2ae392c
d10ae74
 
f0eb1da
d10ae74
fb62e9e
d10ae74
 
f0eb1da
d10ae74
f0eb1da
d10ae74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
from typing import Any, List, Optional

from dotenv import load_dotenv
from PIL import Image
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    LiteLLMModel,
    PythonInterpreterTool,
    VisitWebpageTool,
)

from src.tools import (
    add,
    divide,
    modulus,
    multiply,
    read_excel_file,
    subtract,
    transcribe_audio_file,
    transcribe_from_youtube,
    wiki_search,
)


# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file.
load_dotenv()

# GAIA-style formatting instructions appended to every question so the model
# ends its reply with an exact-match-friendly "FINAL ANSWER: ..." line.
SYSTEM_PROMPT = """
You are a helpful assistant tasked with answering questions using a set of tools. 
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template: 
FINAL ANSWER: [YOUR FINAL ANSWER]. 
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer. 
"""


class CustomAgent:
    """Wrapper around smolagents' ``CodeAgent`` for GAIA-style Q&A.

    Bundles a fixed toolset (web search, Python interpreter, webpage
    visiting, plus the project's audio/Excel/wiki/math tools) with a
    LiteLLM-backed model, and post-processes answers so they survive
    exact-match grading.
    """

    def __init__(
        self,
        model_id: str = "gemini/gemini-2.0-flash",
        additional_imports: Optional[List[str]] = None,
        logging=False,
        max_steps=10,
        verbose: bool = False,
        executor_type: str = "local",
        timeout: int = 120,
    ):
        """
        Initialize the CustomAgent with a model and tools.
        If no model is provided, a default one is used.

        Args:
            model_id: LiteLLM model identifier.
            additional_imports: Extra module names the code agent may
                import, appended to the built-in allow-list.
            logging: Stored on the instance; not used internally here.
            max_steps: Maximum reasoning/tool-call steps per run.
            verbose: When True, print progress and raise agent verbosity.
            executor_type: smolagents executor backend (e.g. "local").
            timeout: LLM request timeout in seconds.
        """
        self.logging = logging
        self.verbose = verbose
        # Modules the sandboxed code executor is allowed to import.
        self.imports = [
            "pandas",
            "numpy",
            "io",
            "datetime",
            "json",
            "re",
            "math",
            "os",
            "requests",
            "csv",
            "urllib",
            "youtube-transcript-api",
            "SpeechRecognition",
            "pydub",
        ]
        if additional_imports:
            self.imports.extend(additional_imports)

        # Initialize tools
        self.tools = [
            DuckDuckGoSearchTool(),
            PythonInterpreterTool(),
            VisitWebpageTool(),
            wiki_search,
            transcribe_audio_file,
            transcribe_from_youtube,
            read_excel_file,
            multiply,
            add,
            subtract,
            divide,
            modulus,
        ]

        # Initialize the model (API key is read from the environment).
        model = LiteLLMModel(
            model_id=model_id,
            api_key=os.getenv("GEMINI_API_KEY"),
            timeout=timeout,
        )

        # Initialize the CodeAgent
        self.agent = CodeAgent(
            model=model,
            tools=self.tools,
            additional_authorized_imports=self.imports,
            executor_type=executor_type,
            max_steps=max_steps,
            verbosity_level=2 if verbose else 0,
        )
        if self.verbose:
            print("CustomAgent initialized.")

    def forward(self, question: str, file_path: Optional[str]) -> str:
        """Answer *question*, optionally grounded on the file at *file_path*.

        Images are handed to the agent directly; plain-text files are
        inlined into the prompt; any other file type is referenced by path
        so the agent's tools can open it themselves. Failures are returned
        as strings rather than raised so a batch run never aborts.
        """
        print(f"QUESTION: {question[:100]}...")
        try:
            full_prompt = f"""Question: {question}
                
            {SYSTEM_PROMPT}"""
            if file_path:
                file_path_ext = os.path.splitext(file_path)[1]
                if file_path_ext.lower() in [".jpg", ".jpeg", ".png"]:
                    # Vision input: pass the decoded image to the agent.
                    image = Image.open(file_path).convert("RGB")
                    answer = self.agent.run(full_prompt, images=[image])
                elif file_path_ext.lower() in [".txt", ".py"]:
                    # Plain text: inline the whole content into the prompt.
                    with open(file_path, "r") as f:
                        content = f.read()
                    full_prompt = f"""Question: {question}
                    File content: ```{content}```

                    {SYSTEM_PROMPT}"""
                    answer = self.agent.run(full_prompt)
                else:
                    # Other formats (audio, Excel, ...): the agent's tools
                    # will open the file from its path.
                    full_prompt = f"""Question: {question}
                    File path: {file_path}

                    {SYSTEM_PROMPT}"""
                    answer = self.agent.run(full_prompt)
            else:
                answer = self.agent.run(full_prompt)
            answer = self._clean_answer(answer)
            return answer

        except Exception as e:
            # Best-effort: report the failure as the answer string.
            error_msg = f"Error answering question: {e}"
            if self.verbose:
                print(error_msg)
            return error_msg

    def _clean_answer(self, answer: Any) -> str:
        """
        Clean up the answer to remove common prefixes and formatting
        that models often add but that can cause exact match failures.

        Args:
            answer: The raw answer from the model

        Returns:
            The cleaned answer as a string
        """
        # Convert non-string types to strings
        if not isinstance(answer, str):
            if isinstance(answer, float):
                # Integral floats render as plain integers ("5.0" -> "5").
                if answer.is_integer():
                    return str(int(answer))
                # BUG FIX: floats >= 1000 used to be formatted as
                # "$1,234.56", which violates SYSTEM_PROMPT's own rules
                # ("don't use comma ... neither use units such as $") and
                # breaks exact matching. Use the plain representation.
                return str(answer)
            return str(answer)

        # Normalize whitespace
        answer = answer.strip()

        # Remove common prefixes and formatting that models add
        prefixes_to_remove = [
            "The answer is ",
            "Answer: ",
            "Final answer: ",
            "The result is ",
            "To answer this question: ",
            "Based on the information provided, ",
            "According to the information: ",
        ]
        for prefix in prefixes_to_remove:
            if answer.startswith(prefix):
                answer = answer[len(prefix) :].strip()

        # Remove quotes if they wrap the entire answer
        if (answer.startswith('"') and answer.endswith('"')) or (
            answer.startswith("'") and answer.endswith("'")
        ):
            answer = answer[1:-1].strip()

        return answer


def get_config():
    """Return the default agent configuration as a dict.

    Keys mirror ``CustomAgent.__init__`` keyword arguments; callers may
    override individual entries before constructing the agent.
    """
    defaults = dict(
        model_id="gemini/gemini-2.5-flash-preview-04-17",
        logging=False,
        max_steps=10,
        verbose=False,
        executor_type="local",
        timeout=120,
    )
    return defaults