| """Domain A v2 academic bundle: dialogue normalization + staged QA / gold state.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Iterator, Mapping |
|
|
| from eval_framework.datasets.schemas import NormalizedTurn, normalize_turn |
| from eval_framework.pipeline.gold_state import ( |
| SessionGoldState, |
| build_session_gold_states, |
| ) |
|
|
|
|
@dataclass(frozen=True)
class Stage4Record:
    """Immutable view of one stage-4 memory-point row for a sample.

    ``memory_sessions`` keeps the row's sessions in file order as
    ``(session_id, memory_points)`` pairs; every memory point is stored as the
    raw mapping that was parsed from JSON, without further validation.
    """

    # Identifiers of the sample this row belongs to, stored as strings.
    uuid: str
    sample_id: str
    # ((session_id, (memory_point, ...)), ...) in source order.
    memory_sessions: tuple[tuple[str, tuple[Mapping[str, Any], ...]], ...]
|
|
|
|
@dataclass(frozen=True)
class QARecord:
    """Immutable view of one QA-checkpoint row for a sample.

    The checkpoints are kept exactly as parsed from JSON; normalization into
    typed checkpoint objects happens separately.
    """

    # Identifiers of the sample this row belongs to, stored as strings.
    uuid: str
    sample_id: str
    # Unprocessed checkpoint objects, in source order.
    raw_checkpoints: tuple[Mapping[str, Any], ...]
|
|
|
|
@dataclass(frozen=True)
class NormalizedCheckpointQuestion:
    """A single checkpoint question with its gold answer and gold evidence.

    All fields are plain strings (or tuples of strings) so the object is
    hashable and safe to share.
    """

    # Question text and the reference answer for it.
    question: str
    gold_answer: str
    # Memory ids cited as evidence, in source order.
    gold_evidence_memory_ids: tuple[str, ...]
    # Contents resolved for the cited ids. Ids without known, non-empty content
    # contribute nothing, so this may be shorter than the id tuple.
    gold_evidence_contents: tuple[str, ...]
    # Classification metadata carried over from the raw question.
    question_type: str
    question_type_abbrev: str
    difficulty: str
|
|
|
|
@dataclass(frozen=True)
class NormalizedCheckpoint:
    """A QA checkpoint after normalization: id, covered sessions, questions."""

    # Checkpoint identifier as a string (empty when the raw row had none).
    checkpoint_id: str
    # Session ids listed under the checkpoint's ``covered_sessions`` key.
    covered_sessions: tuple[str, ...]
    # Normalized questions attached to this checkpoint, in source order.
    questions: tuple[NormalizedCheckpointQuestion, ...]
|
|
|
|
@dataclass(frozen=True)
class DomainAV2Session:
    """One dialogue session of a sample, reduced to its normalized turns."""

    # Session identifier (the loader reads it from ``_v2_session_id``).
    session_id: str
    # Normalized dialogue turns of this session, in dialogue order.
    turns: tuple[NormalizedTurn, ...]
|
|
|
|
@dataclass(frozen=True)
class DomainAV2AcademicSample:
    """Everything the loader assembles for a single sample.

    Combines the normalized dialogue sessions with the stage-4 and QA rows
    matched to the same ``sample_id``, plus the artifacts derived from them.
    """

    # Identifiers of the sample, stored as strings.
    uuid: str
    sample_id: str
    # Dialogue sessions in source order.
    sessions: tuple[DomainAV2Session, ...]
    # Source rows matched to this sample.
    stage4: Stage4Record
    qa_record: QARecord
    # Checkpoints derived from ``qa_record.raw_checkpoints``.
    normalized_checkpoints: tuple[NormalizedCheckpoint, ...]
    # Result of ``build_session_gold_states`` for this sample's sessions.
    session_gold_states: tuple[SessionGoldState, ...]
|
|
|
|
@dataclass(frozen=True)
class DomainAV2AcademicBundle:
    """Immutable container for all samples produced by one load."""

    # Samples in the order they appear in the main dataset file.
    samples: tuple[DomainAV2AcademicSample, ...]
|
|
|
|
| def _read_jsonl(path: Path) -> Iterator[dict[str, Any]]: |
| with path.open(encoding="utf-8") as fh: |
| for line in fh: |
| line = line.strip() |
| if not line: |
| continue |
| yield json.loads(line) |
|
|
|
|
def _stage4_from_obj(obj: Mapping[str, Any]) -> Stage4Record:
    """Build a :class:`Stage4Record` from one parsed stage-4 JSON row.

    Malformed parts of the row are tolerated rather than raised on, matching
    the other normalizers in this module:

    * a ``memory_sessions`` value that is missing, falsy, or not a list/tuple
      is treated as "no sessions";
    * entries of ``memory_sessions`` that are not mappings are skipped;
    * a ``memory_points`` value that is missing, falsy, or not a list yields
      an empty tuple of points for that session.

    Args:
        obj: Parsed JSON object for one row.

    Returns:
        The record with ``uuid`` / ``sample_id`` coerced to ``str`` and one
        ``(session_id, points)`` pair per well-formed session entry.

    Raises:
        KeyError: If ``uuid`` or ``sample_id`` is absent from *obj*.
    """
    sessions_raw = obj.get("memory_sessions") or []
    # Iterating a dict would walk its keys and a str its characters; only a
    # real sequence of session objects is meaningful here.
    if not isinstance(sessions_raw, (list, tuple)):
        sessions_raw = []

    blocks: list[tuple[str, tuple[Mapping[str, Any], ...]]] = []
    for entry in sessions_raw:
        if not isinstance(entry, Mapping):
            # Without this guard ``entry.get`` raises AttributeError.
            continue
        points = entry.get("memory_points") or []
        if not isinstance(points, list):
            points = []
        blocks.append((str(entry.get("session_id", "")), tuple(points)))

    return Stage4Record(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        memory_sessions=tuple(blocks),
    )
|
|
|
|
def _qa_from_obj(obj: Mapping[str, Any]) -> QARecord:
    """Build a :class:`QARecord` from one parsed QA-checkpoint JSON row.

    The ``checkpoints`` value is copied into a tuple when it is a list; a
    missing, falsy, or non-list value results in an empty tuple. ``uuid`` and
    ``sample_id`` are required (``KeyError`` otherwise) and coerced to ``str``.
    """
    checkpoints = obj.get("checkpoints") or []
    raw_checkpoints = tuple(checkpoints) if isinstance(checkpoints, list) else ()
    return QARecord(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        raw_checkpoints=raw_checkpoints,
    )
|
|
|
|
def _normalize_checkpoint_question(
    raw: Mapping[str, Any],
    memory_content_map: Mapping[str, str],
) -> NormalizedCheckpointQuestion:
    """Convert one raw checkpoint question into its normalized form.

    Evidence is read from ``raw["evidence"]``: every dict item carrying a
    ``memory_id`` key contributes that id (as ``str``), and additionally the
    content looked up in *memory_content_map* when that content is non-empty.
    Ids without content are still listed, so the two evidence tuples are not
    guaranteed to have the same length.

    The remaining fields are read with an empty-string default and passed
    through ``str``.
    """
    evidence_ids: list[str] = []
    evidence_texts: list[str] = []

    evidence_items = raw.get("evidence") or []
    if not isinstance(evidence_items, list):
        evidence_items = []
    for entry in evidence_items:
        if not (isinstance(entry, dict) and "memory_id" in entry):
            continue
        memory_id = str(entry["memory_id"])
        evidence_ids.append(memory_id)
        resolved = memory_content_map.get(memory_id, "")
        if resolved:
            evidence_texts.append(resolved)

    def text_field(key: str) -> str:
        # Missing keys become ""; present values are stringified as-is.
        return str(raw.get(key, ""))

    return NormalizedCheckpointQuestion(
        question=text_field("question"),
        gold_answer=text_field("answer"),
        gold_evidence_memory_ids=tuple(evidence_ids),
        gold_evidence_contents=tuple(evidence_texts),
        question_type=text_field("question_type"),
        question_type_abbrev=text_field("question_type_abbrev"),
        difficulty=text_field("difficulty"),
    )
|
|
|
|
| def _normalize_checkpoints( |
| raw_checkpoints: tuple[Mapping[str, Any], ...], |
| memory_content_map: Mapping[str, str], |
| ) -> tuple[NormalizedCheckpoint, ...]: |
| out: list[NormalizedCheckpoint] = [] |
| for cp in raw_checkpoints: |
| qs = cp.get("questions") or [] |
| if not isinstance(qs, list): |
| qs = [] |
| covered = cp.get("covered_sessions") or [] |
| if not isinstance(covered, list): |
| covered = [] |
| out.append( |
| NormalizedCheckpoint( |
| checkpoint_id=str(cp.get("checkpoint_id", "")), |
| covered_sessions=tuple(str(x) for x in covered), |
| questions=tuple( |
| _normalize_checkpoint_question(q, memory_content_map) |
| for q in qs |
| if isinstance(q, Mapping) |
| ), |
| ) |
| ) |
| return tuple(out) |
|
|
|
|
| def _dialogue_turns(sample_id: str, session_id: str, dialogue: list[Any]) -> tuple[NormalizedTurn, ...]: |
| turns: list[NormalizedTurn] = [] |
| for turn_index, entry in enumerate(dialogue): |
| if not isinstance(entry, dict): |
| continue |
| text = str(entry.get("content", "")) |
| attachments_raw = entry.get("attachments") or [] |
| captions: list[str] = [] |
| if isinstance(attachments_raw, list): |
| for att in attachments_raw: |
| if isinstance(att, dict): |
| cap = att.get("caption", "") |
| captions.append(cap if isinstance(cap, str) else str(cap)) |
| if captions: |
| text = text + "\n\n" + "\n".join(captions) |
| ts = entry.get("timestamp") |
| timestamp = ts if isinstance(ts, str) else (str(ts) if ts is not None else None) |
| raw_turn = { |
| "sample_id": sample_id, |
| "session_id": session_id, |
| "turn_index": turn_index, |
| "role": str(entry.get("role", "user")), |
| "text": text, |
| "attachments": [], |
| "timestamp": timestamp, |
| } |
| turns.append(normalize_turn(raw_turn)) |
| return tuple(turns) |
|
|
|
|
def _index_memory_contents(points: Any, content_by_id: dict[str, str]) -> None:
    """Add ``memory_id -> memory_content`` to *content_by_id* for each usable point."""
    for point in points:
        if not isinstance(point, Mapping):
            continue
        memory_id = point.get("memory_id")
        content = point.get("memory_content")
        if memory_id is not None and content is not None:
            content_by_id[str(memory_id)] = str(content)


def _sessions_from_raw(
    sample_id: str,
    sessions_raw: list[Any],
) -> tuple[tuple[DomainAV2Session, ...], list[str], tuple[Mapping[str, Any], ...]]:
    """Return ``(sessions, ordered session ids, S00 memory points)`` for one sample."""
    session_blocks: list[DomainAV2Session] = []
    ordered_ids: list[str] = []
    s00_points: tuple[Mapping[str, Any], ...] = ()

    for sess in sessions_raw:
        if not isinstance(sess, dict):
            continue
        sid = str(sess.get("_v2_session_id", ""))
        if not sid:
            continue
        ordered_ids.append(sid)

        dialogue = sess.get("dialogue") or []
        if not isinstance(dialogue, list):
            dialogue = []
        session_blocks.append(
            DomainAV2Session(
                session_id=sid,
                turns=_dialogue_turns(sample_id, sid, dialogue),
            )
        )

        # Only session "S00" carries its memory points inline in the main file.
        if sid == "S00":
            mps = sess.get("memory_points") or []
            if isinstance(mps, list):
                s00_points = tuple(mps)

    return tuple(session_blocks), ordered_ids, s00_points


def load_domain_a_v2_academic(data_dir: Path) -> DomainAV2AcademicBundle:
    """Load the Domain A v2 academic bundle from *data_dir*.

    Three files are read from the directory:

    * ``domain_a_v2.json``: a JSON list of samples with their sessions;
    * ``stage4_memory_points.jsonl``: one stage-4 row per line;
    * ``stage4b_qa_checkpoints.jsonl``: one QA-checkpoint row per line.

    Stage-4 and QA rows are matched to samples by ``sample_id``; when a file
    repeats a ``sample_id`` the last row wins. Sample entries that are not
    dicts, and sessions that are not dicts or lack a ``_v2_session_id``, are
    skipped.

    Args:
        data_dir: Directory containing the three files. Anything accepted by
            ``pathlib.Path`` (including a plain ``str``) works.

    Returns:
        A bundle with one sample per dict entry of the main file, in file
        order.

    Raises:
        ValueError: If ``domain_a_v2.json`` does not contain a JSON list
            (``json.JSONDecodeError`` for invalid JSON is a subclass).
        KeyError: If a sample lacks ``sample_id`` / ``uuid``, or has no
            matching stage-4 or QA row.
        OSError: If one of the files cannot be read.
    """
    data_dir = Path(data_dir).resolve()
    main_path = data_dir / "domain_a_v2.json"
    stage4_path = data_dir / "stage4_memory_points.jsonl"
    qa_path = data_dir / "stage4b_qa_checkpoints.jsonl"

    raw_samples = json.loads(main_path.read_text(encoding="utf-8"))
    if not isinstance(raw_samples, list):
        raise ValueError("domain_a_v2.json must be a list")

    stage4_by_id: dict[str, Stage4Record] = {}
    for obj in _read_jsonl(stage4_path):
        stage4_rec = _stage4_from_obj(obj)
        stage4_by_id[stage4_rec.sample_id] = stage4_rec

    qa_by_id: dict[str, QARecord] = {}
    for obj in _read_jsonl(qa_path):
        qa_rec = _qa_from_obj(obj)
        qa_by_id[qa_rec.sample_id] = qa_rec

    built: list[DomainAV2AcademicSample] = []
    for item in raw_samples:
        if not isinstance(item, dict):
            continue
        sample_id = str(item["sample_id"])
        uuid = str(item["uuid"])
        stage4 = stage4_by_id.get(sample_id)
        qa = qa_by_id.get(sample_id)
        if stage4 is None or qa is None:
            raise KeyError(f"missing stage4 or QA row for sample_id={sample_id}")

        sessions_raw = item.get("sessions") or []
        if not isinstance(sessions_raw, list):
            sessions_raw = []
        sessions, ordered_ids, s00_points = _sessions_from_raw(sample_id, sessions_raw)

        gold_states = build_session_gold_states(
            ordered_ids,
            s00_memory_points=s00_points,
            stage4_by_session_id=dict(stage4.memory_sessions),
        )

        # S00 points are indexed first so that stage-4 points with the same
        # memory id override them.
        memory_content_map: dict[str, str] = {}
        _index_memory_contents(s00_points, memory_content_map)
        for _sid, pts in stage4.memory_sessions:
            _index_memory_contents(pts, memory_content_map)

        built.append(
            DomainAV2AcademicSample(
                uuid=uuid,
                sample_id=sample_id,
                sessions=sessions,
                stage4=stage4,
                qa_record=qa,
                normalized_checkpoints=_normalize_checkpoints(
                    qa.raw_checkpoints, memory_content_map
                ),
                session_gold_states=gold_states,
            )
        )

    return DomainAV2AcademicBundle(samples=tuple(built))
|
|