"""
fVLM benchmark evaluation (batched MCQ scoring).

Usage:
    # 135M model
    python benchmark.py --llm /workspace/models/SmolLM2-135M-Instruct \
        --checkpoint /workspace/checkpoints/final/stage3/best.pt

    # 1.7B model
    python benchmark.py --llm /workspace/models/SmolLM2-1.7B-Instruct \
        --checkpoint /workspace/checkpoints/final_1.7B/stage3/latest.pt

    # Run specific benchmarks only
    python benchmark.py --checkpoint ... --only MVBench ScienceQA

Key optimizations:
    1. Batch all MCQ options into ONE forward pass per mode (not N sequential)
    2. Compute per-option CE from logits (avoid model's averaged loss)
    3. Cache DINO encoding across modes (same frames, reuse kv_cache)
"""
|
|
| import sys, os, json, tarfile, io, time, re, glob, gc |
|
|
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from torchvision import transforms |
| from transformers import AutoTokenizer |
| from collections import defaultdict |
|
|
| from model import FoveatedVLM |
|
|
|
|
| |
|
|
def load_model(checkpoint_path, llm_name, dino_name="/workspace/models/dinov2-small",
               device="cuda"):
    """Construct a FoveatedVLM, restore checkpoint weights, and return it on
    *device* in bfloat16 eval mode."""
    model = FoveatedVLM(
        llm_name=llm_name, dino_name=dino_name,
        query_dim=384, visual_scale=0.14, lambda_coarse=0.0, deep_query=True,
    )
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    weights = checkpoint["model_state_dict"]
    # Checkpoints saved from a torch.compile()'d model carry "._orig_mod." in
    # their parameter names; strip it so keys match the plain module.
    if any("._orig_mod." in name for name in weights):
        weights = {name.replace("._orig_mod", ""): tensor for name, tensor in weights.items()}
    model.load_state_dict(weights)
    model = model.to(device).to(torch.bfloat16).eval()
    print(f"Loaded: {checkpoint_path} (step {checkpoint.get('step', '?')})")
    return model
|
|
|
|
def load_tokenizer(llm_name):
    """Load the HF tokenizer for *llm_name*, falling back to EOS as the pad
    token when none is configured (needed for batched padding)."""
    tokenizer = AutoTokenizer.from_pretrained(llm_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
|
|
|
|
# Frame preprocessing: resize to 224x224, convert to tensor, and normalize
# with the standard ImageNet mean/std (the statistics DINOv2-style encoders
# are commonly trained with — presumably matching this model's training;
# confirm against the training pipeline).
FRAME_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
| |
|
|
def load_all_samples_from_shards(shard_pattern):
    """Load every (metadata, frames) sample from tar shards matching *shard_pattern*.

    Each shard holds members named ``<key>.json`` (sample metadata) plus
    ``<key>.<frame_idx>.jpg|png`` frame images. Returns a list of dicts with
    keys "key", "json", and "frames" (PIL images sorted by frame index).
    Samples missing either the json or all frames are dropped.

    Improvement over the previous version: the duplicated ``if key not in
    grouped`` initialization blocks are replaced with a ``defaultdict``.
    """
    shard_files = sorted(glob.glob(shard_pattern))
    print(f" Loading from {len(shard_files)} shards...")
    samples = []
    for shard_path in shard_files:
        with tarfile.open(shard_path, "r") as tar:
            # key -> {"json": dict, "frames": {idx: PIL.Image}}
            grouped = defaultdict(lambda: {"frames": {}})
            for member in tar.getmembers():
                name = member.name
                parts = name.split(".")
                key = parts[0]
                if name.endswith(".json"):
                    grouped[key]["json"] = json.load(tar.extractfile(member))
                elif name.endswith((".jpg", ".png")):
                    # "<key>.<idx>.<ext>" -> idx; bare "<key>.<ext>" -> frame 0.
                    frame_idx = int(parts[1]) if len(parts) >= 3 else 0
                    img_bytes = tar.extractfile(member).read()
                    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    grouped[key]["frames"][frame_idx] = img
            for key in sorted(grouped):
                entry = grouped[key]
                # Require both metadata and at least one frame.
                if entry.get("json") and entry["frames"]:
                    ordered = [entry["frames"][i] for i in sorted(entry["frames"])]
                    samples.append({
                        "key": key,
                        "json": entry["json"],
                        "frames": ordered,
                    })
    print(f" Loaded {len(samples)} samples")
    return samples
|
|
|
|
def prepare_frames_tensor(pil_frames, device="cuda", replicate_to=8):
    """Transform PIL frames to tensor. Replicate single-frame images to match training."""
    frames = torch.stack([FRAME_TRANSFORM(img) for img in pil_frames])
    # A single still image is tiled along the time axis so its temporal
    # footprint matches the multi-frame clips used in training.
    if replicate_to > 1 and frames.shape[0] == 1:
        frames = frames.repeat(replicate_to, 1, 1, 1)
    # Add batch dim: (1, T, C, H, W), bf16 to match the model's precision.
    return frames.unsqueeze(0).to(device, dtype=torch.bfloat16)
|
|
|
|
| |
|
|
def parse_mcq_options(user_text):
    """Extract MCQ choices like "A. cat" from the prompt text.

    Returns {letter: "<letter>. <text>"}; an empty dict when the prompt has
    no recognizable options. Option bodies may span lines (DOTALL), ending
    at the next "\\n<LETTER>." marker or end of string.
    """
    pattern = re.compile(r'([A-Z])\.\s*(.+?)(?=\n[A-Z]\.|$)', re.DOTALL)
    return {
        m.group(1): f"{m.group(1)}. {m.group(2).strip()}"
        for m in pattern.finditer(user_text)
    }
|
|
|
|
def extract_answer_letter(assistant_text):
    """Return the option letter from an answer like "B. The dog".

    Falls back to the first non-whitespace character when there is no
    "<LETTER>." prefix, and to "?" for empty/blank text.
    """
    text = assistant_text.strip()
    match = re.match(r'([A-Z])\.', text)
    if match:
        return match.group(1)
    if text:
        return text[0]
    return "?"
|
|
|
|
| |
|
|
@torch.no_grad()
def score_options_batched(model, tokenizer, frames, question_text, options_dict, mode, device):
    """
    Score all MCQ options in ONE batched forward pass.
    Returns dict {letter: loss} where lower loss = better match.
    """
    letters = sorted(options_dict.keys())
    if not letters:
        return {}

    # Tokenize the shared (system + question) prompt once so we know how many
    # leading tokens to exclude from the loss: only the assistant's answer
    # tokens should be scored.
    prompt_messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": question_text},
    ]
    prompt_text = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
    prompt_ids = tokenizer.encode(prompt_text)
    # NOTE(review): assumes this generation-prompt encoding is an exact token
    # prefix of every full (prompt + answer) encoding built below — confirm
    # for the tokenizer/chat template in use.
    S_prompt = len(prompt_ids)

    # One full transcript (prompt + candidate answer) per option letter.
    all_ids = []
    for letter in letters:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": question_text},
            {"role": "assistant", "content": options_dict[letter]},
        ]
        full_text = tokenizer.apply_chat_template(messages, tokenize=False)
        ids = tokenizer.encode(full_text)
        all_ids.append(ids)

    # Right-pad all transcripts to a common length for a single batched forward.
    max_len = max(len(ids) for ids in all_ids)
    pad_id = tokenizer.pad_token_id
    N = len(letters)

    batch_ids = torch.full((N, max_len), pad_id, dtype=torch.long, device=device)
    batch_attn = torch.zeros(N, max_len, dtype=torch.long, device=device)
    batch_loss_mask = torch.zeros(N, max_len, dtype=torch.float32, device=device)

    for i, ids in enumerate(all_ids):
        L = len(ids)
        batch_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
        batch_attn[i, :L] = 1
        # Loss only over the answer span [S_prompt, L); prompt and padding excluded.
        batch_loss_mask[i, S_prompt:L] = 1.0

    # Same frames for every option: expand() gives a broadcast view, no copy.
    frames_batch = frames.expand(N, -1, -1, -1, -1)

    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        result = model.forward(
            frames=frames_batch,
            input_ids=batch_ids,
            attention_mask=batch_attn,
            loss_mask=batch_loss_mask,
            mode=mode,
        )

    # The model's output sequence is longer than the text input; the leading
    # positions are presumably visual tokens — drop them so logits align with
    # the text ids (TODO confirm against FoveatedVLM.forward).
    logits = result["logits"]
    T_visual = logits.shape[1] - batch_ids.shape[1]

    # Standard causal-LM shift: logits at position t predict the token at t+1.
    text_logits = logits[:, T_visual:, :]
    shift_logits = text_logits[:, :-1, :].contiguous()
    shift_labels = batch_ids[:, 1:].contiguous()
    shift_mask = batch_loss_mask[:, 1:].contiguous()

    # Per-token CE (reduction="none") so each option gets its own average,
    # instead of the model's batch-averaged loss mixing the options together.
    V = shift_logits.shape[-1]
    per_token_loss = F.cross_entropy(
        shift_logits.reshape(-1, V),
        shift_labels.reshape(-1),
        reduction="none",
        ignore_index=pad_id,
    ).reshape(N, -1)

    # Mean CE over each option's answer tokens; clamp avoids division by zero
    # for a degenerate all-masked row.
    # NOTE(review): when the pad token falls back to EOS (see load_tokenizer),
    # ignore_index=pad_id also zeroes the loss at any real EOS label inside
    # the answer span while the mask still counts it in the denominator —
    # confirm this slight dilution is acceptable/consistent across options.
    masked_loss = (per_token_loss * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    return {letters[i]: masked_loss[i].item() for i in range(N)}
|
|
|
|
| |
|
|
@torch.no_grad()
def evaluate_mcq_benchmark(model, tokenizer, samples, benchmark_name, modes, device):
    """Score every sample's MCQ options under each mode.

    Returns per-mode tallies: {"correct", "total", "per_category"}; the
    per-category dict is keyed by the last path component of the sample's
    "source" field.
    """
    results = {
        mode: {"correct": 0, "total": 0,
               "per_category": defaultdict(lambda: {"correct": 0, "total": 0})}
        for mode in modes
    }
    start = time.time()

    for idx, sample in enumerate(samples):
        meta = sample["json"]
        question = meta["user"]
        source = meta.get("source", "unknown")
        category = source.rsplit("/", 1)[-1]
        gt_letter = extract_answer_letter(meta["assistant"])
        options = parse_mcq_options(question)
        if not options:
            # Not a parseable MCQ sample — skip entirely.
            continue

        frames = prepare_frames_tensor(sample["frames"], device=device)

        for mode in modes:
            losses = score_options_batched(
                model, tokenizer, frames, question, options, mode, device
            )
            if not losses:
                continue
            # Lowest average CE over the answer tokens wins.
            prediction = min(losses, key=losses.get)
            hit = int(prediction == gt_letter)
            bucket = results[mode]
            bucket["total"] += 1
            bucket["correct"] += hit
            per_cat = bucket["per_category"][category]
            per_cat["total"] += 1
            per_cat["correct"] += hit

        # Periodic progress report with running accuracy per mode.
        if (idx + 1) % 100 == 0:
            elapsed = time.time() - start
            for mode in modes:
                r = results[mode]
                acc = r["correct"] / max(r["total"], 1) * 100
                print(f" [{benchmark_name}] {idx+1}/{len(samples)} | {mode}: {acc:.1f}% ({r['correct']}/{r['total']}) | {elapsed:.0f}s", flush=True)

    return results
|
|
|
|
| |
|
|
@torch.no_grad()
def evaluate_val_loss(model, tokenizer, shard_pattern, modes, device, max_samples=1000):
    """Accumulate the model's own loss over held-out shards, per mode.

    Samples may be pre-tokenized ("token_ids" in metadata) or raw text
    (caption/assistant + user fields tokenized here via the chat template).
    Returns {mode: {"total_loss", "count"}}.
    """
    samples = load_all_samples_from_shards(shard_pattern)
    if max_samples:
        samples = samples[:max_samples]

    results = {mode: {"total_loss": 0.0, "count": 0} for mode in modes}
    start = time.time()

    for idx, sample in enumerate(samples):
        meta = sample["json"]
        frames = prepare_frames_tensor(sample["frames"], device=device)

        if "token_ids" in meta:
            # Pre-tokenized sample: use stored ids and optional loss mask
            # (defaults to scoring every token).
            token_ids = meta["token_ids"]
            input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
            mask_vals = meta.get("loss_mask", [1] * len(token_ids))
            loss_mask = torch.tensor(mask_vals, dtype=torch.float32).unsqueeze(0).to(device)
        else:
            # Raw-text sample: build a chat transcript and tokenize it here.
            caption = meta.get("caption", meta.get("assistant", ""))
            user = meta.get("user", "Describe this video.")
            transcript = tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "You are a helpful AI assistant."},
                    {"role": "user", "content": user},
                    {"role": "assistant", "content": caption},
                ],
                tokenize=False,
            )
            input_ids = tokenizer.encode(transcript, return_tensors="pt").to(device)
            loss_mask = torch.ones_like(input_ids, dtype=torch.float32)

        attention_mask = torch.ones_like(input_ids)

        for mode in modes:
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                out = model.forward(
                    frames=frames, input_ids=input_ids,
                    attention_mask=attention_mask, loss_mask=loss_mask,
                    mode=mode,
                )
            results[mode]["total_loss"] += out["loss"].item()
            results[mode]["count"] += 1

        # Periodic progress report with running average loss per mode.
        if (idx + 1) % 200 == 0:
            elapsed = time.time() - start
            for mode in modes:
                r = results[mode]
                avg = r["total_loss"] / max(r["count"], 1)
                print(f" [val_10k] {idx+1}/{len(samples)} | {mode}: loss={avg:.4f} | {elapsed:.0f}s", flush=True)

    return results
|
|
|
|
| |
|
|
def run_mcq_benchmark(model, tokenizer, name, shard_pattern, modes, device, all_results):
    """Load, evaluate, and free one MCQ benchmark.

    Results are recorded in *all_results* under a normalized key
    ("Video-MME" -> "video_mme"). Silently skips benchmarks whose shards
    are absent. Returns None.

    Fixes: the normalized key was computed twice (one redundant dict write),
    and the skip message contained a mojibake "β" in place of a dash.
    """
    shards = glob.glob(shard_pattern)
    if not shards:
        print(f" Skipping {name} - shards not found")
        return
    samples = load_all_samples_from_shards(shard_pattern)
    results = evaluate_mcq_benchmark(model, tokenizer, samples, name, modes, device)
    # Frames are PIL images held in memory — free them before the next benchmark.
    del samples; gc.collect()

    key = name.lower().replace("-", "_").replace(" ", "_")
    all_results[key] = {}
    for mode in modes:
        r = results[mode]
        acc = r["correct"] / max(r["total"], 1) * 100
        all_results[key][mode] = {
            "accuracy": acc, "correct": r["correct"], "total": r["total"],
            "per_category": {
                cat: {"accuracy": v["correct"] / max(v["total"], 1) * 100,
                      "correct": v["correct"], "total": v["total"]}
                for cat, v in r["per_category"].items()
            },
        }
        print(f" {mode}: {acc:.1f}% ({r['correct']}/{r['total']})")
|
|
|
|
def main():
    """CLI entry point: parse args, load the model, run the selected MCQ
    benchmarks, then save results to JSON and print a summary table."""
    import argparse
    parser = argparse.ArgumentParser(description="fVLM benchmark evaluation")
    parser.add_argument("--llm", default="/workspace/models/SmolLM2-135M-Instruct",
                        help="HuggingFace LLM path (e.g. SmolLM2-135M or 1.7B)")
    parser.add_argument("--checkpoint", default="/workspace/checkpoints/final/stage3/best.pt",
                        help="Path to model checkpoint (.pt)")
    parser.add_argument("--dino", default="/workspace/models/dinov2-small",
                        help="HuggingFace DINOv2 path")
    parser.add_argument("--only", nargs="+", default=None,
                        help="Run only specified benchmarks (e.g. --only MVBench ScienceQA)")
    parser.add_argument("--output", default=None,
                        help="Output JSON path (default: /workspace/benchmark_results_{model}.json)")
    parser.add_argument("--merge", action="store_true",
                        help="Merge with existing results file instead of overwriting")
    args = parser.parse_args()

    # Short model tag for the output filename: "SmolLM2-1.7B-Instruct" -> "1.7B".
    model_name = os.path.basename(args.llm).replace("-Instruct", "").replace("SmolLM2-", "")
    if args.output is None:
        args.output = f"/workspace/benchmark_results_{model_name}.json"

    device = "cuda"
    # Every benchmark is scored under all three inference modes.
    modes = ["coarse_only", "coarse_fine", "autoregressive"]

    print("=" * 70)
    print(f"fVLM-{model_name} BENCHMARK EVALUATION")
    print(f" LLM: {args.llm}")
    print(f" Checkpoint: {args.checkpoint}")
    print(f" Output: {args.output}")
    print("=" * 70)

    print("\nLoading model (bf16)...")
    model = load_model(args.checkpoint, args.llm, args.dino, device)
    tokenizer = load_tokenizer(args.llm)

    # --merge: start from the existing results file so a rerun of a subset
    # (--only ...) does not clobber previously computed benchmarks.
    all_results = {}
    if args.merge and os.path.exists(args.output):
        with open(args.output) as f:
            all_results = json.load(f)
        print(f" Loaded existing results from {args.output}")

    t_global = time.time()

    # (display name, shard glob) pairs; benchmarks with no matching shards
    # are skipped inside run_mcq_benchmark.
    benchmarks = [
        ("MVBench", "/workspace/data/eval/benchmarks/mvbench_shards/mvbench_*.tar"),
        ("Video-MME", "/workspace/data/eval/benchmarks/video_mme_shards/video_mme_*.tar"),
        ("ScienceQA", "/workspace/data/eval/benchmarks/scienceqa_shards/scienceqa_*.tar"),
        ("POPE", "/workspace/data/eval/benchmarks/pope_shards/pope_*.tar"),
        ("MLVU", "/workspace/data/eval/benchmarks/mlvu_shards/mlvu_*.tar"),
    ]

    # Case-insensitive filter for --only.
    if args.only:
        only_set = {n.lower() for n in args.only}
        benchmarks = [(n, p) for n, p in benchmarks if n.lower() in only_set]

    for i, (name, pattern) in enumerate(benchmarks):
        print(f"\n{'-' * 70}")
        print(f"BENCHMARK: {name}")
        print(f"{'-' * 70}")
        run_mcq_benchmark(model, tokenizer, name, pattern, modes, device, all_results)

    # Persist results (merged + freshly computed) after all benchmarks finish.
    with open(args.output, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved: {args.output}")

    # Final per-benchmark summary table, one column per mode.
    total_time = time.time() - t_global
    print("\n" + "=" * 70)
    print(f"SUMMARY (total: {total_time:.0f}s = {total_time/60:.1f}min)")
    print("=" * 70)
    print(f"\n{'Benchmark':<15} {'Coarse-Only':>15} {'Coarse->Fine':>15} {'Autoregressive':>15}")
    print("-" * 62)
    for bench_name, bench_data in all_results.items():
        vals = []
        for mode in modes:
            if mode not in bench_data:
                # NOTE(review): "β" looks like a mojibake'd dash/em-dash
                # placeholder for "missing" — confirm the intended glyph.
                vals.append("β")
            elif "accuracy" in bench_data[mode]:
                vals.append(f"{bench_data[mode]['accuracy']:.1f}%")
            else:
                # NOTE(review): nothing in this file writes "avg_loss"; this
                # branch presumably serves entries merged in from an older
                # results file (e.g. val-loss runs) — confirm.
                vals.append(f"{bench_data[mode]['avg_loss']:.4f}")
        # Assumes exactly three modes (vals[0..2]) — true for the list above.
        print(f"{bench_name:<15} {vals[0]:>15} {vals[1]:>15} {vals[2]:>15}")
    print("\n" + "=" * 70)




if __name__ == "__main__":
    main()
|
|