# Prompt_Squirrel_RAG/scripts/build_simplified_probe_set.py
"""Build a simplified probe tag set from informativeness data.
This script creates a small, bundle-balanced probe list intended for a single
structured LLM probe query. If reliability results are provided, it also
computes reliability-aware final selection flags.
Outputs (overwrite):
- data/simplified_probe_tags.csv
- data/analysis/simplified_probe_tags_summary.json
"""
from __future__ import annotations
import argparse
import csv
import json
from pathlib import Path
from typing import Dict, List, Set
# Repository root: this file lives in <repo>/scripts/, so go up one level.
REPO = Path(__file__).resolve().parents[1]
# Input: per-tag informativeness metrics produced by an upstream analysis step.
PROBE_INFO_CSV = REPO / "data" / "analysis" / "probe_informativeness.csv"
# Outputs, overwritten on every run (see module docstring).
OUT_CSV = REPO / "data" / "simplified_probe_tags.csv"
OUT_SUMMARY = REPO / "data" / "analysis" / "simplified_probe_tags_summary.json"
# Per-bundle selection spec used by main():
#   cap   - maximum number of tags selected for this bundle
#   force - tags tried first, in this order (still subject to the
#           prevalence window and the deny set)
#   deny  - tags never selected for this bundle
BUNDLE_SPECS = {
    "clothing_state": {
        "cap": 6,
        "force": ["clothing", "clothed", "topwear", "bottomwear", "topless", "nude"],
        "deny": {"5_fingers"},
    },
    "scene_pose": {
        "cap": 4,
        "force": ["simple_background", "standing", "sitting", "outside"],
        "deny": set(),
    },
    "gaze_expression": {
        "cap": 5,
        "force": ["smile", "looking_at_viewer", "open_mouth", "blush", "eyes_closed"],
        "deny": set(),
    },
    "text_symbols": {
        "cap": 3,
        "force": ["text", "dialogue", "<3"],
        "deny": set(),
    },
    "body_type_presence": {
        "cap": 4,
        "force": ["anthro", "feral", "biped", "humanoid"],
        "deny": set(),
    },
    "count_cardinality": {
        "cap": 5,
        "force": ["zero_pictured", "solo", "duo", "trio", "group"],
        "deny": {"husky", "marsupial", "black_bars"},
    },
    "body_shape_breasts": {
        "cap": 4,
        "force": ["breasts", "big_breasts", "wide_hips", "thick_thighs"],
        "deny": set(),
    },
    "species_taxonomy": {
        "cap": 6,
        "force": ["canid", "canis", "felid", "leporid", "bird", "bear", "unicorn", "equid"],
        # "mammal" is too broad to be informative for this bundle.
        "deny": {"mammal"},
    },
}
def _load_probe_rows(path: Path) -> List[Dict[str, str]]:
with path.open("r", encoding="utf-8", newline="") as f:
return list(csv.DictReader(f))
def _load_reliability(path: Path) -> Dict[str, Dict[str, str]]:
if not path or not path.is_file():
return {}
with path.open("r", encoding="utf-8", newline="") as f:
rows = list(csv.DictReader(f))
return {r["tag"]: r for r in rows}
def _as_float(v: str, default: float = 0.0) -> float:
try:
return float(v)
except Exception:
return default
def _as_int(v: str, default: int = 0) -> int:
try:
return int(v)
except Exception:
return default
def main() -> None:
    """CLI entry point: select a bundle-capped probe set and write outputs.

    Steps:
      1. Load the informativeness CSV (required) and reliability CSV (optional).
      2. For each bundle in BUNDLE_SPECS, pick up to ``cap`` tags — forced tags
         first (in spec order), then the bundle's candidates ranked by
         ``actionable_score`` — all subject to the prevalence window and the
         bundle's deny set.
      3. Score each selection (reliability-weighted when reliability data
         exists) and overwrite OUT_CSV / OUT_SUMMARY.

    Raises:
        FileNotFoundError: if the probe informativeness CSV is missing.
    """
    ap = argparse.ArgumentParser(description="Build simplified probe set.")
    ap.add_argument("--probe-info", type=Path, default=PROBE_INFO_CSV)
    ap.add_argument("--reliability-csv", type=Path, default=None, help="Optional probe reliability CSV.")
    ap.add_argument("--min-prevalence", type=float, default=0.01)
    ap.add_argument("--max-prevalence", type=float, default=0.70)
    ap.add_argument("--min-support-pos", type=int, default=5)
    ap.add_argument("--min-f1-strong", type=float, default=0.45)
    ap.add_argument("--min-precision-strong", type=float, default=0.50)
    args = ap.parse_args()

    if not args.probe_info.is_file():
        raise FileNotFoundError(f"Missing probe informativeness CSV: {args.probe_info}")

    rows = _load_probe_rows(args.probe_info)
    rel = _load_reliability(args.reliability_csv) if args.reliability_csv else {}

    # Index rows by bundle (for ranked candidate fill) and by tag (for forced picks).
    by_bundle: Dict[str, List[Dict[str, str]]] = {}
    by_tag: Dict[str, Dict[str, str]] = {}
    for r in rows:
        by_tag[r["tag"]] = r
        by_bundle.setdefault(r.get("suggested_probe_bundle", "other"), []).append(r)
    for b in by_bundle:
        by_bundle[b].sort(key=lambda x: _as_float(x.get("actionable_score", "0")), reverse=True)

    selected_initial: List[Dict[str, str]] = []
    selected_tags: Set[str] = set()
    for bundle, spec in BUNDLE_SPECS.items():
        cap = int(spec["cap"])
        deny = set(spec["deny"])

        def ok(r: Dict[str, str]) -> bool:
            # Prevalence window plus this bundle's deny list.
            p = _as_float(r.get("prevalence", "0"))
            return r["tag"] not in deny and args.min_prevalence <= p <= args.max_prevalence

        # Incremental count of selections attributed to this bundle.  This
        # replaces the original's O(n^2) re-scan of selected_initial after
        # every pick.  Forced tags may carry a different
        # suggested_probe_bundle than the spec that forced them, so the count
        # keys off the row data (identical semantics to the original sums).
        picked = sum(1 for x in selected_initial if x.get("suggested_probe_bundle") == bundle)

        # Phase 1: forced tags, in spec order.
        for t in spec["force"]:
            r = by_tag.get(t)
            if not r or not ok(r):
                continue
            if t in selected_tags:
                continue
            selected_initial.append(r)
            selected_tags.add(t)
            if r.get("suggested_probe_bundle") == bundle:
                picked += 1
            if picked >= cap:
                break
        if picked >= cap:
            continue

        # Phase 2: fill remaining slots with the highest-scoring candidates.
        for r in by_bundle.get(bundle, []):
            if not ok(r):
                continue
            if r["tag"] in selected_tags:
                continue
            selected_initial.append(r)
            selected_tags.add(r["tag"])
            # Rows in by_bundle[bundle] have suggested_probe_bundle == bundle
            # by construction, so they always count toward this cap.
            picked += 1
            if picked >= cap:
                break

    # Reliability-aware scoring (if reliability CSV exists).
    out_rows: List[Dict[str, str]] = []
    for r in selected_initial:
        tag = r["tag"]
        rr = rel.get(tag, {})
        support_pos = _as_int(rr.get("support_pos", "0"))
        precision_strong = _as_float(rr.get("precision_strong", "0"))
        recall_strong = _as_float(rr.get("recall_strong", "0"))
        f1_strong = _as_float(rr.get("f1_strong", "0"))
        actionable = _as_float(r.get("actionable_score", "0"))
        if rr:
            # Scale actionability by reliability, floored at 25% so a noisy
            # probe is penalised but never zeroed out entirely.
            reliability_weight = f1_strong
            final_score = actionable * (0.25 + 0.75 * reliability_weight)
            selected_final = int(
                support_pos >= args.min_support_pos
                and f1_strong >= args.min_f1_strong
                and precision_strong >= args.min_precision_strong
            )
            rel_note = (
                f"support={support_pos}, f1={f1_strong:.3f}, "
                f"prec={precision_strong:.3f}, rec={recall_strong:.3f}"
            )
        else:
            # Without reliability data a probe can never be flagged final.
            reliability_weight = 0.0
            final_score = actionable
            selected_final = 0
            rel_note = "no_reliability_data"
        out_rows.append(
            {
                "tag": tag,
                "bundle": r.get("suggested_probe_bundle", "other"),
                "needs_glossary": r.get("needs_glossary", "0"),
                "prevalence": r.get("prevalence", ""),
                "actionable_score": f"{actionable:.6f}",
                "selected_initial": "1",
                "support_pos": str(support_pos),
                "precision_strong": f"{precision_strong:.6f}",
                "recall_strong": f"{recall_strong:.6f}",
                "f1_strong": f"{f1_strong:.6f}",
                "reliability_weight": f"{reliability_weight:.6f}",
                "final_score": f"{final_score:.6f}",
                "selected_final": str(selected_final),
                "reliability_note": rel_note,
            }
        )

    # Best probes first; actionable_score breaks final_score ties.
    out_rows.sort(key=lambda x: (_as_float(x["final_score"]), _as_float(x["actionable_score"])), reverse=True)

    OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
    with OUT_CSV.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "tag",
                "bundle",
                "needs_glossary",
                "prevalence",
                "actionable_score",
                "selected_initial",
                "support_pos",
                "precision_strong",
                "recall_strong",
                "f1_strong",
                "reliability_weight",
                "final_score",
                "selected_final",
                "reliability_note",
            ],
        )
        writer.writeheader()
        writer.writerows(out_rows)

    selected_final_tags = [r["tag"] for r in out_rows if r["selected_final"] == "1"]
    # JSON-serializable copy of the specs (sets are not JSON encodable).
    bundle_specs_json = {
        k: {"cap": v["cap"], "force": list(v["force"]), "deny": sorted(v["deny"])}
        for k, v in BUNDLE_SPECS.items()
    }
    summary = {
        "probe_info_csv": str(args.probe_info),
        "reliability_csv": str(args.reliability_csv) if args.reliability_csv else None,
        "n_selected_initial": len(out_rows),
        "n_selected_final": len(selected_final_tags),
        "selected_final_tags": selected_final_tags,
        "bundle_specs": bundle_specs_json,
        "thresholds": {
            "min_prevalence": args.min_prevalence,
            "max_prevalence": args.max_prevalence,
            "min_support_pos": args.min_support_pos,
            "min_f1_strong": args.min_f1_strong,
            "min_precision_strong": args.min_precision_strong,
        },
        "outputs": {
            "csv": str(OUT_CSV),
            "summary_json": str(OUT_SUMMARY),
        },
    }
    with OUT_SUMMARY.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"Selected initial probes: {len(out_rows)}")
    print(f"Selected final probes: {len(selected_final_tags)}")
    print(f"Outputs: {OUT_CSV}, {OUT_SUMMARY}")


if __name__ == "__main__":
    main()