shinka-backup / scripts /trajectory_viewer.py
JustinTX's picture
Add files using upload-large-folder tool
6f90f5c verified
#!/usr/bin/env python3
import argparse
import html
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import streamlit as st
# Display style for each known message role. "label" is the badge text,
# "color" is used for the badge text and the message box's left border,
# and "bg" is the background fill of both the badge and the message box.
ROLE_STYLE = {
    role: {"label": role.upper(), "color": foreground, "bg": background}
    for role, foreground, background in (
        ("system", "#4B5563", "#F3F4F6"),
        ("user", "#1D4ED8", "#DBEAFE"),
        ("assistant", "#065F46", "#D1FAE5"),
        ("tool", "#7C2D12", "#FFEDD5"),
    )
}
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the viewer's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script itself relies on; an
            explicit list lets tests and other callers drive the parser.

    Returns:
        A namespace whose ``dir`` attribute is the run directory given with
        ``--dir``, or an empty string when the option is omitted.
    """
    parser = argparse.ArgumentParser(description="Streamlit viewer for eval-agent trajectory files.")
    parser.add_argument("--dir", type=str, default="", help="Directory containing agent run files.")
    return parser.parse_args(argv)
# Position of each known per-generation file within its generation group;
# any other suffix sorts after these (rank 99).
_GEN_FILE_ORDER = {
    "task_message.txt": 0,
    "result.json": 1,
    "trajectory_messages.json": 2,
}


def file_sort_key(path: Path) -> Tuple[int, int, str]:
    """Sort key that groups ``gen_<N>_<suffix>`` files by generation number.

    Files are ordered numerically by ``N``, then by the known-suffix rank in
    ``_GEN_FILE_ORDER``, then by name. Names that do not follow the
    ``gen_<N>_`` pattern sort after every generation file, by name.
    """
    name = path.name
    match = re.match(r"gen_(\d+)_(.*)$", name)
    if match is None:
        return (10**9, 10**9, name)
    generation, suffix = match.groups()
    return (int(generation), _GEN_FILE_ORDER.get(suffix, 99), name)
def try_load_json(path: Path) -> Optional[Any]:
    """Best-effort load of a UTF-8 JSON file.

    Returns the decoded value, or ``None`` when the file cannot be read or
    decoded for any reason. A file whose content is the JSON literal ``null``
    also yields ``None`` and so is indistinguishable from a failure.
    """
    try:
        raw_text = Path(path).read_text(encoding="utf-8")
        return json.loads(raw_text)
    except Exception:
        # Deliberately broad: callers treat None as "could not load".
        return None
def extract_text_from_message(message: Dict[str, Any]) -> str:
    """Return the human-readable text of a chat message dict.

    Two shapes of ``message["content"]`` are understood:

    * a plain string, returned with surrounding whitespace stripped;
    * a list of content parts, from which every dict part with
      ``"type": "text"`` and a non-empty string ``"text"`` is kept; the
      pieces are joined with newlines and the result is stripped.

    Any other ``content`` (missing, ``None``, non-text parts only, other
    types) yields an empty string.
    """
    content = message.get("content")
    if isinstance(content, str):
        # A bare-string content would otherwise be reported as having no text.
        return content.strip()
    if not isinstance(content, list):
        return ""
    text_parts: List[str] = []
    for item in content:
        if not isinstance(item, dict) or item.get("type") != "text":
            continue
        text = item.get("text")
        if isinstance(text, str) and text:
            text_parts.append(text)
    return "\n".join(text_parts).strip()
def trajectory_summary(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build one overview-table row per message.

    Each row holds the message's position (``idx``), its ``role`` (default
    ``"unknown"``), the number of entries in a list-valued ``tool_calls``
    (0 otherwise), the length of its extracted text (``chars``) and a
    ``preview`` of at most 120 characters, suffixed with ``...`` when the
    text was truncated.
    """
    def _summarize(position: int, message: Dict[str, Any]) -> Dict[str, Any]:
        body = extract_text_from_message(message)
        calls = message.get("tool_calls")
        return {
            "idx": position,
            "role": message.get("role", "unknown"),
            "tool_calls": len(calls) if isinstance(calls, list) else 0,
            "chars": len(body),
            "preview": f"{body[:120]}..." if len(body) > 120 else body,
        }

    return [_summarize(position, message) for position, message in enumerate(messages)]
def _role_badge_html(title: str, style: Dict[str, str]) -> str:
    """Build the pill-shaped badge shown above a message; *title* is escaped."""
    return (
        f"<div style='margin:8px 0 4px 0;'>"
        f"<span style='background:{style['bg']}; color:{style['color']};"
        " padding:4px 10px; border-radius:999px; font-weight:700;'>"
        f"{html.escape(title)}</span></div>"
    )


def _message_body_html(text: str, style: Dict[str, str]) -> str:
    """Build the bordered box holding a message's text.

    The text is HTML-escaped so tags, ``<`` and ``&`` in model or tool output
    are shown literally instead of being interpreted. Line breaks become
    ``<br>`` so the markup stays on a single line: under CommonMark a blank
    line ends a raw ``<div>`` HTML block, which would otherwise spill the rest
    of the message out of the box as Markdown.
    """
    body = "<br>".join(html.escape(text).splitlines())
    return (
        f"<div style='border-left:4px solid {style['color']}; padding:8px 12px;"
        f" background:{style['bg']}; border-radius:6px; white-space:pre-wrap;'>"
        f"{body}</div>"
    )


def render_trajectory(messages: List[Dict[str, Any]]):
    """Render a trajectory as an overview table followed by a message timeline.

    Each message gets a role badge, a toggle (keyed by its index) to show or
    hide it, its extracted text, and optionally its raw dict when the
    "Show raw dict" checkbox is ticked.

    Args:
        messages: Message dicts, rendered in list order.
    """
    st.subheader("Trajectory Overview")
    rows = trajectory_summary(messages)
    st.dataframe(rows, width="stretch")
    st.subheader("Full Message Timeline")
    show_raw = st.checkbox("Show raw dict under each message", value=False)
    for idx, msg in enumerate(messages):
        role = str(msg.get("role", "unknown"))
        # Unknown roles fall back to a neutral palette with the role as label.
        style = ROLE_STYLE.get(role, {"label": role.upper(), "color": "#111827", "bg": "#F9FAFB"})
        text = extract_text_from_message(msg)
        tool_calls = msg.get("tool_calls")
        tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else 0
        title = f"{style['label']} #{idx}"
        if tool_call_count > 0:
            title += f" | tool_calls={tool_call_count}"
        st.markdown(_role_badge_html(title, style), unsafe_allow_html=True)
        show_msg = st.toggle(f"Show message #{idx}", value=True, key=f"show_msg_{idx}")
        if not show_msg:
            continue
        if text:
            st.markdown(_message_body_html(text, style), unsafe_allow_html=True)
        else:
            st.caption("<no text content>")
        if show_raw:
            st.json(msg)
def _render_selected_file(selected_path: Path) -> None:
    """Render one run-directory file according to its name.

    ``*_trajectory_messages.json`` files are shown as a message timeline,
    other ``.json`` files as a JSON tree, and everything else as plain text.
    A ``.json`` file that fails to load is reported and shown raw.
    """
    name = selected_path.name
    if not name.endswith(".json"):
        raw = selected_path.read_text(encoding="utf-8", errors="replace")
        st.code(raw, language="text")
        return
    data = try_load_json(selected_path)
    if data is None:
        st.error("Failed to parse JSON.")
        raw = selected_path.read_text(encoding="utf-8", errors="replace")
        st.code(raw, language="json")
        return
    if name.endswith("_trajectory_messages.json"):
        if not isinstance(data, list):
            st.error("Trajectory file is not a JSON list.")
            return
        # Entries that are not dicts cannot be rendered as messages; skip them.
        msg_list = [x for x in data if isinstance(x, dict)]
        st.success(f"Loaded {len(msg_list)} message dicts.")
        render_trajectory(msg_list)
    else:
        st.json(data)


def main():
    """Streamlit entry point: pick a run directory and a file, then render it."""
    args = parse_args()
    st.set_page_config(page_title="Eval Agent Trajectory Viewer", layout="wide")
    st.title("Eval Agent Trajectory Viewer")
    # Strip so a pasted path with stray surrounding whitespace still resolves.
    run_dir_input = st.sidebar.text_input("Run directory", value=args.dir or "").strip()
    if not run_dir_input:
        st.info("Pass `--dir` or set the directory in the sidebar.")
        return
    run_dir = Path(run_dir_input).expanduser()
    # is_dir() is False for missing paths too, so it covers both checks.
    if not run_dir.is_dir():
        st.error(f"Directory not found: {run_dir_input}")
        return
    files = sorted((p for p in run_dir.iterdir() if p.is_file()), key=file_sort_key)
    if not files:
        st.warning("No files found in this directory.")
        return
    file_names = [p.name for p in files]
    selected_name = st.sidebar.selectbox("Select file", options=file_names, index=0)
    selected_path = run_dir / selected_name
    st.caption(f"Selected: `{selected_path}`")
    st.caption(f"Size: {selected_path.stat().st_size:,} bytes")
    _render_selected_file(selected_path)
if __name__ == "__main__":
    # Script entry point: build and render the viewer page via main().
    main()