#!/usr/bin/env python3
"""
Quick evaluation script for JSON-formatted models.
Reads the base model name from adapter_config.json automatically.
"""
import argparse
import json
import logging
import os
import random
import re
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the repo root importable so the local classes package resolves
sys.path.insert(0, str(Path(__file__).parent.parent))
from classes.expression import Expression

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_model_auto(model_path: str):
"""Load model with automatic base model detection from adapter_config.json"""
adapter_config_path = os.path.join(model_path, "adapter_config.json")
if not os.path.exists(adapter_config_path):
raise FileNotFoundError(f"No adapter_config.json found in {model_path}")
with open(adapter_config_path) as f:
adapter_config = json.load(f)
base_model_name = adapter_config.get("base_model_name_or_path", "gpt2")
logger.info(f"Loading base model: {base_model_name}")
# Load base model
model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # GPT-2 has no dedicated pad token, so reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token
# Load LoRA adapter
logger.info(f"Loading LoRA adapter from {model_path}")
model = PeftModel.from_pretrained(model, model_path)
model = model.merge_and_unload()
model.eval()
return model, tokenizer, base_model_name

def create_json_prompt(vars_list, ops_list, cons="C"):
"""Create JSON format prompt"""
prompt = {
"vars": vars_list,
"ops": ops_list,
"cons": cons,
"expr": ""
}
prompt_str = json.dumps(prompt, ensure_ascii=False)
prompt_str = prompt_str.rsplit('"expr":', 1)[0] + '"expr": "'
return prompt_str

def extract_expression_json(output: str):
"""Extract expression from JSON output"""
# Try to extract from "expr": "..." pattern
match = re.search(r'"expr":\s*"([^"]*)"', output)
if match:
return match.group(1)
# Try without closing quote
match = re.search(r'"expr":\s*"([^"]+)', output)
if match:
expr = match.group(1)
# Clean up common artifacts
expr = expr.split('"')[0].split('}')[0].strip()
return expr
return None

def evaluate_model(model, tokenizer, num_samples=500,
                   dataset_name="augustocsc/sintetico_natural",
                   data_dir="700K", model_path=None):
    """Evaluate the model: generate expressions and report parseable, valid, and diversity rates."""
device = model.device
logger.info(f"Using device: {device}")
# Load dataset
logger.info(f"Loading dataset {dataset_name}/{data_dir}")
dataset = load_dataset(dataset_name, data_dir, split="train")
    # Sample a fixed subset of the dataset (seeded for reproducibility)
    random.seed(42)
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
results = []
valid_count = 0
parseable_count = 0
unique_expressions = set()
logger.info(f"Evaluating on {len(indices)} samples...")
for idx in tqdm(indices, desc="Evaluating"):
sample = dataset[idx]
prompt_text = sample.get("i_prompt_n", "")
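        # The dataset prompt is expected to contain lines such as
        # "vars: x1, x2" and "oper: +, *" (values here are illustrative)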
# Parse prompt to extract vars and ops
        lines = prompt_text.split('\n')
        vars_line = [line for line in lines if line.startswith('vars:')]
        ops_line = [line for line in lines if line.startswith('oper:')]
if not vars_line or not ops_line:
continue
vars_list = [v.strip() for v in vars_line[0].replace('vars:', '').split(',')]
ops_list = [o.strip() for o in ops_line[0].replace('oper:', '').split(',')]
# Create JSON prompt
prompt = create_json_prompt(vars_list, ops_list)
# Generate
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=100,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
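        # Sampling is stochastic (do_sample=True), so the reported rates can vary
        # slightly across runs even though the evaluation subset itself is seeded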
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract expression
expr_str = extract_expression_json(generated)
# Validate
is_valid = False
is_parseable = False
error_msg = None
if expr_str:
try:
expr = Expression.parse_infix(expr_str)
is_parseable = True
is_valid = expr.validate()
if is_valid:
unique_expressions.add(expr_str)
except Exception as e:
error_msg = str(e)[:100]
else:
error_msg = "Failed to extract expression"
if is_valid:
valid_count += 1
if is_parseable:
parseable_count += 1
results.append({
"sample_idx": idx,
"prompt": prompt,
"generated": generated[:500],
"expression": expr_str,
"valid": is_valid,
"parseable": is_parseable,
"error": error_msg
})
total = len(results)
    metrics = {
        "model_path": model_path,
"num_samples": total,
"valid_rate": valid_count / total if total > 0 else 0,
"parseable_rate": parseable_count / total if total > 0 else 0,
"unique_expressions": len(unique_expressions),
"diversity_rate": len(unique_expressions) / total if total > 0 else 0,
}
return metrics, results

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument("--output_dir", type=str, default="./results_corrected")
args = parser.parse_args()
# Load model
model, tokenizer, base_model_name = load_model_auto(args.model_path)
# Evaluate
    metrics, results = evaluate_model(model, tokenizer, args.num_samples, model_path=args.model_path)
# Print results
print("\n" + "="*60)
print(f"EVALUATION RESULTS - {os.path.basename(args.model_path)}")
print("="*60)
print(f"Base model: {base_model_name}")
print(f"Valid rate: {metrics['valid_rate']*100:.1f}%")
print(f"Parseable rate: {metrics['parseable_rate']*100:.1f}%")
print(f"Unique expressions: {metrics['unique_expressions']}")
print(f"Diversity rate: {metrics['diversity_rate']*100:.1f}%")
print("="*60)
# Save results
os.makedirs(args.output_dir, exist_ok=True)
metrics_path = os.path.join(args.output_dir, f"{model_name}_metrics.json")
with open(metrics_path, 'w') as f:
json.dump(metrics, f, indent=2)
results_path = os.path.join(args.output_dir, f"{model_name}_results.json")
with open(results_path, 'w') as f:
json.dump(results, f, indent=2)
logger.info(f"Results saved to {args.output_dir}")

if __name__ == "__main__":
main()