# shinka-backup/eval_agent/feedback.py
"""Feedback loop for eval agent — computes effectiveness of auxiliary metrics."""
from __future__ import annotations
import json
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Statistics helpers (no scipy dependency)
# ---------------------------------------------------------------------------
def _rank(values: List[float]) -> List[float]:
    """Assign ranks to values (average rank for ties)."""
    n = len(values)
    indexed = sorted(range(n), key=lambda i: values[i])
    ranks = [0.0] * n
    i = 0
    while i < n:
        # Find the end of the current run of tied values and give every member
        # of the run the average of the ranks it spans.
        j = i
        while j < n - 1 and values[indexed[j + 1]] == values[indexed[j]]:
            j += 1
        avg_rank = (i + j) / 2.0 + 1  # 1-based
        for k in range(i, j + 1):
            ranks[indexed[k]] = avg_rank
        i = j + 1
    return ranks

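# Illustrative example (comment only, not executed at import time): tied values
# share the average of the ranks they occupy.
#   _rank([3.0, 1.0, 3.0, 2.0]) -> [3.5, 1.0, 3.5, 2.0]
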
def _spearman_correlation(x: List[float], y: List[float]) -> float:
    """Spearman rank correlation without scipy.

    Uses the closed form 1 - 6 * sum(d^2) / (n * (n^2 - 1)), which is exact
    when there are no ties and a close approximation otherwise.
    """
    n = len(x)
    if n < 3:
        return 0.0
    rx = _rank(x)
    ry = _rank(y)
    d_sq = sum((a - b) ** 2 for a, b in zip(rx, ry))
    denom = n * (n * n - 1)
    if denom == 0:
        return 0.0
    return 1.0 - (6.0 * d_sq) / denom

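# Illustrative examples (comment only; monotone inputs, no ties):
#   _spearman_correlation([1.0, 2.0, 3.0], [2.0, 4.0, 9.0]) -> 1.0
#   _spearman_correlation([1.0, 2.0, 3.0], [9.0, 4.0, 2.0]) -> -1.0
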
def _variance(values: List[float]) -> float:
    """Population variance."""
    n = len(values)
    if n < 2:
        return 0.0
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / n

def _linear_trend(values: List[float]) -> float:
    """Slope of linear regression (values indexed 0..n-1)."""
    n = len(values)
    if n < 2:
        return 0.0
    x_mean = (n - 1) / 2.0
    y_mean = sum(values) / n
    num = sum((i - x_mean) * (v - y_mean) for i, v in enumerate(values))
    den = sum((i - x_mean) ** 2 for i in range(n))
    if den == 0:
        return 0.0
    return num / den

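# Illustrative example (comment only): a series gaining 2.0 per step has slope 2.0.
#   _linear_trend([1.0, 3.0, 5.0]) -> 2.0
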
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _load_generation_metrics(
    results_dir: Path, gen_start: int, gen_end: int
) -> List[Tuple[int, Dict]]:
    """Load metrics.json for a range of generations. Returns [(gen, data), ...]."""
    records = []
    for gen in range(gen_start, gen_end + 1):
        metrics_path = results_dir / f"gen_{gen}" / "results" / "metrics.json"
        if not metrics_path.exists():
            continue
        try:
            with open(metrics_path) as f:
                data = json.load(f)
            records.append((gen, data))
        except Exception:
            logger.debug("Skipping unreadable metrics file: %s", metrics_path)
            continue
    return records

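# Expected on-disk layout. The payload fields sketched below are inferred from
# how the loaded records are consumed elsewhere in this module, not from a
# formal schema:
#   <results_dir>/gen_<N>/results/metrics.json
# where each metrics.json looks roughly like:
#   {"combined_score": 0.73, "public": {"aux_some_metric": 0.41, ...}}
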
def _extract_aux_metric_names(records: List[Tuple[int, Dict]]) -> List[str]:
    """Collect all aux_ metric names from public metrics across records.

    Filters out framework metrics (aux_aux_metric_*) that are always present.
    """
    # Exact names of metrics injected by the eval framework itself (the
    # membership test below is exact, not a prefix match).
    FRAMEWORK_METRICS = (
        "aux_aux_metric_eval_success",
        "aux_aux_metric_error_code",
        "aux_aux_metric_error_message_length",
        "aux_aux_metric_error_detail_length",
        "aux_aux_metric_non_numeric_dropped_count",
    )
    names: set[str] = set()
    for _, data in records:
        public = data.get("public", {})
        for key in public:
            if key.startswith("aux_") and key not in FRAMEWORK_METRICS:
                names.add(key)
    return sorted(names)

# ---------------------------------------------------------------------------
# Feedback computation
# ---------------------------------------------------------------------------
def _classify_correlation(rho: float) -> Tuple[str, str]:
    """Classify correlation strength. Returns (label, signal)."""
    abs_rho = abs(rho)
    if abs_rho >= 0.6:
        if rho > 0:
            return "strong positive", "USEFUL"
        return "strong negative", "MISLEADING"
    if abs_rho >= 0.3:
        if rho > 0:
            return "moderate positive", "USEFUL"
        return "moderate negative", "REVIEW"
    return "weak/none", "LOW SIGNAL"

def compute_metric_feedback(
    results_dir: Path,
    current_gen: int,
    last_agent_gen: int,
    pre_window: int = 10,
) -> str:
    """
    Compute feedback on auxiliary metrics effectiveness since last agent intervention.

    Args:
        results_dir: Experiment root directory.
        current_gen: Current generation number.
        last_agent_gen: Generation when agent last ran.
        pre_window: Number of generations before agent intervention to compare.

    Returns:
        Formatted markdown string for inclusion in agent task message.
        Empty string if insufficient data.
    """
    if last_agent_gen < 0 or current_gen <= last_agent_gen:
        return ""

    results_dir = Path(results_dir)

    # Load post-intervention data (last_agent_gen+1 .. current_gen)
    post_records = _load_generation_metrics(results_dir, last_agent_gen + 1, current_gen)
    if len(post_records) < 2:
        return ""  # Not enough data for meaningful feedback

    # Load pre-intervention data for trajectory comparison
    pre_start = max(0, last_agent_gen - pre_window)
    pre_records = _load_generation_metrics(results_dir, pre_start, last_agent_gen)

    # --- Section 1: Score trajectory comparison ---
    trajectory_lines = _build_trajectory_section(pre_records, post_records, last_agent_gen)

    # --- Section 2: Per-metric effectiveness ---
    aux_names = _extract_aux_metric_names(post_records)
    if not aux_names:
        # No auxiliary metrics to evaluate
        if trajectory_lines:
            return (
                f"\n\n📋 Feedback on Your Previous Actions (gen {last_agent_gen} → gen {current_gen}):\n\n"
                + "\n".join(trajectory_lines)
                + "\n\nNo auxiliary metrics found in recent generations.\n"
            )
        return ""

    metric_lines = _build_metric_effectiveness_section(post_records, aux_names)

    # --- Section 3: Actionable recommendation ---
    recommendation = _build_recommendation(post_records, aux_names)

    # Assemble
    parts = [
        f"\n\n📋 Feedback on Your Previous Actions (gen {last_agent_gen} → gen {current_gen}):",
        "",
    ]
    if trajectory_lines:
        parts.extend(trajectory_lines)
        parts.append("")
    parts.extend(metric_lines)
    if recommendation:
        parts.append("")
        parts.append(recommendation)
    return "\n".join(parts)

def _build_trajectory_section(
    pre_records: List[Tuple[int, Dict]],
    post_records: List[Tuple[int, Dict]],
    last_agent_gen: int,
) -> List[str]:
    """Build score trajectory comparison lines."""
    lines: List[str] = ["Score Trajectory:"]
    pre_scores = [
        d.get("combined_score", 0.0)
        for _, d in pre_records
        if isinstance(d.get("combined_score"), (int, float))
    ]
    post_scores = [
        d.get("combined_score", 0.0)
        for _, d in post_records
        if isinstance(d.get("combined_score"), (int, float))
    ]
    if pre_scores:
        pre_avg = sum(pre_scores) / len(pre_scores)
        pre_trend = _linear_trend(pre_scores)
        lines.append(
            f"- Before intervention ({len(pre_scores)} gens): avg {pre_avg:.4f}, trend {pre_trend:+.4f}/gen"
        )
    if post_scores:
        post_avg = sum(post_scores) / len(post_scores)
        post_trend = _linear_trend(post_scores)
        lines.append(
            f"- After intervention ({len(post_scores)} gens): avg {post_avg:.4f}, trend {post_trend:+.4f}/gen"
        )
    if pre_scores and post_scores:
        delta = post_avg - pre_avg
        lines.append(f"- Net impact: avg score {'improved' if delta > 0 else 'declined'} by {delta:+.4f}")
        # Score decline warning
        if pre_avg > 0 and delta < 0 and abs(delta) / pre_avg > 0.05:
            lines.append("")
            lines.append("⚠️ WARNING: Score has DECLINED significantly since agent intervention.")
            lines.append("Consider simplifying or clearing auxiliary metrics to eliminate potential interference.")
            lines.append("Focus diagnostic_report.md on algorithmic advice rather than adding new metrics.")
    return lines if len(lines) > 1 else []

def _build_metric_effectiveness_section(
    post_records: List[Tuple[int, Dict]],
    aux_names: List[str],
) -> List[str]:
    """Build per-metric effectiveness table."""
    scores = []
    for _, data in post_records:
        cs = data.get("combined_score")
        if isinstance(cs, (int, float)) and math.isfinite(cs):
            scores.append(cs)
        else:
            scores.append(None)

    lines = [
        "Metric Effectiveness Report:",
        "| Metric | Corr w/ Score | Variance | Coverage | Signal |",
        "|--------|--------------|----------|----------|--------|",
    ]
    for name in aux_names:
        metric_vals: List[Optional[float]] = []
        for _, data in post_records:
            v = data.get("public", {}).get(name)
            if isinstance(v, (int, float)) and math.isfinite(v):
                metric_vals.append(float(v))
            else:
                metric_vals.append(None)

        # Paired non-None values for correlation
        paired_m = []
        paired_s = []
        for mv, sv in zip(metric_vals, scores):
            if mv is not None and sv is not None:
                paired_m.append(mv)
                paired_s.append(sv)

        n_total = len(metric_vals)
        n_present = sum(1 for v in metric_vals if v is not None)
        coverage = n_present / n_total if n_total > 0 else 0.0

        if len(paired_m) >= 3:
            rho = _spearman_correlation(paired_m, paired_s)
            var = _variance(paired_m)
            label, signal = _classify_correlation(rho)
            # Override signal if variance is near zero
            if var < 1e-9:
                signal = "LOW SIGNAL"
            lines.append(
                f"| {name} | {rho:+.2f} ({label}) | {var:.4f} | {coverage:.0%} | {signal} |"
            )
        else:
            lines.append(
                f"| {name} | N/A (< 3 pts) | N/A | {coverage:.0%} | INSUFFICIENT |"
            )
    return lines

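# Example of a rendered row (comment only; metric name and numbers are hypothetical):
#   | Metric | Corr w/ Score | Variance | Coverage | Signal |
#   |--------|--------------|----------|----------|--------|
#   | aux_candidate_diversity | +0.85 (strong positive) | 0.0123 | 100% | USEFUL |
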
def _build_recommendation(
    post_records: List[Tuple[int, Dict]],
    aux_names: List[str],
) -> str:
    """Build actionable recommendation based on metric analysis."""
    to_remove: List[str] = []
    to_review: List[str] = []
    for name in aux_names:
        paired_m = []
        paired_s = []
        for _, data in post_records:
            v = data.get("public", {}).get(name)
            cs = data.get("combined_score")
            if (
                isinstance(v, (int, float)) and math.isfinite(v)
                and isinstance(cs, (int, float)) and math.isfinite(cs)
            ):
                paired_m.append(float(v))
                paired_s.append(float(cs))
        if len(paired_m) < 3:
            continue

        rho = _spearman_correlation(paired_m, paired_s)
        var = _variance(paired_m)
        if var < 1e-9:
            # Distinguish: constant because the code is broken vs constant because the metric is bad.
            all_scores_zero = all(s <= 0.01 for s in paired_s)
            if not all_scores_zero:
                to_remove.append(f"{name} (constant value, no signal)")
            # Otherwise skip the metric: it is constant only because the code never
            # worked, and it may become useful once candidates start scoring.
        elif rho < -0.3:
            to_review.append(f"{name} (negative correlation {rho:+.2f}, may mislead optimizer)")
        elif abs(rho) < 0.1:
            to_remove.append(f"{name} (near-zero correlation {rho:+.2f})")

    parts: List[str] = []
    if to_remove:
        parts.append("Consider REMOVING: " + "; ".join(to_remove) + ".")
    if to_review:
        parts.append("Consider REVIEWING: " + "; ".join(to_review) + ".")
    if not parts:
        parts.append(
            "All current metrics show reasonable signal. "
            "Focus on diagnostic_report.md rather than adding new metrics."
        )
    # Churn warning: if most metrics are flagged for removal, the agent is likely churning
    if len(to_remove) >= 3:
        parts.append(
            "⚠️ CHURN WARNING: Most metrics have low signal. "
            "Stop modifying metrics and focus on writing a useful diagnostic_report.md instead."
        )
    return "Recommendation: " + " ".join(parts)
