# Source: Hugging Face Hub repository "ViT_Fast", file test_preprocess_reduce.py (8.4 kB, library tag: timm).
# Last commit by 1999xia: "Upload folder using huggingface_hub" (54ee1eb, verified).
"""
Test image preprocessing strategies that actually reduce token count.
Corrected experiments:
1. Downsample (no restore): resize to smaller, feed directly (inference only)
2. Checkerboard removal: remove every other row and column (halves each image side)
3. Center crop: crop a smaller region from the image
Usage:
python test_preprocess_reduce.py --dataset cifar100 --gpu 5
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import argparse
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from datasets import get_cifar100_loader, get_oxford_pets_loader, get_food101_loader
import timm
# Registry of supported datasets: name -> (loader factory, number of classes).
DATASETS = dict(
    cifar100=(get_cifar100_loader, 100),
    oxford_pets=(get_oxford_pets_loader, 37),
    food101=(get_food101_loader, 101),
)
def resize_down(images, target_size):
    """Resize a batch of images to a smaller resolution (no upsampling back).

    Uses bilinear interpolation with ``align_corners=False``.

    Args:
        images: 4-D tensor laid out as (batch, channels, height, width).
        target_size: Output resolution. A single number gives a square
            ``target_size × target_size`` output; a ``(height, width)``
            tuple or list gives a rectangular one.

    Returns:
        Tensor with the same batch and channel sizes as ``images`` and the
        requested spatial size.
    """
    if isinstance(target_size, (tuple, list)):
        out_size = tuple(target_size)
    else:
        out_size = (target_size, target_size)
    return F.interpolate(images, size=out_size, mode='bilinear', align_corners=False)
def remove_even_rows_cols(images):
    """Subsample rows and columns with stride 2, halving both spatial sides.

    Rows and columns at indices 0, 2, 4, ... are kept and the ones in
    between are dropped, e.g. 224×224 → 112×112. Three quarters of the
    pixels are removed, so the image shrinks to a quarter of its area.
    """
    every_other = slice(None, None, 2)
    return images[:, :, every_other, every_other]
def remove_even_rows_only(images):
    """Subsample rows with stride 2; the width is left untouched.

    Rows at indices 0, 2, 4, ... are kept, e.g. 224×224 → 112×224
    (non-square, but fewer patches).
    """
    kept_rows = slice(None, None, 2)
    all_cols = slice(None)
    return images[:, :, kept_rows, all_cols]
def center_crop(images, crop_size):
    """Cut a centered ``crop_size × crop_size`` window out of each image.

    Args:
        images: Tensor whose last two dimensions are (height, width).
        crop_size: Side length of the square crop, in pixels.

    Returns:
        The cropped tensor; all leading dimensions are unchanged.

    Raises:
        ValueError: If ``crop_size`` is negative or larger than the image
            height or width.
    """
    height, width = images.shape[-2:]
    # An oversized crop would make the start offsets negative, and negative
    # slice starts silently select the wrong (or a tiny) region.
    if crop_size < 0 or crop_size > height or crop_size > width:
        raise ValueError(
            f'crop_size {crop_size} does not fit inside a {height}x{width} image')
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return images[..., top:top + crop_size, left:left + crop_size]
def get_num_patches(img_size, patch_size=16, stride=16):
    """Count the patches a square image is split into.

    Args:
        img_size: Side length of the square image, in pixels.
        patch_size: Side length of one square patch.
        stride: Step between the starts of consecutive patches.

    Returns:
        ``grid * grid`` where ``grid`` is the number of patch positions along
        one side; 0 when the image is smaller than a single patch.
    """
    # Without this guard a negative grid (possible when stride < patch_size)
    # would be squared into a bogus positive count.
    if img_size < patch_size:
        return 0
    grid = (img_size - patch_size) // stride + 1
    return grid * grid
def make_model_for_size(device, num_classes, img_size, checkpoint_path, patch_size=16):
    """Create a ViT-B/16 model adapted for a specific input size.

    Builds the timm ``vit_base_patch16_224.augreg_in21k`` model for
    ``img_size`` and, when ``checkpoint_path`` exists, loads its weights
    non-strictly. If the checkpoint's position embedding has a different
    shape from the model's, its patch part is bilinearly resized to the new
    patch grid before loading.

    Args:
        device: Device the model is moved to.
        num_classes: Size of the classification head.
        img_size: Side length of the square input the model is built for.
        checkpoint_path: Path to a checkpoint holding either a state dict or
            a dict with a ``'model_state_dict'`` entry.
        patch_size: Patch side length used to derive the new patch grid.

    Returns:
        The model, on ``device`` and in eval mode.
    """
    model = timm.create_model(
        'vit_base_patch16_224.augreg_in21k',
        pretrained=True, num_classes=num_classes, img_size=img_size)

    if os.path.exists(checkpoint_path):
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        sd = ckpt.get('model_state_dict', ckpt)

        # Interpolate position embeddings if sizes differ
        if 'pos_embed' in sd and sd['pos_embed'].shape != model.pos_embed.shape:
            old_pos = sd.pop('pos_embed')
            # NOTE(review): assumes exactly one leading (class) token ahead of
            # the patch tokens and a square patch grid — confirm for other models.
            cls_pos = old_pos[:, 0:1, :]
            patch_pos = old_pos[:, 1:, :]
            # Take the embedding width from the checkpoint rather than
            # hard-coding ViT-B's 768.
            embed_dim = patch_pos.shape[-1]
            old_grid = int(patch_pos.shape[1] ** 0.5)
            new_grid = (img_size - patch_size) // patch_size + 1
            if old_grid != new_grid:
                # (1, N, D) -> (1, D, grid, grid) so the grid can be resized spatially.
                patch_2d = patch_pos.reshape(1, old_grid, old_grid, embed_dim).permute(0, 3, 1, 2)
                patch_new = F.interpolate(patch_2d, size=(new_grid, new_grid),
                                          mode='bilinear', align_corners=False)
                patch_new = patch_new.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, embed_dim)
                sd['pos_embed'] = torch.cat([cls_pos, patch_new], dim=1)
            # Otherwise the mismatching embedding stays dropped; the non-strict
            # load below then reports 'pos_embed' as missing and the model keeps
            # the embedding it was created with.

        missing, unexpected = model.load_state_dict(sd, strict=False)
        if missing:
            print(f' Missing keys: {len(missing)} (expected for diff img_size)', flush=True)
        if unexpected:
            print(f' Unexpected keys: {len(unexpected)}', flush=True)
    else:
        # Make the fallback visible instead of silently evaluating a model
        # that never received the fine-tuned weights.
        print(f' Checkpoint not found: {checkpoint_path} (using pretrained weights only)',
              flush=True)

    model = model.to(device)
    model.eval()
    return model
@torch.no_grad()
def evaluate(model, loader, device, preprocess=None, track_time=False, img_size=224):
    """Measure top-1 accuracy of ``model`` on ``loader``, optionally timing it.

    Only the forward pass is timed; moving the batch to ``device`` and the
    ``preprocess`` call are outside the timed region.

    Args:
        model: Module mapping an image batch to per-class logits.
        loader: Iterable of ``(images, targets)`` batches.
        device: Device the batches are moved to (string or ``torch.device``).
        preprocess: Optional callable applied to each image batch before the
            forward pass.
        track_time: When true, also return latency and throughput.
        img_size: Unused; kept so existing callers keep working.

    Returns:
        Accuracy in percent, or ``(accuracy, latency, throughput)`` when
        ``track_time`` is set: latency is milliseconds per sample and
        throughput is samples per second. An empty loader yields zeros.
    """
    model.eval()
    # CUDA synchronisation is only valid for CUDA devices; calling it on the
    # CPU fallback would raise.
    on_cuda = torch.device(device).type == 'cuda'

    correct = 0
    total = 0
    total_time = 0.0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        if preprocess is not None:
            images = preprocess(images)

        if track_time:
            if on_cuda:
                torch.cuda.synchronize(device)
            start = time.perf_counter()

        logits = model(images)

        if track_time:
            # Synchronise before reading the clock so the interval covers
            # the whole forward pass on CUDA devices.
            if on_cuda:
                torch.cuda.synchronize(device)
            total_time += time.perf_counter() - start

        _, predicted = logits.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return (0.0, 0.0, 0.0) if track_time else 0.0

    acc = 100. * correct / total
    if track_time:
        latency = total_time / total * 1000
        throughput = total / total_time if total_time > 0 else float('inf')
        return acc, latency, throughput
    return acc
def main():
    """Benchmark inference-time preprocessing that shrinks the ViT input.

    Parses ``--dataset``, ``--gpu`` and ``--batch_size``, loads the test
    split, warms the device up with one batch, then for every preprocessing
    method builds a model for the matching input size, evaluates it on the
    test split and prints accuracy, patch count, latency and throughput.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar100', choices=list(DATASETS.keys()))
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=32)
    args = parser.parse_args()

    device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}', flush=True)
    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(args.gpu)}', flush=True)

    loader_fn, _ = DATASETS[args.dataset]
    result = loader_fn(batch_size=args.batch_size, data_dir='./data', num_workers=4)
    # Loaders return either (train, val, test, n_cls) or (train, test, n_cls);
    # only the test split and the class count are needed here.
    if len(result) == 4:
        _, _, test_loader, n_cls = result
    else:
        _, test_loader, n_cls = result
    print(f'Dataset: {args.dataset}, Test: {len(test_loader.dataset)}, Classes: {n_cls}', flush=True)

    # Load fine-tuned ViT-B/16
    print('Loading ViT-B/16 IN-21K...', flush=True)
    model = timm.create_model('vit_base_patch16_224.augreg_in21k', pretrained=True, num_classes=n_cls)
    model = model.to(device)
    model.eval()
    ckpt_path = f'checkpoints/{args.dataset}_vit_b16_ft/best_model.pth'
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt.get('model_state_dict', ckpt))
        print(f' Loaded fine-tuned checkpoint: {ckpt_path}', flush=True)

    # Warmup: a single forward pass, without building an autograd graph.
    with torch.no_grad():
        for images, _ in test_loader:
            model(images.to(device))
            break

    # This model is only used for the warm-up (every method below builds its
    # own), so release it before the benchmark instead of keeping it on the device.
    del model
    torch.cuda.empty_cache()

    methods = [
        # (name, preprocess_fn, img_size, patches)
        ('Baseline (224×224)', None, 224, 196),
        ('Downsample 168×168 (推理)', lambda x: resize_down(x, 168), 168, 100),
        ('Downsample 112×112 (推理)', lambda x: resize_down(x, 112), 112, 49),
        ('隔行去行列 (224→112)', remove_even_rows_cols, 112, 49),
        ('Center crop 168×168', lambda x: center_crop(x, 168), 168, 100),
        ('Center crop 112×112', lambda x: center_crop(x, 112), 112, 49),
    ]

    print(f'\n=== {args.dataset}: 图片预处理(实际减少 token 数量) ===\n', flush=True)
    print(f'{"方法":<30} {"Acc":>8} {"Patches":>8} {"Latency":>10} {"Throughput":>12}', flush=True)
    print('-' * 75, flush=True)

    for name, preprocess, img_size, expected_patches in methods:
        if isinstance(img_size, tuple):
            # Skip non-square cases (need separate model creation)
            print(f'{name:<30} skipped (non-square input)', flush=True)
            continue
        # Create a fresh model for each input size
        m = make_model_for_size(device, n_cls, img_size, ckpt_path)
        acc, latency, throughput = evaluate(m, test_loader, device, preprocess, track_time=True, img_size=img_size)
        print(f'{name:<30} {acc:>7.2f}% {expected_patches:>8} {latency:>8.2f}ms {throughput:>10.1f}/s', flush=True)
        del m
        torch.cuda.empty_cache()

    # NOTE(review): the reference numbers and the conclusion below are
    # hard-coded and printed for every dataset and every outcome — confirm
    # they apply to the run being reported.
    print(f'\n对比:降采样训练(已训练过的模型,非推理时硬切)')
    print(f' 168×168 训练: 91.56%, 100 patches (来自 test_downsample_train.py)')
    print(f' 112×112 训练: 90.00%, 49 patches (来自 test_downsample_train.py)')
    print(f'\n结论:推理时直接缩小图片会掉点较多,因为位置编码不匹配')
    print(f'降采样训练(重新训练适配)的效果远好于推理时硬切', flush=True)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()