| |
| """ |
| Robust JSON Extraction from LLM Output |
| ======================================= |
| |
| LLMs frequently wrap JSON in markdown, add conversational preamble/postamble, |
| use Python-style booleans, or output malformed JSON. This module handles all |
| of those cases with a multi-strategy approach. |
| |
| Extracted from production tool calling orchestrator. Battle-tested against |
| Hermes-3, Llama 3.3, Qwen2, and Mistral models. |
| |
| Usage: |
| from robust_json_extraction import extract_json, extract_tool_calls |
| |
| # Handle any LLM output format |
| data = extract_json('Here is the result: ```json\\n{"key": "value"}\\n``` Hope that helps!') |
| |
| # Extract tool calls from Hermes-format XML |
| calls = extract_tool_calls('<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>') |
| """ |
|
|
| import json |
| import re |
| import ast |
| import xml.etree.ElementTree as ET |
| from json import JSONDecoder |
| from typing import Any, Dict, List, Optional |
|
|
|
|
# Double-quoted JSON strings are matched first so their contents are never
# rewritten; only bare Python literal tokens land in group 1.
_PY_LITERAL_OR_JSON_STRING_RE = re.compile(r'"(?:\\.|[^"\\])*"|\b(True|False|None)\b')
_PY_TO_JSON_LITERAL = {"True": "true", "False": "false", "None": "null"}


def _python_literals_to_json(text: str) -> str:
    """Rewrite bare ``True``/``False``/``None`` tokens as JSON literals.

    Text inside double-quoted strings is left untouched, so a value such as
    ``"None of these"`` survives the rewrite unchanged.
    """
    return _PY_LITERAL_OR_JSON_STRING_RE.sub(
        lambda m: _PY_TO_JSON_LITERAL[m.group(1)] if m.group(1) else m.group(0),
        text,
    )


def _strip_markdown_fence(text: str) -> str:
    """Return the body of the first ``` fenced block, or *text* unchanged if it has none."""
    if "```" not in text:
        return text
    match = re.search(r"```(?:json)?\s*\n(.*?)\n```", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    # Unterminated or single-line fence: drop everything up to the end of the
    # opening fence line, then a trailing fence if one is left.
    first_newline = text.find("\n", text.find("```"))
    if first_newline != -1:
        text = text[first_newline + 1:]
    if text.endswith("```"):
        text = text[:-3].strip()
    return text


def extract_json(text: str) -> Any:
    """
    Extract a JSON value from LLM output, handling common issues:
    1. Markdown code blocks (```json ... ```); the first fenced block wins.
    2. Preamble text ("Here is the result: {...").
    3. Postamble text ("...} Let me know if you need help!").
    4. Python-style literals (True/False/None instead of true/false/null)
       outside of quoted strings.

    The value is assumed to start at whichever of the first '{' or first '['
    appears earlier, so "Results: [{...}, {...}]" yields the whole list. If
    that start does not parse, the other bracket is tried as a fallback.

    Args:
        text: Raw model output.

    Returns:
        The parsed JSON data (dict, list, or scalar).

    Raises:
        json.JSONDecodeError: if the input is empty or no valid JSON can be
            extracted. The error is the one from the first strict parse attempt.
    """
    text = text.strip()
    if not text:
        raise json.JSONDecodeError("Empty input", text, 0)

    text = _strip_markdown_fence(text)

    # Candidate start offsets in order of appearance; without any bracket,
    # try the whole text so bare scalars (e.g. "42") still parse.
    starts = sorted(i for i in (text.find("{"), text.find("[")) if i != -1) or [0]

    decoder = JSONDecoder()
    first_error: Optional[json.JSONDecodeError] = None
    for start in starts:
        snippet = text[start:]
        variants = [snippet]
        fixed = _python_literals_to_json(snippet)
        if fixed != snippet:
            variants.append(fixed)
        for candidate in variants:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError as err:
                if first_error is None:
                    first_error = err
            # raw_decode stops after the first complete value, which tolerates
            # trailing postamble text that json.loads rejects as "Extra data".
            try:
                return decoder.raw_decode(candidate)[0]
            except json.JSONDecodeError:
                pass

    raise first_error
|
|
|
|
# Quoted strings (either quote style) are matched first so their contents are
# never rewritten; only bare literal tokens land in group 1.
_QUOTED_OR_BARE_LITERAL_RE = re.compile(
    r'"(?:\\.|[^"\\])*"'
    r"|'(?:\\.|[^'\\])*'"
    r"|\b(true|false|null|True|False|None)\b"
)
_JSON_TO_PYTHON_LITERALS = {"true": "True", "false": "False", "null": "None"}
_PYTHON_TO_JSON_LITERALS = {"True": "true", "False": "false", "None": "null"}
# ast.literal_eval documents all of these for malformed input (e.g. "{[1]: 2}"
# raises TypeError); catching them keeps this parser best-effort.
_LENIENT_PARSE_ERRORS = (SyntaxError, ValueError, TypeError, MemoryError, RecursionError)


def _swap_bare_literals(text: str, mapping: Dict[str, str]) -> str:
    """Rewrite bare literal tokens found in *mapping*, leaving quoted strings untouched."""
    def _replace(match):
        token = match.group(1)
        return mapping.get(token, match.group(0)) if token else match.group(0)

    return _QUOTED_OR_BARE_LITERAL_RE.sub(_replace, text)


def _parse_arguments_object(text: str) -> Dict:
    """Best-effort parse of the object literal at the start of *text*; ``{}`` on failure."""
    try:
        value, _ = JSONDecoder().raw_decode(text)
    except json.JSONDecodeError:
        value = None
    if isinstance(value, dict):
        return value
    # Fallback for Python-style objects: the leading {...} span without nested braces.
    flat = re.match(r"\{[^}]+\}", text)
    if flat:
        for parse in (json.loads, ast.literal_eval):
            try:
                value = parse(flat.group(0))
            except _LENIENT_PARSE_ERRORS:
                continue
            if isinstance(value, dict):
                return value
    return {}


def parse_single_call(json_text: str) -> Optional[Dict]:
    """
    Parse a single tool call payload using progressively more lenient strategies.

    1. Strict JSON.
    2. Python literal syntax (single quotes, True/False/None), after rewriting
       bare true/false/null tokens to their Python spelling.
    3. JSON after swapping single quotes for double quotes and rewriting bare
       True/False/None tokens.
    4. Regex salvage of the 'name' field plus the object following 'arguments'
       (nested objects are decoded with JSONDecoder.raw_decode).

    The literal rewrites in 2 and 3 skip quoted strings, so argument values such
    as 'true crime' or "nullable" are preserved.

    Args:
        json_text: Text of one tool call, e.g. the body of a <tool_call> tag.

    Returns:
        The parsed value, normally a dict with 'name' and 'arguments' keys.
        Strategies 1-3 return whatever the text parses to (possibly a list or
        scalar), so callers should check the type. None if nothing parses.
    """
    json_text = json_text.strip()
    if not json_text:
        return None

    try:
        return json.loads(json_text)
    except json.JSONDecodeError:
        pass

    try:
        return ast.literal_eval(_swap_bare_literals(json_text, _JSON_TO_PYTHON_LITERALS))
    except _LENIENT_PARSE_ERRORS:
        pass

    try:
        double_quoted = json_text.replace("'", '"')
        return json.loads(_swap_bare_literals(double_quoted, _PYTHON_TO_JSON_LITERALS))
    except json.JSONDecodeError:
        pass

    # (?<!\w) stops the key patterns from matching inside e.g. "filename" or "tool_name".
    name_match = re.search(r"(?<!\w)['\"]?name['\"]?\s*:\s*['\"]([^'\"]+)['\"]", json_text)
    if not name_match:
        return None
    args_key = re.search(r"(?<!\w)['\"]?arguments['\"]?\s*:\s*(?=\{)", json_text)
    arguments = _parse_arguments_object(json_text[args_key.end():]) if args_key else {}
    return {"name": name_match.group(1), "arguments": arguments}
|
|
|
|
_TOOL_CALL_BLOCK_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def _decode_json_stream(text: str) -> Optional[List[Any]]:
    """Decode *text* as one or more whitespace-separated JSON values; None if any part is invalid."""
    decoder = JSONDecoder()
    values: List[Any] = []
    rest = text.strip()
    while rest:
        try:
            value, consumed = decoder.raw_decode(rest)
        except json.JSONDecodeError:
            return None
        values.append(value)
        rest = rest[consumed:].lstrip()
    return values


def _calls_from_value(value: Any) -> List[Dict]:
    """Normalize one parsed payload into a list of tool-call dicts.

    Accepts a call dict (has 'name'), a wrapper dict whose 'tool_calls' is a
    list (its dict entries are returned), or a list of either. Anything else
    yields no calls.
    """
    if isinstance(value, list):
        return [call for item in value if isinstance(item, dict) for call in _calls_from_value(item)]
    if not isinstance(value, dict):
        return []
    if "tool_calls" in value:
        nested = value["tool_calls"]
        if not isinstance(nested, list):
            return []
        return [call for call in nested if isinstance(call, dict)]
    if "name" in value:
        return [value]
    return []


def _calls_from_block(raw_text: str) -> List[Dict]:
    """Extract every tool call from the text inside one <tool_call> tag."""
    text = raw_text.strip()
    if not text:
        return []

    # Strict JSON first, including several values in one block (one per line).
    values = _decode_json_stream(text)
    if values is not None:
        return [call for value in values for call in _calls_from_value(value)]

    stripped_lines = (line.strip() for line in text.splitlines())
    object_lines = [line for line in stripped_lines if line.startswith("{")]

    # Several lenient calls, one per line. Accept the split only when every such
    # line yields a call by itself, so one pretty-printed call is not chopped up.
    if len(object_lines) > 1:
        per_line = [_calls_from_value(parse_single_call(line)) for line in object_lines]
        if all(per_line):
            return [call for calls in per_line for call in calls]

    whole = parse_single_call(text)
    if whole:
        return _calls_from_value(whole)

    # Last resort: salvage whatever individual lines can be parsed.
    return [call for line in object_lines for call in _calls_from_value(parse_single_call(line))]


def extract_tool_calls(assistant_message: str) -> List[Dict]:
    """
    Extract tool calls from an assistant message containing <tool_call> XML tags.

    Supports:
    - Single tool call: <tool_call>{"name": "fn", "arguments": {...}}</tool_call>
    - Nested format: <tool_call>{"tool_calls": [...]}</tool_call>
    - A JSON list of calls inside one tag
    - Multiple JSON objects in one block (e.g. one per line)
    - Malformed XML (regex fallback), with the same per-block handling

    Args:
        assistant_message: Full assistant text, possibly with prose around the tags.

    Returns:
        List of dicts in document order. Direct calls have a 'name' key; entries
        taken from a 'tool_calls' list are passed through as-is.
    """
    try:
        # Wrap in a synthetic root so surrounding prose is legal element text.
        root = ET.fromstring(f"<root>{assistant_message}</root>")
    except ET.ParseError:
        # A stray '<' or '&' in prose or JSON makes the message invalid XML.
        blocks = _TOOL_CALL_BLOCK_RE.findall(assistant_message)
    else:
        blocks = [element.text or "" for element in root.findall(".//tool_call")]

    return [call for block in blocks for call in _calls_from_block(block)]
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Self-test: run each sample through the extractors and report progress.
    rule = "=" * 60

    def _section(title: str, *, lead: str = "") -> None:
        """Print a ruled section header; *lead* is prepended to the top rule."""
        print(lead + rule)
        print(title)
        print(rule)

    _section("Testing robust JSON extraction")

    # (label, raw model output, expected parsed value)
    json_cases = [
        ("Clean JSON", '{"key": "value"}', {"key": "value"}),
        ("Markdown-wrapped JSON", '```json\n{"key": "value"}\n```', {"key": "value"}),
        ("Preamble text", 'Here is the result: {"key": "value"}', {"key": "value"}),
        ("Postamble text", '{"key": "value"} Hope that helps!', {"key": "value"}),
        (
            "Preamble + markdown + postamble",
            'Sure! ```json\n{"key": "value"}\n``` Let me know!',
            {"key": "value"},
        ),
        (
            "Python-style booleans",
            '{"active": True, "deleted": False, "value": None}',
            {"active": True, "deleted": False, "value": None},
        ),
    ]
    for label, raw, expected in json_cases:
        assert extract_json(raw) == expected
        print(f" [PASS] {label}")

    _section("Testing tool call extraction", lead="\n")

    # (label, assistant message, expected number of calls, expected first name or None)
    tool_cases = [
        (
            "Single tool call",
            '<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>',
            1,
            "search",
        ),
        (
            "Nested tool_calls array",
            '<tool_call>{"tool_calls": [{"name": "a", "arguments": {}}, {"name": "b", "arguments": {}}]}</tool_call>',
            2,
            None,
        ),
        (
            "Mixed content with tool call",
            'I will search for that.\n<tool_call>\n{"name": "search", "arguments": {"q": "hello"}}\n</tool_call>\nDone.',
            1,
            "search",
        ),
    ]
    for label, message, count, first_name in tool_cases:
        found = extract_tool_calls(message)
        assert len(found) == count
        if first_name is not None:
            assert found[0]["name"] == first_name
        print(f" [PASS] {label}")

    print("\nAll tests passed.")
|
|