cc-api / response_formatter.py
Severian's picture
Upload 81 files
995af0f verified
from typing import Dict, Optional, Tuple, List, Any
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
class ToolType:
    """Names of the agent tools whose output this module knows about.

    These are plain ``str`` class attributes rather than ``Enum`` members, so a
    tool name taken from an observation payload can be compared to them with
    ``==`` or ``in`` directly.
    """

    # Search / news tools
    DUCKDUCKGO: str = "duckduckgo_search"
    REDDIT_NEWS: str = "reddit_x_gnews_newswire_crunchbase"
    PUBMED: str = "pubmed_search"

    # Data and diagram helpers
    CENSUS: str = "get_census_data"
    HEATMAP: str = "heatmap_code"
    MERMAID: str = "mermaid_diagram"

    # Sources that ResponseFormatter routes to its health-data formatter
    WISQARS: str = "wisqars"
    WONDER: str = "wonder"
    NCHS: str = "nchs"

    # Remaining tools
    ONESTEP: str = "onestep"
    DQS: str = "dqs_nhis_adult_summary_health_statistics"
class ResponseFormatter:
    """Render agent output as a ``(terminal_text, xml_text)`` pair.

    Every public ``format_*`` method returns a 2-tuple: a plain string meant for
    terminal display and a serialized ``<agent_response>`` XML document.
    """

    # Characters that are not legal in an XML element name. ``\w`` keeps
    # letters, digits and ``_``; ``.`` and ``-`` are also legal name chars.
    # ``:`` and ``{`` are replaced too, because ElementTree gives them
    # namespace meaning.
    _INVALID_TAG_CHARS = re.compile(r"[^\w.-]")
    # Code points that XML 1.0 forbids anywhere in a document. ElementTree
    # escapes &, < and > but writes these through unchanged, which would make
    # the serialized output unparseable.
    _INVALID_XML_CHARS = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufffe\uffff]")
    # Errors raised when a tool payload does not have the expected shape
    # (missing key, wrong type, bad JSON, pathologically nested JSON).
    _PARSE_ERRORS = (AttributeError, KeyError, TypeError, ValueError, RecursionError)

    @staticmethod
    def format_thought(
        thought: str,
        observation: Optional[str] = None,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
    ) -> Tuple[str, str]:
        """Format an agent thought (plus optional observation) for terminal and XML.

        Args:
            thought: The agent's reasoning text; surrounding whitespace is stripped.
            observation: Optional tool observation. In the XML, tool outputs found
                in it are rendered under ``<observation><tools>``. Otherwise the
                markdown-stripped observation text becomes the element's text.
            citations: Optional list of citations. Each mapping becomes a
                ``<citation>`` with one child per key; a non-mapping item is
                stored as the citation's text.
            metadata: Optional mapping rendered as children of ``<metadata>``.

        Returns:
            ``(terminal_output, xml_output)``.
        """
        fmt = ResponseFormatter
        stripped_thought = thought.strip()
        cleaned_obs = fmt._clean_markdown(observation) if observation else ""

        # Terminal format
        terminal_output = stripped_thought
        if cleaned_obs:
            terminal_output += f"\n\nObservation:\n{cleaned_obs}"

        # XML format
        root = ET.Element("agent_response")
        thought_elem = ET.SubElement(root, "thought")
        thought_elem.text = fmt._safe_text(stripped_thought)

        if observation:
            obs_elem = ET.SubElement(root, "observation")
            tool_outputs = fmt._extract_tool_outputs(observation)
            if tool_outputs:
                tools_elem = ET.SubElement(obs_elem, "tools")
                for tool_name, tool_data in tool_outputs.items():
                    fmt._create_tool_element(tools_elem, tool_name, tool_data)
            elif cleaned_obs:
                # No structured tool output: keep the observation text rather
                # than emitting an empty <observation/>.
                obs_elem.text = fmt._safe_text(cleaned_obs)

        if citations:
            citations_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(citations_elem, "citation")
                if hasattr(citation, "items"):
                    fmt._add_children(cite_elem, citation)
                else:
                    cite_elem.text = fmt._safe_text(citation)

        if metadata:
            metadata_elem = ET.SubElement(root, "metadata")
            fmt._add_children(metadata_elem, metadata)

        xml_output = ET.tostring(root, encoding='unicode')
        return terminal_output, xml_output

    @staticmethod
    def _create_tool_element(parent: ET.Element, tool_name: str, tool_data: Dict) -> ET.Element:
        """Append a ``<tool name=...>`` element under *parent* and fill it by tool type.

        Census, mermaid and the health-data tools (WISQARS, WONDER, NCHS) get
        dedicated layouts. Any other tool gets a generic ``<content>`` child.
        Returns the new ``<tool>`` element.
        """
        tool_elem = ET.SubElement(parent, "tool")
        tool_elem.set("name", ResponseFormatter._safe_text(tool_name))
        if tool_name == ToolType.CENSUS:
            ResponseFormatter._format_census_data(tool_elem, tool_data)
        elif tool_name == ToolType.MERMAID:
            ResponseFormatter._format_mermaid_data(tool_elem, tool_data)
        elif tool_name in (ToolType.WISQARS, ToolType.WONDER, ToolType.NCHS):
            ResponseFormatter._format_health_data(tool_elem, tool_data)
        else:
            ResponseFormatter._append_content(tool_elem, tool_data)
        return tool_elem

    @staticmethod
    def _format_census_data(tool_elem: ET.Element, data: Dict) -> None:
        """Render census output as ``<census_tracts><tract id=...>...</tract></census_tracts>``.

        ``data["llm_result"]`` may be a JSON string or an already-parsed mapping
        with a ``"tracts"`` list. The structured element is attached only if the
        whole payload parses. Otherwise a single ``<content>`` fallback is
        added, so partial or empty structure is never left behind.
        """
        fmt = ResponseFormatter
        try:
            raw = data["llm_result"]
            result = json.loads(raw) if isinstance(raw, str) else raw
            # Build detached so a failure part-way leaves tool_elem untouched.
            tracts_elem = ET.Element("census_tracts")
            for tract_data in result.get("tracts", []):
                tract_elem = ET.SubElement(tracts_elem, "tract")
                tract_elem.set("id", fmt._safe_text(tract_data.get("tract", "")))
                details = {k: v for k, v in tract_data.items() if k != "tract"}
                fmt._add_children(tract_elem, details, strip_underscores=True)
        except ResponseFormatter._PARSE_ERRORS:
            fmt._append_content(tool_elem, data)
            return
        tool_elem.append(tracts_elem)

    @staticmethod
    def _format_mermaid_data(tool_elem: ET.Element, data: Dict) -> None:
        """Render ``data["mermaid_diagram"]`` as a ``<diagram>`` with code fences removed.

        If the key is missing or its value is not a string, a ``<content>``
        fallback is added instead of an empty ``<diagram/>``.
        """
        try:
            mermaid_code = re.sub(r'```mermaid\s*|\s*```', '', data["mermaid_diagram"]).strip()
        except ResponseFormatter._PARSE_ERRORS:
            ResponseFormatter._append_content(tool_elem, data)
            return
        diagram_elem = ET.SubElement(tool_elem, "diagram")
        diagram_elem.text = ResponseFormatter._safe_text(mermaid_code)

    @staticmethod
    def _format_health_data(tool_elem: ET.Element, data: Dict) -> None:
        """Render health-data tool output (WISQARS, WONDER, ...) as nested elements.

        Each top-level key becomes an element (underscores removed). A dict
        value becomes one child per sub-key; any other value becomes the text.
        Non-dict *data* falls back to a ``<content>`` element rather than being
        dropped.
        """
        if not isinstance(data, dict):
            ResponseFormatter._append_content(tool_elem, data)
            return
        for key, value in data.items():
            tag = ResponseFormatter._safe_tag(str(key).replace("_", ""))
            category_elem = ET.SubElement(tool_elem, tag)
            if isinstance(value, dict):
                ResponseFormatter._add_children(category_elem, value, strip_underscores=True)
            else:
                category_elem.text = ResponseFormatter._safe_text(value)

    @staticmethod
    def _extract_tool_outputs(observation: str) -> Dict[str, Any]:
        """Pull tool outputs out of a JSON-object observation.

        Only string values that contain ``"llm_result"`` are kept. Each one is
        JSON-decoded when possible and kept as the raw string otherwise.
        Non-string, non-JSON or non-object observations yield ``{}``.
        """
        tool_outputs: Dict[str, Any] = {}
        if not isinstance(observation, str):
            return tool_outputs
        try:
            data = json.loads(observation)
        except (ValueError, RecursionError):
            # Plain-text observations are normal; they carry no tool outputs.
            return tool_outputs
        if not isinstance(data, dict):
            return tool_outputs
        for key, value in data.items():
            if isinstance(value, str) and "llm_result" in value:
                try:
                    tool_outputs[key] = json.loads(value)
                except (ValueError, RecursionError):
                    tool_outputs[key] = value
        return tool_outputs

    @staticmethod
    def format_message(message: str) -> Tuple[str, str]:
        """Format a plain agent message as ``(terminal_output, xml_output)``."""
        text = message.strip()
        root = ET.Element("agent_response")
        msg_elem = ET.SubElement(root, "message")
        msg_elem.text = ResponseFormatter._safe_text(text)
        return text, ET.tostring(root, encoding='unicode')

    @staticmethod
    def format_error(error: str) -> Tuple[str, str]:
        """Format an error as ``(terminal_output, xml_output)``.

        *error* is passed through ``str()``, so an Exception instance is accepted
        too. ElementTree cannot serialize non-string text.
        """
        terminal_output = f"Error: {error}"
        root = ET.Element("agent_response")
        error_elem = ET.SubElement(root, "error")
        error_elem.text = ResponseFormatter._safe_text(error)
        return terminal_output, ET.tostring(root, encoding='unicode')

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Strip markdown from *text*.

        Removes fenced ``` blocks entirely, deletes the characters ``* _ ` #``,
        and collapses runs of 3+ newlines to one blank line.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())

    @staticmethod
    def _safe_tag(name: Any) -> str:
        """Coerce *name* into a well-formed XML element name.

        Names that are already valid are returned unchanged. Illegal characters
        become ``_``. A ``_`` is prepended when the result is empty or does not
        start with a letter or underscore.
        """
        tag = ResponseFormatter._INVALID_TAG_CHARS.sub("_", str(name))
        if not tag or not (tag[0].isalpha() or tag[0] == "_"):
            tag = "_" + tag
        return tag

    @staticmethod
    def _safe_text(value: Any) -> str:
        """Return ``str(value)`` with characters that are illegal in XML 1.0 removed."""
        return ResponseFormatter._INVALID_XML_CHARS.sub("", str(value))

    @staticmethod
    def _add_children(parent: ET.Element, mapping: Any, strip_underscores: bool = False) -> None:
        """Append one ``<key>value</key>`` child to *parent* per item of *mapping*.

        Underscores are removed from keys first when *strip_underscores* is set.
        The key is then sanitized with ``_safe_tag`` and the value with
        ``_safe_text``.
        """
        for key, value in mapping.items():
            name = str(key)
            if strip_underscores:
                name = name.replace("_", "")
            child = ET.SubElement(parent, ResponseFormatter._safe_tag(name))
            child.text = ResponseFormatter._safe_text(value)

    @staticmethod
    def _append_content(tool_elem: ET.Element, data: Any) -> None:
        """Append a generic ``<content>`` child holding markdown-stripped ``str(data)``."""
        content_elem = ET.SubElement(tool_elem, "content")
        content_elem.text = ResponseFormatter._safe_text(ResponseFormatter._clean_markdown(str(data)))