Spaces:
Running
Running
| import json | |
| import re | |
| import logging | |
| from typing import Type, Any, Union, Optional, List, Dict | |
| from pydantic import BaseModel | |
| import os | |
| from .errors import StructuredOutputError | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Phrases that may nudge a model towards adding prose around its answer
# instead of emitting raw JSON. schema_guard() looks for each entry as a
# lowercase substring of the prompt/instruction text.
MISALIGNMENT_PHRASES = [
    # Requests for narrative output.
    "explain", "describe", "why",
    # Requests about visible deliberation or presentation.
    "step by step", "formatting", "reasoning", "thought process",
]
def schema_guard(prompt: str, instruction: Optional[str] = None) -> None:
    """
    Warn (or fail) when the prompt text contains wording that may conflict
    with strict JSON generation.

    The prompt and the optional instruction are joined with a single space,
    lowercased, and searched for every entry of MISALIGNMENT_PHRASES using a
    plain substring test.

    Args:
        prompt: The user/system prompt text.
        instruction: Optional extra instruction text; treated as empty when
            None or falsy.

    Raises:
        ValueError: When at least one phrase is found and the environment
            variable LLM_SCHEMA_GUARD_STRICT equals "true" (case-insensitive).
    """
    haystack = " ".join([prompt, instruction or ""]).lower()

    hits = []
    for phrase in MISALIGNMENT_PHRASES:
        if phrase in haystack:
            hits.append(phrase)

    # Nothing suspicious: stay silent.
    if not hits:
        return

    message = f"Schema misalignment guard hit: problematic phrases found: {hits}"

    # Strict mode is opt-in through the environment; anything other than
    # "true" only produces a warning.
    if os.getenv("LLM_SCHEMA_GUARD_STRICT", "false").lower() != "true":
        logger.warning(message)
        return

    logger.error(f"STRICT MODE: {message}")
    raise ValueError(message)
def get_json_instruction(schema: Type[BaseModel], current_instruction: Optional[str] = None) -> str:
    """
    Build a short, strict "JSON only" instruction that embeds the schema.

    Args:
        schema: Model class whose ``model_json_schema()`` output is serialized
            with ``json.dumps`` and appended after ``Schema: ``.
        current_instruction: Existing instruction text. When truthy it is kept
            in front of the JSON requirements, separated by a blank line.

    Returns:
        The combined instruction string.
    """
    pieces = []

    # Keep whatever the caller already had, followed by a blank line.
    if current_instruction:
        pieces.append(f"{current_instruction}\n\n")

    pieces.append(
        "Return ONLY valid JSON. No prose, no preamble. "
        "Must conform exactly to this schema. No extra keys."
    )
    pieces.append("\nSchema: ")
    pieces.append(json.dumps(schema.model_json_schema()))

    return "".join(pieces)
def extract_json(text: str) -> str:
    """
    Extract the largest JSON object embedded in ``text``.

    Every ``{`` in the text is tried as the start of a JSON object using
    ``json.JSONDecoder.raw_decode``. The longest object that decodes
    successfully is returned (the earliest one wins on a tie). This copes
    with prose around the payload that itself contains braces, e.g.
    ``Use {braces}: {"a": 1}``.

    If no position decodes as JSON, the function falls back to the span from
    the first ``{`` to the last ``}`` so callers can still report the
    offending text; when there is no such span, the stripped input is
    returned.

    Args:
        text: Raw model output, possibly containing prose or code fences.

    Returns:
        The extracted JSON-like substring, or ``text.strip()`` when nothing
        brace-delimited is present.
    """
    decoder = json.JSONDecoder()
    best = ""

    start = text.find("{")
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            # Not valid JSON from this brace; try the next candidate.
            start = text.find("{", start + 1)
            continue
        if end - start > len(best):
            best = text[start:end]
        # Skip past the decoded object so its nested braces are not
        # re-examined as separate (smaller) candidates.
        start = text.find("{", end)

    if best:
        return best

    # Fallback: nothing decoded. Return the outermost brace-delimited span
    # (or the stripped text) so the caller's parse error shows the raw data.
    first = text.find("{")
    last = text.rfind("}")
    if first != -1 and last > first:
        return text[first : last + 1]
    return text.strip()
def validate_structured_output(
    text: str, schema: Type[BaseModel], provider: str, model: str, prompt_id: str
) -> Union[Dict[str, Any], BaseModel]:
    """
    Parse the LLM output as JSON and validate it against ``schema``.

    The JSON-like portion of ``text`` is isolated with ``extract_json``,
    decoded with ``json.loads`` and passed to ``schema(**data)``.

    Args:
        text: Raw model output.
        schema: Model class used to validate the decoded data.
        provider: Provider name, recorded on any raised error.
        model: Model name, recorded on any raised error.
        prompt_id: Prompt identifier, recorded on any raised error.

    Returns:
        The object produced by ``schema(**data)``.

    Raises:
        StructuredOutputError: With reason "JSON Parse Failure" when the
            text cannot be decoded, or "Schema Validation Failure" when the
            decoded value is not a JSON object or the schema rejects it.
            The original exception, if any, is attached as ``__cause__``.
    """

    def _failure(reason: str, details: str) -> StructuredOutputError:
        # Single place that attaches the call context to the error.
        return StructuredOutputError(
            provider=provider,
            model=model,
            prompt_id=prompt_id,
            raw_output=text,
            reason=reason,
            details=details,
        )

    clean_text = extract_json(text)

    try:
        data = json.loads(clean_text)
    except json.JSONDecodeError as e:
        logger.error("JSON Parse Failure. Raw text between braces: %s", clean_text)
        raise _failure("JSON Parse Failure", str(e)) from e

    # schema(**data) needs a mapping; report a list/number/string top level
    # explicitly instead of surfacing an opaque TypeError message.
    if not isinstance(data, dict):
        logger.error("Schema Validation Failure. Data: %s", data)
        raise _failure(
            "Schema Validation Failure",
            f"Expected a JSON object at the top level, got {type(data).__name__}",
        )

    try:
        return schema(**data)
    except Exception as e:
        # Broad on purpose: any failure while building the model is
        # converted into the single error type callers handle.
        logger.error("Schema Validation Failure. Data: %s", data)
        raise _failure("Schema Validation Failure", str(e)) from e