Instructions to use 1999xia/ViT_Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- timm
How to use 1999xia/ViT_Fast with timm:
import timm
model = timm.create_model("hf_hub:1999xia/ViT_Fast", pretrained=True)
- Notebooks
- Google Colab
- Kaggle
File size: 7,236 Bytes
54ee1eb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | """Visualize all experiment types across all datasets."""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import os, glob, random, cv2
# Setup
data_dir = '/data/ypxia/Workspace/miss_patch/data'
docs_dir = 'docs'
os.makedirs(docs_dir, exist_ok=True)

# Find sample images from each dataset.
# Each entry is a (row label, PIL image, number) tuple. The number is a
# per-dataset constant (presumably a baseline accuracy in percent -- confirm).
samples = []

# Oxford Pets (high-res animals)
pets_dir = os.path.join(data_dir, 'oxford-iiit-pet', 'images')
if os.path.exists(pets_dir):
    files = sorted(glob.glob(os.path.join(pets_dir, '*.jpg')))
    # random.choice() raises IndexError on an empty list, so skip the dataset
    # when the directory exists but holds no JPEGs.
    if files:
        random.seed(42)  # fixed seed -> the same sample on every run
        f = random.choice(files)
        # convert('RGB') forces the lazy file to be read now and guarantees a
        # 3-channel image for the later colour-based processing steps.
        img = Image.open(f).convert('RGB')
        samples.append(('Oxford Pets\n(cat/dog breed)', img, 93.81))
# Food-101 (fine-grained dishes)
food_dir = os.path.join(data_dir, 'food-101', 'images')
if os.path.isdir(food_dir):
    # Only real class directories are candidates; stray files next to them
    # (e.g. hidden metadata files) would otherwise be picked as a "class".
    subdirs = sorted(
        d for d in os.listdir(food_dir)
        if os.path.isdir(os.path.join(food_dir, d))
    )
    # random.choice() raises IndexError on an empty list.
    if subdirs:
        random.seed(43)  # fixed seed -> the same class on every run
        sd = random.choice(subdirs)
        files = sorted(glob.glob(os.path.join(food_dir, sd, '*.jpg')))
        if files:
            # First file (alphabetically) of the chosen class. convert('RGB')
            # loads the pixels now and guarantees a 3-channel image.
            img = Image.open(files[0]).convert('RGB')
            name = sd.replace('_', ' ')
            samples.append((f'Food-101\n({name})', img, 91.37))
# DTD (texture)
dtd_dir = os.path.join(data_dir, 'dtd', 'dtd', 'images')
if os.path.isdir(dtd_dir):
    # Restrict the candidates to actual category directories.
    cats = sorted(
        d for d in os.listdir(dtd_dir)
        if os.path.isdir(os.path.join(dtd_dir, d))
    )
    # random.choice() raises IndexError on an empty list.
    if cats:
        random.seed(44)  # fixed seed -> the same category on every run
        c = random.choice(cats)
        files = sorted(glob.glob(os.path.join(dtd_dir, c, '*.jpg')))
        if files:
            # First file (alphabetically) of the chosen category.
            # convert('RGB') loads the pixels now and guarantees 3 channels.
            img = Image.open(files[0]).convert('RGB')
            samples.append((f'DTD\n({c})', img, 80.85))
# CIFAR-100 (32x32 native)
cifar_dir = os.path.join(data_dir, 'cifar-100-python')
cifar_train_path = os.path.join(cifar_dir, 'train')
# Check the file that is actually opened, not just its parent directory.
if os.path.exists(cifar_train_path):
    import pickle
    # NOTE: unpickling executes code embedded in the file; only point this at
    # a trusted local copy of the dataset.
    with open(cifar_train_path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    # Fine-label names, indexed by the integer stored in b'fine_labels'.
    cifar_labels = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
        'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
        'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
        'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
        'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
        'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
    ]
    random.seed(45)  # fixed seed -> the same sample on every run
    # randrange() excludes its upper bound. The previous randint(0, 50000)
    # was inclusive and could produce an index one past the last sample.
    idx = random.randrange(len(batch[b'fine_labels']))
    # Each row is reshaped as channel-major (3, 32, 32) and reordered to
    # height x width x channel for PIL.
    arr = batch[b'data'][idx].reshape(3, 32, 32).transpose(1, 2, 0)
    label = cifar_labels[batch[b'fine_labels'][idx]]
    img = Image.fromarray(arr)
    samples.append((f'CIFAR-100\n({label}, 32x32 native)', img, 91.69))
# Process each sample into a standard 224x224 base
def preprocess_to_224(img):
    """Resize so the short edge is 255 px, then center-crop to 224x224.

    Stated by the original author to match the training preprocessing.

    Args:
        img: PIL image with non-zero width and height.

    Returns:
        A new 224x224 PIL image cut from the centre of the resized image.
    """
    width, height = img.size
    short = min(width, height)
    # Exact integer arithmetic: the short edge becomes exactly 255 and the
    # long edge is the true floor of long * 255 / short. The float form
    # int(size * (255 / short)) can land just below an integer and truncate
    # one pixel low.
    new_size = (width * 255 // short, height * 255 // short)
    img_resized = img.resize(new_size, Image.BILINEAR)
    left = (new_size[0] - 224) // 2
    top = (new_size[1] - 224) // 2
    return img_resized.crop((left, top, left + 224, top + 224))
def process_funcs():
    """Return the ordered (column title, transform) pairs for the main figure.

    Every transform takes one PIL image and returns the processed image.
    The helpers resolve their dependencies only when called, so this
    function may be defined before `canny_on_pil` and `posterize`.
    """
    def keep(img):
        return img

    def shrink_to(side):
        # Bilinear downscale to a side x side square.
        return lambda img: img.resize((side, side), Image.BILINEAR)

    def to_gray(img):
        # Grayscale content, converted back to RGB mode.
        return ImageOps.grayscale(img).convert('RGB')

    def blur(img):
        # The GaussianBlur argument is the blur radius.
        return img.filter(ImageFilter.GaussianBlur(15))

    def edges(img):
        return canny_on_pil(img)

    def two_bit(img):
        return posterize(img, 2)

    titles = (
        'Original\n224x224',
        '168x168\n(100 patches)',
        '112x112\n(49 patches)',
        '80x80\n(25 patches)',
        'Grayscale',
        'Blur k=15\n(low freq)',
        'Canny Edges\n(structure)',
        '2-bit Color\n(4 colors)',
    )
    transforms = (
        keep,
        shrink_to(168),
        shrink_to(112),
        shrink_to(80),
        to_gray,
        blur,
        edges,
        two_bit,
    )
    return list(zip(titles, transforms))
def canny_on_pil(pil_img):
    """Return a Canny edge map of `pil_img` as a 3-channel PIL image.

    Args:
        pil_img: PIL image in any mode that can be converted to RGB.

    Returns:
        PIL image whose three channels all hold the same edge mask.
    """
    # cv2.COLOR_RGB2GRAY needs a 3- or 4-channel array. Converting to RGB
    # first lets grayscale and palette images through instead of raising.
    arr = np.array(pil_img.convert('RGB'))
    gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 50, 150)  # lower / upper hysteresis thresholds
    # Replicate the single-channel mask so it displays like the other images.
    return Image.fromarray(np.stack([edges] * 3, axis=-1))
def posterize(pil_img, bits):
    """Reduce colour depth by keeping only the top `bits` bits of each value.

    The low `8 - bits` bits of every array element are shifted out and
    replaced with zeros; the result is returned as a new PIL image.
    """
    drop = 8 - bits
    pixels = np.array(pil_img)
    truncated = np.right_shift(pixels, drop)
    return Image.fromarray(np.left_shift(truncated, drop))
# Create figure: one row per dataset; column 0 shows the preprocessed
# 224x224 sample, the remaining columns show each processing method.
if not samples:
    # plt.subplots() cannot build a grid with zero rows; stop with an
    # actionable message instead of an opaque matplotlib error.
    raise SystemExit(f"No sample images found under {data_dir}; nothing to plot.")

methods = process_funcs()
n_rows = len(samples)
n_cols = len(methods) + 1  # +1 for dataset info column
# squeeze=False keeps `axes` two-dimensional even when only one dataset was
# found; otherwise axes[row, col] fails on the squeezed 1-D array.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 2.8, n_rows * 2.5), squeeze=False
)

# Column headers (top row only); column 0 gets its title inside the row loop.
for col, (title, _) in enumerate(methods, start=1):
    axes[0, col].set_title(title, fontsize=9, fontweight='bold')

# The third tuple element is carried in `samples` but not drawn here.
for row, (dataset_name, img, _baseline_acc) in enumerate(samples):
    img_224 = preprocess_to_224(img)

    # First column: the preprocessed sample, labelled with the dataset name.
    ax = axes[row, 0]
    ax.imshow(np.array(img_224))
    ax.set_ylabel(dataset_name, fontsize=9, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])
    if row == 0:
        ax.set_title('Original\nimage', fontsize=9)

    # Remaining columns: each processing method applied to the same base.
    for col, (title, func) in enumerate(methods, start=1):
        ax = axes[row, col]
        result = func(img_224)
        ax.imshow(np.array(result), cmap='gray' if 'Canny' in title else None)
        ax.set_xticks([])
        ax.set_yticks([])

fig.tight_layout()
fig.savefig(os.path.join(docs_dir, 'all_experiments_guide.png'), dpi=180, bbox_inches='tight')
plt.close(fig)  # the figure is on disk; release it
print(f"Saved to {docs_dir}/all_experiments_guide.png")
# Also create a second figure: resolution sweep only, one row per dataset
# and one column per resolution.
res_methods = [
    ('224x224\n196 patches', lambda x: x),
    ('168x168\n100 patches', lambda x: x.resize((168, 168), Image.BILINEAR)),
    ('112x112\n49 patches', lambda x: x.resize((112, 112), Image.BILINEAR)),
    ('80x80\n25 patches', lambda x: x.resize((80, 80), Image.BILINEAR)),
    ('64x64\n16 patches', lambda x: x.resize((64, 64), Image.BILINEAR)),
    ('48x48\n9 patches', lambda x: x.resize((48, 48), Image.BILINEAR)),
]

if not samples:
    # plt.subplots() cannot build a grid with zero rows; stop with an
    # actionable message instead of an opaque matplotlib error.
    raise SystemExit(f"No sample images found under {data_dir}; nothing to plot.")

# The column count follows res_methods rather than a hard-coded 6.
# squeeze=False keeps `axes2` two-dimensional even for a single dataset, so
# axes2[row, col] is always valid.
fig2, axes2 = plt.subplots(
    len(samples), len(res_methods),
    figsize=(16, 2.5 * len(samples)), squeeze=False,
)

for row, (dataset_name, img, _) in enumerate(samples):
    img_224 = preprocess_to_224(img)
    for col, (label, func) in enumerate(res_methods):
        ax = axes2[row, col]
        result = func(img_224)
        ax.imshow(np.array(result))
        ax.set_xticks([])
        ax.set_yticks([])
        if row == 0:
            # Resolution labels only on the top row.
            ax.set_title(label, fontsize=10, fontweight='bold')
        if col == 0:
            # Dataset name only on the left-most column.
            ax.set_ylabel(dataset_name, fontsize=8)

fig2.tight_layout()
fig2.savefig(os.path.join(docs_dir, 'resolution_sweep_guide.png'), dpi=180, bbox_inches='tight')
print(f"Saved to {docs_dir}/resolution_sweep_guide.png")
|