| |
| """ |
| FunctionGemma evaluation script (v2). |
| |
| Uses a unified system prompt for evaluation. |
| |
| Usage: |
| python -m src.evaluate --model_path ./runs/<run>/final_model --benchmark_path ./data/benchmark_dataset.json |
| """ |
|
|
| import os |
| import re |
| import sys |
| import json |
| import argparse |
| import logging |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| from datetime import datetime |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from threading import Lock |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
| from tqdm import tqdm |
|
|
| |
# Directory two levels above this file's resolved location.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Put that directory at the front of sys.path (only once) so the
# ``from src.config import ...`` statement further down can resolve.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))


# Default input/output locations, both relative to PROJECT_ROOT; used as
# fallbacks by the command-line interface in ``main``.
DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json"
DEFAULT_RESULTS_DIR = PROJECT_ROOT / "results"
|
|
| from src.config import ( |
| get_system_prompt, get_system_prompt_short, TOOLS, |
| SOLANA_TOKENS, get_token_address |
| ) |
|
|
| |
# Configure logging at import time: INFO level, records formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
def load_model(
    model_path: str,
    lora_path: Optional[str] = None,
    device: str = "auto",
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
):
    """Load a causal language model and its tokenizer for evaluation.

    Args:
        model_path: Location passed to ``from_pretrained`` for both the
            tokenizer and the model.
        lora_path: Optional LoRA adapter location. When given, the loaded
            model is wrapped with ``PeftModel.from_pretrained``.
        device: Value forwarded to ``from_pretrained`` as ``device_map``.
        load_in_8bit: Load the weights with 8-bit quantization. Takes
            precedence over ``load_in_4bit`` when both are set.
        load_in_4bit: Load the weights with 4-bit NF4 quantization
            (double quantization, bfloat16 compute dtype).

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.
    """
    logger.info(f"Loading model: {model_path}")

    # NOTE(review): trust_remote_code=True lets the checkpoint run its own
    # Python code on load; only point this at checkpoints you trust.
    kwargs = {
        "device_map": device,
        "trust_remote_code": True,
    }

    if load_in_8bit and load_in_4bit:
        logger.warning(
            "Both load_in_8bit and load_in_4bit were requested; using 8-bit quantization"
        )

    if load_in_8bit or load_in_4bit:
        # Imported lazily so the unquantized path does not require it.
        from transformers import BitsAndBytesConfig

        if load_in_8bit:
            # Go through quantization_config (like the 4-bit path) instead of
            # the deprecated bare ``load_in_8bit=True`` keyword argument.
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        else:
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
    else:
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # generate() is later given pad_token_id, so make sure one exists.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)

    if lora_path:
        logger.info(f"Loading LoRA adapter: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)

    model.eval()
    return model, tokenizer
|
|
|
|
def parse_functiongemma_output(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    """Extract the function name and arguments from a FunctionGemma response.

    Expected format:
        <start_function_call>call:FUNC_NAME{key:<escape>value<escape>,...}<end_function_call>

    Matching is attempted from strictest to loosest:
      1. a complete call including the closing tag,
      2. a call with a ``{...}`` argument body but no closing tag,
      3. a call with a name only (arguments become ``{}``),
      4. a plain mention of a known function name anywhere in the text
         (arguments become ``{}``).

    Returns:
        ``(function_name, arguments)``, or ``(None, None)`` when nothing
        resembling a call is found.
    """
    # Patterns that carry an argument body, strictest first.
    patterns_with_args = (
        r'<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>',
        r'<start_function_call>call:(\w+)\{([^}]*)\}',
    )
    for regex in patterns_with_args:
        found = re.search(regex, response)
        if found:
            name, raw_params = found.groups()
            return name, parse_params_string(raw_params)

    # A call tag with a name but no argument body.
    name_only = re.search(r'<start_function_call>call:(\w+)', response)
    if name_only:
        return name_only.group(1), {}

    # Last resort: the model mentioned a known function without call syntax.
    mentioned = next(
        (known for known in ("SEARCH_TOKEN", "EXECUTE_SWAP") if known in response),
        None,
    )
    if mentioned is not None:
        return mentioned, {}

    return None, None
|
|
|
|
def _coerce_param_value(text: str):
    """Convert one raw argument string into a float, an int, or leave it as text."""
    # "NN%" becomes a fraction (e.g. "50%" -> 0.5) when the prefix is numeric.
    if text.endswith('%'):
        try:
            return float(text[:-1]) / 100
        except ValueError:
            pass  # not a numeric percentage; try the generic conversion below

    # A dot selects float parsing, otherwise int; non-numeric text is kept as is.
    convert = float if '.' in text else int
    try:
        return convert(text)
    except ValueError:
        return text


def parse_params_string(params_str: str) -> Dict:
    """Parse the ``key:value`` list found between the braces of a function call.

    Values may be wrapped as ``<escape>value<escape>`` or written bare (up to
    the next ``,`` or ``}``). Each value is stripped of surrounding whitespace
    and then converted: percentages become fractions, numeric text becomes
    ``float``/``int``, everything else stays a string.

    Returns:
        Mapping of argument name to converted value (empty for empty input).
    """
    if not params_str:
        return {}

    pair_regex = r'(\w+):(?:<escape>([^<]*)<escape>|([^,}]+))'

    parsed: Dict = {}
    for hit in re.finditer(pair_regex, params_str):
        name, escaped, bare = hit.groups()
        raw = bare if escaped is None else escaped
        if raw is None:
            continue
        parsed[name] = _coerce_param_value(raw.strip())

    return parsed
|
|
|
|
def is_rejection_response(response: str) -> bool:
    """Tell whether a response looks like a refusal or a clarifying question.

    A response counts as a rejection when it contains no
    ``<start_function_call>`` tag at all, or when it contains any of the
    English/Chinese refusal or clarification phrases below (matched
    case-insensitively), even if a function call is also present.
    """
    # No function call emitted at all -> treated as a rejection.
    if '<start_function_call>' not in response:
        return True

    markers = (
        "please specify", "could you", "what token", "which token",
        "请问", "请提供", "请告诉", "您能", "什么代币", "哪个代币",
        "sorry", "can't", "cannot", "unable", "抱歉", "无法",
        "more information", "more details", "更多信息",
    )

    lowered = response.lower()
    return any(marker.lower() in lowered for marker in markers)
|
|
|
|
def format_messages_for_model(
    messages: List[Dict],
    tokenizer,
    tools: Optional[List[Dict]] = None,
) -> str:
    """Render chat messages into a single prompt string for generation.

    The tokenizer's own ``apply_chat_template`` is preferred (called with
    ``tokenize=False`` and ``add_generation_prompt=True``). If the tokenizer
    has no such method, or the call raises, a manual
    ``<start_of_turn>ROLE ... <end_of_turn>`` layout is used instead.

    Args:
        messages: Dicts with ``"role"`` and ``"content"`` keys. In the manual
            fallback only ``system``, ``user`` and ``assistant`` roles are
            rendered (``assistant`` is written as ``model``); other roles are
            skipped.
        tokenizer: Object that may provide ``apply_chat_template``.
        tools: Optional tool definitions forwarded to the chat template; not
            used by the manual fallback.

    Returns:
        The prompt text, ending with an open model turn.
    """
    if hasattr(tokenizer, 'apply_chat_template'):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tools=tools,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception as exc:
            # Falling back is deliberate (best effort), but a silently
            # different prompt format skews results, so report it once.
            if not getattr(format_messages_for_model, "_fallback_warned", False):
                format_messages_for_model._fallback_warned = True
                logging.getLogger(__name__).warning(
                    "apply_chat_template failed (%s: %s); "
                    "falling back to manual turn formatting",
                    type(exc).__name__,
                    exc,
                )

    # Manual fallback: map chat roles onto turn names.
    turn_names = {"system": "system", "user": "user", "assistant": "model"}

    parts = []
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        turn = turn_names.get(role)
        if turn is not None:
            parts.append(f"<start_of_turn>{turn}\n{content}<end_of_turn>\n")

    # Leave a model turn open so generation continues from here.
    parts.append("<start_of_turn>model\n")
    return "".join(parts)
|
|
|
|
def generate_response(
    model,
    tokenizer,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 256,
) -> str:
    """Run the model on one system + user exchange and return its reply text.

    The prompt is rendered with ``format_messages_for_model`` (with the
    module's ``TOOLS``), sampled at temperature 0.1, and only the newly
    generated tokens are decoded. Special tokens are kept in the decoded
    text, except that ``<end_of_turn>`` markers are removed and surrounding
    whitespace is stripped.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    rendered = format_messages_for_model(conversation, tokenizer, TOOLS)

    # Tokenize and move every tensor to the model's device.
    encoded = tokenizer(rendered, return_tensors="pt")
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_length = encoded["input_ids"].shape[1]

    generation_options = {
        "max_new_tokens": max_new_tokens,
        "temperature": 0.1,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_options)

    # Drop the echoed prompt tokens; keep only what the model produced.
    completion_ids = generated[0][prompt_length:]
    text = tokenizer.decode(completion_ids, skip_special_tokens=False)
    return text.replace("<end_of_turn>", "").strip()
|
|
|
|
def compare_arguments(expected: Dict, actual: Dict) -> Tuple[float, List[str]]:
    """Score how well the extracted arguments match the expected ones.

    Only expected keys whose value is not ``None`` are scored. Each scored
    key earns 1 point for a match (equal string forms, or numbers within
    0.01 of each other) and 0.5 for two strings sharing their first 10
    characters. Keys present only in ``actual`` are reported as errors but
    do not lower the score.

    Args:
        expected: Reference arguments; ``None`` values mean "not checked".
        actual: Arguments parsed from the model output.

    Returns:
        ``(score, errors)`` where ``score`` is points divided by the number
        of scored keys (1.0 when there is nothing to score) and ``errors``
        lists human-readable mismatch descriptions.
    """
    if not expected:
        # Nothing expected: perfect only if nothing was produced either.
        return (1.0 if not actual else 0.0), []

    if not actual:
        return 0.0, ["No arguments extracted"]

    errors: List[str] = []
    matched = 0.0
    scored_keys = 0

    for key, exp_val in expected.items():
        if exp_val is None:
            continue  # unconstrained argument; excluded from the score
        scored_keys += 1

        act_val = actual.get(key)
        if act_val is None:
            errors.append(f"Missing key: {key}")
            continue

        if str(exp_val) == str(act_val):
            matched += 1
        elif isinstance(exp_val, str) and isinstance(act_val, str):
            # Half credit when two differing strings share a 10-char prefix.
            if exp_val[:10] == act_val[:10]:
                matched += 0.5
                errors.append(f"Partial match for {key}")
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        elif isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
            if abs(float(exp_val) - float(act_val)) < 0.01:
                matched += 1
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        else:
            errors.append(f"Type mismatch for {key}")

    errors.extend(f"Extra key: {key}" for key in actual if key not in expected)

    if scored_keys == 0:
        # Every expected value was None: nothing to verify, so avoid the
        # division by zero and treat the arguments as fully matching.
        return 1.0, errors

    return matched / scored_keys, errors
|
|
|
|
def process_single_sample(
    sample: Dict,
    idx: int,
    model,
    tokenizer,
    system_prompt: str,
) -> Dict:
    """Run the model on one benchmark sample and grade the outcome.

    Args:
        sample: Benchmark entry. Reads ``"input"`` (a string, or a dict with
            a ``"messages"`` list), ``"expected"`` (with ``"function_name"``
            and optional ``"arguments"``), and the optional ``"id"`` and
            ``"category"`` fields.
        idx: Zero-based position of the sample; ``idx + 1`` is used as the
            id when the sample has none.
        model: Model handed to ``generate_response``.
        tokenizer: Tokenizer handed to ``generate_response``.
        system_prompt: System prompt used for generation.

    Returns:
        A flat dict with the grading flags, argument score, error message
        and the raw inputs/outputs for this sample.
    """
    sample_id = sample.get("id", idx + 1)
    category = sample.get("category", "unknown")
    user_input = sample["input"]
    expected_func = sample["expected"]["function_name"]
    expected_args = sample["expected"].get("arguments", {})

    # For chat-style inputs, use the content of the first user message.
    if isinstance(user_input, dict) and "messages" in user_input:
        prompt = next(
            (
                message.get("content", "")
                for message in user_input["messages"]
                if message.get("role") == "user"
            ),
            "",
        )
    else:
        prompt = str(user_input)

    response = generate_response(model, tokenizer, prompt, system_prompt)
    actual_func, actual_args = parse_functiongemma_output(response)
    is_rejection = is_rejection_response(response)

    # Pessimistic defaults; overwritten by the grading below.
    args_correct = False
    exact_match = False
    rejection_correct = False
    arg_score = 0.0
    error_msg = None

    if expected_func is None:
        # The sample expects a refusal: any rejection-like reply (or no
        # detected function at all) counts as fully correct.
        refused = is_rejection or actual_func is None
        func_correct = refused
        args_correct = refused
        exact_match = refused
        rejection_correct = refused
        arg_score = 1.0 if refused else 0.0
        if not refused:
            error_msg = f"Expected rejection, got {actual_func}"
    else:
        func_correct = actual_func == expected_func
        if not func_correct:
            error_msg = f"Expected {expected_func}, got {actual_func}"
        else:
            arg_score, arg_errors = compare_arguments(expected_args, actual_args or {})
            args_correct = arg_score >= 0.99
            exact_match = args_correct
            if not args_correct:
                error_msg = "; ".join(arg_errors)

    return {
        "sample_id": sample_id,
        "category": category,
        "expected_func": expected_func,
        "actual_func": actual_func,
        "func_correct": func_correct,
        "args_correct": args_correct,
        "exact_match": exact_match,
        "rejection_correct": rejection_correct,
        "arg_score": arg_score,
        "error_msg": error_msg,
        "user_input": user_input,
        "expected_args": expected_args,
        "actual_args": actual_args,
        "response": response,
    }
|
|
|
|
def evaluate_benchmark(
    model,
    tokenizer,
    benchmark: List[Dict],
    chain: str = "solana",
    verbose: bool = False,
    num_workers: int = 1,
) -> Dict:
    """Evaluate every benchmark sample and aggregate the results.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer matching ``model``.
        benchmark: List of benchmark samples.
        chain: Chain name passed to ``get_system_prompt_short`` to build the
            system prompt.
        verbose: When true, a status line is logged for each sample.
        num_workers: Number of worker threads; values of 1 or less run the
            samples sequentially in the calling thread.

    Returns:
        Aggregated counters, per-category and per-function statistics and
        a capped list of error samples.
    """
    system_prompt = get_system_prompt_short(chain)

    summary = {
        "total": len(benchmark),
        "function_correct": 0,
        "arguments_correct": 0,
        "exact_match": 0,
        "rejection_correct": 0,
        "total_arg_score": 0.0,
        "by_category": {},
        "by_function": {},
        "errors": [],
    }
    summary_guard = Lock()

    if num_workers <= 1:
        logger.info("Evaluating with a single thread")
        for position, sample in enumerate(tqdm(benchmark, desc="Evaluation")):
            outcome = process_single_sample(sample, position, model, tokenizer, system_prompt)
            _update_results(summary, outcome, verbose)
        return summary

    logger.info(f"Evaluating with {num_workers} worker threads")
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending = [
            pool.submit(process_single_sample, sample, position, model, tokenizer, system_prompt)
            for position, sample in enumerate(benchmark)
        ]

        # Results are folded in completion order, not benchmark order.
        with tqdm(total=len(benchmark), desc="Evaluation") as progress:
            for finished in as_completed(pending):
                outcome = finished.result()
                with summary_guard:
                    _update_results(summary, outcome, verbose)
                progress.update(1)

    return summary
|
|
|
|
| def _update_results(results: Dict, result: Dict, verbose: bool): |
| """Update aggregated evaluation results.""" |
| sample_id = result["sample_id"] |
| category = result["category"] |
| expected_func = result["expected_func"] |
| actual_func = result["actual_func"] |
| func_correct = result["func_correct"] |
| args_correct = result["args_correct"] |
| exact_match = result["exact_match"] |
| rejection_correct = result["rejection_correct"] |
| arg_score = result["arg_score"] |
| error_msg = result["error_msg"] |
| |
| |
| if func_correct: |
| results["function_correct"] += 1 |
| if args_correct: |
| results["arguments_correct"] += 1 |
| if exact_match: |
| results["exact_match"] += 1 |
| if rejection_correct: |
| results["rejection_correct"] += 1 |
| results["total_arg_score"] += arg_score |
| |
| |
| if category not in results["by_category"]: |
| results["by_category"][category] = { |
| "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0 |
| } |
| results["by_category"][category]["total"] += 1 |
| if func_correct: |
| results["by_category"][category]["func_correct"] += 1 |
| if exact_match: |
| results["by_category"][category]["exact_match"] += 1 |
| results["by_category"][category]["arg_score"] += arg_score |
| |
| |
| func_key = expected_func or "None" |
| if func_key not in results["by_function"]: |
| results["by_function"][func_key] = { |
| "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0 |
| } |
| results["by_function"][func_key]["total"] += 1 |
| if func_correct: |
| results["by_function"][func_key]["func_correct"] += 1 |
| if exact_match: |
| results["by_function"][func_key]["exact_match"] += 1 |
| results["by_function"][func_key]["arg_score"] += arg_score |
| |
| |
| if error_msg and len(results["errors"]) < 10: |
| results["errors"].append({ |
| "id": sample_id, |
| "category": category, |
| "input": result["user_input"], |
| "expected_func": expected_func, |
| "actual_func": actual_func, |
| "expected_args": result["expected_args"], |
| "actual_args": result["actual_args"], |
| "error": error_msg, |
| "response": result["response"][:200], |
| }) |
| |
| if verbose: |
| status = "✓" if exact_match else "✗" |
| |
| user_input = result["user_input"] |
| if isinstance(user_input, dict): |
| user_msg = "" |
| if "messages" in user_input: |
| for msg in user_input["messages"]: |
| if msg.get("role") == "user": |
| user_msg = msg.get("content", "") |
| break |
| input_preview = user_msg[:50] if user_msg else str(user_input)[:50] |
| else: |
| input_preview = str(user_input)[:50] |
| logger.info(f"[{sample_id}] {status} {category}: {input_preview}...") |
|
|
|
|
def print_report(results: Dict):
    """Print a human-readable evaluation report to stdout.

    Args:
        results: Aggregated results. Reads the overall counters
            (``total``, ``function_correct``, ``arguments_correct``,
            ``exact_match``, ``rejection_correct``, ``total_arg_score``),
            the ``by_function`` table and the ``errors`` list.

    The report shows overall percentages, one line per expected function
    (sorted by name) and up to five recorded error samples. Rejection
    accuracy is measured against the samples filed under ``"None"`` in
    ``by_function``.
    """

    def percent(part, whole):
        # Percentage that degrades to 0 instead of dividing by zero.
        return part / whole * 100 if whole > 0 else 0

    heavy_rule = "=" * 70
    light_rule = "-" * 70
    total = results["total"]

    print("\n" + heavy_rule)
    print("FunctionGemma Evaluation Report")
    print(heavy_rule)
    print(f"\nTotal samples: {total}")

    print("\n" + light_rule)
    print("Overall metrics")
    print(light_rule)

    func_acc = percent(results["function_correct"], total)
    arg_acc = percent(results["arguments_correct"], total)
    exact_acc = percent(results["exact_match"], total)
    avg_arg_score = percent(results["total_arg_score"], total)

    # Samples that expected a rejection are grouped under the "None" key.
    rejection_total = results["by_function"].get("None", {}).get("total", 0)
    rejection_acc = percent(results["rejection_correct"], rejection_total)

    print(f"Function selection accuracy: {func_acc:.2f}%")
    print(f"Argument accuracy: {arg_acc:.2f}%")
    print(f"Exact match accuracy: {exact_acc:.2f}%")
    print(f"Average argument score: {avg_arg_score:.2f}%")
    print(f"Rejection accuracy: {rejection_acc:.2f}%")

    print("\n" + light_rule)
    print("By function")
    print(light_rule)

    for func, stats in sorted(results["by_function"].items()):
        func_total = stats["total"]
        func_correct = percent(stats["func_correct"], func_total)
        func_arg_score = percent(stats["arg_score"], func_total)
        func_exact = percent(stats["exact_match"], func_total)

        print(f"{func:15} | samples: {func_total:3} | func acc: {func_correct:6.2f}% | "
              f"arg score: {func_arg_score:6.2f}% | exact: {func_exact:6.2f}%")

    if results["errors"]:
        print("\n" + light_rule)
        print("Error samples")
        print(light_rule)

        # Only the first five recorded errors are shown.
        for err in results["errors"][:5]:
            print(f"\nID: {err['id']} | category: {err['category']}")
            print(f"Input: {err['input']}")
            print(f"Expected: {err['expected_func']} | Actual: {err['actual_func']}")
            print(f"Error: {err['error']}")

    print("\n" + heavy_rule)
|
|
|
|
def main():
    """Command-line entry point: load the model, evaluate, report and save.

    Parses the CLI options, loads the model (optionally quantized and/or
    with a LoRA adapter), reads the benchmark JSON, runs the evaluation,
    prints the report and writes the aggregated results as JSON. When no
    output path is given, a timestamped file under DEFAULT_RESULTS_DIR is
    used.
    """
    cli = argparse.ArgumentParser(description="FunctionGemma evaluation (v2)")
    cli.add_argument("--model_path", type=str, required=True, help="Model path")
    cli.add_argument("--lora_path", type=str, default=None, help="LoRA adapter path")
    cli.add_argument("--benchmark_path", type=str, default=str(DEFAULT_BENCHMARK_PATH), help="Benchmark dataset path")
    cli.add_argument("--output_path", type=str, default=None, help="Output path (defaults to results/ with timestamp)")
    cli.add_argument("--chain", type=str, default="solana", help="Chain name")
    cli.add_argument("--load_in_8bit", action="store_true", help="Enable 8-bit quantization")
    cli.add_argument("--load_in_4bit", action="store_true", help="Enable 4-bit quantization")
    cli.add_argument("--verbose", action="store_true", help="Verbose logging")
    cli.add_argument("--num_workers", type=int, default=4, help="Number of worker threads (default 4)")
    args = cli.parse_args()

    # The model is loaded first, before the benchmark file is read.
    model, tokenizer = load_model(
        args.model_path,
        lora_path=args.lora_path,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    )

    benchmark_path = Path(args.benchmark_path)
    logger.info(f"Loading benchmark: {benchmark_path}")
    with open(benchmark_path, 'r', encoding='utf-8') as source:
        benchmark = json.load(source)
    logger.info(f"Benchmark samples: {len(benchmark)}")

    logger.info("Starting evaluation...")
    results = evaluate_benchmark(
        model,
        tokenizer,
        benchmark,
        chain=args.chain,
        verbose=args.verbose,
        num_workers=args.num_workers,
    )

    print_report(results)

    # Explicit path wins; otherwise write a timestamped file in the results dir.
    if args.output_path:
        output_path = Path(args.output_path)
    else:
        output_path = DEFAULT_RESULTS_DIR / f"evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as sink:
        json.dump(results, sink, ensure_ascii=False, indent=2)
    logger.info(f"Evaluation saved to: {output_path}")


if __name__ == "__main__":
    main()
|
|