Claude committed on
Commit
5d9b495
·
unverified ·
1 Parent(s): 9df1925

perf: Sprint 2a — caching, parallel pipeline, async I/O, tests

Browse files

Performance:
- Cache AI providers singleton (model_registry.py)
- LRU cache for prompt templates (prompt_loader.py)
- Cache profiles at startup (profiles.py)
- Parallelize corpus_runner with asyncio.Semaphore(3) + gather
- Wrap all blocking I/O in job_runner with asyncio.to_thread
- Defer base64 encoding in Mistral provider (skip for text-only models)

Code quality:
- Add pagination to list_manuscripts endpoint
- Add 31 new tests for response_parser and prompt_loader

https://claude.ai/code/session_012NCh8yLxMXkRmBYQgHCTik

backend/app/api/v1/corpora.py CHANGED
@@ -127,13 +127,19 @@ async def delete_corpus(corpus_id: str, db: AsyncSession = Depends(get_db)) -> N
127
 
128
  @router.get("/{corpus_id}/manuscripts", response_model=list[ManuscriptResponse])
129
  async def list_manuscripts(
130
- corpus_id: str, db: AsyncSession = Depends(get_db)
 
 
 
131
  ) -> list[ManuscriptModel]:
132
- """Retourne tous les manuscrits d'un corpus."""
133
  corpus = await db.get(CorpusModel, corpus_id)
134
  if corpus is None:
135
  raise HTTPException(status_code=404, detail="Corpus introuvable")
136
  result = await db.execute(
137
- select(ManuscriptModel).where(ManuscriptModel.corpus_id == corpus_id)
 
 
 
138
  )
139
  return list(result.scalars().all())
 
127
 
128
@router.get("/{corpus_id}/manuscripts", response_model=list[ManuscriptResponse])
async def list_manuscripts(
    corpus_id: str,
    db: AsyncSession = Depends(get_db),
    skip: int = Query(0, ge=0, description="Nombre d'éléments à sauter"),
    limit: int = Query(100, ge=1, le=1000, description="Nombre maximum d'éléments"),
) -> list[ManuscriptModel]:
    """Return the manuscripts of a corpus, paginated via ``skip``/``limit``.

    Args:
        corpus_id: primary key of the parent corpus.
        db: async SQLAlchemy session (injected).
        skip: number of rows to skip (>= 0).
        limit: page size (1..1000, default 100).

    Raises:
        HTTPException: 404 if the corpus does not exist.
    """
    corpus = await db.get(CorpusModel, corpus_id)
    if corpus is None:
        raise HTTPException(status_code=404, detail="Corpus introuvable")
    result = await db.execute(
        select(ManuscriptModel)
        .where(ManuscriptModel.corpus_id == corpus_id)
        # Deterministic ordering: OFFSET/LIMIT without ORDER BY lets the
        # database return rows in any order, so consecutive pages could
        # overlap or drop rows.  Sort by primary key for a stable paging
        # order.  NOTE(review): assumes the PK column is named ``id`` —
        # confirm against the ManuscriptModel definition.
        .order_by(ManuscriptModel.id)
        .offset(skip)
        .limit(limit)
    )
    return list(result.scalars().all())
backend/app/api/v1/profiles.py CHANGED
@@ -26,6 +26,27 @@ logger = logging.getLogger(__name__)
26
 
27
  router = APIRouter(prefix="/profiles", tags=["profiles"])
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def _load_profile(path: Path) -> CorpusProfile | None:
31
  """Charge et valide un fichier de profil JSON. Retourne None si invalide."""
@@ -40,27 +61,8 @@ def _load_profile(path: Path) -> CorpusProfile | None:
40
  @router.get("", response_model=list[dict])
41
  async def list_profiles() -> list[dict]:
42
  """Retourne tous les profils valides du dossier profiles/."""
43
- logger.info(
44
- "Résolution profiles_dir",
45
- extra={
46
- "profiles_dir": str(settings.profiles_dir),
47
- "resolved": str(settings.profiles_dir.resolve()),
48
- "is_dir": settings.profiles_dir.is_dir(),
49
- },
50
- )
51
- if not settings.profiles_dir.is_dir():
52
- logger.warning("profiles_dir introuvable : %s", settings.profiles_dir)
53
- return []
54
-
55
- def _scan_profiles() -> list[dict]:
56
- result = []
57
- for path in sorted(settings.profiles_dir.glob("*.json")):
58
- profile = _load_profile(path)
59
- if profile is not None:
60
- result.append(profile.model_dump())
61
- return result
62
-
63
- return await asyncio.to_thread(_scan_profiles)
64
 
65
 
66
  _SAFE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]*$")
@@ -71,16 +73,9 @@ async def get_profile(profile_id: str) -> dict:
71
  """Retourne un profil par son id (nom du fichier sans extension)."""
72
  if not _SAFE_ID_RE.match(profile_id):
73
  raise HTTPException(status_code=400, detail="profile_id invalide")
74
- path = settings.profiles_dir / f"{profile_id}.json"
75
-
76
- def _read() -> CorpusProfile | None:
77
- if not path.exists():
78
- return None
79
- return _load_profile(path)
80
 
81
- profile = await asyncio.to_thread(_read)
82
- if profile is None and not path.exists():
83
- raise HTTPException(status_code=404, detail="Profil introuvable")
84
  if profile is None:
85
- raise HTTPException(status_code=422, detail="Profil invalide")
86
  return profile.model_dump()
 
26
 
27
  router = APIRouter(prefix="/profiles", tags=["profiles"])
28
 
29
# Module-level singleton cache: filled on first call, never invalidated.
_profiles_cache: dict[str, CorpusProfile] | None = None


def _load_all_profiles() -> dict[str, CorpusProfile]:
    """Load every profile from disk once and memoise the result.

    Returns:
        Mapping of ``profile.profile_id`` -> validated ``CorpusProfile``.
        Invalid files are skipped (``_load_profile`` returns ``None`` for them).

    NOTE(review): the cache is filled at most once per process; profiles added
    or edited on disk afterwards are not picked up until restart (the old code
    re-scanned the directory on every request).
    NOTE(review): entries are keyed by the ``profile_id`` field inside the
    JSON, presumably identical to the file name without extension — confirm,
    since callers look profiles up by the file-name-derived id.
    NOTE(review): the first call performs synchronous disk I/O and is invoked
    from async endpoints; the previous implementation used asyncio.to_thread —
    consider keeping that for the cold-start path.
    """
    global _profiles_cache
    if _profiles_cache is not None:
        return _profiles_cache

    result: dict[str, CorpusProfile] = {}
    if settings.profiles_dir.is_dir():
        # Sorted glob gives a deterministic insertion (and listing) order.
        for path in sorted(settings.profiles_dir.glob("*.json")):
            profile = _load_profile(path)
            if profile is not None:
                result[profile.profile_id] = profile
    else:
        logger.warning("profiles_dir introuvable : %s", settings.profiles_dir)

    _profiles_cache = result
    return _profiles_cache
49
+
50
 
51
  def _load_profile(path: Path) -> CorpusProfile | None:
52
  """Charge et valide un fichier de profil JSON. Retourne None si invalide."""
 
61
@router.get("", response_model=list[dict])
async def list_profiles() -> list[dict]:
    """Retourne tous les profils valides du dossier profiles/."""
    # Serve from the startup cache; serialise each validated profile.
    return [profile.model_dump() for profile in _load_all_profiles().values()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  _SAFE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]*$")
 
73
  """Retourne un profil par son id (nom du fichier sans extension)."""
74
  if not _SAFE_ID_RE.match(profile_id):
75
  raise HTTPException(status_code=400, detail="profile_id invalide")
 
 
 
 
 
 
76
 
77
+ profiles = _load_all_profiles()
78
+ profile = profiles.get(profile_id)
 
79
  if profile is None:
80
+ raise HTTPException(status_code=404, detail="Profil introuvable")
81
  return profile.model_dump()
backend/app/services/ai/model_registry.py CHANGED
@@ -23,24 +23,27 @@ _PROVIDER_DISPLAY_NAMES: dict[ProviderType, str] = {
23
  }
24
 
25
 
 
 
 
26
  def _build_providers() -> list[AIProvider]:
27
- """Construit la liste des providers — imports différés.
 
 
 
28
 
29
- Pas de cache global : la construction est triviale (4 objets légers)
30
- et l'absence de cache permet de détecter immédiatement les changements
31
- de variables d'environnement sans redémarrage.
32
- """
33
  from app.services.ai.provider_google_ai import GoogleAIProvider
34
  from app.services.ai.provider_mistral import MistralProvider
35
  from app.services.ai.provider_vertex_key import VertexAPIKeyProvider
36
  from app.services.ai.provider_vertex_sa import VertexServiceAccountProvider
37
 
38
- return [
39
  GoogleAIProvider(),
40
  VertexAPIKeyProvider(),
41
  VertexServiceAccountProvider(),
42
  MistralProvider(),
43
  ]
 
44
 
45
 
46
  def get_available_providers() -> list[dict]:
 
23
  }
24
 
25
 
26
# Singleton cache for the provider list — built lazily on first use.
_providers_cache: list[AIProvider] | None = None


def _build_providers() -> list[AIProvider]:
    """Build the provider list — deferred imports, cached as a module singleton."""
    global _providers_cache
    if _providers_cache is None:
        # Imports are deferred so merely importing this module stays cheap.
        from app.services.ai.provider_google_ai import GoogleAIProvider
        from app.services.ai.provider_mistral import MistralProvider
        from app.services.ai.provider_vertex_key import VertexAPIKeyProvider
        from app.services.ai.provider_vertex_sa import VertexServiceAccountProvider

        _providers_cache = [
            GoogleAIProvider(),
            VertexAPIKeyProvider(),
            VertexServiceAccountProvider(),
            MistralProvider(),
        ]
    return _providers_cache
47
 
48
 
49
  def get_available_providers() -> list[dict]:
backend/app/services/ai/prompt_loader.py CHANGED
@@ -7,11 +7,21 @@ Le code charge le fichier, substitue les variables {{nom}}, envoie à l'API.
7
  # 1. stdlib
8
  import logging
9
  import re
 
10
  from pathlib import Path
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
 
 
 
 
 
 
 
 
 
15
  def load_and_render_prompt(template_path: str | Path, context: dict[str, str]) -> str:
16
  """Charge un template de prompt depuis un fichier et substitue les variables.
17
 
@@ -30,10 +40,8 @@ def load_and_render_prompt(template_path: str | Path, context: dict[str, str]) -
30
  FileNotFoundError: si le fichier template n'existe pas.
31
  """
32
  path = Path(template_path)
33
- if not path.exists():
34
- raise FileNotFoundError(f"Template de prompt introuvable : {path}")
35
 
36
- template = path.read_text(encoding="utf-8")
37
 
38
  rendered = template
39
  for key, value in context.items():
 
7
  # 1. stdlib
8
  import logging
9
  import re
10
+ from functools import lru_cache
11
  from pathlib import Path
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
+ @lru_cache(maxsize=32)
17
+ def _read_template(path_str: str) -> str:
18
+ """Lit un template depuis le disque avec cache LRU."""
19
+ path = Path(path_str)
20
+ if not path.exists():
21
+ raise FileNotFoundError(f"Template introuvable : {path_str}")
22
+ return path.read_text(encoding="utf-8")
23
+
24
+
25
  def load_and_render_prompt(template_path: str | Path, context: dict[str, str]) -> str:
26
  """Charge un template de prompt depuis un fichier et substitue les variables.
27
 
 
40
  FileNotFoundError: si le fichier template n'existe pas.
41
  """
42
  path = Path(template_path)
 
 
43
 
44
+ template = _read_template(str(path))
45
 
46
  rendered = template
47
  for key, value in context.items():
backend/app/services/ai/provider_mistral.py CHANGED
@@ -213,8 +213,10 @@ class MistralProvider(AIProvider):
213
  from mistralai import Mistral
214
 
215
  client = Mistral(api_key=os.environ[_ENV_KEY])
216
- image_b64 = base64.b64encode(image_bytes).decode("utf-8")
217
- data_url = f"data:image/jpeg;base64,{image_b64}"
 
 
218
 
219
  # ── Chemin 1 : OCR dédié ─────────────────────────────────────────────
220
  if _is_ocr_model(model_id):
@@ -222,7 +224,7 @@ class MistralProvider(AIProvider):
222
  try:
223
  response = client.ocr.process(
224
  model=model_id,
225
- document={"type": "image_url", "image_url": {"url": data_url}},
226
  )
227
  except Exception as exc:
228
  logger.error("Appel Mistral OCR échoué", extra={"model": model_id, "error": str(exc)})
@@ -236,7 +238,7 @@ class MistralProvider(AIProvider):
236
  # ── Chemin 2 : Vision multimodale ────────────────────────────────────
237
  if supports_vision:
238
  content: object = [
239
- {"type": "image_url", "image_url": {"url": data_url}},
240
  {"type": "text", "text": prompt},
241
  ]
242
  # ── Chemin 3 : Texte seul ─────────────────────────────────────────────
 
213
  from mistralai import Mistral
214
 
215
  client = Mistral(api_key=os.environ[_ENV_KEY])
216
+
217
+ # Encodage base64 différé — calculé uniquement si le modèle a besoin de l'image
218
+ def _image_data_url() -> str:
219
+ return f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode('utf-8')}"
220
 
221
  # ── Chemin 1 : OCR dédié ─────────────────────────────────────────────
222
  if _is_ocr_model(model_id):
 
224
  try:
225
  response = client.ocr.process(
226
  model=model_id,
227
+ document={"type": "image_url", "image_url": {"url": _image_data_url()}},
228
  )
229
  except Exception as exc:
230
  logger.error("Appel Mistral OCR échoué", extra={"model": model_id, "error": str(exc)})
 
238
  # ── Chemin 2 : Vision multimodale ────────────────────────────────────
239
  if supports_vision:
240
  content: object = [
241
+ {"type": "image_url", "image_url": {"url": _image_data_url()}},
242
  {"type": "text", "text": prompt},
243
  ]
244
  # ── Chemin 3 : Texte seul ─────────────────────────────────────────────
backend/app/services/corpus_runner.py CHANGED
@@ -9,6 +9,7 @@ Point d'entrée : execute_corpus_job(corpus_id)
9
  Chaque page reçoit sa propre session pour isoler les échecs.
10
  """
11
  # 1. stdlib
 
12
  import logging
13
 
14
  # 2. third-party
@@ -58,11 +59,18 @@ async def execute_corpus_job(corpus_id: str) -> dict:
58
  extra={"corpus_id": corpus_id, "jobs": len(job_ids)},
59
  )
60
 
61
- # Exécution séquentielle — chaque job gère sa propre session
62
  from app.services.job_runner import execute_page_job
63
 
64
- for job_id in job_ids:
65
- await execute_page_job(job_id)
 
 
 
 
 
 
 
66
 
67
  # Bilan final
68
  async with async_session_factory() as db:
 
9
  Chaque page reçoit sa propre session pour isoler les échecs.
10
  """
11
  # 1. stdlib
12
+ import asyncio
13
  import logging
14
 
15
  # 2. third-party
 
59
  extra={"corpus_id": corpus_id, "jobs": len(job_ids)},
60
  )
61
 
62
+ # Exécution concurrente avec semaphore — chaque job gère sa propre session
63
  from app.services.job_runner import execute_page_job
64
 
65
+ _MAX_CONCURRENT = 3 # limiter la pression sur les APIs IA
66
+
67
+ sem = asyncio.Semaphore(_MAX_CONCURRENT)
68
+
69
+ async def _run_one(jid: str) -> None:
70
+ async with sem:
71
+ await execute_page_job(jid)
72
+
73
+ await asyncio.gather(*[_run_one(jid) for jid in job_ids])
74
 
75
  # Bilan final
76
  async with async_session_factory() as db:
backend/app/services/job_runner.py CHANGED
@@ -17,6 +17,7 @@ Sur toute exception : job → FAILED + error_message, page → ERROR.
17
  Aucun échec silencieux (CLAUDE.md §7).
18
  """
19
  # 1. stdlib
 
20
  import json
21
  import logging
22
  from datetime import datetime, timezone
@@ -139,7 +140,8 @@ async def _run_job_impl(job_id: str, db: AsyncSession) -> None:
139
 
140
  if page.iiif_service_url:
141
  # ── Mode IIIF natif : fetch en mémoire, zéro stockage ────────────
142
- deriv_bytes, deriv_w, deriv_h = fetch_ai_derivative_bytes(
 
143
  iiif_service_url=page.iiif_service_url,
144
  fallback_url=None,
145
  )
@@ -153,7 +155,8 @@ async def _run_job_impl(job_id: str, db: AsyncSession) -> None:
153
  )
154
 
155
  # ── 6. Analyse primaire IA (R05 : double stockage) ───────────────
156
- page_master = run_primary_analysis(
 
157
  derivative_image_bytes=deriv_bytes,
158
  derivative_width=deriv_w,
159
  derivative_height=deriv_h,
@@ -171,10 +174,12 @@ async def _run_job_impl(job_id: str, db: AsyncSession) -> None:
171
 
172
  elif image_source.startswith(("http://", "https://")):
173
  # ── Mode fallback URL : télécharge + stocke sur disque (legacy) ──
174
- image_info = fetch_and_normalize(
 
175
  image_source, corpus.slug, page.folio_label, data_dir
176
  )
177
- page_master = run_primary_analysis(
 
178
  derivative_image_path=Path(image_info.derivative_path),
179
  corpus_profile=corpus_profile,
180
  model_config=model_config,
@@ -198,10 +203,12 @@ async def _run_job_impl(job_id: str, db: AsyncSession) -> None:
198
  f"{image_source!r} (résolu : {source_path})"
199
  )
200
  source_bytes = source_path.read_bytes()
201
- image_info = create_derivatives(
 
202
  source_bytes, image_source, corpus.slug, page.folio_label, data_dir
203
  )
204
- page_master = run_primary_analysis(
 
205
  derivative_image_path=Path(image_info.derivative_path),
206
  corpus_profile=corpus_profile,
207
  model_config=model_config,
 
17
  Aucun échec silencieux (CLAUDE.md §7).
18
  """
19
  # 1. stdlib
20
+ import asyncio
21
  import json
22
  import logging
23
  from datetime import datetime, timezone
 
140
 
141
  if page.iiif_service_url:
142
  # ── Mode IIIF natif : fetch en mémoire, zéro stockage ────────────
143
+ deriv_bytes, deriv_w, deriv_h = await asyncio.to_thread(
144
+ fetch_ai_derivative_bytes,
145
  iiif_service_url=page.iiif_service_url,
146
  fallback_url=None,
147
  )
 
155
  )
156
 
157
  # ── 6. Analyse primaire IA (R05 : double stockage) ───────────────
158
+ page_master = await asyncio.to_thread(
159
+ run_primary_analysis,
160
  derivative_image_bytes=deriv_bytes,
161
  derivative_width=deriv_w,
162
  derivative_height=deriv_h,
 
174
 
175
  elif image_source.startswith(("http://", "https://")):
176
  # ── Mode fallback URL : télécharge + stocke sur disque (legacy) ──
177
+ image_info = await asyncio.to_thread(
178
+ fetch_and_normalize,
179
  image_source, corpus.slug, page.folio_label, data_dir
180
  )
181
+ page_master = await asyncio.to_thread(
182
+ run_primary_analysis,
183
  derivative_image_path=Path(image_info.derivative_path),
184
  corpus_profile=corpus_profile,
185
  model_config=model_config,
 
203
  f"{image_source!r} (résolu : {source_path})"
204
  )
205
  source_bytes = source_path.read_bytes()
206
+ image_info = await asyncio.to_thread(
207
+ create_derivatives,
208
  source_bytes, image_source, corpus.slug, page.folio_label, data_dir
209
  )
210
+ page_master = await asyncio.to_thread(
211
+ run_primary_analysis,
212
  derivative_image_path=Path(image_info.derivative_path),
213
  corpus_profile=corpus_profile,
214
  model_config=model_config,
backend/tests/test_prompt_loader.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Tests for prompt_loader.py — loading, rendering, and detection of unresolved variables.
"""
import pytest

from app.services.ai.prompt_loader import load_and_render_prompt


@pytest.fixture()
def template_file(tmp_path):
    """Factory fixture: write *content* to a temp template file, return its path."""
    def _create(content: str, name: str = "test_prompt.txt"):
        path = tmp_path / name
        path.write_text(content, encoding="utf-8")
        return path
    return _create


class TestLoadAndRenderPrompt:
    def test_simple_substitution(self, template_file):
        # Each {{variable}} placeholder is replaced by its value from the context.
        path = template_file("Analysez ce {{profile_label}} en {{script_type}}.")
        result = load_and_render_prompt(path, {
            "profile_label": "manuscrit enluminé",
            "script_type": "caroline",
        })
        assert result == "Analysez ce manuscrit enluminé en caroline."

    def test_multiple_occurrences(self, template_file):
        # The same variable may appear several times; every occurrence is replaced.
        path = template_file("{{lang}} et encore {{lang}}")
        result = load_and_render_prompt(path, {"lang": "latin"})
        assert result == "latin et encore latin"

    def test_unused_context_keys_ignored(self, template_file):
        # Extra context entries with no matching placeholder are silently ignored.
        path = template_file("Hello {{name}}")
        result = load_and_render_prompt(path, {"name": "World", "extra": "unused"})
        assert result == "Hello World"

    def test_unresolved_variable_raises(self, template_file):
        # A placeholder left without a value must raise, not leak "{{...}}" to the AI.
        path = template_file("{{resolved}} but {{unresolved}} remains")
        with pytest.raises(ValueError, match="Variables non résolues"):
            load_and_render_prompt(path, {"resolved": "ok"})

    def test_file_not_found_raises(self, tmp_path):
        with pytest.raises(FileNotFoundError, match="Template introuvable"):
            load_and_render_prompt(tmp_path / "nonexistent.txt", {})

    def test_empty_template(self, template_file):
        path = template_file("")
        result = load_and_render_prompt(path, {"key": "value"})
        assert result == ""

    def test_no_variables_in_template(self, template_file):
        path = template_file("Just plain text, no variables.")
        result = load_and_render_prompt(path, {})
        assert result == "Just plain text, no variables."

    def test_path_as_string(self, template_file):
        # The loader accepts str paths as well as pathlib.Path objects.
        path = template_file("{{x}}")
        result = load_and_render_prompt(str(path), {"x": "replaced"})
        assert result == "replaced"
backend/tests/test_response_parser.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Tests for response_parser.py — JSON extraction, VLM output fixing, tolerant parsing.
"""
import pytest

from app.services.ai.response_parser import (
    ParseError,
    _extract_json_object,
    _fix_common_json_issues,
    _try_parse_json,
    parse_ai_response,
)


# ── _extract_json_object ─────────────────────────────────────────────────────

class TestExtractJsonObject:
    def test_simple_object(self):
        assert _extract_json_object('{"a": 1}') == '{"a": 1}'

    def test_text_before_json(self):
        # Leading chatter before the JSON object is stripped.
        result = _extract_json_object('Here is the JSON: {"a": 1}')
        assert result == '{"a": 1}'

    def test_text_after_json(self):
        # Trailing chatter after the closing brace is stripped.
        result = _extract_json_object('{"a": 1} and more text')
        assert result == '{"a": 1}'

    def test_nested_braces(self):
        result = _extract_json_object('{"a": {"b": {"c": 1}}}')
        assert result == '{"a": {"b": {"c": 1}}}'

    def test_braces_inside_strings(self):
        # Braces inside string values must not terminate the balanced-brace scan.
        result = _extract_json_object('{"text": "value with { and } inside"}')
        assert result == '{"text": "value with { and } inside"}'

    def test_escaped_quotes(self):
        result = _extract_json_object('{"text": "he said \\"hello\\""}')
        assert result == '{"text": "he said \\"hello\\""}'

    def test_no_json(self):
        # Input without any JSON object is returned unchanged.
        result = _extract_json_object("no json here")
        assert result == "no json here"

    def test_unclosed_json(self):
        # A truncated object still yields the text from its opening brace on.
        result = _extract_json_object('some text {"a": 1')
        assert result.startswith('{"a": 1')


# ── _fix_common_json_issues ──────────────────────────────────────────────────

class TestFixCommonJsonIssues:
    def test_trailing_comma_before_brace(self):
        assert _fix_common_json_issues('{"a": 1,}') == '{"a": 1}'

    def test_trailing_comma_before_bracket(self):
        assert _fix_common_json_issues('[1, 2,]') == '[1, 2]'

    def test_trailing_comma_with_whitespace(self):
        assert _fix_common_json_issues('{"a": 1 , }') == '{"a": 1 }'

    def test_no_issues(self):
        # Already-valid JSON passes through untouched.
        text = '{"a": 1, "b": 2}'
        assert _fix_common_json_issues(text) == text


# ── _try_parse_json ──────────────────────────────────────────────────────────

class TestTryParseJson:
    def test_valid_json(self):
        assert _try_parse_json('{"a": 1}') == {"a": 1}

    def test_json_with_trailing_comma(self):
        # Trailing commas are repaired before parsing.
        result = _try_parse_json('{"a": 1,}')
        assert result == {"a": 1}

    def test_invalid_json(self):
        # Returns None instead of raising on unparseable input.
        assert _try_parse_json("not json at all") is None


# ── parse_ai_response ────────────────────────────────────────────────────────

class TestParseAiResponse:
    def test_clean_json(self):
        raw = '{"layout": {"regions": [{"id": "r1", "type": "text_block", "bbox": [10, 20, 100, 200], "confidence": 0.9}]}, "ocr": {"diplomatic_text": "hello", "confidence": 0.8}}'
        layout, ocr = parse_ai_response(raw)
        assert len(layout["regions"]) == 1
        assert layout["regions"][0]["id"] == "r1"
        assert ocr.diplomatic_text == "hello"

    def test_markdown_fenced_json(self):
        # Models often wrap JSON in a ```json fence; the parser must unwrap it.
        raw = '```json\n{"layout": {"regions": []}, "ocr": {"diplomatic_text": "test"}}\n```'
        layout, ocr = parse_ai_response(raw)
        assert layout["regions"] == []
        assert ocr.diplomatic_text == "test"

    def test_text_around_json(self):
        raw = 'Here is my analysis:\n{"layout": {"regions": []}, "ocr": {"diplomatic_text": "ok"}}\nHope this helps!'
        layout, ocr = parse_ai_response(raw)
        assert ocr.diplomatic_text == "ok"

    def test_invalid_region_skipped(self):
        # r1 has a negative bbox coordinate and is dropped; valid r2 survives.
        raw = '{"layout": {"regions": [{"id": "r1", "type": "text_block", "bbox": [-1, 0, 100, 200], "confidence": 0.5}, {"id": "r2", "type": "miniature", "bbox": [10, 20, 100, 200], "confidence": 0.8}]}}'
        layout, ocr = parse_ai_response(raw)
        assert len(layout["regions"]) == 1
        assert layout["regions"][0]["id"] == "r2"

    def test_missing_ocr_returns_default(self):
        # A missing "ocr" section yields an empty-text, zero-confidence default.
        raw = '{"layout": {"regions": []}}'
        layout, ocr = parse_ai_response(raw)
        assert ocr.diplomatic_text == ""
        assert ocr.confidence == 0.0

    def test_not_json_raises_parse_error(self):
        with pytest.raises(ParseError):
            parse_ai_response("This is not JSON at all, no braces anywhere")

    def test_json_array_raises_parse_error(self):
        # A top-level array is not an accepted response shape.
        with pytest.raises(ParseError):
            parse_ai_response("[1, 2, 3]")

    def test_trailing_comma_tolerance(self):
        raw = '{"layout": {"regions": [],}, "ocr": {"diplomatic_text": "tolerant",}}'
        layout, ocr = parse_ai_response(raw)
        assert ocr.diplomatic_text == "tolerant"