| |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import io |
| import json |
| import os |
| import sys |
| from datetime import datetime, timezone |
| from pathlib import Path |
|
|
| import streamlit as st |
| from huggingface_hub import HfApi, hf_hub_download |
| from huggingface_hub.utils import EntryNotFoundError |
|
|
# Make sibling modules importable regardless of the working directory the app
# is launched from (e.g. `streamlit run app/annotate.py`).
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))

# Local validator module; this import relies on the sys.path tweak above.
import validate_compliance_prm as validator
|
|
# Repository layout: this file lives one directory below the project root.
ROOT = Path(__file__).resolve().parents[1]
SOURCE_BUNDLES_PATH = ROOT / "data" / "bundles" / "pilot_bundles_v1.jsonl"
GUIDELINE_PATH = ROOT / "data" / "docs" / "pilot_annotation_guideline_v1.md"
ANNOTATIONS_DIR = ROOT / "data" / "annotations"

# The six calibration bundles annotated in the pilot round, in display order.
TARGET_BUNDLE_IDS = [
    "A17_CN_BUNDLE",
    "A17_US_BUNDLE",
    "A17_ISLAMIC_BUNDLE",
    "M29_CN_BUNDLE",
    "M29_US_BUNDLE",
    "M29_ISLAMIC_BUNDLE",
]
# Allowed per-trace labels and bundle-level workflow statuses.
TRACE_LABELS = ["compliant", "deadline_missed", "hard_violation"]
STATUSES = ["in_progress", "final"]
|
|
|
|
def load_jsonl(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of records, skipping blank lines."""
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if raw_line.strip():
                records.append(json.loads(raw_line))
    return records
|
|
|
|
def load_source_bundles() -> dict[str, dict]:
    """Load the six target bundles from the source JSONL, keyed by bundle_id.

    The returned mapping preserves TARGET_BUNDLE_IDS order so the selectbox
    and the file always agree.

    Raises:
        ValueError: when a target bundle is absent from the source file,
            instead of the cryptic KeyError the re-keying would otherwise
            raise.
    """
    bundles = {
        bundle["bundle_id"]: bundle
        for bundle in load_jsonl(SOURCE_BUNDLES_PATH)
        if bundle["bundle_id"] in TARGET_BUNDLE_IDS
    }
    missing = [bundle_id for bundle_id in TARGET_BUNDLE_IDS if bundle_id not in bundles]
    if missing:
        raise ValueError(
            f"Missing bundles in {SOURCE_BUNDLES_PATH}: {', '.join(missing)}"
        )
    return {bundle_id: bundles[bundle_id] for bundle_id in TARGET_BUNDLE_IDS}
|
|
|
|
def annotation_path(annotator_id: str, bundle_id: str) -> Path:
    """Local path where one annotator's copy of a bundle is stored."""
    filename = f"{bundle_id}.json"
    return ANNOTATIONS_DIR / annotator_id / filename
|
|
|
|
def dataset_repo_id() -> str:
    """Hugging Face dataset repo id from the environment; '' disables sync."""
    raw_value = os.getenv("HF_DATASET_REPO", "")
    return raw_value.strip()
|
|
|
|
def dataset_repo_subdir() -> str:
    """Sub-directory inside the dataset repo; falls back to 'annotations'."""
    configured = os.getenv("HF_DATASET_SUBDIR", "annotations").strip().strip("/")
    if configured:
        return configured
    return "annotations"
|
|
|
|
def hf_token() -> str:
    """First non-empty HF API token in the environment, else ''.

    HF_TOKEN takes precedence over HUGGINGFACEHUB_API_TOKEN.
    """
    candidates = (
        os.getenv("HF_TOKEN", ""),
        os.getenv("HUGGINGFACEHUB_API_TOKEN", ""),
    )
    for candidate in candidates:
        stripped = candidate.strip()
        if stripped:
            return stripped
    return ""
|
|
|
|
def storage_backend() -> str:
    """'hf_dataset' when both a repo id and a token are configured, else 'local'."""
    remote_configured = bool(dataset_repo_id()) and bool(hf_token())
    return "hf_dataset" if remote_configured else "local"
|
|
|
|
def dataset_repo_path(annotator_id: str, bundle_id: str) -> str:
    """Path of one annotation file inside the HF dataset repo."""
    parts = (dataset_repo_subdir(), annotator_id, f"{bundle_id}.json")
    return "/".join(parts)
|
|
|
|
def build_initial_annotation(bundle: dict, annotator_id: str) -> dict:
    """Seed a fresh working annotation from the machine-generated bundle.

    Deep-copies the bundle so later widget edits never touch the source data.
    """
    annotation = copy.deepcopy(bundle)
    annotation.update(
        annotator_id=annotator_id,
        status="in_progress",
        updated_at=None,
        change_notes="",
    )
    return annotation
|
|
|
|
def save_local_annotation(payload: dict) -> Path:
    """Write the annotation JSON under data/annotations/<annotator>/ and return its path."""
    path = annotation_path(payload["annotator_id"], payload["bundle_id"])
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, ensure_ascii=False)
    path.write_text(serialized, encoding="utf-8")
    return path
|
|
|
|
def load_remote_annotation(bundle: dict, annotator_id: str) -> dict | None:
    """Fetch this annotator's saved annotation from the HF dataset repo.

    Returns the parsed JSON payload, or None when the file does not exist or
    the download fails for any reason. The broad except is deliberate
    best-effort: the caller falls back to a fresh annotation instead of
    crashing the app on network/auth errors.
    """
    try:
        downloaded_path = hf_hub_download(
            repo_id=dataset_repo_id(),
            filename=dataset_repo_path(annotator_id, bundle["bundle_id"]),
            repo_type="dataset",
            token=hf_token(),
        )
    except EntryNotFoundError:
        # No annotation uploaded yet for this annotator/bundle pair.
        return None
    except Exception:
        # Any other hub failure (network, auth, repo missing) also degrades
        # gracefully to "no remote copy".
        return None

    with Path(downloaded_path).open("r", encoding="utf-8") as handle:
        return json.load(handle)
|
|
|
|
def save_remote_annotation(payload: dict) -> str:
    """Upload the annotation JSON to the HF dataset repo; return an hf:// URI.

    Creates the (private) dataset repo on first use; exist_ok=True makes the
    call a no-op on subsequent saves.
    """
    repo_id = dataset_repo_id()
    api = HfApi(token=hf_token())
    api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)

    repo_path = dataset_repo_path(payload["annotator_id"], payload["bundle_id"])
    # Serialize to bytes so upload_file can stream from an in-memory buffer.
    payload_bytes = json.dumps(payload, indent=2, ensure_ascii=False).encode("utf-8")
    api.upload_file(
        path_or_fileobj=io.BytesIO(payload_bytes),
        path_in_repo=repo_path,
        repo_id=repo_id,
        repo_type="dataset",
        commit_message=f"Update annotation: {payload['bundle_id']} ({payload['annotator_id']})",
    )
    return f"hf://datasets/{repo_id}/{repo_path}"
|
|
|
|
def load_annotation(bundle: dict, annotator_id: str) -> dict:
    """Resolve the working annotation: local file first, then remote, then a fresh copy."""
    local_path = annotation_path(annotator_id, bundle["bundle_id"])
    if local_path.exists():
        return json.loads(local_path.read_text(encoding="utf-8"))

    if storage_backend() == "hf_dataset":
        remote_payload = load_remote_annotation(bundle, annotator_id)
        if remote_payload is not None:
            return remote_payload

    return build_initial_annotation(bundle, annotator_id)
|
|
|
|
def save_annotation(annotation: dict) -> str:
    """Persist the annotation (always locally; to HF too when configured).

    Stamps updated_at with the current UTC time and returns a human-readable
    description of where the data went.
    """
    payload = copy.deepcopy(annotation)
    payload["updated_at"] = datetime.now(timezone.utc).isoformat()
    local_path = save_local_annotation(payload)
    if storage_backend() != "hf_dataset":
        return str(local_path)
    remote_path = save_remote_annotation(payload)
    return f"{remote_path} (local mirror: {local_path})"
|
|
|
|
def require_password() -> None:
    """Gate the app behind ANNOTATION_APP_PASSWORD when that env var is set.

    Returns immediately when no password is configured or the session is
    already authenticated; otherwise renders the unlock form and st.stop()s.
    """
    expected_password = os.getenv("ANNOTATION_APP_PASSWORD", "").strip()
    if not expected_password:
        # No password configured: the instance is open.
        return

    if st.session_state.get("authenticated"):
        return

    st.title("CPRM Annotation App")
    st.caption("This instance is password-protected.")
    typed_password = st.text_input("Shared Password", type="password")
    if st.button("Unlock"):
        if typed_password == expected_password:
            st.session_state["authenticated"] = True
            # st.rerun() raises, so the error below only renders on a
            # failed attempt.
            st.rerun()
        st.error("Incorrect password.")
    st.stop()
|
|
|
|
def read_guideline() -> str:
    """Return the guideline markdown, or a hint when the file is absent."""
    if not GUIDELINE_PATH.exists():
        return "Guideline file not found. Generate `data/docs/pilot_annotation_guideline_v1.md` first."
    return GUIDELINE_PATH.read_text(encoding="utf-8")
|
|
|
|
def reset_guideline_gate() -> None:
    """Force the annotator back through the guideline confirmation step."""
    for state_key, state_value in (
        ("guideline_acknowledged", False),
        ("guideline_confirmed_for", None),
    ):
        st.session_state[state_key] = state_value
|
|
|
|
def render_guideline_gate() -> None:
    """Step-1 screen: collect an annotator ID and require guideline acknowledgement.

    Never returns control to the caller: every path ends in st.stop() or
    st.rerun(), so main() only proceeds once the gate flags are set.
    """
    st.title("CPRM Pilot Annotation App")
    st.caption("Step 1 of 2: read the guideline, confirm it, then enter the annotation workspace.")

    annotator_id = st.text_input(
        "Annotator ID",
        value=st.session_state.get("annotator_id", "solo_annotator"),
        help="Use a stable ID so saved files go to a consistent annotation folder.",
    ).strip()
    st.session_state["annotator_id"] = annotator_id

    if not annotator_id:
        st.info("Enter an annotator ID before continuing.")
        st.stop()

    st.subheader("Guideline")
    st.markdown(read_guideline())

    acknowledged = st.checkbox(
        "I have read the guideline and I understand that round 1 only edits existing fields and does not change step count.",
        value=False,
        key="guideline_ack_checkbox",
    )

    # The workspace button is disabled until the checkbox is ticked.
    if st.button("Enter Annotation Workspace", type="primary", disabled=not acknowledged):
        st.session_state["guideline_acknowledged"] = True
        # Bind the confirmation to this ID so changing annotators re-gates.
        st.session_state["guideline_confirmed_for"] = annotator_id
        st.rerun()

    st.stop()
|
|
|
|
def ensure_working_annotation(source_bundle: dict, annotator_id: str, bundle_id: str) -> dict:
    """Cache the loaded annotation in session state, reloading when the selection changes.

    Returns a deep copy so widget edits never mutate the cached baseline.
    """
    state_key = "working_bundle_key"
    selection = f"{annotator_id}:{bundle_id}"
    if st.session_state.get(state_key) != selection:
        st.session_state[state_key] = selection
        st.session_state["working_bundle"] = load_annotation(source_bundle, annotator_id)
    return copy.deepcopy(st.session_state["working_bundle"])
|
|
|
|
def step_key(bundle_id: str, trace_id: str, step_id: int, field: str, suffix: str = "") -> str:
    """Unique Streamlit widget key for one step-level field (optional suffix)."""
    base = ":".join((bundle_id, trace_id, str(step_id), field))
    return f"{base}:{suffix}" if suffix else base
|
|
|
|
def trace_key(bundle_id: str, trace_id: str, field: str) -> str:
    """Unique Streamlit widget key for one trace-level field."""
    return ":".join((bundle_id, trace_id, field))
|
|
|
|
def get_rule_options(bundle: dict) -> list[str]:
    """All rule IDs referenced anywhere in the bundle, deduplicated in first-seen order.

    Order: rulebook entries first, then rules referenced by candidate steps
    (active rules, violated rule, soft-coverage keys).
    """
    # dict preserves insertion order, so it doubles as an ordered set.
    ordered: dict[str, None] = {}

    def add_rule(rule_id: str | None) -> None:
        if rule_id:  # skip None and empty strings
            ordered.setdefault(rule_id, None)

    for rule_id in bundle["rulebook"]:
        add_rule(rule_id)

    for candidate in bundle["candidates"]:
        for step in candidate["steps"]:
            for rule_id in step["active_rule_ids"]:
                add_rule(rule_id)
            add_rule(step["violated_rule_id"])
            for rule_id in step["soft_coverage_delta"]:
                add_rule(rule_id)

    return list(ordered)
|
|
|
|
def render_metadata(bundle: dict, annotator_id: str) -> tuple[str, str]:
    """Render the sidebar (bundle info, status, notes) and return (status, change_notes).

    The status/notes widgets are keyed per bundle id so switching bundles
    does not leak values between them.
    """
    with st.sidebar:
        if st.button("Back To Guideline"):
            reset_guideline_gate()
            st.rerun()

        st.header("Bundle")
        st.write(f"`{bundle['bundle_id']}`")
        st.write(f"Annotator: `{annotator_id}`")
        st.write(f"Jurisdiction: `{bundle['jurisdiction']}`")
        st.write(f"Mode: `{bundle['mode']}`")
        st.write(f"Storage backend: `{storage_backend()}`")
        if storage_backend() == "hf_dataset":
            st.write(f"Dataset repo: `{dataset_repo_id()}`")
        st.write("Rulebook:")
        for rule_id in bundle["rulebook"]:
            st.code(rule_id)

        status = st.selectbox(
            "Bundle Status",
            options=STATUSES,
            # Default to the saved status; new annotations start "in_progress".
            index=STATUSES.index(bundle.get("status", "in_progress")),
            key=f"{bundle['bundle_id']}:status",
        )
        change_notes = st.text_area(
            "Change Notes",
            value=bundle.get("change_notes", ""),
            height=160,
            key=f"{bundle['bundle_id']}:change_notes",
            help="Short note on what changed from the machine-generated version.",
        )

        with st.expander("Guideline", expanded=False):
            st.markdown(read_guideline())

    return status, change_notes
|
|
|
|
def render_step_editor(bundle: dict, trace: dict, step: dict, rule_options: list[str]) -> dict:
    """Render editable widgets for one step and return the edited step dict.

    Round 1 keeps step_id and text read-only; only action type, rule links,
    the hard-violation flag, and soft-coverage deltas are editable.
    """
    bundle_id = bundle["bundle_id"]
    trace_id = trace["trace_id"]
    step_id = step["step_id"]

    st.markdown(f"**Step {step_id}:** `{step['text']}`")
    action_type = st.selectbox(
        f"Action Type ({step_id})",
        options=sorted(validator.ALLOWED_ACTION_TYPES),
        # NOTE(review): assumes step["action_type"] is always an allowed
        # type; .index() would raise ValueError otherwise — confirm upstream.
        index=sorted(validator.ALLOWED_ACTION_TYPES).index(step["action_type"]),
        key=step_key(bundle_id, trace_id, step_id, "action_type"),
    )
    active_rule_ids = st.multiselect(
        f"Active Rule IDs ({step_id})",
        options=rule_options,
        default=step["active_rule_ids"],
        key=step_key(bundle_id, trace_id, step_id, "active_rule_ids"),
    )
    hard_violation = st.checkbox(
        f"Hard Violation ({step_id})",
        value=bool(step["hard_violation"]),
        key=step_key(bundle_id, trace_id, step_id, "hard_violation"),
    )
    violated_rule_id = st.selectbox(
        f"Violated Rule ID ({step_id})",
        options=[None] + rule_options,
        # Safe lookup: get_rule_options() includes every violated_rule_id
        # present in the bundle, and None is always an option.
        index=([None] + rule_options).index(step["violated_rule_id"]),
        key=step_key(bundle_id, trace_id, step_id, "violated_rule_id"),
        format_func=lambda value: "None" if value is None else value,
    )

    st.caption("Soft Coverage Delta")
    soft_coverage_delta: dict[str, float] = {}
    # One number input per rule, side by side; "or 1" avoids st.columns(0).
    columns = st.columns(len(rule_options) or 1)
    for index, rule_id in enumerate(rule_options):
        default_value = float(step["soft_coverage_delta"].get(rule_id, 0.0))
        with columns[index]:
            value = st.number_input(
                rule_id,
                min_value=0.0,
                max_value=1.0,
                value=default_value,
                step=0.05,
                key=step_key(bundle_id, trace_id, step_id, "soft_delta", rule_id),
            )
        # Only non-zero deltas are kept, rounded to 2 decimal places.
        if value > 0:
            soft_coverage_delta[rule_id] = round(float(value), 2)

    return {
        "step_id": step_id,
        "action_type": action_type,
        "text": step["text"],
        "active_rule_ids": active_rule_ids,
        # Stored as 0/1 int to match the source bundle schema.
        "hard_violation": int(hard_violation),
        "violated_rule_id": violated_rule_id,
        "soft_coverage_delta": soft_coverage_delta,
    }
|
|
|
|
def render_trace_editor(bundle: dict, trace: dict) -> dict:
    """Render trace-level widgets plus one bordered editor per step; return the edited trace."""
    bundle_id = bundle["bundle_id"]
    trace_id = trace["trace_id"]
    rule_options = get_rule_options(bundle)

    label = st.selectbox(
        "Trace Label",
        options=TRACE_LABELS,
        index=TRACE_LABELS.index(trace["label"]),
        key=trace_key(bundle_id, trace_id, "label"),
    )
    overall_compliant = st.checkbox(
        "Overall Compliant",
        value=bool(trace["overall_compliant"]),
        key=trace_key(bundle_id, trace_id, "overall_compliant"),
    )
    step_ids = [step["step_id"] for step in trace["steps"]]
    first_violation_step = st.selectbox(
        "First Violation Step",
        options=[None] + step_ids,
        index=([None] + step_ids).index(trace["first_violation_step"]),
        key=trace_key(bundle_id, trace_id, "first_violation_step"),
        format_func=lambda value: "None" if value is None else f"Step {value}",
    )

    edited_steps = []
    for step in trace["steps"]:
        with st.container(border=True):
            edited_steps.append(render_step_editor(bundle, trace, step, rule_options))

    # Round 1 keeps the step count fixed: copy the trace and replace only
    # the editable fields.
    edited_trace = copy.deepcopy(trace)
    edited_trace["label"] = label
    edited_trace["overall_compliant"] = overall_compliant
    edited_trace["first_violation_step"] = first_violation_step
    edited_trace["steps"] = edited_steps
    return edited_trace
|
|
|
|
def render_bundle_editor(bundle: dict) -> dict:
    """Render one tab per candidate trace and collect the edited traces into a new bundle."""
    trace_ids = [candidate["trace_id"] for candidate in bundle["candidates"]]
    edited_candidates = []
    for tab, candidate in zip(st.tabs(trace_ids), bundle["candidates"]):
        with tab:
            edited_candidates.append(render_trace_editor(bundle, candidate))

    edited_bundle = copy.deepcopy(bundle)
    edited_bundle["candidates"] = edited_candidates
    return edited_bundle
|
|
|
|
def render_validation_panel(bundle: dict, valid_rule_ids: set[str]) -> None:
    """Run the offline validator on the edited bundle and render a summary panel."""
    result = validator.validate_single_bundle(bundle, valid_rule_ids)
    errors = result["errors"]
    warnings = result["warnings"]
    with st.expander("Validation", expanded=True):
        summary = {"ok": result["ok"], "errors": len(errors), "warnings": len(warnings)}
        st.write(summary)
        if errors:
            st.error("\n".join(errors))
        if warnings:
            st.warning("\n".join(warnings))
        if not errors and not warnings:
            st.success("No validation issues detected.")
|
|
|
|
def main() -> None:
    """Streamlit entry point: password gate -> guideline gate -> annotation workspace."""
    st.set_page_config(page_title="CPRM Annotation App", layout="wide")
    require_password()  # st.stop()s until the shared password is entered (if configured)

    # Re-gate on the guideline whenever the annotator ID changes.
    current_annotator = st.session_state.get("annotator_id", "").strip()
    if (
        not st.session_state.get("guideline_acknowledged")
        or st.session_state.get("guideline_confirmed_for") != current_annotator
    ):
        render_guideline_gate()  # never returns: always st.stop()s or st.rerun()s

    source_bundles = load_source_bundles()
    valid_rule_ids = validator.load_rule_ids(validator.RULE_CARDS_PATH)

    st.title("CPRM Pilot Annotation App")
    st.caption(
        "Step 2 of 2: annotate one of the 6 calibration bundles. Existing steps are editable, but step count is fixed."
    )

    bundle_id = st.selectbox("Bundle", options=TARGET_BUNDLE_IDS)
    source_bundle = source_bundles[bundle_id]
    working_bundle = ensure_working_annotation(source_bundle, current_annotator, bundle_id)
    status, change_notes = render_metadata(working_bundle, current_annotator)

    # Left: the editable bundle; right: a read-only scenario summary.
    left, right = st.columns([3, 2])
    with left:
        edited_bundle = render_bundle_editor(working_bundle)
    with right:
        st.subheader("Scenario")
        st.json(
            {
                "bundle_id": source_bundle["bundle_id"],
                "scenario_id": source_bundle["scenario_id"],
                "intent_id": source_bundle["intent_id"],
                "jurisdiction": source_bundle["jurisdiction"],
                "mode": source_bundle["mode"],
                "rulebook": source_bundle["rulebook"],
            },
            expanded=False,
        )

    # Merge sidebar metadata back into the edited bundle before validating/saving.
    edited_bundle["annotator_id"] = current_annotator
    edited_bundle["status"] = status
    edited_bundle["change_notes"] = change_notes
    edited_bundle["updated_at"] = working_bundle.get("updated_at")

    render_validation_panel(edited_bundle, valid_rule_ids)

    col1, col2 = st.columns(2)
    with col1:
        if st.button("Save Annotation", type="primary"):
            saved_path = save_annotation(edited_bundle)
            # Refresh the session cache so the saved version becomes the baseline.
            st.session_state["working_bundle"] = copy.deepcopy(edited_bundle)
            st.success(f"Saved to {saved_path}")
    with col2:
        st.download_button(
            "Download JSON",
            data=json.dumps(edited_bundle, indent=2, ensure_ascii=False),
            file_name=f"{edited_bundle['bundle_id']}.json",
            mime="application/json",
        )


if __name__ == "__main__":
    main()
|
|