Spaces:
Sleeping
Sleeping
| from typing import Any, Optional | |
| from smolagents.tools import Tool | |
| import duckduckgo_search | |
class DuckDuckGoSearchTool(Tool):
    """Web search tool backed by the ``duckduckgo_search`` package.

    Returns the top hits for a query as Markdown: one ``[title](href)`` link per
    result, followed by that result's body text.
    """

    name = "web_search"
    description = "Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """Create the DuckDuckGo client.

        Args:
            max_results: Value passed as ``max_results`` to ``DDGS.text`` on every query.
            **kwargs: Forwarded unchanged to the ``DDGS`` constructor.

        Raises:
            ImportError: If ``duckduckgo_search`` is not installed.
        """
        super().__init__()
        self.max_results = max_results
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            raise ImportError(
                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
            ) from e
        self.ddgs = DDGS(**kwargs)

    def forward(self, query: str) -> str:
        """Run the search and format the results as Markdown.

        Args:
            query: The search string.

        Returns:
            A ``## Search Results`` section with one link-plus-body entry per hit.

        Raises:
            Exception: If the search yields no results.
        """
        # Depending on the installed duckduckgo_search version, text() may hand
        # back a generator instead of a list; materialize it so the emptiness
        # check and formatting below work either way (`or []` guards a None).
        results = list(self.ddgs.text(query, max_results=self.max_results) or [])
        if not results:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
class GoogleSearchTool(Tool):
    """Google web search tool backed by either SerpAPI or Serper.

    The provider is fixed at construction time. Its API key is read from the
    ``SERPAPI_API_KEY`` (SerpAPI) or ``SERPER_API_KEY`` (Serper) environment variable.
    """

    name = "web_search"
    description = """Performs a google web search for your query then returns a string of the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."},
        "filter_year": {
            "type": "integer",
            "description": "Optionally restrict results to a certain year",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self, provider: str = "serpapi"):
        """Select the provider and load its API key from the environment.

        Args:
            provider: ``"serpapi"`` selects SerpAPI; any other value selects Serper.

        Raises:
            ValueError: If the provider's API-key environment variable is not set.
        """
        super().__init__()
        import os

        self.provider = provider
        # Each provider returns its organic hits under a different JSON key.
        if provider == "serpapi":
            self.organic_key = "organic_results"
            api_key_env_name = "SERPAPI_API_KEY"
        else:
            self.organic_key = "organic"
            api_key_env_name = "SERPER_API_KEY"
        self.api_key = os.getenv(api_key_env_name)
        if self.api_key is None:
            raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")

    def forward(self, query: str, filter_year: Optional[int] = None) -> str:
        """Query the configured provider and format the organic results as Markdown.

        Args:
            query: Search string, sent as the ``q`` parameter.
            filter_year: If given, restricts results to that calendar year via a
                ``tbs`` custom date range from Jan 1 to Dec 31.

        Returns:
            A numbered Markdown list of results. If the provider returns an empty
            organic-results list, a plain "No results found" message is returned instead.

        Raises:
            ValueError: If the provider responds with a non-200 status. The message
                carries the JSON error body, or the raw text if the body is not JSON.
            requests.exceptions.Timeout: If the provider does not respond in time.
            Exception: If the response contains no organic-results key at all.
        """
        import requests

        if self.provider == "serpapi":
            params = {
                "q": query,
                "api_key": self.api_key,
                "engine": "google",
                "google_domain": "google.com",
            }
            base_url = "https://serpapi.com/search.json"
        else:
            params = {
                "q": query,
                "api_key": self.api_key,
            }
            base_url = "https://google.serper.dev/search"
        if filter_year is not None:
            params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"

        # Without a timeout, requests can block forever on an unresponsive server;
        # 20s matches the timeout used for page fetches elsewhere in this module.
        response = requests.get(base_url, params=params, timeout=20)
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON (e.g. an HTML gateway
            # error page); fall back to raw text so the ValueError still surfaces.
            try:
                error_detail = response.json()
            except ValueError:
                error_detail = response.text
            raise ValueError(error_detail)
        results = response.json()

        if self.organic_key not in results:
            if filter_year is not None:
                raise Exception(
                    f"No results found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
                )
            else:
                raise Exception(f"No results found for query: '{query}'. Use a less restrictive query.")

        organic_results = results[self.organic_key]
        if not organic_results:
            year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."

        web_snippets = []
        for idx, page in enumerate(organic_results):
            # Optional fields are appended only when the provider supplied them.
            date_published = ""
            if "date" in page:
                date_published = "\nDate published: " + page["date"]
            source = ""
            if "source" in page:
                source = "\nSource: " + page["source"]
            snippet = ""
            if "snippet" in page:
                snippet = "\n" + page["snippet"]
            redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
            web_snippets.append(redacted_version)
        return "## Search Results\n" + "\n\n".join(web_snippets)
class VisitWebpageTool(Tool):
    """Fetch a web page and return its content converted to Markdown."""

    name = "visit_webpage"
    description = (
        "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    )
    inputs = {
        "url": {
            "type": "string",
            "description": "The url of the webpage to visit.",
        }
    }
    output_type = "string"

    def __init__(self, max_output_length: int = 40000):
        """Args:
            max_output_length: Length limit passed to ``truncate_content`` for the
                returned Markdown.
        """
        super().__init__()
        self.max_output_length = max_output_length

    def forward(self, url: str) -> str:
        """Download ``url``, convert the body to Markdown, and truncate it.

        Fetch failures are not raised. A timeout, any other ``RequestException``
        (including HTTP error statuses), or any unexpected error is returned as a
        human-readable message string, so the calling agent can react to it.

        Raises:
            ImportError: If ``markdownify``/``requests`` (or smolagents utils) are missing.
        """
        try:
            import re

            import requests
            from markdownify import markdownify
            from requests.exceptions import RequestException
            from smolagents.utils import truncate_content
        except ImportError as e:
            raise ImportError(
                "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
            ) from e
        try:
            # Send a GET request to the URL with a 20-second timeout
            response = requests.get(url, timeout=20)
            response.raise_for_status()  # Raise an exception for bad status codes
            # When a text/* response carries no charset, requests falls back to
            # ISO-8859-1 (RFC 2616), which garbles UTF-8 pages; re-detect the
            # encoding from the body bytes in that case.
            if response.encoding is None or response.encoding.lower() == "iso-8859-1":
                response.encoding = response.apparent_encoding
            # Convert the HTML content to Markdown
            markdown_content = markdownify(response.text).strip()
            # Collapse runs of 3+ newlines into a single blank line
            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
            return truncate_content(markdown_content, self.max_output_length)
        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            # Deliberate catch-all: the tool reports failures as text instead of crashing the agent.
            return f"An unexpected error occurred: {str(e)}"