| import json |
| import hashlib |
| from flask import Blueprint, request, jsonify |
| from datasets import load_dataset |
|
|
# Blueprint for the RLM-eval dataset viewer API.
bp = Blueprint("rlm_eval_datasets", __name__, url_prefix="/api/rlm-eval/datasets")


# In-process cache of loaded datasets keyed by the short id from _make_id().
# Entries hold the reconstructed hierarchy plus metadata; not persisted and
# not thread-safe -- presumably fine for a single-worker dev server (TODO confirm).
_cache: dict[str, dict] = {}
|
|
|
|
def _make_id(repo: str, config: str, split: str) -> str:
    """Derive a short, stable identifier for a (repo, config, split) triple.

    MD5 is used purely as a cheap fingerprint here, not for security.
    """
    digest = hashlib.md5(f"{repo}:{config}:{split}".encode()).hexdigest()
    return digest[:12]
|
|
|
|
def _parse_code_blocks(raw) -> list[dict]:
    """Normalize a row's ``code_blocks_json`` field into a list of blocks.

    Accepts either a JSON string or an already-decoded list. Each returned
    block is ``{"code": ...}`` plus an optional ``"stdout"`` key; stdout is
    taken from the nested ``result.stdout`` when present, otherwise from a
    top-level ``stdout`` key on the block itself.

    Returns [] for empty, missing, or unparseable input (including entries
    that are not dicts, which would otherwise raise AttributeError).
    """
    if not raw or raw == "[]":
        return []
    try:
        decoded = json.loads(raw) if isinstance(raw, str) else raw
        blocks = []
        for cb in decoded:
            block = {"code": cb.get("code", "")}
            result = cb.get("result", {})
            if isinstance(result, dict) and result.get("stdout"):
                block["stdout"] = result["stdout"]
            elif cb.get("stdout"):
                block["stdout"] = cb["stdout"]
            blocks.append(block)
        return blocks
    except (json.JSONDecodeError, TypeError, AttributeError):
        # Malformed payloads are treated as "no code blocks" rather than
        # failing the whole hierarchy build.
        return []


def _build_hierarchy(rows: list[dict]) -> dict:
    """Reconstruct hierarchy from flat rows: examples -> iterations.

    Each row is one RLM iteration; rows sharing an ``example_idx`` are
    grouped into an example whose token and time totals accumulate across
    iterations. The last iteration with a truthy ``final_answer`` wins as
    the example's final answer.

    Args:
        rows: Flat dataset rows (dicts with example_idx / rlm_iter / etc.).

    Returns:
        ``{"examples": [...]}`` with examples ordered by ``example_idx`` and
        each example's ``"iterations"`` list ordered by ``rlm_iter``.
    """
    examples: dict[int, dict] = {}

    for row in rows:
        ei = row.get("example_idx", 0)
        ri = row.get("rlm_iter", 0)

        if ei not in examples:
            examples[ei] = {
                "example_idx": ei,
                "question_text": row.get("question_text", ""),
                "eval_correct": row.get("eval_correct"),
                "iterations": {},
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_execution_time": 0.0,
                "final_answer": None,
                "final_answer_preview": "",
            }

        ex = examples[ei]

        iteration = {
            "rlm_iter": ri,
            "prompt": row.get("prompt", ""),
            "response": row.get("response", ""),
            "model": row.get("model", ""),
            "input_tokens": row.get("input_tokens", 0),
            "output_tokens": row.get("output_tokens", 0),
            "execution_time": row.get("execution_time", 0.0),
            "has_code_blocks": row.get("has_code_blocks", False),
            "code_blocks": _parse_code_blocks(row.get("code_blocks_json", "")),
            "final_answer": row.get("final_answer"),
            "timestamp": row.get("timestamp", ""),
        }

        # NOTE(review): a duplicate (example_idx, rlm_iter) pair overwrites
        # the stored iteration but still adds to the totals -- assumes the
        # pair is unique per dataset; confirm upstream.
        ex["iterations"][ri] = iteration
        ex["total_input_tokens"] += iteration["input_tokens"] or 0
        ex["total_output_tokens"] += iteration["output_tokens"] or 0
        ex["total_execution_time"] += iteration["execution_time"] or 0.0

        if iteration["final_answer"]:
            ex["final_answer"] = iteration["final_answer"]
            ex["final_answer_preview"] = iteration["final_answer"][:200]

    # Flatten: sort examples by index and each example's iterations by iter.
    result = []
    for ei_key in sorted(examples.keys()):
        ex = examples[ei_key]
        ex["iterations"] = [ex["iterations"][ri_key] for ri_key in sorted(ex["iterations"])]
        result.append(ex)

    return {"examples": result}
|
|
|
|
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a Hugging Face dataset, build its hierarchy, and cache it.

    Expects a JSON body with ``repo`` (required), ``config`` and ``split``
    (optional). Returns a summary of the loaded dataset, or a 400 error on
    a missing repo / failed load.
    """
    # silent=True: a missing or non-JSON body yields None (-> {}) instead of
    # raising, so bad requests get a clean 400 rather than a 500.
    data = request.get_json(silent=True) or {}
    repo = data.get("repo", "").strip()
    if not repo:
        return jsonify({"error": "repo is required"}), 400

    config = data.get("config", "rlm_call_traces")
    split = data.get("split", "train")

    try:
        ds = load_dataset(repo, config, split=split)
    except Exception as e:
        # Surface the underlying loader error to the client as a bad request.
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    ds_id = _make_id(repo, config, split)
    rows = [ds[i] for i in range(len(ds))]
    hierarchy = _build_hierarchy(rows)

    # Run-level metadata is assumed constant across rows; take it from the
    # first row (empty strings when the dataset has no rows).
    first_row = rows[0] if rows else {}
    metadata = {
        "run_id": first_row.get("run_id", ""),
        "method": first_row.get("method", ""),
        "model": first_row.get("model", ""),
    }

    _cache[ds_id] = {
        "repo": repo,
        "config": config,
        "split": split,
        "hierarchy": hierarchy,
        "metadata": metadata,
        "n_rows": len(rows),
    }

    # rsplit returns [repo] when there is no "/", so this covers both forms.
    short_name = repo.rsplit("/", 1)[-1]

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "config": config,
        "split": split,
        "metadata": metadata,
        "n_examples": len(hierarchy["examples"]),
        "n_rows": len(rows),
    })
|
|
|
|
@bp.route("/", methods=["GET"])
def list_datasets():
    """Return a summary entry for every dataset currently cached."""
    summaries = []
    for ds_id, info in _cache.items():
        repo = info["repo"]
        summaries.append({
            "id": ds_id,
            "repo": repo,
            "name": repo.rsplit("/", 1)[-1] if "/" in repo else repo,
            "config": info["config"],
            "split": info["split"],
            "metadata": info["metadata"],
            "n_rows": info["n_rows"],
            "n_examples": len(info["hierarchy"]["examples"]),
        })
    return jsonify(summaries)
|
|
|
|
@bp.route("/<ds_id>/overview", methods=["GET"])
def get_overview(ds_id):
    """Level 1: Summary of all examples."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    summaries = [
        {
            "example_idx": ex["example_idx"],
            "question_text": (ex["question_text"] or "")[:300],
            "eval_correct": ex["eval_correct"],
            "n_iterations": len(ex["iterations"]),
            "total_input_tokens": ex["total_input_tokens"],
            "total_output_tokens": ex["total_output_tokens"],
            "total_execution_time": ex["total_execution_time"],
            "final_answer_preview": ex["final_answer_preview"],
        }
        for ex in info["hierarchy"]["examples"]
    ]

    return jsonify({"metadata": info["metadata"], "examples": summaries})
|
|
|
|
@bp.route("/<ds_id>/example/<int:example_idx>", methods=["GET"])
def get_example_detail(ds_id, example_idx):
    """Level 2: Iteration timeline for one example."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    ex_data = next(
        (ex for ex in info["hierarchy"]["examples"] if ex["example_idx"] == example_idx),
        None,
    )
    if ex_data is None:
        return jsonify({"error": f"Example {example_idx} not found"}), 404

    iters = [
        {
            "rlm_iter": it["rlm_iter"],
            "model": it["model"],
            "input_tokens": it["input_tokens"],
            "output_tokens": it["output_tokens"],
            "execution_time": it["execution_time"],
            "has_code_blocks": it["has_code_blocks"],
            "n_code_blocks": len(it["code_blocks"]),
            "response_preview": (it["response"] or "")[:300],
            "has_final_answer": it["final_answer"] is not None,
            "timestamp": it["timestamp"],
        }
        for it in ex_data["iterations"]
    ]

    return jsonify({
        "example_idx": example_idx,
        "question_text": ex_data["question_text"],
        "eval_correct": ex_data["eval_correct"],
        "total_input_tokens": ex_data["total_input_tokens"],
        "total_output_tokens": ex_data["total_output_tokens"],
        "total_execution_time": ex_data["total_execution_time"],
        "final_answer": ex_data["final_answer"],
        "iterations": iters,
    })
|
|
|
|
@bp.route("/<ds_id>/example/<int:example_idx>/iter/<int:rlm_iter>", methods=["GET"])
def get_iter_detail(ds_id, example_idx, rlm_iter):
    """Full detail for a specific RLM iteration within an example."""
    info = _cache.get(ds_id)
    if info is None:
        return jsonify({"error": "Dataset not loaded"}), 404

    for ex in info["hierarchy"]["examples"]:
        if ex["example_idx"] != example_idx:
            continue
        match = next(
            (it for it in ex["iterations"] if it["rlm_iter"] == rlm_iter),
            None,
        )
        if match is not None:
            return jsonify(match)

    # Covers both an unknown example_idx and an unknown rlm_iter.
    return jsonify({"error": "Iteration not found"}), 404
|
|
|
|
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Evict a dataset from the cache; idempotent (unknown ids are a no-op)."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})
|
|