| """
|
| vLLM Batch Inference for Variance-based One-Shot Question Selection
|
| ====================================================================
|
| 用 vLLM 做批量推理,比 HuggingFace generate 快 5-10x。
|
| 对 SFT 模型在 1533 道开放题上各跑 10 次推理,
|
| 用精确字符串匹配判断对错,选出方差最大的题目。
|
|
|
| 用法 (在 Docker 容器中):
|
| CUDA_VISIBLE_DEVICES=0 python3 -u variance_select_vllm.py
|
| """
|
|
|
| import json
|
| import re
|
| import os
|
| import numpy as np
|
| from pathlib import Path
|
|
|
|
|
# Paths inside the Docker container: merged SFT checkpoint, held-out test
# set (JSONL), question images, and the directory for all output artifacts.
MODEL_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/sft_qwen25vl_3b/merged"
TEST_FILE = "/workspace/rl4phyx/RL4Phyx/SFT/sft_test.jsonl"
IMAGE_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/images"
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT"

# Sampling setup: each question is answered NUM_RUNS times with temperature
# sampling so per-question accuracy p — and thus Bernoulli variance p*(1-p) —
# can be estimated.
NUM_RUNS = 10
MAX_NEW_TOKENS = 3072
TEMPERATURE = 0.7
TOP_P = 0.9
# Number of duplicated rows of the single selected question written to the
# RLVR training parquet (one-shot RLVR trains on many copies of one item).
RLVR_COPIES = 128
|
|
|
|
|
def extract_boxed(text):
    """Extract the answer inside the last ``\\boxed{...}`` in *text*.

    The pattern tolerates one level of nested braces inside the box.
    Returns the stripped answer string, or None when no box is found.
    """
    pattern = r'\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}'
    found = re.findall(pattern, text)
    if not found:
        return None
    return found[-1].strip()
|
|
|
|
|
def normalize_answer(ans):
    """Lightly normalize an answer string for exact-match comparison.

    Steps: strip surrounding whitespace, drop trailing periods, unwrap
    LaTeX ``\\text{...}``, collapse runs of whitespace, then strip again.
    Returns None when *ans* is None.
    """
    if ans is None:
        return None
    ans = ans.strip()
    ans = ans.rstrip('.')
    ans = re.sub(r'\\text\{([^}]*)\}', r'\1', ans)
    ans = re.sub(r'\s+', ' ', ans)
    # Bug fix: removing a trailing period or a \text{} wrapper can leave a
    # dangling space (e.g. "ans ." -> "ans "), which would defeat the exact
    # string comparison downstream. Strip once more at the end.
    return ans.strip()
|
|
|
|
|
def exact_match(pred_answer, gt_answer):
    """Return True iff prediction equals ground truth after normalization.

    A missing prediction (None) never matches; either side normalizing
    to None is also treated as a mismatch.
    """
    if pred_answer is None:
        return False
    pred_norm = normalize_answer(pred_answer)
    gt_norm = normalize_answer(gt_answer)
    return (
        pred_norm is not None
        and gt_norm is not None
        and pred_norm == gt_norm
    )
|
|
|
|
|
def load_test_data(test_file):
    """Load test samples from a JSONL file.

    Each non-empty line is parsed as one JSON object. Blank lines (e.g. a
    trailing newline at EOF) are skipped instead of crashing ``json.loads``
    with a decode error — that was a latent bug in the original loop.

    Returns the list of parsed sample dicts and prints the sample count.
    """
    data = []
    with open(test_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    print(f"Loaded {len(data)} test samples")
    return data
|
|
|
|
|
def build_vllm_inputs(test_data, processor):
    """Prepare per-sample chat prompts and multimodal payloads for vLLM.

    For each test item, renders a chat-template prompt (with an image
    placeholder first when ``<index>.png`` exists under IMAGE_DIR, then
    the question text) and loads the matching image as RGB.

    Returns ``(prompts, multi_modal_data_list)`` where each list entry in
    the second list is ``{"image": PIL.Image}`` or None for text-only items.
    """
    from PIL import Image

    prompts = []
    multi_modal_data_list = []

    for i, item in enumerate(test_data):
        image_path = os.path.join(IMAGE_DIR, f"{item['index']}.png")
        has_image = os.path.exists(image_path)

        # Build the user message: optional image part, then the text part.
        content = []
        if has_image:
            content.append({
                "type": "image",
                "image": f"file://{image_path}"
            })
        content.append({
            "type": "text",
            "text": item['prompt']
        })
        messages = [{"role": "user", "content": content}]

        rendered = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        prompts.append(rendered)

        # Load the actual pixels separately; vLLM receives them via
        # multi_modal_data rather than through the template string.
        if has_image:
            multi_modal_data_list.append(
                {"image": Image.open(image_path).convert("RGB")}
            )
        else:
            multi_modal_data_list.append(None)

        if (i + 1) % 200 == 0:
            print(f" Prepared {i+1}/{len(test_data)} inputs...")

    return prompts, multi_modal_data_list
|
|
|
|
|
def run_vllm_batch(llm, prompts, multi_modal_data_list, sampling_params, run_id):
    """Run one batched vLLM generation pass over all prompts.

    Builds one request dict per prompt, attaching ``multi_modal_data``
    only when an image payload is present, and submits everything in a
    single ``llm.generate`` call. Returns the list of vLLM outputs in
    input order.

    Fix: dropped the unused ``from vllm import TextPrompt`` — it was
    never referenced, and that import path is version-fragile (the class
    lives in ``vllm.inputs`` in some releases), so it could crash the
    whole run for no benefit.
    """
    print(f"\n{'='*60}")
    print(f" Run {run_id + 1}/{NUM_RUNS}")
    print(f"{'='*60}")

    requests = []
    for prompt, mm_data in zip(prompts, multi_modal_data_list):
        req = {
            "prompt": prompt,
        }
        if mm_data:
            req["multi_modal_data"] = mm_data
        requests.append(req)

    print(f" Submitting {len(requests)} requests to vLLM...")
    outputs = llm.generate(requests, sampling_params)
    print(f" Got {len(outputs)} outputs")

    return outputs
|
|
|
|
|
def main():
    """End-to-end pipeline: load model, run NUM_RUNS sampled passes over the
    test set, compute per-question Bernoulli variance, and export the single
    highest-variance question as a one-shot RLVR training parquet."""
    print("=" * 60)
    print(" vLLM Batch Variance Selection")
    print("=" * 60)

    # [1/5] Heavy imports are deferred so the script fails fast on argument
    # or path issues before loading CUDA libraries.
    print("\n[1/5] Loading vLLM model...")
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor

    llm = LLM(
        model=MODEL_PATH,
        dtype="bfloat16",
        trust_remote_code=True,
        max_model_len=4096,
        gpu_memory_utilization=0.85,
        limit_mm_per_prompt={"image": 1},  # at most one image per prompt
    )

    # Temperature sampling (not greedy) — runs must differ for variance
    # estimation to be meaningful.
    sampling_params = SamplingParams(
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_NEW_TOKENS,
    )

    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    print(" Model loaded!")

    # [2/5] Test samples: one JSON object per line, each with at least
    # 'index', 'prompt', 'ground_truth' (and optionally 'category').
    print("\n[2/5] Loading test data...")
    test_data = load_test_data(TEST_FILE)

    # [3/5] Inputs are built once and reused across all NUM_RUNS passes.
    print("\n[3/5] Building vLLM inputs...")
    prompts, multi_modal_data_list = build_vllm_inputs(test_data, processor)
    print(f" Built {len(prompts)} prompts")

    # [4/5] Repeated inference: all_runs[r][i] holds run r's result for
    # question i (vLLM preserves input order in its outputs).
    print(f"\n[4/5] Running {NUM_RUNS} inference passes...")
    all_runs = []

    for run_id in range(NUM_RUNS):
        outputs = run_vllm_batch(llm, prompts, multi_modal_data_list, sampling_params, run_id)

        run_results = []
        correct_count = 0
        for i, output in enumerate(outputs):
            # First (and only) sampled completion for this request.
            prediction = output.outputs[0].text
            pred_answer = extract_boxed(prediction)
            gt_answer = test_data[i]['ground_truth']
            is_correct = exact_match(pred_answer, gt_answer)
            if is_correct:
                correct_count += 1

            run_results.append({
                'index': test_data[i]['index'],
                'pred_answer': pred_answer,
                'gt_answer': gt_answer,
                'correct': is_correct,
            })

        all_runs.append(run_results)
        print(f" Run {run_id+1} accuracy: {correct_count}/{len(test_data)} "
              f"({100*correct_count/len(test_data):.1f}%)")

        # Persist each run immediately so a crash mid-way loses at most
        # one pass of work.
        interim_path = os.path.join(OUTPUT_DIR, f"variance_run_{run_id}.json")
        with open(interim_path, 'w', encoding='utf-8') as f:
            json.dump(run_results, f, ensure_ascii=False, indent=2)
        print(f" Saved to {interim_path}")

    # [5/5] Per-question Bernoulli variance p*(1-p), maximized at p=0.5 —
    # i.e. questions the model gets right about half the time.
    print(f"\n[5/5] Computing variance...")
    n_questions = len(test_data)
    stats = []

    for qi in range(n_questions):
        correct_flags = [all_runs[r][qi]['correct'] for r in range(NUM_RUNS)]
        n_correct = sum(correct_flags)
        p = n_correct / NUM_RUNS
        variance = p * (1 - p)

        stats.append({
            'index': test_data[qi]['index'],
            'category': test_data[qi].get('category', ''),
            'ground_truth': test_data[qi]['ground_truth'],
            'n_correct': n_correct,
            'accuracy': p,
            'variance': variance,
            'correct_flags': correct_flags,
            'pred_answers': [all_runs[r][qi]['pred_answer'] for r in range(NUM_RUNS)],
        })

    # Highest variance first; stats[0] becomes the selected question.
    stats.sort(key=lambda x: x['variance'], reverse=True)

    stats_path = os.path.join(OUTPUT_DIR, "variance_results.json")
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(stats, f, ensure_ascii=False, indent=2)
    print(f" Saved variance stats to {stats_path}")

    print(f"\n{'='*60}")
    print(f" TOP 20 HIGHEST VARIANCE QUESTIONS")
    print(f"{'='*60}")
    for i, s in enumerate(stats[:20]):
        print(f" #{i+1}: idx={s['index']} | gt={s['ground_truth'][:30]:30s} | "
              f"correct={s['n_correct']}/{NUM_RUNS} | var={s['variance']:.4f} | "
              f"cat={s['category']}")

    best = stats[0]
    print(f"\n{'='*60}")
    print(f" SELECTED QUESTION FOR ONE-SHOT RLVR")
    print(f"{'='*60}")
    print(f" Index: {best['index']}")
    print(f" Category: {best['category']}")
    print(f" Ground Truth: {best['ground_truth']}")
    print(f" Accuracy: {best['n_correct']}/{NUM_RUNS} ({best['accuracy']*100:.0f}%)")
    print(f" Variance: {best['variance']:.4f}")

    # Recover the full original test item for the selected index.
    best_idx = int(best['index'])
    best_item = next(item for item in test_data if int(item['index']) == best_idx)

    best_path = os.path.join(OUTPUT_DIR, "best_question_for_rlvr.json")
    with open(best_path, 'w', encoding='utf-8') as f:
        json.dump({'selected_question': best_item, 'stats': best}, f, ensure_ascii=False, indent=2)
    print(f" Saved to {best_path}")

    # Export RLVR_COPIES identical rows of the chosen question as the
    # one-shot RLVR training set. NOTE(review): 'data_source' is fixed to
    # 'deepscaler' — presumably what the downstream RL trainer keys its
    # reward function on; confirm against the trainer config.
    import pandas as pd
    prompt_messages = [{"role": "user", "content": best_item['prompt']}]
    records = [{
        'prompt': prompt_messages,
        'answer': best_item['ground_truth'],
        'image_path': f"{best_item['index']}.png",
        'data_source': 'deepscaler',
        'category': best_item.get('category', 'Physics'),
        'index': best_item['index'],
    } for _ in range(RLVR_COPIES)]

    df = pd.DataFrame(records)
    parquet_path = os.path.join(OUTPUT_DIR, "rlvr_train.parquet")
    df.to_parquet(parquet_path, index=False)
    print(f" Saved training parquet ({len(df)} rows) to {parquet_path}")

    print(f"\n{'='*60}")
    print(f" DONE!")
    print(f"{'='*60}")
|
|
|
# Script entry point — run directly (see module docstring for the command).
if __name__ == "__main__":
    main()
|
|
|