Instructions to use 1999xia/ViT_Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- timm
How to use 1999xia/ViT_Fast with timm:
import timm model = timm.create_model("hf_hub:1999xia/ViT_Fast", pretrained=True) - Notebooks
- Google Colab
- Kaggle
| """Visualize all experiment types across all datasets.""" | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image, ImageFilter, ImageOps | |
| import os, glob, random, cv2 | |
| # Setup | |
| data_dir = '/data/ypxia/Workspace/miss_patch/data' | |
| docs_dir = 'docs' | |
| os.makedirs(docs_dir, exist_ok=True) | |
| # Find sample images from each dataset | |
| samples = [] | |
| # Oxford Pets (high-res animals) | |
| pets_dir = os.path.join(data_dir, 'oxford-iiit-pet', 'images') | |
| if os.path.exists(pets_dir): | |
| files = sorted(glob.glob(os.path.join(pets_dir, '*.jpg'))) | |
| random.seed(42) | |
| f = random.choice(files) | |
| img = Image.open(f) | |
| name = os.path.basename(f).rsplit('_', 1)[0] | |
| samples.append(('Oxford Pets\n(cat/dog breed)', img, 93.81)) | |
| # Food-101 (fine-grained dishes) | |
| food_dir = os.path.join(data_dir, 'food-101', 'images') | |
| if os.path.exists(food_dir): | |
| subdirs = sorted(os.listdir(food_dir)) | |
| random.seed(43) | |
| sd = random.choice(subdirs) | |
| files = sorted(glob.glob(os.path.join(food_dir, sd, '*.jpg'))) | |
| if files: | |
| img = Image.open(files[0]) | |
| name = sd.replace('_', ' ') | |
| samples.append((f'Food-101\n({name})', img, 91.37)) | |
| # DTD (texture) | |
| dtd_dir = os.path.join(data_dir, 'dtd', 'dtd', 'images') | |
| if os.path.exists(dtd_dir): | |
| cats = sorted(os.listdir(dtd_dir)) | |
| random.seed(44) | |
| c = random.choice(cats) | |
| files = sorted(glob.glob(os.path.join(dtd_dir, c, '*.jpg'))) | |
| if files: | |
| img = Image.open(files[0]) | |
| samples.append((f'DTD\n({c})', img, 80.85)) | |
| # CIFAR-100 (32x32 native) | |
| cifar_dir = os.path.join(data_dir, 'cifar-100-python') | |
| if os.path.exists(cifar_dir): | |
| import pickle | |
| with open(os.path.join(cifar_dir, 'train'), 'rb') as f: | |
| batch = pickle.load(f, encoding='bytes') | |
| cifar_labels = [ | |
| 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', | |
| 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', | |
| 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', | |
| 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', | |
| 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', | |
| 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', | |
| 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', | |
| 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', | |
| 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', | |
| 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', | |
| 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', | |
| 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', | |
| 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', | |
| 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm' | |
| ] | |
| random.seed(45) | |
| idx = random.randint(0, 50000) | |
| arr = batch[b'data'][idx].reshape(3, 32, 32).transpose(1, 2, 0) | |
| label = cifar_labels[batch[b'fine_labels'][idx]] | |
| img = Image.fromarray(arr) | |
| samples.append((f'CIFAR-100\n({label}, 32x32 native)', img, 91.69)) | |
| # Process each sample into a standard 224x224 base | |
| def preprocess_to_224(img): | |
| """Same as training preprocessing: short edge -> 255, center crop -> 224.""" | |
| short = min(img.size) | |
| scale = 255 / short | |
| new_size = (int(img.size[0] * scale), int(img.size[1] * scale)) | |
| img_resized = img.resize(new_size, Image.BILINEAR) | |
| left = (new_size[0] - 224) // 2 | |
| top = (new_size[1] - 224) // 2 | |
| return img_resized.crop((left, top, left + 224, top + 224)) | |
| def process_funcs(): | |
| """Define all processing methods.""" | |
| return [ | |
| ('Original\n224x224', lambda img: img), | |
| ('168x168\n(100 patches)', lambda img: img.resize((168, 168), Image.BILINEAR)), | |
| ('112x112\n(49 patches)', lambda img: img.resize((112, 112), Image.BILINEAR)), | |
| ('80x80\n(25 patches)', lambda img: img.resize((80, 80), Image.BILINEAR)), | |
| ('Grayscale', lambda img: ImageOps.grayscale(img).convert('RGB')), | |
| ('Blur k=15\n(low freq)', lambda img: img.filter(ImageFilter.GaussianBlur(15))), | |
| ('Canny Edges\n(structure)', lambda img: canny_on_pil(img)), | |
| ('2-bit Color\n(4 colors)', lambda img: posterize(img, 2)), | |
| ] | |
| def canny_on_pil(pil_img): | |
| arr = np.array(pil_img) | |
| gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY) | |
| edges = cv2.Canny(gray, 50, 150) | |
| return Image.fromarray(np.stack([edges] * 3, axis=-1)) | |
| def posterize(pil_img, bits): | |
| arr = np.array(pil_img) | |
| arr = (arr >> (8 - bits)) << (8 - bits) | |
| return Image.fromarray(arr) | |
| # Create figure | |
| methods = process_funcs() | |
| n_rows = len(samples) | |
| n_cols = len(methods) + 1 # +1 for dataset info column | |
| fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2.8, n_rows * 2.5)) | |
| # Column headers | |
| for col in range(1, n_cols): | |
| col_idx = col - 1 | |
| axes[0, col].set_title(methods[col_idx][0], fontsize=9, fontweight='bold') | |
| for row, (dataset_name, img, baseline_acc) in enumerate(samples): | |
| img_224 = preprocess_to_224(img) | |
| # First column: dataset info + sample image | |
| ax = axes[row, 0] | |
| # Show a small version of the original | |
| thumb = img_224.copy() | |
| ax.imshow(np.array(thumb)) | |
| ax.set_ylabel(dataset_name, fontsize=9, fontweight='bold') | |
| ax.set_xticks([]) | |
| ax.set_yticks([]) | |
| # Add baseline info | |
| if row == 0: | |
| ax.set_title('Original\nimage', fontsize=9) | |
| # Remaining columns: each processing method | |
| for col, (_, func) in enumerate(methods): | |
| ax = axes[row, col + 1] | |
| result = func(img_224) | |
| ax.imshow(np.array(result), cmap='gray' if 'Canny' in methods[col][0] else None) | |
| ax.set_xticks([]) | |
| ax.set_yticks([]) | |
| plt.tight_layout() | |
| plt.savefig(os.path.join(docs_dir, 'all_experiments_guide.png'), dpi=180, bbox_inches='tight') | |
| print(f"Saved to {docs_dir}/all_experiments_guide.png") | |
| # Also create a second figure: resolution sweep only, all datasets in one row | |
| fig2, axes2 = plt.subplots(len(samples), 6, figsize=(16, 2.5 * len(samples))) | |
| res_methods = [ | |
| ('224x224\n196 patches', lambda x: x), | |
| ('168x168\n100 patches', lambda x: x.resize((168, 168), Image.BILINEAR)), | |
| ('112x112\n49 patches', lambda x: x.resize((112, 112), Image.BILINEAR)), | |
| ('80x80\n25 patches', lambda x: x.resize((80, 80), Image.BILINEAR)), | |
| ('64x64\n16 patches', lambda x: x.resize((64, 64), Image.BILINEAR)), | |
| ('48x48\n9 patches', lambda x: x.resize((48, 48), Image.BILINEAR)), | |
| ] | |
| for row, (dataset_name, img, _) in enumerate(samples): | |
| img_224 = preprocess_to_224(img) | |
| for col, (label, func) in enumerate(res_methods): | |
| ax = axes2[row, col] | |
| result = func(img_224) | |
| ax.imshow(np.array(result)) | |
| ax.set_xticks([]) | |
| ax.set_yticks([]) | |
| if row == 0: | |
| ax.set_title(label, fontsize=10, fontweight='bold') | |
| if col == 0: | |
| ax.set_ylabel(dataset_name, fontsize=8) | |
| plt.tight_layout() | |
| plt.savefig(os.path.join(docs_dir, 'resolution_sweep_guide.png'), dpi=180, bbox_inches='tight') | |
| print(f"Saved to {docs_dir}/resolution_sweep_guide.png") | |