Instructions to use 1999xia/ViT_Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- timm
How to use 1999xia/ViT_Fast with timm:
import timm
model = timm.create_model("hf_hub:1999xia/ViT_Fast", pretrained=True)
- Notebooks
- Google Colab
- Kaggle
| """ | |
Test image preprocessing strategies that actually reduce the token count.

Corrected experiments:
1. Downsample (no restore): resize to a smaller resolution and feed it directly (inference only)
2. Checkerboard removal: remove every other row and column (halves each image side)
3. Center crop: crop a smaller region from the center of the image

Usage:
    python test_preprocess_reduce.py --dataset cifar100 --gpu 5
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import time | |
| import argparse | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from datasets import get_cifar100_loader, get_oxford_pets_loader, get_food101_loader | |
| import timm | |
# Registry of supported datasets: CLI name -> (loader factory, number of classes).
DATASETS = dict(
    cifar100=(get_cifar100_loader, 100),
    oxford_pets=(get_oxford_pets_loader, 37),
    food101=(get_food101_loader, 101),
)
def resize_down(images, target_size):
    """Downsample a batch of images to a square ``target_size`` resolution.

    The result is not upsampled back, so a patch-based model receives fewer
    tokens.

    Args:
        images: 4-D tensor; the callers in this file pass (B, C, H, W) batches.
        target_size: Side length, in pixels, of the square output.

    Returns:
        Tensor whose last two dimensions are ``target_size`` x ``target_size``,
        produced by bilinear interpolation with ``align_corners=False``.

    Raises:
        ValueError: If ``images`` is not 4-D.
    """
    # Explicit rank check instead of unpacking the shape into unused locals.
    if images.dim() != 4:
        raise ValueError(
            f'resize_down expects a 4-D (B, C, H, W) tensor, got {images.dim()}-D')
    return F.interpolate(
        images, size=(target_size, target_size), mode='bilinear', align_corners=False)
def remove_even_rows_cols(images):
    """Subsample a batch by keeping every second row and every second column.

    Slicing starts at index 0, so rows/columns 0, 2, 4, ... are kept and the
    odd-indexed ones are dropped. Each spatial side is halved (rounded up),
    e.g. 224x224 -> 112x112, which leaves a quarter of the pixels.
    """
    every_other = slice(None, None, 2)
    return images[:, :, every_other, every_other]
def remove_even_rows_only(images):
    """Subsample a batch by keeping every second row; columns are untouched.

    Rows 0, 2, 4, ... are kept, so the height is halved (rounded up) while the
    width stays the same, e.g. 224x224 -> 112x224. The output is non-square.
    """
    every_other_row = slice(None, None, 2)
    all_columns = slice(None)
    return images[:, :, every_other_row, all_columns]
def center_crop(images, crop_size):
    """Crop a centered ``crop_size`` x ``crop_size`` window from each image.

    Args:
        images: 4-D tensor indexed as (batch, channel, row, column).
        crop_size: Side length of the square crop. Must satisfy
            ``0 < crop_size <= min(H, W)``.

    Returns:
        The slice ``images[:, :, top:top+crop_size, left:left+crop_size]``.
        When the margin is odd, the extra pixel is left on the bottom/right.

    Raises:
        ValueError: If ``images`` is not 4-D or ``crop_size`` does not fit
            inside the image.
    """
    _, _, height, width = images.shape
    # A crop larger than the image would yield negative start offsets, which
    # slicing silently wraps around and returns a wrong (or empty) region.
    if not 0 < crop_size <= min(height, width):
        raise ValueError(
            f'crop_size must be in 1..{min(height, width)} for a '
            f'{height}x{width} image, got {crop_size}')
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return images[:, :, top:top + crop_size, left:left + crop_size]
def get_num_patches(img_size, patch_size=16, stride=16):
    """Return how many patches a square image of side ``img_size`` yields.

    Patches of side ``patch_size`` are placed every ``stride`` pixels along
    each axis; the per-side count is multiplied by itself to get the total.
    """
    patches_per_side = 1 + (img_size - patch_size) // stride
    return patches_per_side * patches_per_side
def make_model_for_size(device, num_classes, img_size, checkpoint_path, patch_size=16):
    """Build a ViT-B/16 for ``img_size`` inputs and load fine-tuned weights into it.

    The timm model is created with pretrained weights and the requested
    ``img_size``. If ``checkpoint_path`` exists, its state dict is loaded
    non-strictly. When the checkpoint's ``pos_embed`` shape differs from the
    model's, the patch-position part is bilinearly resampled to the new grid
    before loading.

    Args:
        device: Device the model is moved to.
        num_classes: Size of the classification head.
        img_size: Side length of the square input images.
        checkpoint_path: Path to a checkpoint that is either a state dict or a
            dict holding one under ``'model_state_dict'``.
        patch_size: Patch side length used to derive the new grid size.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If the checkpoint's patch position embeddings do not form
            a square grid.
    """
    model = timm.create_model(
        'vit_base_patch16_224.augreg_in21k',
        pretrained=True, num_classes=num_classes, img_size=img_size)

    if os.path.exists(checkpoint_path):
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        sd = ckpt.get('model_state_dict', ckpt)

        # Interpolate position embeddings if sizes differ.
        if 'pos_embed' in sd and sd['pos_embed'].shape != model.pos_embed.shape:
            old_pos = sd.pop('pos_embed')
            # Assumes exactly one leading (class) token before the patch tokens.
            cls_pos = old_pos[:, 0:1, :]
            patch_pos = old_pos[:, 1:, :]

            num_old_patches = patch_pos.shape[1]
            old_grid = round(num_old_patches ** 0.5)
            if old_grid * old_grid != num_old_patches:
                raise ValueError(
                    f'Checkpoint pos_embed has {num_old_patches} patch tokens, '
                    f'which is not a square grid')
            new_grid = (img_size - patch_size) // patch_size + 1

            if old_grid != new_grid:
                # Take the embedding width from the tensor, not a fixed 768.
                embed_dim = patch_pos.shape[-1]
                patch_2d = patch_pos.reshape(1, old_grid, old_grid, embed_dim)
                patch_2d = patch_2d.permute(0, 3, 1, 2)
                patch_new = F.interpolate(patch_2d, size=(new_grid, new_grid),
                                          mode='bilinear', align_corners=False)
                patch_new = patch_new.permute(0, 2, 3, 1)
                patch_new = patch_new.reshape(1, new_grid * new_grid, embed_dim)
                sd['pos_embed'] = torch.cat([cls_pos, patch_new], dim=1)
            # If the grids match but the shapes still differ, pos_embed stays
            # removed so the non-strict load keeps the model's own embedding.

        missing, unexpected = model.load_state_dict(sd, strict=False)
        if missing:
            print(f' Missing keys: {len(missing)} (expected for diff img_size)', flush=True)
        if unexpected:
            print(f' Unexpected keys: {len(unexpected)}', flush=True)
    else:
        # Make the fallback visible: results then come from non-fine-tuned weights.
        print(f' Checkpoint not found: {checkpoint_path} (using pretrained weights only)',
              flush=True)

    model = model.to(device)
    model.eval()
    return model
| def evaluate(model, loader, device, preprocess=None, track_time=False, img_size=224): | |
| model.eval() | |
| correct = 0 | |
| total = 0 | |
| total_time = 0 | |
| for images, targets in loader: | |
| images, targets = images.to(device), targets.to(device) | |
| if preprocess is not None: | |
| images = preprocess(images) | |
| if track_time: | |
| torch.cuda.synchronize() | |
| start = time.time() | |
| logits = model(images) | |
| if track_time: | |
| torch.cuda.synchronize() | |
| end = time.time() | |
| total_time += (end - start) | |
| _, predicted = logits.max(1) | |
| total += targets.size(0) | |
| correct += predicted.eq(targets).sum().item() | |
| acc = 100. * correct / total | |
| if track_time: | |
| latency = total_time / total * 1000 | |
| throughput = total / total_time | |
| return acc, latency, throughput | |
| return acc | |
def main():
    """Benchmark inference-time token-reduction preprocessing on ViT-B/16.

    Parses ``--dataset``, ``--gpu`` and ``--batch_size``, loads the test split
    and runs one warm-up forward pass. For each preprocessing method it then
    builds a model for the reduced input size, evaluates it on the test set
    and prints one table row (accuracy, patch count, latency, throughput).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar100', choices=list(DATASETS.keys()))
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=32)
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = f'cuda:{args.gpu}' if use_cuda else 'cpu'
    print(f'Device: {device}', flush=True)
    if use_cuda:
        print(f'GPU: {torch.cuda.get_device_name(args.gpu)}', flush=True)

    loader_fn, _ = DATASETS[args.dataset]
    result = loader_fn(batch_size=args.batch_size, data_dir='./data', num_workers=4)
    # Loaders return (train, val, test, n_cls) or (train, test, n_cls); only
    # the last two entries are needed here.
    *_, test_loader, n_cls = result
    print(f'Dataset: {args.dataset}, Test: {len(test_loader.dataset)}, Classes: {n_cls}', flush=True)

    # Single source of truth for the fine-tuned checkpoint location.
    ckpt_path = f'checkpoints/{args.dataset}_vit_b16_ft/best_model.pth'

    # Load fine-tuned ViT-B/16 (used for the warm-up pass only).
    print('Loading ViT-B/16 IN-21K...', flush=True)
    model = timm.create_model('vit_base_patch16_224.augreg_in21k', pretrained=True, num_classes=n_cls)
    model = model.to(device)
    model.eval()
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt.get('model_state_dict', ckpt))
        print(f' Loaded fine-tuned checkpoint: {ckpt_path}', flush=True)

    # Warmup: one batch, without building an autograd graph.
    with torch.no_grad():
        for images, _ in test_loader:
            model(images.to(device))
            break

    # The warm-up model is not used again; release it before the per-method
    # models are created.
    del model
    torch.cuda.empty_cache()

    methods = [
        # (name, preprocess_fn, img_size, patches)
        ('Baseline (224×224)', None, 224, 196),
        ('Downsample 168×168 (推理)', lambda x: resize_down(x, 168), 168, 100),
        ('Downsample 112×112 (推理)', lambda x: resize_down(x, 112), 112, 49),
        ('隔行去行列 (224→112)', remove_even_rows_cols, 112, 49),
        ('Center crop 168×168', lambda x: center_crop(x, 168), 168, 100),
        ('Center crop 112×112', lambda x: center_crop(x, 112), 112, 49),
    ]

    print(f'\n=== {args.dataset}: 图片预处理(实际减少 token 数量) ===\n', flush=True)
    print(f'{"方法":<30} {"Acc":>8} {"Patches":>8} {"Latency":>10} {"Throughput":>12}', flush=True)
    print('-' * 75, flush=True)

    for name, preprocess, img_size, expected_patches in methods:
        if isinstance(img_size, tuple):
            # Skip non-square cases (need separate model creation).
            print(f'{name:<30} skipped (non-square input)', flush=True)
            continue

        # Create a fresh model for each input size; img_size is always an int
        # here because tuple sizes were skipped above.
        m = make_model_for_size(device, n_cls, img_size, ckpt_path)
        acc, latency, throughput = evaluate(
            m, test_loader, device, preprocess, track_time=True, img_size=img_size)
        print(f'{name:<30} {acc:>7.2f}% {expected_patches:>8} {latency:>8.2f}ms {throughput:>10.1f}/s', flush=True)
        del m
        torch.cuda.empty_cache()

    print('\n对比:降采样训练(已训练过的模型,非推理时硬切)')
    print(' 168×168 训练: 91.56%, 100 patches (来自 test_downsample_train.py)')
    print(' 112×112 训练: 90.00%, 49 patches (来自 test_downsample_train.py)')
    print('\n结论:推理时直接缩小图片会掉点较多,因为位置编码不匹配')
    print('降采样训练(重新训练适配)的效果远好于推理时硬切', flush=True)


if __name__ == '__main__':
    main()