import logging
import re
from typing import List, Optional

import requests
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Wiki-markup patterns, compiled once and shared by the cleaning helpers.
# ---------------------------------------------------------------------------

# HTML comments may span several lines.
_HTML_COMMENT_RE = re.compile(r"<!--.*?-->", re.DOTALL)
# Self-closing footnote references, e.g. <ref name="a" />. These are removed
# before paired refs so that _REF_RE cannot start on a self-closing tag and
# swallow text up to some unrelated </ref>.
_SELF_CLOSING_REF_RE = re.compile(r"<ref\b[^>]*/\s*>", re.IGNORECASE)
# Paired footnotes, <ref ...>citation</ref>; the citation may span lines.
_REF_RE = re.compile(r"<ref\b[^>]*>.*?</ref\s*>", re.DOTALL | re.IGNORECASE)
# <small>...</small> blocks are dropped together with their contents.
_SMALL_RE = re.compile(r"<small>.*?</small>", re.DOTALL | re.IGNORECASE)
# Line breaks inside table cells become spaces.
_BR_RE = re.compile(r"<br\s*/?>", re.IGNORECASE)
# Any remaining HTML tag (only the tag itself; its text content is kept).
_HTML_TAG_RE = re.compile(r"<[^>]*>")
# Embedded media links: [[File:...]] / [[Image:...]].
_FILE_LINK_RE = re.compile(r"\[\[(?:File|Image):.*?\]\]", re.IGNORECASE)
# Internal links: [[Target]] -> "Target", [[Target|label]] -> "label".
_WIKI_LINK_RE = re.compile(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]")
# A single-line template containing no braces, i.e. an innermost template.
_INNER_TEMPLATE_RE = re.compile(r"\{\{[^{}\n]*\}\}")
# Formatting attributes that can be left over in table cells.
_CELL_ATTRIBUTE_RE = re.compile(
    r'\b(?:rowspan|colspan|class|style|align|width|bgcolor|valign|border'
    r'|cellpadding|cellspacing)="[^"]*"',
    re.IGNORECASE,
)
# Cell separators: data rows use "||"; header rows may use "!!" or "||".
_DATA_CELL_SEPARATOR_RE = re.compile(r"\|\|")
_HEADER_CELL_SEPARATOR_RE = re.compile(r"\|\||!!")
# Table lines that carry no cell data: table open/close, row break, caption.
_TABLE_STRUCTURE_PREFIXES = ("{|", "|}", "|-", "|+")


def _strip_templates(text: str) -> str:
    """Remove single-line ``{{...}}`` templates, including nested ones.

    Innermost templates are removed first and the pass is repeated until the
    text stops changing, so ``{{a|{{b}}}}`` disappears completely instead of
    leaving a dangling ``}}``. Templates whose braces span several lines
    (e.g. multi-line infoboxes) are not matched and stay in the text.
    """
    while True:
        stripped = _INNER_TEMPLATE_RE.sub("", text)
        if stripped == text:
            return stripped
        text = stripped


def _strip_inline_markup(text: str) -> str:
    """Reduce one line of table wikitext to plain text.

    Links and templates are resolved here, before a row is split into cells,
    because their inner ``|`` characters would otherwise be mistaken for
    cell separators.
    """
    text = _HTML_COMMENT_RE.sub("", text)
    text = _SELF_CLOSING_REF_RE.sub("", text)
    text = _REF_RE.sub("", text)
    text = _FILE_LINK_RE.sub("", text)
    text = _WIKI_LINK_RE.sub(r"\1", text)
    text = _strip_templates(text)
    text = _SMALL_RE.sub("", text)
    text = _BR_RE.sub(" ", text)
    return _HTML_TAG_RE.sub("", text)


def _tidy_cell(text: str) -> str:
    """Drop leftover formatting attributes and collapse whitespace in a cell."""
    text = _CELL_ATTRIBUTE_RE.sub("", text)
    return re.sub(r"\s+", " ", text).strip()


def _split_table_row(line: str) -> List[str]:
    """Split one ``|``/``!`` table line into cleaned cell texts.

    Handles several cells on one line (``| a || b``) and per-cell formatting
    (``| style="..." | text``): everything before the single ``|`` inside a
    cell is attributes, not content.
    """
    marker, body = line[0], _strip_inline_markup(line[1:])
    separator = (
        _HEADER_CELL_SEPARATOR_RE if marker == "!" else _DATA_CELL_SEPARATOR_RE
    )
    cells = []
    for cell in separator.split(body):
        if "|" in cell:
            cell = cell.split("|", 1)[1]
        cells.append(_tidy_cell(cell))
    return cells


class SmartSearchTool(Tool):
    """Find a Wikipedia article through a web search and return its text.

    The tool runs a single-result DuckDuckGo search and pulls the Wikipedia
    page title out of the result. It then downloads the page's raw wiki
    markup through the MediaWiki API and flattens ``{| ... |}`` tables into
    tab-separated rows. Finally it strips common markup before returning the
    text.
    """

    name = "smart_search"
    description = """A smart search tool that searches Wikipedia for information."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find information",
        }
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.api_url = "https://en.wikipedia.org/w/api.php"
        self.headers = {
            "User-Agent": "SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; your@email.com)"
        }
        # Seconds to wait on the MediaWiki API. requests applies no timeout
        # by default, so a stalled connection would otherwise block forever.
        self.request_timeout = 30

    def get_wikipedia_page(self, title: str) -> Optional[str]:
        """Return the raw wiki markup of the English Wikipedia page *title*.

        Redirects are followed. Returns ``None`` when no page in the response
        carries revision content, or when the request fails (network error,
        timeout, HTTP error status, or a non-JSON body). Failures are logged,
        not raised.
        """
        params = {
            "action": "query",
            "prop": "revisions",
            "rvprop": "content",
            "rvslots": "main",
            "format": "json",
            "titles": title,
            "redirects": 1,
        }
        try:
            response = requests.get(
                self.api_url,
                params=params,
                headers=self.headers,
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers an unparsable JSON body.
            logger.error("Error getting Wikipedia page %r: %s", title, e)
            return None

        if not isinstance(data, dict):
            logger.error("Unexpected API response for Wikipedia page %r", title)
            return None

        pages = data.get("query", {}).get("pages", {})
        for page_data in pages.values():
            revisions = page_data.get("revisions")
            if not revisions:
                continue  # e.g. a title that does not exist
            content = revisions[0].get("slots", {}).get("main", {}).get("*")
            if content is not None:
                return content
        return None

    def clean_wiki_content(self, content: str) -> str:
        """Clean Wikipedia content by removing markup and formatting.

        Removes:
        - HTML comments and ``<ref>`` footnotes;
        - ``[n]`` / ``[edit]`` markers;
        - file/image links;
        - single-line templates;
        - ``<small>`` blocks.

        Internal links are turned into their display text. Runs of blank
        lines are collapsed, and leading/trailing whitespace is trimmed.
        """
        content = _HTML_COMMENT_RE.sub("", content)
        # Self-closing refs first: see _SELF_CLOSING_REF_RE.
        content = _SELF_CLOSING_REF_RE.sub("", content)
        content = _REF_RE.sub("", content)
        # Remove citation markers such as "[12]"
        content = re.sub(r"\[\d+\]", "", content)
        # Remove edit links
        content = re.sub(r"\[edit\]", "", content)
        content = _FILE_LINK_RE.sub("", content)
        # Links before templates, so a template used as a link label goes too.
        content = _WIKI_LINK_RE.sub(r"\1", content)
        content = _strip_templates(content)
        content = _SMALL_RE.sub("", content)
        # Normalize whitespace
        content = re.sub(r"\n\s*\n", "\n\n", content)
        return content.strip()

    def format_wiki_table(self, table_content: str) -> str:
        """Format a Wikipedia ``{| ... |}`` table into readable text.

        Each table row becomes one output line with its cells joined by tabs.
        ``|-`` row breaks end a row. The ``{|`` / ``|}`` delimiters and
        ``|+`` captions are skipped. Lines that start with neither ``|`` nor
        ``!`` continue the previous cell. Returns ``""`` for a table with no
        cells.
        """
        formatted_rows = []
        current_row: List[str] = []

        for raw_line in table_content.strip().split("\n"):
            line = raw_line.strip()
            if not line or line.startswith(_TABLE_STRUCTURE_PREFIXES):
                if current_row:
                    formatted_rows.append("\t".join(current_row))
                    current_row = []
                continue

            if line[0] in "|!":
                current_row.extend(_split_table_row(line))
            elif current_row:
                # Multi-line cell content: attach it to the previous cell.
                extra = _tidy_cell(_strip_inline_markup(line))
                if extra:
                    current_row[-1] = f"{current_row[-1]} {extra}".lstrip()

        if current_row:
            formatted_rows.append("\t".join(current_row))
        return "\n".join(formatted_rows)

    def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
        """Extract a Wikipedia page title from markdown search results.

        Looks for a link of the form
        ``[Title - Wikipedia](https://en.wikipedia.org/wiki/...)`` and returns
        ``Title`` with surrounding whitespace stripped. Returns ``None`` if no
        such link is present.
        """
        wiki_match = re.search(
            r"\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)",
            search_result,
        )
        if wiki_match:
            return wiki_match.group(1).strip()
        return None

    def _format_page_content(self, page_content: str) -> str:
        """Group page wikitext into blank-line-separated sections.

        Tables (``{|`` ... ``|}``, possibly nested) are passed whole to
        :meth:`format_wiki_table`. A table still open at the end of the page
        is formatted with the rows collected so far. A ``|}`` with no open
        table is ignored.
        """
        sections = []
        current_section: List[str] = []
        table_lines: List[str] = []
        table_depth = 0

        def flush_table() -> None:
            formatted = self.format_wiki_table("\n".join(table_lines))
            if formatted:
                current_section.append(formatted)

        for line in page_content.split("\n"):
            if line.startswith("{|"):
                table_depth += 1
                table_lines.append(line)
            elif table_depth:
                table_lines.append(line)
                if line.startswith("|}"):
                    table_depth -= 1
                    if table_depth == 0:
                        flush_table()
                        table_lines = []
            elif line.startswith("|}"):
                continue  # stray terminator without a matching "{|"
            elif line.strip():
                current_section.append(line)
            elif current_section:
                sections.append("\n".join(current_section))
                current_section = []

        if table_lines:
            flush_table()
        if current_section:
            sections.append("\n".join(current_section))
        return "\n\n".join(sections)

    def forward(self, query: str) -> str:
        """Search the web for *query* and return the matching Wikipedia page text.

        Returns a human-readable message instead of content when no Wikipedia
        link is found in the search results or the page cannot be fetched.
        """
        logger.info("Starting smart search for query: %s", query)

        # First do a web search to find the Wikipedia page
        search_result = self.web_search_tool.forward(query)
        logger.info("Web search results: %s...", search_result[:100])

        wiki_title = self.extract_wikipedia_title(search_result)
        if not wiki_title:
            return f"Could not find Wikipedia page in search results for '{query}'."

        page_content = self.get_wikipedia_page(wiki_title)
        if not page_content:
            return f"Could not find Wikipedia page for '{wiki_title}'."

        cleaned_content = self.clean_wiki_content(
            self._format_page_content(page_content)
        )
        return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"


def main(query: str) -> str:
    """
    Test function to run the SmartSearchTool directly.

    Args:
        query: The search query to test

    Returns:
        The search results (also printed to stdout)
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    tool = SmartSearchTool()
    result = tool.forward(query)

    print("\nSearch Results:")
    print("-" * 80)
    print(result)
    print("-" * 80)

    return result


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        main(" ".join(sys.argv[1:]))
    else:
        print("Usage: python tool.py <query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")