File size: 7,957 Bytes
8f87579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
'''
A simple tool to generate samples of the output of a GAN,
subject to filtering, sorting, or intervention.
'''

import torch, numpy, os, argparse, sys, shutil, errno, numbers
from PIL import Image
from torch.utils.data import TensorDataset
from netdissect.zdataset import standard_z_sample
from netdissect.progress import default_progress, verbose_progress
from netdissect.autoeval import autoimport_eval
from netdissect.workerpool import WorkerBase, WorkerPool
from netdissect.nethook import retain_layers
from netdissect.runningstats import RunningTopK

def main():
    """Command-line entry point.

    Builds the model named by --model, optionally loads weights from
    --pthfile, instruments the layer named by --layer, ranks a universe
    of z samples by per-unit maximum activation, and writes the top
    images for each unit under --outdir.
    """
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    parser.add_argument('--model', type=str, default=None,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
            help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
            help='number of images to output')
    parser.add_argument('--test_size', type=int, default=None,
            help='number of images to test')
    parser.add_argument('--layer', type=str, default=None,
            help='layer to inspect')
    parser.add_argument('--seed', type=int, default=1,
            help='seed')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    if len(sys.argv) == 1:
        # No arguments at all: show usage instead of failing cryptically.
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    verbose_progress(not args.quiet)

    # Instantiate the model by evaluating the constructor expression.
    model = autoimport_eval(args.model)
    if args.pthfile is not None:
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            # Checkpoint wraps the weights in a 'state_dict' entry;
            # any scalar metadata alongside it is not needed here.
            data = data['state_dict']
        model.load_state_dict(data)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())
    # Examine first conv in model to determine input feature size.
    first_layer = [c for c in model.modules()
            if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                torch.nn.Linear))][0]
    # 4d input if convolutional, 2d input if first layer is linear.
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        z_channels = first_layer.in_channels
        spatialdims = (1, 1)
    else:
        z_channels = first_layer.in_features
        spatialdims = ()
    # Instrument the model so activations of --layer are retained.
    retain_layers(model, [args.layer])
    model.cuda()

    # By default, rank each unit over a 20x oversampled universe.
    if args.test_size is None:
        args.test_size = args.size * 20
    z_universe = standard_z_sample(args.test_size, z_channels,
            seed=args.seed)
    z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims)
    indexes = get_all_highest_znums(
            model, z_universe, args.size, seed=args.seed)
    save_chosen_unit_images(args.outdir, model, z_universe, indexes,
            lightbox=True)


def get_all_highest_znums(model, z_universe, size,
        batch_size=10, seed=1):
    """For every unit of the single instrumented layer, find the `size`
    z samples (as indices into `z_universe`) with the highest maximum
    activation.

    Returns a [units, size] tensor of z indices, ascending per unit.
    NOTE(review): `seed` is accepted but unused here — confirm intent.
    """
    # Exactly one layer must have been retained via retain_layers().
    retained = list(model.retained.items())
    assert len(retained) == 1
    layer_name = retained[0][0]
    progress = default_progress()
    with torch.no_grad():
        # Pass 1: stream the z universe through the model and keep a
        # running top-k of per-unit peak activations.
        loader = torch.utils.data.DataLoader(
                TensorDataset(z_universe),
                batch_size=batch_size, num_workers=2,
                pin_memory=True)
        topk = RunningTopK(k=size)
        for [batch_z] in progress(loader, desc='Finding max activations'):
            model(batch_z.cuda())
            acts = model.retained[layer_name]
            unit_count = acts.shape[1]
            # Collapse spatial dims: [batch, units] peak activations.
            peaks = acts.view(acts.shape[0], unit_count, -1).max(2)[0]
            topk.add(peaks)
        _, top_indices = topk.result()
        # Present each unit's chosen z indices in ascending order.
        highest = top_indices.sort(1)[0]
    return highest

def save_chosen_unit_images(dirname, model, z_universe, indices,
        shared_dir="shared_images",
        unitdir_template="unit_{}",
        name_template="image_{}.jpg",
        lightbox=False, batch_size=50, seed=1):
    """Generate and save the images selected by `indices`.

    Each chosen image is rendered once into `dirname`/`shared_dir`
    (split over 100 hash subdirectories), then hard-linked into one
    per-unit directory for every (unit, rank) entry of `indices`.

    indices: [units, k] tensor of indices into `z_universe`.
    NOTE(review): `seed` is accepted but unused here — confirm intent.
    """
    # Deduplicate: an image picked by several units is rendered once.
    all_indices = torch.unique(indices.view(-1), sorted=True)
    z_sample = z_universe[all_indices]
    progress = default_progress()
    sdir = os.path.join(dirname, shared_dir)
    created_hashdirs = set()
    for index in range(len(z_universe)):
        hd = hashdir(index)
        if hd not in created_hashdirs:
            created_hashdirs.add(hd)
            os.makedirs(os.path.join(sdir, hd), exist_ok=True)
    with torch.no_grad():
        # Pass 2: now generate images
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample),
                    batch_size=batch_size, num_workers=2,
                    pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        for batch_num, [z] in enumerate(progress(z_loader,
                desc='Saving images')):
            z = z.cuda()
            # batch position i maps back to all_indices[start_index + i]
            start_index = batch_num * batch_size
            # Map generator output from [-1, 1] to byte HWC images.
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                    0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = all_indices[i + start_index].item()
                filename = os.path.join(sdir, hashdir(index),
                        name_template.format(index))
                saver.add(im[i].numpy(), filename)
        saver.join()
    # Hard-link each shared image into the directories of the units
    # that selected it.
    linker = WorkerPool(MakeLinkWorker)
    for u in progress(range(len(indices)), desc='Making links'):
        udir = os.path.join(dirname, unitdir_template.format(u))
        os.makedirs(udir, exist_ok=True)
        for r in range(indices.shape[1]):
            index = indices[u,r].item()
            fn = name_template.format(index)
            sourcename = os.path.join(sdir, hashdir(index), fn)
            targname = os.path.join(udir, fn)
            linker.add(sourcename, targname)
        if lightbox:
            copy_lightbox_to(udir)
    linker.join()

def copy_lightbox_to(dirname):
    """Copy the lightbox.html viewer shipped beside this script into
    `dirname`, named '+lightbox.html' so it sorts before the images."""
    # (Fixed: original used nonstandard 3-space indentation.)
    srcdir = os.path.realpath(
            os.path.join(os.getcwd(), os.path.dirname(__file__)))
    shutil.copy(os.path.join(srcdir, 'lightbox.html'),
            os.path.join(dirname, '+lightbox.html'))

def hashdir(index):
    """Return one of 100 two-digit bucket names ('00'..'99') for an
    image index, keeping the shared directory's file count low."""
    return '{:02d}'.format(index % 100)

class SaveImageWorker(WorkerBase):
    """Pool worker that JPEG-encodes a numpy image array and writes it
    to disk, moving encoding and file I/O off the main thread."""
    def work(self, data, filename):
        image = Image.fromarray(data)
        image.save(filename, optimize=True, quality=100)

class MakeLinkWorker(WorkerBase):
    # Creating hard links (os.link) is a bit slow and can be done faster
    # in parallel rather than waiting for each to be created.
    # (Corrected comment: the original said "symbolic links", but this
    # worker makes hard links.)
    def work(self, sourcename, targname):
        try:
            os.link(sourcename, targname)
        except OSError as e:
            if e.errno == errno.EEXIST:
                # A link from a previous run exists; replace it.
                os.remove(targname)
                os.link(sourcename, targname)
            else:
                raise

class MakeSyminkWorker(WorkerBase):
    """Pool worker that creates symbolic links in parallel, since
    making each link serially is comparatively slow.

    NOTE(review): class name appears to be missing an 'l' ("Symink" vs
    "Symlink"); kept as-is so external references keep working.
    """
    def work(self, sourcename, targname):
        try:
            os.symlink(sourcename, targname)
        except FileExistsError:
            # Same as checking errno.EEXIST: replace a stale link
            # left over from a previous run.
            os.remove(targname)
            os.symlink(sourcename, targname)

# Run the command-line tool only when executed as a script.
if __name__ == '__main__':
    main()