Spaces:
Running
Running
| """Build a simplified probe tag set from informativeness data. | |
| This script creates a small, bundle-balanced probe list intended for a single | |
| structured LLM probe query. If reliability results are provided, it also | |
| computes reliability-aware final selection flags. | |
| Outputs (overwrite): | |
| - data/simplified_probe_tags.csv | |
| - data/analysis/simplified_probe_tags_summary.json | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, List, Set | |
# Repository root; assumes this file sits one directory below it
# (e.g. scripts/) — TODO confirm against the actual layout.
REPO = Path(__file__).resolve().parents[1]
# Input: per-tag informativeness metrics produced by an upstream step.
PROBE_INFO_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"
# Outputs (overwritten on every run — see module docstring).
OUT_CSV = REPO / "data" / "simplified_probe_tags.csv"
OUT_SUMMARY = REPO / "data" / "analysis" / "simplified_probe_tags_summary.json"

# Per-bundle selection policy:
#   cap   - maximum number of tags selected for the bundle
#   force - tags tried first, in listed order (still subject to the
#           prevalence window and deny-list checks in main())
#   deny  - tags never selected for this bundle
BUNDLE_SPECS = {
    "clothing_state": {
        "cap": 6,
        "force": ["clothing", "clothed", "topwear", "bottomwear", "topless", "nude"],
        "deny": {"5_fingers"},
    },
    "scene_pose": {
        "cap": 4,
        "force": ["simple_background", "standing", "sitting", "outside"],
        "deny": set(),
    },
    "gaze_expression": {
        "cap": 5,
        "force": ["smile", "looking_at_viewer", "open_mouth", "blush", "eyes_closed"],
        "deny": set(),
    },
    "text_symbols": {
        "cap": 3,
        "force": ["text", "dialogue", "<3"],
        "deny": set(),
    },
    "body_type_presence": {
        "cap": 4,
        "force": ["anthro", "feral", "biped", "humanoid"],
        "deny": set(),
    },
    "count_cardinality": {
        "cap": 5,
        "force": ["zero_pictured", "solo", "duo", "trio", "group"],
        "deny": {"husky", "marsupial", "black_bars"},
    },
    "body_shape_breasts": {
        "cap": 4,
        "force": ["breasts", "big_breasts", "wide_hips", "thick_thighs"],
        "deny": set(),
    },
    "species_taxonomy": {
        "cap": 6,
        # Note: 8 forced tags but cap=6 — only the first 6 eligible ones land.
        "force": ["canid", "canis", "felid", "leporid", "bird", "bear", "unicorn", "equid"],
        "deny": {"mammal"},
    },
}
| def _load_probe_rows(path: Path) -> List[Dict[str, str]]: | |
| with path.open("r", encoding="utf-8", newline="") as f: | |
| return list(csv.DictReader(f)) | |
| def _load_reliability(path: Path) -> Dict[str, Dict[str, str]]: | |
| if not path or not path.is_file(): | |
| return {} | |
| with path.open("r", encoding="utf-8", newline="") as f: | |
| rows = list(csv.DictReader(f)) | |
| return {r["tag"]: r for r in rows} | |
| def _as_float(v: str, default: float = 0.0) -> float: | |
| try: | |
| return float(v) | |
| except Exception: | |
| return default | |
| def _as_int(v: str, default: int = 0) -> int: | |
| try: | |
| return int(v) | |
| except Exception: | |
| return default | |
def main() -> None:
    """Build the simplified probe set and write the CSV + JSON summary.

    Pipeline:
      1. Parse CLI args and load the informativeness CSV (required) and the
         reliability CSV (optional).
      2. Per bundle in ``BUNDLE_SPECS``: take eligible forced tags first,
         then the highest-``actionable_score`` candidates, up to ``cap``.
      3. Score each selected tag; when reliability data exists for the tag,
         damp the actionable score by strong-match F1 and gate the final
         selection flag on support/F1/precision thresholds.
      4. Overwrite OUT_CSV and OUT_SUMMARY.

    Raises:
        FileNotFoundError: if the probe informativeness CSV is missing.
    """
    ap = argparse.ArgumentParser(description="Build simplified probe set.")
    ap.add_argument("--probe-info", type=Path, default=PROBE_INFO_CSV)
    ap.add_argument("--reliability-csv", type=Path, default=None, help="Optional probe reliability CSV.")
    ap.add_argument("--min-prevalence", type=float, default=0.01)
    ap.add_argument("--max-prevalence", type=float, default=0.70)
    ap.add_argument("--min-support-pos", type=int, default=5)
    ap.add_argument("--min-f1-strong", type=float, default=0.45)
    ap.add_argument("--min-precision-strong", type=float, default=0.50)
    args = ap.parse_args()

    if not args.probe_info.is_file():
        raise FileNotFoundError(f"Missing probe informativeness CSV: {args.probe_info}")

    rows = _load_probe_rows(args.probe_info)
    rel = _load_reliability(args.reliability_csv) if args.reliability_csv else {}

    # Index by tag and group by suggested bundle, best actionable score first.
    by_bundle: Dict[str, List[Dict[str, str]]] = {}
    by_tag: Dict[str, Dict[str, str]] = {}
    for r in rows:
        by_tag[r["tag"]] = r
        by_bundle.setdefault(r.get("suggested_probe_bundle", "other"), []).append(r)
    for bucket in by_bundle.values():
        bucket.sort(key=lambda x: _as_float(x.get("actionable_score", "0")), reverse=True)

    selected_initial: List[Dict[str, str]] = []
    selected_tags: Set[str] = set()
    # O(1) per-bundle tally replacing the original per-append O(n) scans of
    # selected_initial. Keyed by the row's *own* suggested bundle (no
    # default), preserving the original cap semantics: a forced tag whose
    # suggested bundle differs from the bundle that forced it counts toward
    # its own bundle's cap, not the forcing bundle's.
    bundle_counts: Dict[str, int] = {}

    def _eligible(r: Dict[str, str], deny: Set[str]) -> bool:
        """Prevalence window plus the per-bundle deny list."""
        p = _as_float(r.get("prevalence", "0"))
        return r["tag"] not in deny and args.min_prevalence <= p <= args.max_prevalence

    def _take(r: Dict[str, str]) -> None:
        """Record a selection and bump its bundle tally."""
        selected_initial.append(r)
        selected_tags.add(r["tag"])
        b = r.get("suggested_probe_bundle")
        bundle_counts[b] = bundle_counts.get(b, 0) + 1

    for bundle, spec in BUNDLE_SPECS.items():
        cap = int(spec["cap"])
        deny = set(spec["deny"])
        # Forced picks first. Cap is checked *before* adding, which also
        # fixes an overshoot: previously a bundle already at cap (filled by
        # cross-bundle forced picks) could still gain one extra forced tag.
        for tag in spec["force"]:
            if bundle_counts.get(bundle, 0) >= cap:
                break
            r = by_tag.get(tag)
            if r is None or tag in selected_tags or not _eligible(r, deny):
                continue
            _take(r)
        # Top up with the best remaining candidates from this bundle.
        for r in by_bundle.get(bundle, []):
            if bundle_counts.get(bundle, 0) >= cap:
                break
            if r["tag"] in selected_tags or not _eligible(r, deny):
                continue
            _take(r)

    # Reliability-aware scoring (only when reliability data covers the tag).
    out_rows: List[Dict[str, str]] = []
    for r in selected_initial:
        tag = r["tag"]
        rr = rel.get(tag, {})
        support_pos = _as_int(rr.get("support_pos", "0"))
        precision_strong = _as_float(rr.get("precision_strong", "0"))
        recall_strong = _as_float(rr.get("recall_strong", "0"))
        f1_strong = _as_float(rr.get("f1_strong", "0"))
        actionable = _as_float(r.get("actionable_score", "0"))
        if rr:
            reliability_weight = f1_strong
            # 0.25 floor: low-F1 tags are damped, never zeroed outright.
            final_score = actionable * (0.25 + 0.75 * reliability_weight)
            selected_final = int(
                support_pos >= args.min_support_pos
                and f1_strong >= args.min_f1_strong
                and precision_strong >= args.min_precision_strong
            )
            rel_note = (
                f"support={support_pos}, f1={f1_strong:.3f}, "
                f"prec={precision_strong:.3f}, rec={recall_strong:.3f}"
            )
        else:
            # No reliability data: keep the raw score but never auto-select.
            reliability_weight = 0.0
            final_score = actionable
            selected_final = 0
            rel_note = "no_reliability_data"
        out_rows.append(
            {
                "tag": tag,
                "bundle": r.get("suggested_probe_bundle", "other"),
                "needs_glossary": r.get("needs_glossary", "0"),
                "prevalence": r.get("prevalence", ""),
                "actionable_score": f"{actionable:.6f}",
                "selected_initial": "1",
                "support_pos": str(support_pos),
                "precision_strong": f"{precision_strong:.6f}",
                "recall_strong": f"{recall_strong:.6f}",
                "f1_strong": f"{f1_strong:.6f}",
                "reliability_weight": f"{reliability_weight:.6f}",
                "final_score": f"{final_score:.6f}",
                "selected_final": str(selected_final),
                "reliability_note": rel_note,
            }
        )

    # Best first; actionable score breaks final-score ties.
    out_rows.sort(key=lambda x: (_as_float(x["final_score"]), _as_float(x["actionable_score"])), reverse=True)

    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "tag",
                "bundle",
                "needs_glossary",
                "prevalence",
                "actionable_score",
                "selected_initial",
                "support_pos",
                "precision_strong",
                "recall_strong",
                "f1_strong",
                "reliability_weight",
                "final_score",
                "selected_final",
                "reliability_note",
            ],
        )
        writer.writeheader()
        writer.writerows(out_rows)

    selected_final_tags = [r["tag"] for r in out_rows if r["selected_final"] == "1"]
    # JSON-serializable copy of the specs (sets -> sorted lists).
    bundle_specs_json = {
        k: {"cap": v["cap"], "force": list(v["force"]), "deny": sorted(v["deny"])}
        for k, v in BUNDLE_SPECS.items()
    }
    summary = {
        "probe_info_csv": str(args.probe_info),
        "reliability_csv": str(args.reliability_csv) if args.reliability_csv else None,
        "n_selected_initial": len(out_rows),
        "n_selected_final": len(selected_final_tags),
        "selected_final_tags": selected_final_tags,
        "bundle_specs": bundle_specs_json,
        "thresholds": {
            "min_prevalence": args.min_prevalence,
            "max_prevalence": args.max_prevalence,
            "min_support_pos": args.min_support_pos,
            "min_f1_strong": args.min_f1_strong,
            "min_precision_strong": args.min_precision_strong,
        },
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_SUMMARY),
        },
    }
    with OUT_SUMMARY.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"Selected initial probes: {len(out_rows)}")
    print(f"Selected final probes: {len(selected_final_tags)}")
    print(f"Outputs: {OUT_CSV}, {OUT_SUMMARY}")


if __name__ == "__main__":
    main()