Spaces:
Sleeping
Sleeping
"""
Format Detection Utilities for GEPA Optimizer.

This module provides utilities to automatically detect output format patterns
from expected outputs and generate format constraints for reflection prompts.

Key Features:
1. Auto-detect JSON, key-value, list, structured-text, or free-text formats
2. Generate format specifications from examples
3. Create format constraint strings for prompt injection
"""
| import re | |
| import json | |
| from typing import List, Dict, Any, Optional, Tuple | |
def detect_output_format(expected_outputs: List[str]) -> Dict[str, Any]:
    """
    Analyze expected outputs to detect the common format pattern.

    Detectors are tried from most to least specific (JSON, key-value, list,
    multi-line structured text). The first detector that returns a result
    wins; free text is the fallback.

    Args:
        expected_outputs: List of expected output strings from the dataset.
            Empty or whitespace-only entries are ignored.

    Returns:
        Dictionary containing:
        - format_type: 'json', 'key_value', 'list', 'structured_text',
          'free_text', or 'unknown' when there is nothing to analyze
        - format_spec: Human-readable format specification
        - format_example: Example showing the format
        - format_constraint: Constraint text to add to prompts
          (empty for 'unknown')
        - detected_keys: List of keys/fields detected (for structured formats)
        - avg_length: Average length of outputs (to enforce conciseness)
        - max_length: Length of the longest output
    """
    # Blank entries carry no format signal, so drop them up front.
    valid_outputs = [o for o in (expected_outputs or []) if o and o.strip()]

    if not valid_outputs:
        # Same shape for "no outputs" and "only blank outputs": no constraint
        # is emitted, because a "target length ~0" instruction would be wrong.
        return {
            'format_type': 'unknown',
            'format_spec': 'Unknown format',
            'format_example': '',
            'format_constraint': '',
            'detected_keys': [],
            'avg_length': 0,
            'max_length': 0,
        }

    # Length statistics feed the conciseness part of the constraint.
    lengths = [len(o) for o in valid_outputs]
    avg_length = sum(lengths) // len(lengths)
    max_length = max(lengths)

    # Order matters: most specific detector first.
    detectors = (
        _detect_json_format,        # 1. JSON objects
        _detect_key_value_format,   # 2. "Department: X | Sentiment: Y"
        _detect_list_format,        # 3. bullet / numbered lists
        _detect_structured_text,    # 4. multi-line structured text
    )
    for detector in detectors:
        result = detector(valid_outputs, avg_length, max_length)
        if result:
            return result

    # 5. Default to free text with length constraint
    return _create_format_result(
        'free_text',
        f'Free-form text response (typically {avg_length} characters)',
        valid_outputs[0][:100],
        [],
        avg_length,
        max_length
    )
| def _detect_json_format(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]: | |
| """Detect if outputs are JSON format.""" | |
| json_count = 0 | |
| all_keys = [] | |
| for output in outputs: | |
| stripped = output.strip() | |
| if stripped.startswith('{') and stripped.endswith('}'): | |
| try: | |
| parsed = json.loads(stripped) | |
| if isinstance(parsed, dict): | |
| json_count += 1 | |
| all_keys.extend(parsed.keys()) | |
| except json.JSONDecodeError: | |
| pass | |
| # If majority are JSON | |
| if json_count >= len(outputs) * 0.7: | |
| # Find common keys | |
| key_counts = {} | |
| for key in all_keys: | |
| key_counts[key] = key_counts.get(key, 0) + 1 | |
| common_keys = [k for k, v in key_counts.items() if v >= json_count * 0.5] | |
| # Build format spec | |
| format_spec = f"JSON object with keys: {', '.join(common_keys)}" | |
| format_example = outputs[0][:200] if outputs else '{}' | |
| return _create_format_result( | |
| 'json', | |
| format_spec, | |
| format_example, | |
| common_keys, | |
| avg_length, | |
| max_length | |
| ) | |
| return None | |
| def _detect_key_value_format(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]: | |
| """Detect key-value formats like 'Department: X | Sentiment: Y'.""" | |
| # Common separators for key-value pairs | |
| separators = ['|', '\n', ';', ','] | |
| key_patterns = [ | |
| r'([A-Za-z_][A-Za-z0-9_\s]*)\s*[:=]\s*([^|;\n,]+)', # Key: Value or Key = Value | |
| ] | |
| all_keys = [] | |
| kv_count = 0 | |
| detected_separator = None | |
| for output in outputs: | |
| # Try to find key-value pairs | |
| for pattern in key_patterns: | |
| matches = re.findall(pattern, output) | |
| if len(matches) >= 2: # At least 2 key-value pairs | |
| kv_count += 1 | |
| for key, _ in matches: | |
| all_keys.append(key.strip()) | |
| # Detect separator | |
| for sep in separators: | |
| if sep in output: | |
| detected_separator = sep | |
| break | |
| break | |
| # If majority are key-value | |
| if kv_count >= len(outputs) * 0.6: | |
| # Find common keys | |
| key_counts = {} | |
| for key in all_keys: | |
| normalized = key.strip().lower() | |
| key_counts[normalized] = key_counts.get(normalized, 0) + 1 | |
| common_keys = [k for k, v in sorted(key_counts.items(), key=lambda x: -x[1]) | |
| if v >= kv_count * 0.4][:5] # Top 5 keys | |
| # Determine the exact format pattern | |
| sep_display = detected_separator if detected_separator else ' | ' | |
| format_spec = f"Key-value pairs: {sep_display.join([f'{k}: [value]' for k in common_keys])}" | |
| format_example = outputs[0] if outputs else '' | |
| return _create_format_result( | |
| 'key_value', | |
| format_spec, | |
| format_example, | |
| common_keys, | |
| avg_length, | |
| max_length | |
| ) | |
| return None | |
| def _detect_list_format(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]: | |
| """Detect bullet/numbered list formats.""" | |
| list_patterns = [ | |
| r'^[-*•]\s+', # Bullet points | |
| r'^\d+[.)]\s+', # Numbered list | |
| ] | |
| list_count = 0 | |
| for output in outputs: | |
| lines = output.strip().split('\n') | |
| list_lines = 0 | |
| for line in lines: | |
| for pattern in list_patterns: | |
| if re.match(pattern, line.strip()): | |
| list_lines += 1 | |
| break | |
| if list_lines >= len(lines) * 0.5: # Majority are list items | |
| list_count += 1 | |
| if list_count >= len(outputs) * 0.6: | |
| return _create_format_result( | |
| 'list', | |
| 'Bullet or numbered list format', | |
| outputs[0][:200] if outputs else '', | |
| [], | |
| avg_length, | |
| max_length | |
| ) | |
| return None | |
| def _detect_structured_text(outputs: List[str], avg_length: int, max_length: int) -> Optional[Dict[str, Any]]: | |
| """Detect structured text with consistent patterns.""" | |
| # Check for consistent line patterns | |
| line_counts = [len(o.strip().split('\n')) for o in outputs] | |
| avg_lines = sum(line_counts) // len(line_counts) if line_counts else 1 | |
| if avg_lines >= 2: | |
| return _create_format_result( | |
| 'structured_text', | |
| f'Structured text with ~{avg_lines} lines', | |
| outputs[0][:200] if outputs else '', | |
| [], | |
| avg_length, | |
| max_length | |
| ) | |
| return None | |
| def _create_format_result( | |
| format_type: str, | |
| format_spec: str, | |
| format_example: str, | |
| detected_keys: List[str], | |
| avg_length: int, | |
| max_length: int = 0 | |
| ) -> Dict[str, Any]: | |
| """Create a standardized format detection result.""" | |
| # Generate format constraint based on type | |
| if format_type == 'json': | |
| constraint = f"""OUTPUT FORMAT REQUIREMENT: | |
| - Return ONLY a valid JSON object | |
| - Required keys: {', '.join(detected_keys) if detected_keys else 'as shown in examples'} | |
| - NO explanations, NO prose, NO markdown code blocks | |
| - Maximum length: ~{max_length} characters | |
| - Example format: {format_example[:150]}""" | |
| elif format_type == 'key_value': | |
| constraint = f"""OUTPUT FORMAT REQUIREMENT: | |
| - Return ONLY in key-value format: {format_spec} | |
| - NO explanations, NO reasoning, NO additional text | |
| - Be CONCISE - output should be ~{avg_length} characters max | |
| - Example: {format_example}""" | |
| elif format_type == 'list': | |
| constraint = f"""OUTPUT FORMAT REQUIREMENT: | |
| - Return as a bullet or numbered list | |
| - NO explanations before or after the list | |
| - Keep it concise (~{avg_length} characters)""" | |
| elif format_type == 'structured_text': | |
| constraint = f"""OUTPUT FORMAT REQUIREMENT: | |
| - Follow the structured format shown in examples | |
| - NO additional explanations or commentary | |
| - Keep output concise (~{avg_length} characters)""" | |
| else: | |
| constraint = f"""OUTPUT FORMAT REQUIREMENT: | |
| - Keep response CONCISE and DIRECT | |
| - NO lengthy explanations or reasoning | |
| - Target length: ~{avg_length} characters (max {max_length}) | |
| - Match the format/style of the expected examples""" | |
| return { | |
| 'format_type': format_type, | |
| 'format_spec': format_spec, | |
| 'format_example': format_example[:200] if format_example else '', | |
| 'format_constraint': constraint, | |
| 'detected_keys': detected_keys, | |
| 'avg_length': avg_length, | |
| 'max_length': max_length | |
| } | |
def build_format_aware_reflection_prompt(
    base_prompt: str,
    format_info: Dict[str, Any],
    include_example: bool = True
) -> str:
    """
    Enhance a reflection prompt with format awareness.

    Args:
        base_prompt: The original reflection prompt
        format_info: Format detection result from detect_output_format()
        include_example: Whether to include format example

    Returns:
        ``base_prompt`` unchanged when ``format_info`` is empty or its type is
        'unknown'; otherwise ``base_prompt`` followed by the format
        requirement section and, optionally, the example section.
    """
    detected_type = format_info.get('format_type') if format_info else 'unknown'
    if detected_type == 'unknown':
        return base_prompt

    pieces = [base_prompt]

    pieces.append(f"""
🎯 CRITICAL FORMAT REQUIREMENT:
The optimized prompt MUST produce outputs that match this EXACT format:
{format_info['format_constraint']}
⚠️ COMMON FAILURE MODES TO AVOID:
1. Generating explanations when only the answer is needed
2. Adding "Here's the analysis..." or similar preambles
3. Producing verbose output when concise is required
4. Wrong structure (e.g., prose instead of key-value pairs)
""")

    wants_example = include_example and format_info.get('format_example')
    if wants_example:
        pieces.append(f"""
📋 EXAMPLE OF CORRECT OUTPUT FORMAT:
{format_info['format_example']}
""")

    # The format sections are appended after the base prompt.
    return ''.join(pieces)
def generate_format_feedback(
    predicted_output: str,
    expected_output: str,
    format_info: Dict[str, Any]
) -> str:
    """
    Generate specific feedback about format compliance.

    Checks, in order: length against the detected average/maximum, structural
    compliance for 'json' and 'key_value' formats, and the presence of common
    explanatory phrases.

    Args:
        predicted_output: What the model actually produced
        expected_output: The ground truth output (not inspected by the
            current checks; kept for interface compatibility)
        format_info: Format detection result; None or empty disables the
            length and structure checks

    Returns:
        A formatted list of issues, or an empty string when none are found
    """
    format_info = format_info or {}
    predicted_len = len(predicted_output) if predicted_output else 0
    # Empty and whitespace-only predictions are treated the same way.
    stripped = predicted_output.strip() if predicted_output else ''
    issues = []

    # Check length discrepancy
    avg_length = format_info.get('avg_length', 0)
    max_length = format_info.get('max_length', 0)
    if avg_length > 0:
        if predicted_len > avg_length * 3:
            issues.append(f"OUTPUT TOO VERBOSE: Generated {predicted_len} chars, expected ~{avg_length} chars")
        # A missing or zero max_length means "not known": skip the check
        # rather than flag every non-empty output as too long.
        elif max_length > 0 and predicted_len > max_length * 2:
            issues.append(f"OUTPUT TOO LONG: {predicted_len} chars vs max expected {max_length}")

    # Check format type compliance
    format_type = format_info.get('format_type', 'unknown')

    if format_type == 'json':
        try:
            parsed = json.loads(stripped)
        except json.JSONDecodeError:
            issues.append("FORMAT ERROR: Expected JSON but got non-JSON output")
        else:
            # Detection only accepts JSON objects, so a bare number, string
            # or array is also a format violation.
            if not isinstance(parsed, dict):
                issues.append("FORMAT ERROR: Expected a JSON object but got a different JSON value")

    elif format_type == 'key_value':
        # Detection accepts both "Key: Value" and "Key = Value".
        if ':' not in stripped and '=' not in stripped:
            issues.append("FORMAT ERROR: Expected key-value pairs (Key: Value) but output lacks this structure")

    # Check for common verbose patterns
    verbose_indicators = [
        'let me', 'i will', 'here is', "here's", 'analysis:', 'step-by-step',
        'first,', 'to begin', 'in order to', 'the following', 'please note'
    ]

    if predicted_output:
        lower_output = predicted_output.lower()
        found_verbose = [v for v in verbose_indicators if v in lower_output]
        if found_verbose:
            issues.append(f"VERBOSITY WARNING: Output contains explanatory phrases: {', '.join(found_verbose[:3])}")

    if not issues:
        return ""

    return "\n🚨 FORMAT ISSUES DETECTED:\n" + "\n".join(f" • {issue}" for issue in issues)