from typing import Dict, Optional, Tuple, List, Any, Set, Union
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
import logging
from enum import Enum

# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create console handler if needed
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)


class StreamingFormatter:
    def __init__(self):
        self.processed_events = set()
        self.current_tool_outputs = []
        self.current_citations = []
        self.current_metadata = {}
        self.current_message_id = None
        self.current_message_buffer = ""

    def reset(self):
        """Reset the formatter state"""
        self.processed_events.clear()
        self.current_tool_outputs.clear()
        self.current_citations.clear()
        self.current_metadata.clear()
        self.current_message_id = None
        self.current_message_buffer = ""

    def append_to_buffer(self, text: str):
        """Append text to the current message buffer"""
        self.current_message_buffer += text

    def get_and_clear_buffer(self) -> str:
        """Get the current buffer content and clear it"""
        content = self.current_message_buffer
        self.current_message_buffer = ""
        return content


class ToolType(Enum):
    """Enum for supported tool types"""
    DUCKDUCKGO = "ddgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_output"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"

    @classmethod
    def get_tool_type(cls, tool_name: str) -> Optional['ToolType']:
        """Get enum member from a tool name string.

        Tool names arrive as the member *values* (e.g. "mermaid_output"),
        so match on value first, then fall back to a member-name lookup.
        """
        for member in cls:
            if member.value == tool_name:
                return member
        try:
            return cls[tool_name.upper()]
        except KeyError:
            return None


class ResponseFormatter:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ResponseFormatter, cls).__new__(cls)
            cls._instance.streaming_state = StreamingFormatter()
            cls._instance.logger = logger
        return cls._instance

    def format_thought(
        self,
        thought: str,
        observation: str,
        citations: List[Dict] = None,
        metadata: Dict = None,
        tool_outputs: List[Dict] = None,
        event_id: str = None,
        message_id: str = None
    ) -> Optional[Tuple[str, str]]:
        """Format agent thought for both terminal and XML output"""
        # Skip if already processed in streaming mode
        if event_id and event_id in self.streaming_state.processed_events:
            return None

        # Handle message state
        if message_id != self.streaming_state.current_message_id:
            self.streaming_state.reset()
            self.streaming_state.current_message_id = message_id

        # Skip empty thoughts
        if not thought and not observation and not tool_outputs:
            return None

        # Terminal format
        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {}
        }

        unique_outputs = []
        if tool_outputs:
            # Deduplicate tool outputs
            seen_outputs = set()
            for output in tool_outputs:
                output_key = f"{output.get('type')}:{output.get('content')}"
                if output_key not in seen_outputs:
                    seen_outputs.add(output_key)
                    unique_outputs.append(output)
            terminal_output["tool_outputs"] = unique_outputs

        # XML format
        root = ET.Element("agent_response")

        if thought:
            thought_elem = ET.SubElement(root, "thought")
            thought_elem.text = thought

        if observation:
            obs_elem = ET.SubElement(root, "observation")
            obs_elem.text = observation

        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in unique_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
tool_elem.attrib["type"] = tool_output.get("type", "") tool_elem.text = tool_output.get("content", "") if citations: cites_elem = ET.SubElement(root, "citations") for citation in citations: cite_elem = ET.SubElement(cites_elem, "citation") for key, value in citation.items(): cite_elem.attrib[key] = str(value) xml_output = ET.tostring(root, encoding='unicode') # Track processed event if event_id: self.streaming_state.processed_events.add(event_id) return json.dumps(terminal_output), xml_output def format_message( self, message: str, event_id: str = None, message_id: str = None ) -> Optional[Tuple[str, str]]: """Format agent message for both terminal and XML output""" # Skip if already processed if event_id and event_id in self.streaming_state.processed_events: return None # Handle message state if message_id != self.streaming_state.current_message_id: self.streaming_state.reset() self.streaming_state.current_message_id = message_id # Accumulate message content self.streaming_state.append_to_buffer(message) # Only output if we have meaningful content if not self.streaming_state.current_message_buffer.strip(): return None # Terminal format terminal_output = self.streaming_state.current_message_buffer.strip() # XML format root = ET.Element("agent_response") msg_elem = ET.SubElement(root, "message") msg_elem.text = terminal_output xml_output = ET.tostring(root, encoding='unicode') # Track processed event if event_id: self.streaming_state.processed_events.add(event_id) return terminal_output, xml_output def format_error( self, error: str, event_id: str = None, message_id: str = None ) -> Optional[Tuple[str, str]]: """Format error message for both terminal and XML output""" # Skip if already processed if event_id and event_id in self.streaming_state.processed_events: return None # Handle message state if message_id != self.streaming_state.current_message_id: self.streaming_state.reset() self.streaming_state.current_message_id = message_id # Skip empty errors if not error: return None # Terminal format terminal_output = f"Error: {error}" # XML format root = ET.Element("agent_response") error_elem = ET.SubElement(root, "error") error_elem.text = error xml_output = ET.tostring(root, encoding='unicode') # Track processed event if event_id: self.streaming_state.processed_events.add(event_id) return terminal_output, xml_output def format_tool_output( self, tool_type: str, content: Union[str, Dict], metadata: Optional[Dict] = None ) -> Dict: """Format tool output into standardized structure""" try: # Get enum tool type tool = ToolType.get_tool_type(tool_type) if not tool: self.logger.warning(f"Unknown tool type: {tool_type}") return { "type": tool_type, "content": content, "metadata": metadata or {} } # Format based on tool type if tool == ToolType.MERMAID: return { "type": "mermaid", "content": self._clean_mermaid_content(content), "metadata": metadata or {} } elif tool == ToolType.HEATMAP: return { "type": "heatmap", "content": self._format_heatmap_data(content), "metadata": metadata or {} } else: # Default formatting for other tools return { "type": tool.value, "content": content, "metadata": metadata or {} } except Exception as e: self.logger.error(f"Error formatting tool output: {str(e)}") return { "type": "error", "content": str(e), "metadata": metadata or {} } def _clean_mermaid_content(self, content: Union[str, Dict]) -> str: """Clean and standardize mermaid diagram content""" try: if isinstance(content, dict): content = content.get("mermaid_diagram", "") # Remove markdown formatting content = 
re.sub(r'```mermaid\s*|\s*```', '', content) # Clean up whitespace content = content.strip() return content except Exception as e: self.logger.error(f"Error cleaning mermaid content: {str(e)}") return str(content) def _format_heatmap_data(self, content: Union[str, Dict]) -> Dict: """Format heatmap data into standardized structure""" try: if isinstance(content, str): content = json.loads(content) return { "data": content.get("data", []), "options": content.get("options", {}), "metadata": content.get("metadata", {}) } except Exception as e: self.logger.error(f"Error formatting heatmap data: {str(e)}") return {"error": str(e)} @staticmethod def _clean_markdown(text: str) -> str: """Clean markdown formatting from text""" text = re.sub(r'```.*?```', '', text, flags=re.DOTALL) text = re.sub(r'[*_`#]', '', text) return re.sub(r'\n{3,}', '\n\n', text.strip())
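

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the formatter API itself). The tool
# payload, event IDs, and message IDs below are made-up demo values; in
# practice they would come from the streaming agent runtime driving this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    formatter = ResponseFormatter()  # singleton: repeated calls return the same instance

    # Normalize a tool result, then attach it to an agent thought.
    mermaid_output = formatter.format_tool_output(
        "mermaid_output",
        "```mermaid\ngraph TD; A-->B;\n```",
    )
    result = formatter.format_thought(
        thought="Summarizing the generated diagram",
        observation="Diagram generated",
        tool_outputs=[mermaid_output],
        event_id="evt-1",
        message_id="msg-1",
    )
    if result:
        terminal_json, xml_payload = result
        print(terminal_json)
        print(xml_payload)

    # Stream a message in chunks; each call returns the accumulated buffer so far.
    for i, chunk in enumerate(["Hello, ", "world!"]):
        result = formatter.format_message(chunk, event_id=f"evt-msg-{i}", message_id="msg-1")
        if result:
            terminal_text, _ = result
            print(terminal_text)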