# First_agent_template / tools/web_search.py
# Shane
# refactoring with help from Claude to remove duplicated code
# 5acc791
from typing import Any, Optional, Callable, List
from smolagents.tools import Tool
import duckduckgo_search
import googlesearch
from functools import wraps
def setup_search_dependency(package_name: str, import_func: Callable):
    """Invoke ``import_func`` and hand back its result, with an install hint on ImportError.

    Args:
        package_name: pip package name quoted in the error message.
        import_func: zero-argument callable that performs the import and/or
            constructs the dependency.

    Returns:
        Whatever ``import_func`` returns.

    Raises:
        ImportError: when ``import_func`` raises ImportError; the original
            exception is chained as ``__cause__``.
    """
    try:
        dependency = import_func()
    except ImportError as missing:
        hint = (
            f"You must install package `{package_name}` to run this tool: "
            f"for instance run `pip install {package_name}`."
        )
        raise ImportError(hint) from missing
    return dependency
def handle_search_errors(func):
    """Decorate a search function so failures come back as text instead of exceptions.

    The wrapper returns ``func``'s result unchanged when it is truthy.
    A falsy result (e.g. ``""``, ``None``, ``[]``) becomes
    ``"Error performing search: No results found! Try a less restrictive/shorter query."``.
    Any exception raised by ``func`` becomes ``"Error performing search: <message>"``.

    The broad ``except Exception`` is kept on purpose. Presumably a caller
    (e.g. an agent) should receive a readable message rather than a traceback;
    confirm this is still the desired contract.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            results = func(*args, **kwargs)
        except Exception as e:
            return f"Error performing search: {e}"
        # Report an empty result directly rather than raising an exception only
        # to catch it again in the same function.
        if not results:
            return "Error performing search: No results found! Try a less restrictive/shorter query."
        return results
    return wrapper
class DuckDuckGoSearchTool(Tool):
    """Tool that runs a DuckDuckGo text search and returns the hits as Markdown."""

    name = "web_search"
    description = "Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """Create the tool.

        Args:
            max_results: maximum number of hits requested per query
                (passed as ``max_results`` to ``DDGS.text``).
            **kwargs: forwarded unchanged to ``duckduckgo_search.DDGS``.
        """
        super().__init__()
        self.max_results = max_results
        # NOTE(review): `duckduckgo_search` is imported at module top level, so a
        # missing package most likely fails at import time, before this helper
        # can raise its friendlier install hint. Confirm and consider a lazy import.
        self.ddgs = setup_search_dependency(
            'duckduckgo-search',
            lambda: duckduckgo_search.DDGS(**kwargs)
        )

    @handle_search_errors
    def forward(self, query: str) -> str:
        """Search DuckDuckGo for ``query`` and return the hits as Markdown.

        Each hit is rendered as ``[title](href)`` followed by its body text.
        Returns ``""`` when there are no hits; ``handle_search_errors`` turns a
        falsy result into its "No results found" message. Exceptions raised
        here (network failures, a hit missing a key) are likewise converted to
        an error string by that decorator.
        """
        # Normalise None or an iterator into a list so the emptiness check is
        # reliable regardless of what DDGS.text returns.
        results = list(self.ddgs.text(query, max_results=self.max_results) or [])
        # Check the raw hits, not the formatted output. The formatted string
        # always carries the "## Search Results" header, so it is never falsy
        # and the decorator's "No results found" path could never trigger.
        if not results:
            return ""
        formatted_results = [
            f"[{result['title']}]({result['href']})\n{result['body']}"
            for result in results
        ]
        return "## Search Results\n\n" + "\n\n".join(formatted_results)
class GoogleSearchTool(Tool):
    """Tool that runs a Google search via ``googlesearch`` and returns the result URLs as Markdown."""

    name = "google_search"
    description = "Performs a Google web search based on your query and returns the top search results."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """Create the tool.

        Args:
            max_results: number of results requested per query (passed as
                ``num_results`` to ``googlesearch.search``).
            **kwargs: accepted but currently ignored.
        """
        super().__init__()
        self.max_results = max_results
        # NOTE(review): `googlesearch` is imported at module top level, so a
        # missing package most likely fails at import time, before this helper
        # can raise its friendlier install hint. Confirm.
        self.search = setup_search_dependency(
            'googlesearch-python',
            lambda: googlesearch.search
        )

    @handle_search_errors
    def forward(self, query: str) -> str:
        """Search Google for ``query`` and return the results as a numbered Markdown list.

        Returns ``""`` when the search yields nothing; ``handle_search_errors``
        turns a falsy result into its "No results found" message. Exceptions
        raised here are likewise converted to an error string by that decorator.
        """
        # Presumably yields plain URL strings (googlesearch's default mode); confirm.
        search_results = list(self.search(query, num_results=self.max_results))
        # Check the raw results, not the formatted output. The header makes the
        # formatted string non-empty even when nothing was found.
        if not search_results:
            return ""
        formatted_results = [
            f"[Result {i}]({url})\n{url}"
            for i, url in enumerate(search_results, start=1)
        ]
        return "## Search Results\n\n" + "\n\n".join(formatted_results)