#!/usr/bin/env python
# eval_battleground_rlaif.py
#
# Evaluation script for Battlegrounds RLAIF models: No FT, SFT, and SFT+GRPO.
# Measures action prediction accuracy against expert/labeled actions.
import argparse
import json
import os
import sys
from typing import Optional, Dict, Any, List

from tqdm import tqdm
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.append(_SCRIPT_DIR)

from battleground_nl_utils import (
    dataset_state_to_game_state,
    game_state_to_natural_language,
)
# ================== Constants ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"
DEFAULT_DATA_FILE = "RL/datasets/battleground_rlaif_multicandidate.jsonl"


def _resolve_default_model_id() -> str:
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    # Hub fallback aligned with the local "-2507" instruct checkpoint above.
    return "Qwen/Qwen3-4B-Instruct-2507"


DEFAULT_MODEL_ID = _resolve_default_model_id()
# ================== Data loading ==================
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the game state JSON:
"""

INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the description of the game state:
"""
def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build prompt from game state (same format as training)."""
    if input_mode == "nl":
        game_state = dataset_state_to_game_state(example)
        nl_state = game_state_to_natural_language(game_state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": example["phase"],
            "turn": example["turn"],
            "state": example["state"],
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX
    return prefix + "\n" + state_text
def load_eval_dataset(
    data_file: str,
    test_size: float = 0.1,
    seed: int = 42,
    limit: Optional[int] = None,
    input_mode: str = "json",
):
    """
    Load evaluation dataset from JSONL file.
    Uses the same train/test split as training to get the held-out test set.
    """
    raw = load_dataset("json", data_files={"train": data_file})["train"]
    # Same split as training
    split = raw.train_test_split(test_size=test_size, seed=seed)
    test_ds = split["test"]

    def format_example(example):
        prompt = _build_prompt(example, input_mode=input_mode)
        candidates = example["candidates"]
        # Find expert action
        expert = None
        for c in candidates:
            if c.get("role") == "expert":
                expert = c
                break
        if expert is None:
            expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))
        return {
            "prompt": prompt,
            "expert_action": expert["action"],
            "candidates": candidates,
            "game_id": example.get("game_id", ""),
            "step_id": example.get("step_id", 0),
            "turn": example["turn"],
            "phase": example["phase"],
        }

    test_ds = test_ds.map(format_example, remove_columns=raw.column_names)
    if limit is not None:
        test_ds = test_ds.select(range(min(limit, len(test_ds))))
    return test_ds
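
# Illustrative sketch (not taken from the real dataset) of the JSONL record shape the
# loader above assumes, based only on the fields it reads: "phase", "turn", "state",
# "candidates" (each with "role", "reward", "action"), plus optional "game_id"/"step_id".
# {"game_id": "g0001", "step_id": 3, "turn": 2, "phase": "<phase>",
#  "state": {...},
#  "candidates": [
#    {"role": "expert", "reward": 1.0,
#     "action": {"type": "ROLL", "tavern_index": null, "hand_index": null,
#                "board_index": null, "card_name": null}},
#    ...
#  ]}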
# ================== Action parsing & comparison ==================
def parse_action_from_completion(text: str) -> Optional[Dict[str, Any]]:
    """
    Parse model completion to extract action dict.
    Expected format from training: {"action": {...}}
    """
    text = text.strip()
    # Try to find JSON in the text
    # Sometimes model outputs extra text before/after JSON
    start_idx = text.find("{")
    if start_idx == -1:
        return None
    # Find matching closing brace
    brace_count = 0
    end_idx = -1
    for i, c in enumerate(text[start_idx:], start=start_idx):
        if c == "{":
            brace_count += 1
        elif c == "}":
            brace_count -= 1
            if brace_count == 0:
                end_idx = i + 1
                break
    if end_idx == -1:
        # No matching brace, try to find any closing brace
        end_idx = text.rfind("}") + 1
        if end_idx == 0:
            return None
    json_str = text[start_idx:end_idx]
    try:
        obj = json.loads(json_str)
    except Exception:
        # Try to fix common issues
        try:
            # Sometimes model outputs incomplete JSON, try adding closing braces
            obj = json.loads(json_str + "}")
        except Exception:
            try:
                obj = json.loads(json_str + "}}")
            except Exception:
                return None
    if isinstance(obj, dict):
        # Format from training: {"action": {...}}
        if "action" in obj and isinstance(obj["action"], dict):
            return obj["action"]
        # If it's directly an action dict (has "type" field)
        if "type" in obj:
            return obj
    return None
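
# Quick illustrative example of the parser's behavior (values are made up):
#   parse_action_from_completion(
#       'Sure! {"action":{"type":"ROLL","tavern_index":null,"hand_index":null,'
#       '"board_index":null,"card_name":null}}')
#   -> {"type": "ROLL", "tavern_index": None, "hand_index": None,
#       "board_index": None, "card_name": None}
# Leading chatter is skipped via the first "{", and unbalanced JSON gets one or two
# closing braces appended before the parser gives up and returns None.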
def actions_match(pred: Dict[str, Any], gold: Dict[str, Any], strict: bool = True) -> bool:
    """
    Compare predicted action with gold action.

    Args:
        pred: Predicted action dict
        gold: Gold/expert action dict
        strict: If True, all fields must match exactly. If False, only compare key fields.
    """
    if strict:
        return pred == gold
    # Relaxed matching: compare only essential fields
    key_fields = ["type", "tavern_index", "hand_index", "board_index", "card_name"]
    for field in key_fields:
        pred_val = pred.get(field)
        gold_val = gold.get(field)
        # Treat None and missing as equivalent
        if pred_val is None and gold_val is None:
            continue
        if pred_val != gold_val:
            return False
    return True


def get_action_reward(pred: Dict[str, Any], candidates: List[Dict[str, Any]]) -> float:
    """Get reward for predicted action by matching against candidates."""
    for cand in candidates:
        cand_action = cand.get("action", {})
        if actions_match(pred, cand_action, strict=False):
            return float(cand.get("reward", 0.0))
    return 0.0
# ================== Model loading ==================
def load_base_model(model_path: str, bf16: bool = True):
    """Load base model without any adapters."""
    dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
    model_kwargs = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
    }
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return model, tokenizer


def load_peft_model(base_model_path: str, adapter_path: str, bf16: bool = True):
    """Load base model with PEFT adapter."""
    dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
    model_kwargs = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
    }
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, **model_kwargs)
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model = model.merge_and_unload()  # Merge for faster inference
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return model, tokenizer
# ================== Evaluation ==================
def evaluate_model(
    model,
    tokenizer,
    test_ds,
    max_new_tokens: int = 128,
    batch_size: int = 8,
    verbose: bool = False,
):
    """
    Evaluate model on Battlegrounds test set.

    Returns:
        - exact_match_acc: Accuracy of exact action match
        - relaxed_match_acc: Accuracy with relaxed matching (key fields only)
        - avg_reward: Average reward of predicted actions
        - results: List of per-sample results
    """
    model.eval()
    device = next(model.parameters()).device
    exact_correct = 0
    relaxed_correct = 0
    total_reward = 0.0
    total = 0
    parse_failures = 0
    results = []
    for i in tqdm(range(0, len(test_ds), batch_size), desc="Evaluating"):
        batch = test_ds[i : i + batch_size]
        prompts = batch["prompt"] if isinstance(batch["prompt"], list) else [batch["prompt"]]
        expert_actions = batch["expert_action"] if isinstance(batch["expert_action"], list) else [batch["expert_action"]]
        candidates_list = batch["candidates"] if isinstance(batch["candidates"], list) else [batch["candidates"]]
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Decode and evaluate each sample
        for j, (output, prompt, expert_action, candidates) in enumerate(
            zip(outputs, prompts, expert_actions, candidates_list)
        ):
            input_len = inputs["input_ids"][j].shape[0]
            generated = tokenizer.decode(output[input_len:], skip_special_tokens=True)
            pred_action = parse_action_from_completion(generated)
            is_exact_match = False
            is_relaxed_match = False
            reward = 0.0
            if pred_action is None:
                parse_failures += 1
            else:
                is_exact_match = actions_match(pred_action, expert_action, strict=True)
                is_relaxed_match = actions_match(pred_action, expert_action, strict=False)
                reward = get_action_reward(pred_action, candidates)
            if is_exact_match:
                exact_correct += 1
            if is_relaxed_match:
                relaxed_correct += 1
            total_reward += reward
            total += 1
            result = {
                "game_id": batch["game_id"][j] if isinstance(batch["game_id"], list) else batch["game_id"],
                "step_id": batch["step_id"][j] if isinstance(batch["step_id"], list) else batch["step_id"],
                "turn": batch["turn"][j] if isinstance(batch["turn"], list) else batch["turn"],
                "phase": batch["phase"][j] if isinstance(batch["phase"], list) else batch["phase"],
                "expert_action": expert_action,
                "predicted_action": pred_action,
                "generated_text": generated.strip()[:200],  # Truncate for readability
                "exact_match": is_exact_match,
                "relaxed_match": is_relaxed_match,
                "reward": reward,
            }
            results.append(result)
            if verbose and not is_relaxed_match:
                print(f"\n[WRONG] Game: {result['game_id']}, Step: {result['step_id']}")
                print(f"  Expert: {expert_action}")
                print(f"  Pred:   {pred_action}")
                print(f"  Gen:    {generated[:150]}")
    exact_match_acc = exact_correct / total if total > 0 else 0.0
    relaxed_match_acc = relaxed_correct / total if total > 0 else 0.0
    avg_reward = total_reward / total if total > 0 else 0.0
    return {
        "exact_match_acc": exact_match_acc,
        "relaxed_match_acc": relaxed_match_acc,
        "avg_reward": avg_reward,
        "parse_failure_rate": parse_failures / total if total > 0 else 0.0,
        "total_samples": total,
        "results": results,
    }
# ================== Main ==================
def main():
    parser = argparse.ArgumentParser(description="Evaluate Battlegrounds RLAIF models: No FT, SFT, SFT+GRPO")
    parser.add_argument(
        "--base-model",
        default=DEFAULT_MODEL_ID,
        help="Base model path (Qwen instruct checkpoint).",
    )
    parser.add_argument(
        "--output-dir",
        default="./battleground_rlaif_qwen",
        help="Directory containing SFT and GRPO checkpoints.",
    )
    parser.add_argument(
        "--data-file",
        default=DEFAULT_DATA_FILE,
        help="Path to JSONL file with multi-candidate Battlegrounds data.",
    )
    parser.add_argument(
        "--sft-adapter",
        default=None,
        help="Path to SFT adapter (default: <output-dir>/sft_model).",
    )
    parser.add_argument(
        "--grpo-adapter",
        default=None,
        help="Path to GRPO adapter (default: <output-dir>/grpo_model).",
    )
    parser.add_argument(
        "--eval-samples",
        type=int,
        default=50,
        help="Number of test samples to evaluate (default: 50 for quick testing, use -1 for full set).",
    )
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size for inference (default: 8 for A800).")
    parser.add_argument("--max-new-tokens", type=int, default=128, help="Max tokens to generate.")
    parser.add_argument("--disable-bf16", action="store_true", help="Use fp16 instead of bf16.")
    parser.add_argument("--verbose", action="store_true", help="Print wrong predictions.")
    parser.add_argument(
        "--eval-no-ft", action="store_true", help="Evaluate base model (no fine-tuning)."
    )
    parser.add_argument("--eval-sft", action="store_true", help="Evaluate SFT model.")
    parser.add_argument("--eval-grpo", action="store_true", help="Evaluate SFT+GRPO model.")
    parser.add_argument(
        "--save-results",
        default=None,
        help="Path to save detailed results as JSON.",
    )
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Input format for game state: 'json' uses raw JSON, 'nl' uses natural language description.",
    )
    args = parser.parse_args()
    bf16 = not args.disable_bf16

    # Default: evaluate all if none specified
    eval_all = not (args.eval_no_ft or args.eval_sft or args.eval_grpo)
    if eval_all:
        args.eval_no_ft = True
        args.eval_sft = True
        args.eval_grpo = True

    # Resolve adapter paths
    sft_adapter = args.sft_adapter or os.path.join(args.output_dir, "sft_model")
    grpo_adapter = args.grpo_adapter or os.path.join(args.output_dir, "grpo_model")

    # Handle eval_samples=-1 as full set
    eval_samples = None if args.eval_samples == -1 else args.eval_samples

    # Load test data
    print("Loading Battlegrounds test set...")
    if not os.path.exists(args.data_file):
        print(f"ERROR: Data file not found: {args.data_file}")
        return
    test_ds = load_eval_dataset(
        args.data_file,
        limit=eval_samples,
        input_mode=args.input_mode,
    )
    print(f"Test samples: {len(test_ds)}")
    all_results = {}

    # ===== Evaluate No FT (base model) =====
    if args.eval_no_ft:
        print("\n" + "=" * 60)
        print("Evaluating: No Fine-Tuning (Base Model)")
        print("=" * 60)
        model, tokenizer = load_base_model(args.base_model, bf16=bf16)
        metrics = evaluate_model(
            model, tokenizer, test_ds,
            max_new_tokens=args.max_new_tokens,
            batch_size=args.batch_size,
            verbose=args.verbose,
        )
        print(f"[No FT] Exact Match: {metrics['exact_match_acc']:.4f}")
        print(f"[No FT] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
        print(f"[No FT] Avg Reward: {metrics['avg_reward']:.4f}")
        print(f"[No FT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
        all_results["no_ft"] = metrics
        del model
        torch.cuda.empty_cache()

    # ===== Evaluate SFT =====
    if args.eval_sft:
        print("\n" + "=" * 60)
        print("Evaluating: SFT Fine-Tuned Model")
        print("=" * 60)
        if not os.path.exists(sft_adapter):
            print(f"[SKIP] SFT adapter not found at: {sft_adapter}")
        else:
            model, tokenizer = load_peft_model(args.base_model, sft_adapter, bf16=bf16)
            metrics = evaluate_model(
                model, tokenizer, test_ds,
                max_new_tokens=args.max_new_tokens,
                batch_size=args.batch_size,
                verbose=args.verbose,
            )
            print(f"[SFT] Exact Match: {metrics['exact_match_acc']:.4f}")
            print(f"[SFT] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
            print(f"[SFT] Avg Reward: {metrics['avg_reward']:.4f}")
            print(f"[SFT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
            all_results["sft"] = metrics
            del model
            torch.cuda.empty_cache()
    # ===== Evaluate SFT + GRPO =====
    if args.eval_grpo:
        print("\n" + "=" * 60)
        print("Evaluating: SFT + GRPO Fine-Tuned Model")
        print("=" * 60)
        grpo_epoch_dir = os.path.join(args.output_dir, "grpo")
        adapters_to_eval: List[tuple] = []
        # If user did not override --grpo-adapter and epoch checkpoints exist,
        # evaluate all checkpoint-* directories under output_dir/grpo plus final grpo_model.
        default_grpo_adapter = os.path.join(args.output_dir, "grpo_model")
        using_default_adapter = (args.grpo_adapter is None) or (
            grpo_adapter == default_grpo_adapter
        )
        if using_default_adapter and os.path.isdir(grpo_epoch_dir):
            checkpoint_names = [
                d
                for d in os.listdir(grpo_epoch_dir)
                if d.startswith("checkpoint")
                and os.path.isdir(os.path.join(grpo_epoch_dir, d))
            ]
            checkpoint_names.sort()
            for name in checkpoint_names:
                path = os.path.join(grpo_epoch_dir, name)
                label = f"sft_grpo_{name}"
                adapters_to_eval.append((label, path))
            if os.path.exists(grpo_adapter):
                adapters_to_eval.append(("sft_grpo_final", grpo_adapter))
        else:
            if os.path.exists(grpo_adapter):
                adapters_to_eval.append(("sft_grpo", grpo_adapter))
        if not adapters_to_eval:
            print(f"[SKIP] No GRPO adapters found. Expected at: {grpo_adapter} or under {grpo_epoch_dir}")
        else:
            for label, adapter_path in adapters_to_eval:
                print("\n" + "-" * 60)
                print(f"Evaluating GRPO adapter: {label}")
                print(f"Path: {adapter_path}")
                model, tokenizer = load_peft_model(
                    args.base_model, adapter_path, bf16=bf16
                )
                metrics = evaluate_model(
                    model,
                    tokenizer,
                    test_ds,
                    max_new_tokens=args.max_new_tokens,
                    batch_size=args.batch_size,
                    verbose=args.verbose,
                )
                print(f"[{label}] Exact Match: {metrics['exact_match_acc']:.4f}")
                print(f"[{label}] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
                print(f"[{label}] Avg Reward: {metrics['avg_reward']:.4f}")
                print(f"[{label}] Parse Failures: {metrics['parse_failure_rate']:.2%}")
                all_results[label] = metrics
                del model
                torch.cuda.empty_cache()
    # ===== Summary =====
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Model':<12} {'Exact':<10} {'Relaxed':<10} {'Reward':<10} {'Parse Fail':<10}")
    print("-" * 52)
    for name, data in all_results.items():
        if "results" in data:  # Has actual results
            print(f"{name:<12} {data['exact_match_acc']:<10.4f} {data['relaxed_match_acc']:<10.4f} {data['avg_reward']:<10.4f} {data['parse_failure_rate']:<10.2%}")

    # Save results
    if args.save_results:
        save_data = {
            name: {
                "exact_match_acc": data["exact_match_acc"],
                "relaxed_match_acc": data["relaxed_match_acc"],
                "avg_reward": data["avg_reward"],
                "parse_failure_rate": data["parse_failure_rate"],
                "total_samples": data["total_samples"],
                "sample_predictions": data["results"][:10],  # First 10 for inspection
            }
            for name, data in all_results.items()
            if "results" in data
        }
        with open(args.save_results, "w") as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to: {args.save_results}")


if __name__ == "__main__":
    main()
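
# Example invocations (illustrative; checkpoint paths are placeholders):
#   python eval_battleground_rlaif.py --eval-no-ft --eval-samples 50
#   python eval_battleground_rlaif.py --eval-sft --output-dir ./battleground_rlaif_qwen \
#       --save-results eval_results.json
#   python eval_battleground_rlaif.py --eval-grpo --input-mode nl --eval-samples -1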