import logging
import re
from typing import Optional

import requests
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool

logger = logging.getLogger(__name__)


class SmartSearchTool(Tool):
    name = "smart_search"
    description = "A smart search tool that searches Wikipedia for information."
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find information",
        }
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.api_url = "https://en.wikipedia.org/w/api.php"
        # Wikipedia's API policy asks for a descriptive User-Agent; replace the
        # placeholder project URL and contact address with your own.
        self.headers = {
            "User-Agent": "SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; your@email.com)"
        }

    def get_wikipedia_page(self, title: str) -> Optional[str]:
        """Get the raw wiki markup of a Wikipedia page."""
        try:
            params = {
                "action": "query",
                "prop": "revisions",
                "rvprop": "content",
                "rvslots": "main",
                "format": "json",
                "titles": title,
                "redirects": 1,
            }
            # A timeout keeps the tool from hanging on a stalled connection
            response = requests.get(
                self.api_url, params=params, headers=self.headers, timeout=10
            )
            response.raise_for_status()
            data = response.json()

            # Extract the wikitext from the first page that has a revision
            pages = data.get("query", {}).get("pages", {})
            for page_data in pages.values():
                if "revisions" in page_data:
                    return page_data["revisions"][0]["slots"]["main"]["*"]
            return None
        except Exception as e:
            logger.error(f"Error getting Wikipedia page: {e}")
            return None
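
    # For reference, the response parsed above has roughly this shape (an
    # assumption based on the documented MediaWiki Action API format, not on
    # anything pinned down in this repo):
    #
    #     {"query": {"pages": {"<page_id>": {
    #         "title": "...",
    #         "revisions": [{"slots": {"main": {"*": "<wikitext>"}}}]}}}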

    def clean_wiki_content(self, content: str) -> str:
        """Clean Wikipedia content by removing markup and formatting."""
        # Remove citations
        content = re.sub(r"\[\d+\]", "", content)
        # Remove edit links
        content = re.sub(r"\[edit\]", "", content)
        # Remove file links
        content = re.sub(r"\[\[File:.*?\]\]", "", content)
        # Convert links to just their display text
        content = re.sub(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]", r"\1", content)
        # Remove HTML comments
        content = re.sub(r"<!--.*?-->", "", content, flags=re.DOTALL)
        # Remove templates
        content = re.sub(r"\{\{.*?\}\}", "", content)
        # Remove small tags
        content = re.sub(r"<small>.*?</small>", "", content)
        # Normalize whitespace
        content = re.sub(r"\n\s*\n", "\n\n", content)
        return content.strip()
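
    # As a rough illustration of what clean_wiki_content does, a fragment like
    #
    #     "She sang in [[Buenos Aires|the capital]].{{citation needed}}"
    #
    # comes out as "She sang in the capital.": links are reduced to their
    # display text and (non-nested) templates are dropped.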

    def format_wiki_table(self, table_content: str) -> str:
        """Format a Wikipedia table into readable, tab-separated text."""
        rows = table_content.strip().split("\n")
        formatted_rows = []
        current_row = []
        for row in rows:
            # Skip blank lines and table structure markers: {| and |} open and
            # close the table, |- starts a new row, |+ marks the caption
            if (
                not row.strip()
                or row.startswith("{|")
                or row.startswith("|}")
                or row.startswith("|-")
                or row.startswith("|+")
            ):
                if current_row:
                    formatted_rows.append("\t".join(current_row))
                    current_row = []
                continue
            # Split the row into cells on | or ! separators
            cells = []
            for cell in re.split(r"[|!]", row)[1:]:  # skip the leading empty part
                cell = cell.strip()
                # Remove references and line breaks before stripping other tags,
                # so that reference bodies are dropped rather than kept as text
                cell = re.sub(r"<ref.*?</ref>", "", cell)
                cell = re.sub(r"<ref.*?/>", "", cell)
                cell = re.sub(r"<br\s*/?>", " ", cell)
                cell = re.sub(r"<small>.*?</small>", "", cell)
                cell = re.sub(r"<.*?>", "", cell)
                # Convert [[target|text]] and [[text]] links to plain text
                cell = re.sub(r"\[\[.*?\|(.*?)\]\]", r"\1", cell)
                cell = re.sub(r"\[\[(.*?)\]\]", r"\1", cell)
                # Remove templates
                cell = re.sub(r"\{\{.*?\}\}", "", cell)
                # Remove cell attributes such as rowspan="2" or style="..."
                cell = re.sub(
                    r'(?:rowspan|colspan|class|style|align|width|bgcolor'
                    r'|valign|border|cellpadding|cellspacing)=".*?"',
                    "",
                    cell,
                )
                # Normalize whitespace and drop fragments left empty by cleanup
                cell = re.sub(r"\s+", " ", cell).strip()
                if cell:
                    cells.append(cell)
            if cells:
                current_row.extend(cells)
        if current_row:
            formatted_rows.append("\t".join(current_row))
        return "\n".join(formatted_rows)

    def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
        """Extract a Wikipedia page title from a web search result."""
        # Look for Wikipedia links in the format [Title - Wikipedia](url)
        wiki_match = re.search(
            r"\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)",
            search_result,
        )
        if wiki_match:
            return wiki_match.group(1).strip()
        return None
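
    # For example, a search result line such as
    #
    #     [Mercedes Sosa - Wikipedia](https://en.wikipedia.org/wiki/Mercedes_Sosa)
    #
    # yields the title "Mercedes Sosa".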

    def forward(self, query: str) -> str:
        """Search the web for a Wikipedia page and return its cleaned content."""
        logger.info(f"Starting smart search for query: {query}")

        # First do a web search to find the Wikipedia page
        search_result = self.web_search_tool.forward(query)
        logger.info(f"Web search results: {search_result[:100]}...")

        # Extract the Wikipedia page title from the search results
        wiki_title = self.extract_wikipedia_title(search_result)
        if not wiki_title:
            return f"Could not find Wikipedia page in search results for '{query}'."

        # Get the Wikipedia page content
        page_content = self.get_wikipedia_page(wiki_title)
        if not page_content:
            return f"Could not find Wikipedia page for '{wiki_title}'."

        # Walk the wikitext line by line, formatting each table as it closes
        formatted_content = []
        current_section = []
        in_table = False
        table_content = []
        for line in page_content.split("\n"):
            if line.startswith("{|"):
                in_table = True
                table_content = [line]
            elif line.startswith("|}"):
                in_table = False
                table_content.append(line)
                formatted_table = self.format_wiki_table("\n".join(table_content))
                if formatted_table:
                    current_section.append(formatted_table)
            elif in_table:
                table_content.append(line)
            else:
                if line.strip():
                    current_section.append(line)
                elif current_section:
                    formatted_content.append("\n".join(current_section))
                    current_section = []
        if current_section:
            formatted_content.append("\n".join(current_section))

        # Clean and return the formatted content
        cleaned_content = self.clean_wiki_content("\n\n".join(formatted_content))
        return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"


def main(query: str) -> str:
    """
    Test function to run the SmartSearchTool directly.

    Args:
        query: The search query to test

    Returns:
        The search results
    """
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Create and run the tool
    tool = SmartSearchTool()
    result = tool.forward(query)

    # Print the result
    print("\nSearch Results:")
    print("-" * 80)
    print(result)
    print("-" * 80)
    return result


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
        main(query)
    else:
        print("Usage: python tool.py <search query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")