Spaces:
Running
Running
from typing import Dict, Optional, Tuple, List, Any | |
import re | |
import xml.etree.ElementTree as ET | |
from datetime import datetime | |
import json | |
class ToolType:
    """Namespace of tool-name strings used to recognise tool outputs.

    ``ResponseFormatter`` compares tool names taken from an observation
    payload against these values with ``==`` / ``in``, so each value must
    equal the tool's name exactly.
    """

    # Search / news tools
    DUCKDUCKGO: str = "duckduckgo_search"
    REDDIT_NEWS: str = "reddit_x_gnews_newswire_crunchbase"
    PUBMED: str = "pubmed_search"

    # Rendered with a dedicated <census_tracts> layout by ResponseFormatter
    CENSUS: str = "get_census_data"

    # Code / diagram producers (MERMAID gets a dedicated <diagram> layout)
    HEATMAP: str = "heatmap_code"
    MERMAID: str = "mermaid_diagram"

    # WISQARS, WONDER and NCHS share ResponseFormatter's health-data layout
    WISQARS: str = "wisqars"
    WONDER: str = "wonder"
    NCHS: str = "nchs"

    # Other data tools (rendered with the generic <content> layout)
    ONESTEP: str = "onestep"
    DQS: str = "dqs_nhis_adult_summary_health_statistics"
class ResponseFormatter:
    """Render agent thoughts, messages and errors as ``(terminal_text, xml_text)`` pairs.

    Every method is a ``staticmethod``: the class is only a namespace, and the
    methods may be called on the class or on an instance. The XML side is
    built with ``xml.etree.ElementTree`` under an ``<agent_response>`` root
    and serialized to a ``str``.
    """

    # Fenced code blocks (```...```, spanning lines); _clean_markdown drops them whole.
    _CODE_FENCE_RE = re.compile(r'```.*?```', re.DOTALL)
    # Inline markdown marker characters stripped by _clean_markdown.
    _MARKDOWN_CHARS_RE = re.compile(r'[*_`#]')
    # Runs of three or more newlines, collapsed to a single blank line.
    _EXCESS_NEWLINES_RE = re.compile(r'\n{3,}')
    # Opening ```mermaid fence or any closing ``` fence around a diagram.
    _MERMAID_FENCE_RE = re.compile(r'```mermaid\s*|\s*```')
    # Characters replaced by "_" when a dict key is used as an XML tag name.
    # ":" is replaced too, because ElementTree would otherwise emit it as an
    # undeclared namespace prefix.
    _INVALID_TAG_CHARS_RE = re.compile(r'[^\w.\-]')

    @staticmethod
    def format_thought(
        thought: str,
        observation: Optional[str] = None,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
    ) -> Tuple[str, str]:
        """Format an agent thought (and optional observation) for terminal and XML.

        Args:
            thought: Agent reasoning text; surrounding whitespace is stripped.
            observation: Raw observation text. If it is a JSON object whose
                string values mention ``llm_result``, those entries become
                ``<tool>`` elements under ``<observation><tools>``; otherwise
                the markdown-cleaned text becomes ``<observation>``'s text.
            citations: Each dict becomes a ``<citation>``; its keys become
                child tags (sanitized into valid XML names).
            metadata: Keys become child tags of ``<metadata>`` (sanitized).

        Returns:
            ``(terminal_output, xml_output)``.
        """
        thought_text = thought.strip()
        cleaned_obs = ResponseFormatter._clean_markdown(observation) if observation else ""

        # Terminal format
        terminal_output = thought_text
        if cleaned_obs:
            terminal_output += f"\n\nObservation:\n{cleaned_obs}"

        # XML format
        root = ET.Element("agent_response")
        ET.SubElement(root, "thought").text = thought_text

        if observation:
            obs_elem = ET.SubElement(root, "observation")
            tool_outputs = ResponseFormatter._extract_tool_outputs(observation)
            if tool_outputs:
                tools_elem = ET.SubElement(obs_elem, "tools")
                for tool_name, tool_data in tool_outputs.items():
                    ResponseFormatter._create_tool_element(tools_elem, tool_name, tool_data)
            elif cleaned_obs:
                # Unstructured observations are kept as text so the XML carries
                # the same information as the terminal output.
                obs_elem.text = cleaned_obs

        if citations:
            citations_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(citations_elem, "citation")
                for key, value in citation.items():
                    tag = ResponseFormatter._xml_tag(key)
                    ET.SubElement(cite_elem, tag).text = str(value)

        if metadata:
            metadata_elem = ET.SubElement(root, "metadata")
            for key, value in metadata.items():
                tag = ResponseFormatter._xml_tag(key)
                ET.SubElement(metadata_elem, tag).text = str(value)

        return terminal_output, ET.tostring(root, encoding='unicode')

    @staticmethod
    def _xml_tag(name: Any) -> str:
        """Turn an arbitrary key into an XML element name (best effort).

        ElementTree does not validate tag names, so keys like ``"page count"``
        or ``"2020"`` would otherwise serialize as malformed XML. Characters
        other than word characters, ``.`` and ``-`` become ``_``. A result
        that is empty or starts with a digit, ``.`` or ``-`` gets a leading
        ``_``. Names made only of letters, digits, ``_``, ``.`` and ``-`` that
        start with a letter or ``_`` are returned unchanged.
        """
        tag = ResponseFormatter._INVALID_TAG_CHARS_RE.sub("_", str(name))
        if not tag or tag[0].isdigit() or tag[0] in ".-":
            tag = "_" + tag
        return tag

    @staticmethod
    def _append_fallback_content(tool_elem: ET.Element, data: Any) -> None:
        """Append a ``<content>`` child holding the markdown-cleaned ``str(data)``."""
        content_elem = ET.SubElement(tool_elem, "content")
        content_elem.text = ResponseFormatter._clean_markdown(str(data))

    @staticmethod
    def _create_tool_element(parent: ET.Element, tool_name: str, tool_data: Any) -> ET.Element:
        """Append ``<tool name="...">`` to *parent* and fill it according to the tool.

        Census, mermaid and health-data tools (WISQARS/WONDER/NCHS) get
        dedicated layouts; any other tool gets a generic ``<content>`` child.

        Returns:
            The newly created ``<tool>`` element.
        """
        tool_elem = ET.SubElement(parent, "tool")
        tool_elem.set("name", str(tool_name))

        if tool_name == ToolType.CENSUS:
            ResponseFormatter._format_census_data(tool_elem, tool_data)
        elif tool_name == ToolType.MERMAID:
            ResponseFormatter._format_mermaid_data(tool_elem, tool_data)
        elif tool_name in (ToolType.WISQARS, ToolType.WONDER, ToolType.NCHS):
            ResponseFormatter._format_health_data(tool_elem, tool_data)
        else:
            ResponseFormatter._append_fallback_content(tool_elem, tool_data)
        return tool_elem

    @staticmethod
    def _format_census_data(tool_elem: ET.Element, data: Any) -> None:
        """Render census output as ``<census_tracts>`` with one ``<tract id=...>`` per tract.

        Expects ``data["llm_result"]`` to be a JSON string (or an already
        decoded dict) with a ``"tracts"`` list of dicts. Each tract's non-
        ``tract`` keys become child tags with underscores removed. Any other
        shape falls back to a single ``<content>`` element; no partial
        ``<census_tracts>`` is left behind on failure.
        """
        if not isinstance(data, dict) or "llm_result" not in data:
            ResponseFormatter._append_fallback_content(tool_elem, data)
            return

        # Build detached and attach only on success, so a failure halfway
        # through does not leave a partial element next to the fallback.
        tracts_elem = ET.Element("census_tracts")
        try:
            result = data["llm_result"]
            if isinstance(result, (str, bytes, bytearray)):
                result = json.loads(result)
            for tract_data in result.get("tracts", []):
                tract_elem = ET.SubElement(tracts_elem, "tract")
                tract_elem.set("id", str(tract_data.get("tract", "")))
                for key, value in tract_data.items():
                    if key != "tract":
                        tag = ResponseFormatter._xml_tag(str(key).replace("_", ""))
                        ET.SubElement(tract_elem, tag).text = str(value)
        except (ValueError, TypeError, AttributeError, KeyError):
            # Undecodable JSON or an unexpected structure: best-effort fallback.
            ResponseFormatter._append_fallback_content(tool_elem, data)
            return
        tool_elem.append(tracts_elem)

    @staticmethod
    def _format_mermaid_data(tool_elem: ET.Element, data: Any) -> None:
        """Render ``data["mermaid_diagram"]`` as ``<diagram>`` with the code fences removed.

        If *data* is not a dict holding a string ``mermaid_diagram``, a
        ``<content>`` fallback is emitted instead (never both).
        """
        diagram = data.get("mermaid_diagram") if isinstance(data, dict) else None
        if not isinstance(diagram, str):
            ResponseFormatter._append_fallback_content(tool_elem, data)
            return
        diagram_elem = ET.SubElement(tool_elem, "diagram")
        diagram_elem.text = ResponseFormatter._MERMAID_FENCE_RE.sub('', diagram).strip()

    @staticmethod
    def _format_health_data(tool_elem: ET.Element, data: Any) -> None:
        """Render a health-data dict (WISQARS, WONDER, NCHS) as nested elements.

        Each top-level key becomes a child tag (underscores removed). A
        nested dict becomes one more level of children; any other value
        becomes the element text. Non-dict data falls back to ``<content>``
        instead of being dropped.
        """
        if not isinstance(data, dict):
            ResponseFormatter._append_fallback_content(tool_elem, data)
            return

        categories = []
        for key, value in data.items():
            category_tag = ResponseFormatter._xml_tag(str(key).replace("_", ""))
            category_elem = ET.Element(category_tag)
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    sub_tag = ResponseFormatter._xml_tag(str(sub_key).replace("_", ""))
                    ET.SubElement(category_elem, sub_tag).text = str(sub_value)
            else:
                category_elem.text = str(value)
            categories.append(category_elem)
        tool_elem.extend(categories)

    @staticmethod
    def _extract_tool_outputs(observation: str) -> Dict[str, Any]:
        """Pull tool results out of a JSON-object observation.

        Only entries whose value is a string containing ``llm_result`` are
        kept. Such a value is JSON-decoded when possible, otherwise kept as
        the raw string.

        Returns:
            ``{tool_name: result}``. It is empty when *observation* is not a
            string, not valid JSON, or not a JSON object.
        """
        if not isinstance(observation, str):
            return {}
        try:
            data = json.loads(observation)
        except (ValueError, RecursionError):
            return {}
        if not isinstance(data, dict):
            return {}

        tool_outputs: Dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, str) and "llm_result" in value:
                try:
                    tool_outputs[key] = json.loads(value)
                except (ValueError, RecursionError):
                    tool_outputs[key] = value
        return tool_outputs

    @staticmethod
    def format_message(message: str) -> Tuple[str, str]:
        """Format a plain agent message for terminal and XML (``<message>``).

        Returns:
            ``(terminal_output, xml_output)``; both use the stripped message.
        """
        text = message.strip()
        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = text
        return text, ET.tostring(root, encoding='unicode')

    @staticmethod
    def format_error(error: str) -> Tuple[str, str]:
        """Format an error for terminal (``"Error: ..."``) and XML (``<error>``).

        *error* may also be an exception or any other object; it is converted
        with ``str()``. ElementTree cannot serialize non-string text.
        """
        error_text = str(error)
        terminal_output = f"Error: {error_text}"
        root = ET.Element("agent_response")
        ET.SubElement(root, "error").text = error_text
        return terminal_output, ET.tostring(root, encoding='unicode')

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Strip markdown from *text* for plain display.

        Removes fenced code blocks entirely and the characters ``* _ ` #``,
        trims surrounding whitespace, and collapses 3+ newlines into one
        blank line. Non-string input is converted with ``str()`` first.
        """
        if not isinstance(text, str):
            text = str(text)
        text = ResponseFormatter._CODE_FENCE_RE.sub('', text)
        text = ResponseFormatter._MARKDOWN_CHARS_RE.sub('', text)
        return ResponseFormatter._EXCESS_NEWLINES_RE.sub('\n\n', text.strip())