"""
experiments/parameter_sweep.py
================================
Sweep beam_size, length_penalty, and max_new_tokens across BLIP, ViT-GPT2,
and GIT to measure the effect of decoding parameters on caption quality (CIDEr).
Usage:
python -m experiments.parameter_sweep --model blip --eval_batches 15
The sweep matrix:
beam_size : [3, 5, 10]
length_penalty: [0.8, 1.0, 1.2]
max_new_tokens: [20, 50]
Each cell reports CIDEr on a subset of the validation set (25 batches by default
when calling run_parameter_sweep directly; the CLI entrypoint defaults to 15).
A summary table is printed at the end.
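Programmatic use (a minimal sketch; assumes `model_objs` and `val_loader` have
already been built as in `main()` below):
    from experiments.parameter_sweep import run_parameter_sweep
    results = run_parameter_sweep(
        "blip", model_objs, val_loader, device,
        beam_sizes=[3, 5], length_penalties=[1.0], max_tokens=[20],
        eval_batches=5,
    )
    best = max(results, key=lambda r: r["cider"])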
Insight guide:
- beam_size ↑ → more candidates considered, usually better quality
  but slower decoding; diminishing returns above ~5
- length_penalty > 1.0 → length normalization rewards longer sequences → longer captions
- length_penalty < 1.0 → length normalization favors shorter sequences → more compact
  captions (see the scoring note below)
- max_new_tokens ↑ → allows longer captions; may hurt CIDEr if the model rambles
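Scoring note (a sketch of the length normalization used by Hugging Face beam
search: score = sum_logprobs / length**length_penalty, with sum_logprobs
negative). A worked example with sum_logprobs = -6.0 and length = 12:
    length_penalty = 1.0 → -6.0 / 12**1.0 ≈ -0.50
    length_penalty = 1.2 → -6.0 / 12**1.2 ≈ -0.30   (longer caption ranks higher)
    length_penalty = 0.8 → -6.0 / 12**0.8 ≈ -0.82   (longer caption ranks lower)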
"""
import argparse
import itertools
import torch
from tqdm.auto import tqdm
from pycocoevalcap.cider.cider import Cider
# ─────────────────────────────────────────────────────────────────────────────
# Default Search Space
# ─────────────────────────────────────────────────────────────────────────────
BEAM_SIZES = [3, 5, 10]
LENGTH_PENALTIES = [0.8, 1.0, 1.2]
MAX_TOKENS = [20, 50]
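# 3 beam sizes × 3 length penalties × 2 token limits = 18 configurations per model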
# ─────────────────────────────────────────────────────────────────────────────
# Per-Model Caption Generator (handles BLIP / ViT-GPT2 / GIT)
# ─────────────────────────────────────────────────────────────────────────────
def _generate_blip(model, processor, batch, device,
num_beams, max_new_tokens, length_penalty):
pixel_values = batch["pixel_values"].to(device)
with torch.no_grad():
out = model.generate(
pixel_values=pixel_values,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
length_penalty=length_penalty,
)
return processor.batch_decode(out, skip_special_tokens=True)
def _generate_vit_gpt2(model, tokenizer, batch, device,
num_beams, max_new_tokens, length_penalty):
pixel_values = batch["pixel_values"].to(device)
with torch.no_grad():
out = model.generate(
pixel_values=pixel_values,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
length_penalty=length_penalty,
)
return [tokenizer.decode(ids, skip_special_tokens=True) for ids in out]
def _generate_git(model, processor, batch, device,
num_beams, max_new_tokens, length_penalty):
inputs = {k: v.to(device) for k, v in batch.items()
if k in ("pixel_values", "input_ids", "attention_mask")}
with torch.no_grad():
out = model.generate(
**inputs,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
length_penalty=length_penalty,
)
return processor.batch_decode(out, skip_special_tokens=True)
# ─────────────────────────────────────────────────────────────────────────────
# CIDEr Evaluator for One Configuration
# ─────────────────────────────────────────────────────────────────────────────
def eval_one_config(model_name, model_objs, dataloader, device,
num_beams, max_new_tokens, length_penalty,
eval_batches=25):
"""
Evaluate CIDEr for one (model, num_beams, max_new_tokens, length_penalty) combo.
model_objs: dict with keys depending on model_name
- blip: {'model': ..., 'processor': ...}
- vit_gpt2: {'model': ..., 'tokenizer': ...}
- git: {'model': ..., 'processor': ...}
Returns:
cider_score: float
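    Example of the dict layout handed to pycocoevalcap's Cider scorer
    (illustrative captions; one hypothesis and one reference per image key):
        res = {"0": ["a dog runs on the beach"]}
        gts = {"0": ["a dog running along the shore"]}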
"""
gts, res = {}, {}
for i, batch in enumerate(tqdm(
dataloader,
desc=f" {model_name} b={num_beams} L={length_penalty} T={max_new_tokens}",
leave=False)):
if i >= eval_batches:
break
if model_name == "blip":
preds = _generate_blip(
model_objs["model"], model_objs["processor"],
batch, device, num_beams, max_new_tokens, length_penalty)
            labels = batch["labels"].clone()
            # guard against -100 label masking, mirroring the other branches
            labels[labels == -100] = model_objs["processor"].tokenizer.pad_token_id
            gt_texts = model_objs["processor"].batch_decode(
                labels, skip_special_tokens=True)
elif model_name == "vit_gpt2":
preds = _generate_vit_gpt2(
model_objs["model"], model_objs["tokenizer"],
batch, device, num_beams, max_new_tokens, length_penalty)
labels = batch["labels"].clone()
labels[labels == -100] = model_objs["pad_token_id"]
gt_texts = model_objs["tokenizer"].batch_decode(
labels, skip_special_tokens=True)
elif model_name == "git":
preds = _generate_git(
model_objs["model"], model_objs["processor"],
batch, device, num_beams, max_new_tokens, length_penalty)
labels = batch["labels"].clone()
labels[labels == -100] = model_objs["processor"].tokenizer.pad_token_id
gt_texts = model_objs["processor"].batch_decode(
labels, skip_special_tokens=True)
else:
raise ValueError(f"Unknown model: {model_name}")
        for j, (pred, gt) in enumerate(zip(preds, gt_texts)):
            # composite batch/item key; avoids collisions when the last batch is smaller
            key = f"{i}_{j}"
            res[key] = [pred]
            gts[key] = [gt]
if not gts:
return 0.0
scorer = Cider()
score, _ = scorer.compute_score(gts, res)
return score
# ─────────────────────────────────────────────────────────────────────────────
# Full Sweep Runner
# ─────────────────────────────────────────────────────────────────────────────
def run_parameter_sweep(model_name, model_objs, dataloader, device,
beam_sizes=None, length_penalties=None, max_tokens=None,
eval_batches=25):
"""
Run the full decoding parameter sweep for one model.
Args:
model_name : 'blip' | 'vit_gpt2' | 'git'
model_objs : dict of model + processor/tokenizer references
dataloader : validation DataLoader
device : torch.device
beam_sizes : list of int beam sizes (default: [3, 5, 10])
length_penalties : list of float penalties (default: [0.8, 1.0, 1.2])
max_tokens : list of int max new tokens (default: [20, 50])
eval_batches : number of batches per configuration
Returns:
results: list of dicts with keys:
model, beam_size, length_penalty, max_tokens, cider
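    Example row (values illustrative):
        {"model": "blip", "beam_size": 5, "length_penalty": 1.0,
         "max_tokens": 20, "cider": 0.8123}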
"""
beam_sizes = beam_sizes or BEAM_SIZES
length_penalties = length_penalties or LENGTH_PENALTIES
max_tokens = max_tokens or MAX_TOKENS
combos = list(itertools.product(beam_sizes, length_penalties, max_tokens))
print(f"\nπŸ”¬ Parameter Sweep β€” {model_name.upper()} ({len(combos)} configurations)")
print("=" * 70)
results = []
for num_beams, lp, mt in combos:
score = eval_one_config(
model_name, model_objs, dataloader, device,
num_beams=num_beams, max_new_tokens=mt,
length_penalty=lp, eval_batches=eval_batches,
)
results.append({
"model": model_name, "beam_size": num_beams,
"length_penalty": lp, "max_tokens": mt, "cider": score,
})
# ── Print summary table ───────────────────────────────────────────────────
print(f"\n{'='*70}")
print(f" Parameter Sweep Results β€” {model_name.upper()}")
print(f"{'='*70}")
print(f" {'Beams':>5} {'LenPenalty':>10} {'MaxTok':>7} {'CIDEr':>8}")
print(f" {'-'*5} {'-'*10} {'-'*7} {'-'*8}")
best = max(results, key=lambda r: r["cider"])
for r in sorted(results, key=lambda x: (-x["cider"], x["beam_size"])):
marker = " ← best" if r == best else ""
print(f" {r['beam_size']:>5} {r['length_penalty']:>10.1f} "
f"{r['max_tokens']:>7} {r['cider']:>8.4f}{marker}")
print(f"{'='*70}")
return results
# ─────────────────────────────────────────────────────────────────────────────
# CLI Entrypoint
# ─────────────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="Decoding parameter sweep")
parser.add_argument("--model", choices=["blip", "vit_gpt2", "git"],
default="blip")
parser.add_argument("--eval_batches", type=int, default=15)
args = parser.parse_args()
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import CFG
from data_prep import get_dataloaders, get_dataloaders_for_model
device = torch.device(
"mps" if torch.backends.mps.is_available() else
"cuda" if torch.cuda.is_available() else "cpu"
)
cfg = CFG.load_for_model(args.model)
if args.model == "blip":
from models.blip_tuner import get_blip_model
model, processor = get_blip_model(cfg, device)
model.eval()
_, val_loader = get_dataloaders(cfg, processor)
model_objs = {"model": model, "processor": processor}
elif args.model == "vit_gpt2":
from models.vit_gpt2_tuner import get_vit_gpt2_model
model, processor, tokenizer = get_vit_gpt2_model(cfg, device)
model.eval()
_, val_loader = get_dataloaders_for_model(cfg, "vit_gpt2", processor, tokenizer)
model_objs = {"model": model, "tokenizer": tokenizer,
"pad_token_id": tokenizer.pad_token_id}
elif args.model == "git":
from models.git_tuner import get_git_model
model, processor = get_git_model(cfg, device)
model.eval()
_, val_loader = get_dataloaders_for_model(cfg, "git", processor)
model_objs = {"model": model, "processor": processor}
run_parameter_sweep(
args.model, model_objs, val_loader, device,
eval_batches=args.eval_batches,
)
if __name__ == "__main__":
main()