Artem Zhirkevich committed
Commit ea48b73 · Parent: f590bb2

refactor agent, tools, and libs

Files changed (5)
  1. agent.py +11 -222
  2. app.py +2 -2
  3. dry_run.py +3 -3
  4. requirements.txt +1 -16
  5. tools.py +213 -0
agent.py CHANGED
@@ -1,247 +1,36 @@
  import os
  import time
- import tempfile
- import requests
- import pytesseract
- import wikipedia
- import mwclient
- import pandas as pd
- import easyocr
- from typing import List, Optional, Dict, Any
- from urllib.parse import urlparse
+
+ from typing import List
  from dotenv import load_dotenv
- from PIL import Image
- from tavily import TavilyClient
- from arxiv import Search, Client, SortCriterion, SortOrder

  from langgraph.graph.state import CompiledStateGraph
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import tools_condition
  from langgraph.prebuilt import ToolNode

- from langchain_groq import ChatGroq
  from langchain_core.messages import HumanMessage, SystemMessage
  from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain.memory import ConversationBufferMemory
  from langchain.tools import Tool, tool
  from langchain.callbacks.tracers import ConsoleCallbackHandler
- from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
- from langchain_community.utilities import WikipediaAPIWrapper
- from langchain_experimental.utilities import PythonREPL
- from langchain_community.document_loaders import WebBaseLoader
-
-
- load_dotenv()
-
- vision_llm = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct", groq_api_key=os.getenv('GROQ_API_KEY'))
-
-
- @tool
- def web_search(query: str, domain: Optional[str] = None) -> str:
-     """
-     Perform a web search and return the raw results as a string.
-
-     Args:
-         query (str): The search query.
-         domain (Optional[str]): If provided, restricts the search to this domain.
-
-     Returns:
-         str: Raw search results concatenated into a string.
-     """
-     try:
-         time.sleep(2)
-         search = DuckDuckGoSearchAPIWrapper()
-         if domain:
-             query = f"{query} site:{domain}"
-         results = search.results(query, max_results=3)
-
-         if not results:
-             return "No results found."
-
-         # Format into simple title + snippet
-         formatted = ""
-         for r in results:
-             formatted += f"Title: {r['title']}\nURL: {r['link']}\nSnippet: {r['snippet']}\n\n"
-         return formatted.strip()
-
-     except Exception as e:
-         return f"Search error: {e}"
-
-
- @tool
- def visit_webpage(url: str):
-     """
-     Fetches and loads the content of a webpage given its URL.
-
-     Parameters:
-         url (str): The URL of the webpage to be visited.
-
-     Returns:
-         str: A string containing the loaded content of the webpage.
-     """
-     # Initialize a WebBaseLoader with the provided URL
-     loader = WebBaseLoader(url)
-
-     # Set requests_kwargs to disable SSL certificate verification
-     # This can help bypass SSL certificate errors but should be used cautiously
-     loader.requests_kwargs = {'verify': False}
-
-     # Load the webpage content using the loader
-     docs = loader.load()
-
-     # Return the loaded content formatted as a string
-     return f"Page content: {docs}"
-
-
- @tool
- def wikipedia_search(query: str, max_docs: int = 1) -> str:
-     """
-     Search Wikipedia using mwclient and return exactly `max_docs` results.
-
-     Args:
-         query (str): The search query.
-         max_docs (int): Number of results to return. Default is 1.
-     """
-     try:
-         time.sleep(2)
-         site = mwclient.Site("en.wikipedia.org")
-         results = site.search(query, limit=max_docs)
-
-         output = ""
-         count = 0
-
-         for page_info in results:
-             title = page_info["title"]
-             try:
-                 page = site.pages[title]
-                 content = page.text()
-                 first_paragraph = content.split('\n\n')[0]
-
-                 url = f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}"
-
-                 output += (
-                     f"--- Result {count + 1} ---\n"
-                     f"Title: {title}\n"
-                     f"Summary: {first_paragraph}...\n"
-                     f"URL: {url}\n\n"
-                 )
-                 count += 1
-                 if count >= max_docs:
-                     break
-
-             except Exception:
-                 continue
-
-         return output.strip() or "No valid matching pages found."
-
-     except Exception as e:
-         return f"Wikipedia search error: {str(e)}"
-
-
- @tool
- def extract_text_from_image(image_path: str) -> str:
-     """
-     Extracts text from an image file.
-
-     Args:
-         image_path (str): The file path to the image
-             (e.g., '/path/to/document.png').
-
-     Returns:
-         str: Extracted text paragraphs separated by newlines,
-             prefixed with "Extracted text:\n". Returns an error message
-             string starting with 'Error:' on failure.
-     """
-     try:
-         time.sleep(2)
-
-         with open(image_path, "rb") as image_file:
-             image_bytes = image_file.read()
-
-         image_base64 = base64.b64encode(image_bytes).decode("utf-8")
-
-         message = [
-             HumanMessage(
-                 content=[
-                     {
-                         "type": "text",
-                         "text": (
-                             "Extract text or provide explanation of this image"
-                         ),
-                     },
-                     {
-                         "type": "image_url",
-                         "image_url": {
-                             "url": f"data:image/png;base64,{image_base64}"
-                         },
-                     },
-                 ]
-             )
-         ]
-
-         response = vision_llm.invoke(message)
-
-         all_text = response.content + "\n\n"
-
-         return all_text.strip()
-     except Exception as e:
-         # A butler should handle errors gracefully
-         error_msg = f"Error extracting text: {str(e)}"
-         print(error_msg)
-         return ""
-
-
- @tool
- def analyze_file(file_path: str) -> str:
-     """
-     Load and analyze a CSV or Excel file using pandas.
-
-     Provides basic metadata and summary statistics for numeric columns.
-
-     Args:
-         file_path (str): Path to the CSV or Excel file.
-
-     Returns:
-         str: Summary statistics and metadata about the file data.
-     """
-     try:
-         # Determine file type
-         _, ext = os.path.splitext(file_path.lower())
-
-         if ext == '.csv':
-             df = pd.read_csv(file_path)
-         elif ext in ['.xls', '.xlsx']:
-             df = pd.read_excel(file_path)
-         else:
-             return f"Error: Unsupported file extension '{ext}'. Supported: .csv, .xls, .xlsx"
-
-         result = "Summary statistics for numeric columns:\n"
-         result += str(df.describe())
-         result += "\n\n"
-
-         result += f"Columns: {', '.join(df.columns)}\n\n"
-         result += "Content:\n"
-         result += df.astype(str).head(1000).to_string(index=False)
-
-         return result
-
-     except ImportError:
-         return "Error: Required libraries are not installed. Install with 'pip install pandas openpyxl'."
-     except FileNotFoundError:
-         return f"Error: File not found at path '{file_path}'."
-     except Exception as e:
-         return f"Error analyzing file: {str(e)}"
+
+ from tools import (
+     web_search,
+     visit_webpage,
+     wikipedia_search,
+     extract_text_from_image,
+     analyze_file,
+ )
+
+
+ load_dotenv()


- class Agent:
+ class GeminiAgent:

      _api_key: str
      _model_name: str
      _tools: List[Tool]
-     _memory: ConversationBufferMemory
      _llm: ChatGoogleGenerativeAI
      _graph: CompiledStateGraph
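The agent.py hunk ends at the class attributes, so the body of GeminiAgent is not shown. For orientation, here is a minimal sketch of the standard LangGraph wiring that the surviving imports (StateGraph, ToolNode, tools_condition) point to. The node names, assistant function, and Gemini model id below are assumptions for illustration, not code from this commit:

# Sketch only -- not part of the commit. Assumes GOOGLE_API_KEY is set in the environment.
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI

from tools import (
    web_search,
    visit_webpage,
    wikipedia_search,
    extract_text_from_image,
    analyze_file,
)

tools = [web_search, visit_webpage, wikipedia_search, extract_text_from_image, analyze_file]

# Bind the refactored tools to the Gemini chat model (model id is an assumption).
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
llm_with_tools = llm.bind_tools(tools)

def assistant(state: MessagesState):
    # The model either answers directly or emits tool calls.
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# Route to the tool node when the last message contains tool calls, else finish.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
graph = builder.compile()

Note that tools_condition routes to a node that must be named "tools", which is why that node name is fixed in the sketch.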
app.py CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
  import requests
  import inspect
  import pandas as pd
- from agent import Agent
+ from agent import GeminiAgent
  from evaluation_api import EvaluationApi


@@ -41,7 +41,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):

      # 1. Instantiate Agent ( modify this part to create your agent)
      try:
-         agent = Agent()
+         agent = GeminiAgent()
      except Exception as e:
          return f"Error initializing agent: {e}", None

dry_run.py CHANGED
@@ -3,9 +3,9 @@ import tempfile
  import json
  import os

- from agent import Agent
+ from agent import GeminiAgent

- random.seed(1)
+ random.seed(3)

  def get_question(file_path: str) -> str:
      with open(file_path, "r") as file:
@@ -45,7 +45,7 @@ print(json.dumps(question, indent=2))

  # print(file_path)

- agent = Agent()
+ agent = GeminiAgent()

  # messages = agent.run(f"Question: `{question["Question"]}` File path: {file_path}")
  messages = agent.run(f"Question: `{question["Question"]}`")
requirements.txt CHANGED
@@ -3,8 +3,6 @@ requests
  pandas
  openpyxl
  openai
- google-genai
- google-generativeai
  langchain
  langchain-community
  langchain-core
@@ -12,19 +10,6 @@ langchain-google-genai
  langgraph
  huggingface_hub
  python-dotenv
- wikipedia-api
- wikipedia
- arxiv
- datasets
- yt-dlp
- google-cloud-speech
- google-api-python-client
  duckduckgo-search
- pytesseract
- tavily-python
  langchain_groq
- langchain-tavily
- mwclient
- langchain_experimental
- easyocr
- smolagents
+ mwclient
tools.py ADDED
@@ -0,0 +1,213 @@
+ import os
+ import time
+ import base64  # required by extract_text_from_image
+ import requests
+ import mwclient
+ import pandas as pd  # required by analyze_file
+ from typing import Optional
+ from dotenv import load_dotenv
+
+ from langchain_groq import ChatGroq
+ from langchain_core.messages import HumanMessage
+ from langchain.tools import tool
+ from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
+ from langchain_community.document_loaders import WebBaseLoader
+
+
+ load_dotenv()
+
+ vision_llm = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct", groq_api_key=os.getenv('GROQ_API_KEY'))
+
+
+ @tool
+ def web_search(query: str, domain: Optional[str] = None) -> str:
+     """
+     Perform a web search and return the raw results as a string.
+
+     Args:
+         query (str): The search query.
+         domain (Optional[str]): If provided, restricts the search to this domain.
+
+     Returns:
+         str: Raw search results concatenated into a string.
+     """
+     try:
+         time.sleep(2)
+
+         search = DuckDuckGoSearchAPIWrapper()
+         if domain:
+             query = f"{query} site:{domain}"
+         results = search.results(query, max_results=3)
+
+         if not results:
+             return "No results found."
+
+         formatted = ""
+         for r in results:
+             formatted += f"Title: {r['title']}\nURL: {r['link']}\nSnippet: {r['snippet']}\n\n"
+         return formatted.strip()
+
+     except Exception as e:
+         return f"Search error: {e}"
+
+
+ @tool
+ def visit_webpage(url: str):
+     """
+     Fetches and loads the content of a webpage given its URL.
+
+     Parameters:
+         url (str): The URL of the webpage to be visited.
+
+     Returns:
+         str: A string containing the loaded content of the webpage.
+     """
+     loader = WebBaseLoader(url)
+     # Disable SSL certificate verification; convenient for scraping, but use cautiously.
+     loader.requests_kwargs = {'verify': False}
+
+     docs = loader.load()
+
+     return f"Page content: {docs}"
+
+
+ @tool
+ def wikipedia_search(query: str, max_docs: int = 1) -> str:
+     """
+     Search Wikipedia using mwclient and return up to `max_docs` results.
+
+     Args:
+         query (str): The search query.
+         max_docs (int): Maximum number of results to return. Default is 1.
+     """
+     try:
+         time.sleep(2)
+
+         site = mwclient.Site("en.wikipedia.org")
+         results = site.search(query, limit=max_docs)
+
+         output = ""
+         count = 0
+
+         for page_info in results:
+             title = page_info["title"]
+             try:
+                 page = site.pages[title]
+                 content = page.text()
+                 first_paragraph = content.split('\n\n')[0]
+
+                 url = f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}"
+
+                 output += (
+                     f"--- Result {count + 1} ---\n"
+                     f"Title: {title}\n"
+                     f"Summary: {first_paragraph}...\n"
+                     f"URL: {url}\n\n"
+                 )
+                 count += 1
+                 if count >= max_docs:
+                     break
+
+             except Exception:
+                 continue
+
+         return output.strip() or "No valid matching pages found."
+
+     except Exception as e:
+         return f"Wikipedia search error: {str(e)}"
+
+
+ @tool
+ def extract_text_from_image(image_path: str) -> str:
+     """
+     Extracts text from (or describes) an image file using a vision LLM.
+
+     Args:
+         image_path (str): The file path to the image
+             (e.g., '/path/to/document.png').
+
+     Returns:
+         str: The extracted text or image description. Returns an empty
+             string on failure; the error is printed.
+     """
+     try:
+         time.sleep(2)
+
+         with open(image_path, "rb") as image_file:
+             image_bytes = image_file.read()
+
+         image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+
+         message = [
+             HumanMessage(
+                 content=[
+                     {
+                         "type": "text",
+                         "text": (
+                             "Extract text or provide explanation of this image"
+                         ),
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/png;base64,{image_base64}"
+                         },
+                     },
+                 ]
+             )
+         ]
+
+         response = vision_llm.invoke(message)
+
+         all_text = response.content + "\n\n"
+
+         return all_text.strip()
+     except Exception as e:
+         error_msg = f"Error extracting text: {str(e)}"
+         print(error_msg)
+         return ""
+
+
+ @tool
+ def analyze_file(file_path: str) -> str:
+     """
+     Load and analyze a CSV or Excel file using pandas.
+
+     Provides basic metadata and summary statistics for numeric columns.
+
+     Args:
+         file_path (str): Path to the CSV or Excel file.
+
+     Returns:
+         str: Summary statistics and metadata about the file data.
+     """
+     try:
+         _, ext = os.path.splitext(file_path.lower())
+
+         if ext == '.csv':
+             df = pd.read_csv(file_path)
+         elif ext in ['.xls', '.xlsx']:
+             df = pd.read_excel(file_path)
+         else:
+             return f"Error: Unsupported file extension '{ext}'. Supported: .csv, .xls, .xlsx"
+
+         result = "Summary statistics for numeric columns:\n"
+         result += str(df.describe())
+         result += "\n\n"
+
+         result += f"Columns: {', '.join(df.columns)}\n\n"
+         result += "Content:\n"
+         result += df.astype(str).head(1000).to_string(index=False)
+
+         return result
+
+     except ImportError:
+         return "Error: Required libraries are not installed. Install with 'pip install pandas openpyxl'."
+     except FileNotFoundError:
+         return f"Error: File not found at path '{file_path}'."
+     except Exception as e:
+         return f"Error analyzing file: {str(e)}"
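
Because the @tool-decorated functions are ordinary LangChain structured tools, the extracted module can be smoke-tested without building the agent at all. A hypothetical check (the query, domain, and CSV path are placeholders, not from the commit):

# Sketch only -- not part of the commit.
from tools import web_search, analyze_file

# @tool wraps each function as a runnable; arguments are passed as a dict.
print(web_search.invoke({"query": "LangGraph ToolNode", "domain": "langchain.com"}))
print(analyze_file.invoke({"file_path": "example.csv"}))  # placeholder path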