# Hugging Face file-page header (kept as a comment so the script stays valid Python):
#   repository path: timm / ViT_Fast / scripts/visualize_all_experiments.py
#   uploaded by 1999xia ("Upload folder using huggingface_hub"), commit 54ee1eb, 7.24 kB
"""Visualize all experiment types across all datasets."""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import os, glob, random, cv2
# Setup: root of the raw datasets and the directory the figures are written to.
data_dir, docs_dir = '/data/ypxia/Workspace/miss_patch/data', 'docs'
os.makedirs(docs_dir, exist_ok=True)

# One (title, PIL image, baseline accuracy) entry is appended per dataset found.
samples = list()
# Oxford Pets (high-res animals)
pets_dir = os.path.join(data_dir, 'oxford-iiit-pet', 'images')
if os.path.exists(pets_dir):
    files = sorted(glob.glob(os.path.join(pets_dir, '*.jpg')))
    random.seed(42)
    # random.choice raises IndexError on an empty list; skip the dataset
    # instead, as the other loaders do with `if files:`.
    if files:
        f = random.choice(files)
        # Decode eagerly inside `with` so the file handle is closed. Also force
        # RGB: a JPEG is not guaranteed to be 3-channel, and canny_on_pil's
        # cv2.COLOR_RGB2GRAY rejects single-channel arrays. Already-RGB images
        # are unchanged.
        with Image.open(f) as im:
            img = im.convert('RGB')
        samples.append(('Oxford Pets\n(cat/dog breed)', img, 93.81))
# Food-101 (fine-grained dishes)
food_dir = os.path.join(data_dir, 'food-101', 'images')
if os.path.exists(food_dir):
    # Only class sub-directories are candidates. A stray file in the listing
    # would otherwise be picked and yield an empty glob.
    subdirs = sorted(
        d for d in os.listdir(food_dir) if os.path.isdir(os.path.join(food_dir, d))
    )
    random.seed(43)
    if subdirs:  # random.choice raises IndexError on an empty list
        sd = random.choice(subdirs)
        files = sorted(glob.glob(os.path.join(food_dir, sd, '*.jpg')))
        if files:
            # Decode eagerly (closes the file) and normalize to RGB for the
            # RGB-only OpenCV conversion used later; RGB images are unchanged.
            with Image.open(files[0]) as im:
                img = im.convert('RGB')
            name = sd.replace('_', ' ')
            samples.append((f'Food-101\n({name})', img, 91.37))
# DTD (texture)
dtd_dir = os.path.join(data_dir, 'dtd', 'dtd', 'images')
if os.path.exists(dtd_dir):
    # Only category sub-directories are candidates (ignore stray files).
    cats = sorted(
        d for d in os.listdir(dtd_dir) if os.path.isdir(os.path.join(dtd_dir, d))
    )
    random.seed(44)
    if cats:  # random.choice raises IndexError on an empty list
        c = random.choice(cats)
        files = sorted(glob.glob(os.path.join(dtd_dir, c, '*.jpg')))
        if files:
            # Decode eagerly (closes the file) and normalize to RGB for the
            # RGB-only OpenCV conversion used later; RGB images are unchanged.
            with Image.open(files[0]) as im:
                img = im.convert('RGB')
            samples.append((f'DTD\n({c})', img, 80.85))
# CIFAR-100 (32x32 native)
cifar_dir = os.path.join(data_dir, 'cifar-100-python')
if os.path.exists(cifar_dir):
    import pickle
    # NOTE(security): pickle.load can execute arbitrary code. This is acceptable
    # only because the file is the locally downloaded, trusted CIFAR-100 archive.
    with open(os.path.join(cifar_dir, 'train'), 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    # Fine-label names, indexed by the integers stored in batch[b'fine_labels'].
    cifar_labels = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
        'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
        'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
        'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
        'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
        'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
    ]
    random.seed(45)
    # randint(0, 50000) is inclusive at both ends and could return 50000, one
    # past the last row of a 50,000-row array. randrange(len(...)) is always a
    # valid index.
    idx = random.randrange(len(batch[b'data']))
    # Each row is reshaped channel-first (3, 32, 32), then transposed to
    # (32, 32, 3) height x width x channel for PIL.
    arr = batch[b'data'][idx].reshape(3, 32, 32).transpose(1, 2, 0)
    label = cifar_labels[batch[b'fine_labels'][idx]]
    img = Image.fromarray(arr)
    samples.append((f'CIFAR-100\n({label}, 32x32 native)', img, 91.69))
# Process each sample into a standard 224x224 base
def preprocess_to_224(img, resize_short=255, crop_size=224):
    """Resize the short edge to ``resize_short``, then center-crop a square.

    Intended to match the training preprocessing (short edge -> 255, center
    crop -> 224); those values are the defaults.

    Args:
        img: PIL image of any size.
        resize_short: exact length of the shorter edge after resizing.
        crop_size: side length of the square center crop.

    Returns:
        A ``crop_size`` x ``crop_size`` PIL image.
    """
    width, height = img.size
    # Pin the short edge to exactly ``resize_short`` and scale the long edge
    # with exact integer arithmetic. The previous form,
    # int(side * (255 / short)), can evaluate to just below 255 (e.g. for a
    # 100-pixel short edge), which int() truncates to 254.
    if width <= height:
        new_size = (resize_short, resize_short * height // width)
    else:
        new_size = (resize_short * width // height, resize_short)
    img_resized = img.resize(new_size, Image.BILINEAR)
    left = (new_size[0] - crop_size) // 2
    top = (new_size[1] - crop_size) // 2
    return img_resized.crop((left, top, left + crop_size, top + crop_size))
def process_funcs():
    """Build the ordered list of (column title, transform) pairs for figure 1.

    Each transform is called with the 224x224 preprocessed PIL image and
    returns the image to display. Helpers such as ``canny_on_pil`` and
    ``posterize`` are looked up only when a transform runs, so they may be
    defined after this function.
    """
    def _downsample(side):
        # Factory so each lambda captures its own side length.
        return lambda img: img.resize((side, side), Image.BILINEAR)

    downsampled = [
        (title, _downsample(side))
        for title, side in (
            ('168x168\n(100 patches)', 168),
            ('112x112\n(49 patches)', 112),
            ('80x80\n(25 patches)', 80),
        )
    ]
    identity = ('Original\n224x224', lambda img: img)
    appearance = [
        ('Grayscale', lambda img: ImageOps.grayscale(img).convert('RGB')),
        ('Blur k=15\n(low freq)', lambda img: img.filter(ImageFilter.GaussianBlur(15))),
        ('Canny Edges\n(structure)', lambda img: canny_on_pil(img)),
        ('2-bit Color\n(4 colors)', lambda img: posterize(img, 2)),
    ]
    return [identity, *downsampled, *appearance]
def canny_on_pil(pil_img, low_threshold=50, high_threshold=150):
    """Return the Canny edge map of a PIL image as a 3-channel PIL image.

    Args:
        pil_img: PIL image in any mode. It is converted to RGB first because
            cv2.COLOR_RGB2GRAY cannot take a single-channel (2-D) array, as
            produced by grayscale or palette images. For RGB input this
            conversion is a no-op copy.
        low_threshold: lower hysteresis threshold passed to cv2.Canny.
        high_threshold: upper hysteresis threshold passed to cv2.Canny.

    Returns:
        PIL image whose three channels are all the same edge map, so it can be
        displayed like the other RGB panels.
    """
    rgb = np.array(pil_img.convert('RGB'))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    return Image.fromarray(np.stack([edges] * 3, axis=-1))
def posterize(pil_img, bits):
    """Keep only the ``bits`` most significant bits of every pixel value.

    The low ``8 - bits`` bits are cleared by shifting right then left again.
    For 8-bit channels and bits=2 this leaves 4 levels per channel
    (0, 64, 128, 192).
    """
    dropped_bits = 8 - bits
    pixels = np.array(pil_img)
    kept_high = np.right_shift(pixels, dropped_bits)
    quantized = np.left_shift(kept_high, dropped_bits)
    return Image.fromarray(quantized)
# Create figure 1: every processing method applied to every dataset sample.
if not samples:
    # plt.subplots(0, ...) would otherwise fail with an unrelated-looking error.
    raise SystemExit(f'No datasets found under {data_dir}; nothing to plot.')
methods = process_funcs()
n_rows = len(samples)
n_cols = len(methods) + 1  # +1 for dataset info column
# squeeze=False keeps `axes` 2-D even when a single dataset was found;
# otherwise plt.subplots returns a 1-D array and axes[row, col] raises.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 2.8, n_rows * 2.5), squeeze=False
)
for row, (dataset_name, img, baseline_acc) in enumerate(samples):
    img_224 = preprocess_to_224(img)
    # First column: dataset name, its baseline accuracy and the unprocessed crop.
    ax = axes[row, 0]
    ax.imshow(np.array(img_224))
    ax.set_ylabel(dataset_name, fontsize=9, fontweight='bold')
    ax.set_xlabel(f'baseline acc: {baseline_acc:.2f}%', fontsize=8)
    ax.set_xticks([])
    ax.set_yticks([])
    if row == 0:
        ax.set_title('Original\nimage', fontsize=9)
    # Remaining columns: one per processing method; titles go on the top row.
    for col, (method_name, func) in enumerate(methods, start=1):
        ax = axes[row, col]
        result = func(img_224)
        ax.imshow(np.array(result), cmap='gray' if 'Canny' in method_name else None)
        ax.set_xticks([])
        ax.set_yticks([])
        if row == 0:
            ax.set_title(method_name, fontsize=9, fontweight='bold')
fig.tight_layout()
fig.savefig(os.path.join(docs_dir, 'all_experiments_guide.png'), dpi=180, bbox_inches='tight')
plt.close(fig)  # release the figure before building the next one
print(f"Saved to {docs_dir}/all_experiments_guide.png")
# Second figure: resolution sweep only, with one row per dataset and one
# column per input resolution.
if not samples:
    # plt.subplots(0, ...) would otherwise fail with an unrelated-looking error.
    raise SystemExit(f'No datasets found under {data_dir}; nothing to plot.')
res_methods = [
    ('224x224\n196 patches', lambda x: x),
    ('168x168\n100 patches', lambda x: x.resize((168, 168), Image.BILINEAR)),
    ('112x112\n49 patches', lambda x: x.resize((112, 112), Image.BILINEAR)),
    ('80x80\n25 patches', lambda x: x.resize((80, 80), Image.BILINEAR)),
    ('64x64\n16 patches', lambda x: x.resize((64, 64), Image.BILINEAR)),
    ('48x48\n9 patches', lambda x: x.resize((48, 48), Image.BILINEAR)),
]
# Column count follows res_methods. squeeze=False keeps `axes2` 2-D when only
# one dataset was found (otherwise axes2[row, col] raises).
fig2, axes2 = plt.subplots(
    len(samples), len(res_methods), figsize=(16, 2.5 * len(samples)), squeeze=False
)
for row, (dataset_name, img, _) in enumerate(samples):
    img_224 = preprocess_to_224(img)
    for col, (label, func) in enumerate(res_methods):
        ax = axes2[row, col]
        result = func(img_224)
        ax.imshow(np.array(result))
        ax.set_xticks([])
        ax.set_yticks([])
        if row == 0:
            ax.set_title(label, fontsize=10, fontweight='bold')
        if col == 0:
            ax.set_ylabel(dataset_name, fontsize=8)
fig2.tight_layout()
fig2.savefig(os.path.join(docs_dir, 'resolution_sweep_guide.png'), dpi=180, bbox_inches='tight')
plt.close(fig2)
print(f"Saved to {docs_dir}/resolution_sweep_guide.png")