timm
File size: 7,236 Bytes
54ee1eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Visualize all experiment types across all datasets."""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import os, glob, random, cv2

# Setup
data_dir = '/data/ypxia/Workspace/miss_patch/data'
docs_dir = 'docs'
os.makedirs(docs_dir, exist_ok=True)

# Find sample images from each dataset
samples = []

# Oxford Pets (high-res animals)
pets_dir = os.path.join(data_dir, 'oxford-iiit-pet', 'images')
if os.path.exists(pets_dir):
    files = sorted(glob.glob(os.path.join(pets_dir, '*.jpg')))
    random.seed(42)
    f = random.choice(files)
    img = Image.open(f)
    name = os.path.basename(f).rsplit('_', 1)[0]
    samples.append(('Oxford Pets\n(cat/dog breed)', img, 93.81))

# Food-101 (fine-grained dishes)
food_dir = os.path.join(data_dir, 'food-101', 'images')
if os.path.exists(food_dir):
    subdirs = sorted(os.listdir(food_dir))
    random.seed(43)
    sd = random.choice(subdirs)
    files = sorted(glob.glob(os.path.join(food_dir, sd, '*.jpg')))
    if files:
        img = Image.open(files[0])
        name = sd.replace('_', ' ')
        samples.append((f'Food-101\n({name})', img, 91.37))

# DTD (texture)
dtd_dir = os.path.join(data_dir, 'dtd', 'dtd', 'images')
if os.path.exists(dtd_dir):
    cats = sorted(os.listdir(dtd_dir))
    random.seed(44)
    c = random.choice(cats)
    files = sorted(glob.glob(os.path.join(dtd_dir, c, '*.jpg')))
    if files:
        img = Image.open(files[0])
        samples.append((f'DTD\n({c})', img, 80.85))

# CIFAR-100 (32x32 native)
cifar_dir = os.path.join(data_dir, 'cifar-100-python')
if os.path.exists(cifar_dir):
    import pickle
    with open(os.path.join(cifar_dir, 'train'), 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    cifar_labels = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
        'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
        'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
        'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
        'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
        'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
    ]
    random.seed(45)
    idx = random.randint(0, 50000)
    arr = batch[b'data'][idx].reshape(3, 32, 32).transpose(1, 2, 0)
    label = cifar_labels[batch[b'fine_labels'][idx]]
    img = Image.fromarray(arr)
    samples.append((f'CIFAR-100\n({label}, 32x32 native)', img, 91.69))

# Process each sample into a standard 224x224 base
def preprocess_to_224(img):
    """Resize so the short edge is 255 px, then center-crop to 224x224.

    The original author describes this as the same preprocessing used for
    training (short edge -> 255, center crop -> 224); the training pipeline
    itself is not visible in this file.

    Args:
        img: PIL image; only ``size``, ``resize`` and ``crop`` are used.

    Returns:
        The 224x224 center crop of the resized image.
    """
    width, height = img.size
    short = min(width, height)
    # Integer arithmetic pins the short edge to exactly 255. The previous
    # ``int(side * (255 / short))`` went through float rounding and could
    # land just below 255, truncating the short edge to 254.
    new_size = (width * 255 // short, height * 255 // short)
    img_resized = img.resize(new_size, Image.BILINEAR)
    # Center the 224x224 window inside the resized image.
    left = (new_size[0] - 224) // 2
    top = (new_size[1] - 224) // 2
    return img_resized.crop((left, top, left + 224, top + 224))

def process_funcs():
    """Build the (column title, transform) pairs for the overview figure.

    Returns:
        A list of ``(title, func)`` tuples in display order. Each ``func``
        takes a PIL image and returns the image to show in that column.
    """
    def _downscale_to(side):
        # Bind `side` per call so every lambda keeps its own target size.
        return lambda img: img.resize((side, side), Image.BILINEAR)

    variants = [('Original\n224x224', lambda img: img)]

    # Square down-scaled versions of the input.
    for title, side in (
        ('168x168\n(100 patches)', 168),
        ('112x112\n(49 patches)', 112),
        ('80x80\n(25 patches)', 80),
    ):
        variants.append((title, _downscale_to(side)))

    # NOTE(review): Pillow's GaussianBlur argument is a radius; confirm that
    # matches what "k=15" means in the experiments.
    # NOTE(review): posterize(img, 2) keeps 2 bits per channel, i.e. 4 levels
    # per channel rather than 4 colors overall; confirm the label wording.
    variants.extend([
        ('Grayscale', lambda img: ImageOps.grayscale(img).convert('RGB')),
        ('Blur k=15\n(low freq)', lambda img: img.filter(ImageFilter.GaussianBlur(15))),
        ('Canny Edges\n(structure)', lambda img: canny_on_pil(img)),
        ('2-bit Color\n(4 colors)', lambda img: posterize(img, 2)),
    ])
    return variants

def canny_on_pil(pil_img):
    """Return the Canny edge map of *pil_img* as a 3-channel PIL image.

    Args:
        pil_img: PIL image in any mode convertible to RGB.

    Returns:
        A PIL image whose three channels all hold the same edge map.
    """
    # Normalize the mode first: cv2.COLOR_RGB2GRAY rejects single-channel
    # arrays, which is what np.array() yields for mode "L" or "P" images.
    arr = np.array(pil_img.convert('RGB'))
    gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    # 50 / 150 are the lower / upper hysteresis thresholds passed to Canny.
    edges = cv2.Canny(gray, 50, 150)
    # Replicate the single-channel map so it displays like the other columns.
    return Image.fromarray(np.stack([edges] * 3, axis=-1))

def posterize(pil_img, bits):
    """Keep only the top *bits* bits of every value in the image array.

    Args:
        pil_img: PIL image to quantize.
        bits: number of high-order bits to keep, out of 8.

    Returns:
        A new PIL image built from the quantized array.
    """
    dropped = 8 - bits
    pixels = np.array(pil_img)
    # Shifting right then left by the same amount zeroes the low bits.
    quantized = np.left_shift(np.right_shift(pixels, dropped), dropped)
    return Image.fromarray(quantized)

# Create figure: one row per dataset sample, one column per processing method
methods = process_funcs()
n_rows = len(samples)
n_cols = len(methods) + 1  # +1 for the leading unprocessed-image column

if n_rows == 0:
    # plt.subplots rejects a zero-row grid; stop with an actionable message.
    raise SystemExit(f"No sample images found under {data_dir}; nothing to plot.")

# squeeze=False keeps `axes` two-dimensional even when only one dataset was
# found, so the axes[row, col] indexing below stays valid.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 2.8, n_rows * 2.5), squeeze=False
)

# Column headers (top row only)
axes[0, 0].set_title('Original\nimage', fontsize=9)
for col, (title, _) in enumerate(methods, start=1):
    axes[0, col].set_title(title, fontsize=9, fontweight='bold')

# The third tuple element (baseline accuracy) is unpacked but not rendered.
for row, (dataset_name, img, _baseline_acc) in enumerate(samples):
    img_224 = preprocess_to_224(img)

    # First column: the preprocessed 224x224 crop, labelled with the dataset
    ax = axes[row, 0]
    ax.imshow(np.array(img_224))
    ax.set_ylabel(dataset_name, fontsize=9, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])

    # Remaining columns: each processing method applied to the same crop
    for col, (title, func) in enumerate(methods, start=1):
        ax = axes[row, col]
        result = func(img_224)
        # cmap only affects 2-D data; imshow ignores it for RGB arrays.
        ax.imshow(np.array(result), cmap='gray' if 'Canny' in title else None)
        ax.set_xticks([])
        ax.set_yticks([])

fig.tight_layout()
fig.savefig(os.path.join(docs_dir, 'all_experiments_guide.png'), dpi=180, bbox_inches='tight')
# Release the figure once written; nothing below reuses it.
plt.close(fig)
print(f"Saved to {docs_dir}/all_experiments_guide.png")

# Second figure: resolution sweep only, one row per dataset and one column
# per target resolution.
res_methods = [
    ('224x224\n196 patches', lambda x: x),
    ('168x168\n100 patches', lambda x: x.resize((168, 168), Image.BILINEAR)),
    ('112x112\n49 patches', lambda x: x.resize((112, 112), Image.BILINEAR)),
    ('80x80\n25 patches', lambda x: x.resize((80, 80), Image.BILINEAR)),
    ('64x64\n16 patches', lambda x: x.resize((64, 64), Image.BILINEAR)),
    ('48x48\n9 patches', lambda x: x.resize((48, 48), Image.BILINEAR)),
]

# The column count follows res_methods instead of a hard-coded 6, and
# squeeze=False keeps `axes2` two-dimensional when only one dataset was found.
fig2, axes2 = plt.subplots(
    len(samples), len(res_methods), figsize=(16, 2.5 * len(samples)), squeeze=False
)

for row, (dataset_name, img, _) in enumerate(samples):
    img_224 = preprocess_to_224(img)
    for col, (label, func) in enumerate(res_methods):
        ax = axes2[row, col]
        result = func(img_224)
        ax.imshow(np.array(result))
        ax.set_xticks([])
        ax.set_yticks([])
        # Resolution titles on the top row, dataset names down the left edge.
        if row == 0:
            ax.set_title(label, fontsize=10, fontweight='bold')
        if col == 0:
            ax.set_ylabel(dataset_name, fontsize=8)

fig2.tight_layout()
fig2.savefig(os.path.join(docs_dir, 'resolution_sweep_guide.png'), dpi=180, bbox_inches='tight')
# Release the figure once written.
plt.close(fig2)
print(f"Saved to {docs_dir}/resolution_sweep_guide.png")