File size: 16,721 Bytes
ba5dcdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL
from torchvision import transforms
from torch.utils.data import TensorDataset
from netdissect.progress import verbose_progress, print_progress
from netdissect import InstrumentedModel, BrodenDataset, dissect
from netdissect import MultiSegmentDataset, GeneratorSegRunner
from netdissect import ImageOnlySegRunner
from netdissect.parallelfolder import ParallelImageFolders
from netdissect.zdataset import z_dataset_for_model
from netdissect.autoeval import autoimport_eval
from netdissect.modelconfig import create_instrumented_model
from netdissect.pidfile import exit_if_job_done, mark_job_done

help_epilog = '''\
Example: to dissect three layers of the pretrained alexnet in torchvision:

python -m netdissect \\
        --model "torchvision.models.alexnet(pretrained=True)" \\
        --layers features.6:conv3 features.8:conv4 features.10:conv5 \\
        --imgsize 227 \\
        --outdir dissect/alexnet-imagenet

To dissect a progressive GAN model:

python -m netdissect \\
        --model "proggan.from_pth_file('model/churchoutdoor.pth')" \\
        --gan
'''

def main():
    # Training settings
    def strpair(arg):
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p
    def intpair(arg):
        p = arg.split(',')
        if len(p) == 1:
            p = p + p
        return tuple(int(v) for v in p)

    parser = argparse.ArgumentParser(description='Net dissect utility',
            prog='python -m netdissect',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--unstrict', action='store_true', default=False,
                        help='ignore unexpected pth parameters')
    parser.add_argument('--submodule', type=str, default=None,
                        help='submodule to load from pthfile')
    parser.add_argument('--outdir', type=str, default='dissect',
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to dissect' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--segments', type=str, default='dataset/broden',
                        help='directory containing segmentation dataset')
    parser.add_argument('--segmenter', type=str, default=None,
                        help='constructor for asegmenter class')
    parser.add_argument('--download', action='store_true', default=False,
                        help='downloads Broden dataset if needed')
    parser.add_argument('--imagedir', type=str, default=None,
                        help='directory containing image-only dataset')
    parser.add_argument('--imgsize', type=intpair, default=(227, 227),
                        help='input image size to use')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--meta', type=str, nargs='+',
                        help='json files of metadata to add to report')
    parser.add_argument('--merge', type=str,
                        help='json file of unit data to merge in report')
    parser.add_argument('--examples', type=int, default=20,
                        help='number of image examples per unit')
    parser.add_argument('--size', type=int, default=10000,
                        help='dataset subset size to use')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='batch size for forward pass')
    parser.add_argument('--num_workers', type=int, default=24,
                        help='number of DataLoader workers')
    parser.add_argument('--quantile_threshold', type=strfloat, default=None,
                        choices=[FloatRange(0.0, 1.0), 'iqr'],
                        help='quantile to use for masks')
    parser.add_argument('--no-labels', action='store_true', default=False,
                        help='disables labeling of units')
    parser.add_argument('--maxiou', action='store_true', default=False,
                        help='enables maxiou calculation')
    parser.add_argument('--covariance', action='store_true', default=False,
                        help='enables covariance calculation')
    parser.add_argument('--rank_all_labels', action='store_true', default=False,
                        help='include low-information labels in rankings')
    parser.add_argument('--no-images', action='store_true', default=False,
                        help='disables generation of unit images')
    parser.add_argument('--no-report', action='store_true', default=False,
                        help='disables generation report summary')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--gen', action='store_true', default=False,
                        help='test a generator model (e.g., a GAN)')
    parser.add_argument('--gan', action='store_true', default=False,
                        help='synonym for --gen')
    parser.add_argument('--perturbation', default=None,
                        help='filename of perturbation attack to apply')
    parser.add_argument('--add_scale_offset', action='store_true', default=None,
                        help='offsets masks according to stride and padding')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    args.images = not args.no_images
    args.report = not args.no_report
    args.labels = not args.no_labels
    if args.gan:
        args.gen = args.gan

    # Set up console output
    verbose_progress(not args.quiet)

    # Exit right away if job is already done or being done.
    if args.outdir is not None:
        exit_if_job_done(args.outdir)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Special case: download flag without model to test.
    if args.model is None and args.download:
        from netdissect.broden import ensure_broden_downloaded
        for resolution in [224, 227, 384]:
            ensure_broden_downloaded(args.segments, resolution, 1)
        from netdissect.segmenter import ensure_upp_segmenter_downloaded
        ensure_upp_segmenter_downloaded('dataset/segmodel')
        sys.exit(0)

    # Help if broden is not present
    if not args.gen and not args.imagedir and not os.path.isdir(args.segments):
        print_progress('Segmentation dataset not found at %s.' % args.segments)
        print_progress('Specify dataset directory using --segments [DIR]')
        print_progress('To download Broden, run: netdissect --download')
        sys.exit(1)

    # Default segmenter class
    if args.gen and args.segmenter is None:
        args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                "segsizes=[256], segdiv='quad')")

    # Default threshold
    if args.quantile_threshold is None:
        if args.gen:
            args.quantile_threshold = 'iqr'
        else:
            args.quantile_threshold = 0.005

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Construct the network with specified layers instrumented
    if args.model is None:
        print_progress('No model specified')
        sys.exit(1)
    model = create_instrumented_model(args)

    # Update any metadata from files, if any
    meta = getattr(model, 'meta', {})
    if args.meta:
        for mfilename in args.meta:
            with open(mfilename) as f:
                meta.update(json.load(f))

    # Load any merge data from files
    mergedata = None
    if args.merge:
        with open(args.merge) as f:
            mergedata = json.load(f)

    # Set up the output directory, verify write access
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        exit_if_job_done(args.outdir)
        print_progress('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)
    train_dataset = None

    if not args.gen:
        # Load dataset for classifier case.
        # Load perturbation
        perturbation = numpy.load(args.perturbation
                ) if args.perturbation else None
        segrunner = None

        # Load broden dataset
        if args.imagedir is not None:
            dataset = try_to_load_images(args.imagedir, args.imgsize,
                    perturbation, args.size)
            segrunner = ImageOnlySegRunner(dataset)
        else:
            dataset = try_to_load_broden(args.segments, args.imgsize, 1,
                perturbation, args.download, args.size)
        if dataset is None:
            dataset = try_to_load_multiseg(args.segments, args.imgsize,
                    perturbation, args.size)
        if dataset is None:
            print_progress('No segmentation dataset found in %s',
                    args.segments)
            print_progress('use --download to download Broden.')
            sys.exit(1)
    else:
        # For segmenter case the dataset is just a random z
        dataset = z_dataset_for_model(model, args.size)
        train_dataset = z_dataset_for_model(model, args.size, seed=2)
        segrunner = GeneratorSegRunner(autoimport_eval(args.segmenter))

    # Run dissect
    dissect(args.outdir, model, dataset,
            train_dataset=train_dataset,
            segrunner=segrunner,
            examples_per_unit=args.examples,
            netname=args.netname,
            quantile_threshold=args.quantile_threshold,
            meta=meta,
            merge=mergedata,
            make_images=args.images,
            make_labels=args.labels,
            make_maxiou=args.maxiou,
            make_covariance=args.covariance,
            make_report=args.report,
            make_row_images=args.images,
            make_single_images=True,
            rank_all_labels=args.rank_all_labels,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            settings=vars(args))

    # Mark the directory so that it's not done again.
    mark_job_done(args.outdir)

class AddPerturbation(object):
    def __init__(self, perturbation):
        self.perturbation = perturbation

    def __call__(self, pic):
        if self.perturbation is None:
            return pic
        # Convert to a numpy float32 array
        npyimg = numpy.array(pic, numpy.uint8, copy=False
                ).astype(numpy.float32)
        # Center the perturbation
        oy, ox = ((self.perturbation.shape[d] - npyimg.shape[d]) // 2
                for d in [0, 1])
        npyimg += self.perturbation[
                oy:oy+npyimg.shape[0], ox:ox+npyimg.shape[1]]
        # Pytorch conventions: as a float it should be [0..1]
        npyimg.clip(0, 255, npyimg)
        return npyimg / 255.0

def test_dissection():
    verbose_progress(True)
    from torchvision.models import alexnet
    from torchvision import transforms
    model = InstrumentedModel(alexnet(pretrained=True))
    model.eval()
    # Load an alexnet
    model.retain_layers([
        ('features.0', 'conv1'),
        ('features.3', 'conv2'),
        ('features.6', 'conv3'),
        ('features.8', 'conv4'),
        ('features.10', 'conv5') ])
    # load broden dataset
    bds = BrodenDataset('dataset/broden',
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=100)
    # run dissect
    dissect('dissect/test', model, bds,
            examples_per_unit=10)

def try_to_load_images(directory, imgsize, perturbation, size):
    # Load plain image dataset
    # TODO: allow other normalizations.
    return ParallelImageFolders(
            [directory],
            transform=transforms.Compose([
                transforms.Resize(imgsize),
                AddPerturbation(perturbation),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=size)

def try_to_load_broden(directory, imgsize, broden_version, perturbation,
        download, size):
    # Load broden dataset
    ds_resolution = (224 if max(imgsize) <= 224 else
                     227 if max(imgsize) <= 227 else 384)
    if not os.path.isfile(os.path.join(directory,
           'broden%d_%d' % (broden_version, ds_resolution), 'index.csv')):
        return None
    return BrodenDataset(directory,
            resolution=ds_resolution,
            download=download,
            broden_version=broden_version,
            transform=transforms.Compose([
                transforms.Resize(imgsize),
                AddPerturbation(perturbation),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=size)

def try_to_load_multiseg(directory, imgsize, perturbation, size):
    if not os.path.isfile(os.path.join(directory, 'labelnames.json')):
        return None
    minsize = min(imgsize) if hasattr(imgsize, '__iter__') else imgsize
    return MultiSegmentDataset(directory,
            transform=(transforms.Compose([
                transforms.Resize(minsize),
                transforms.CenterCrop(imgsize),
                AddPerturbation(perturbation),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            transforms.Compose([
                transforms.Resize(minsize, interpolation=PIL.Image.NEAREST),
                transforms.CenterCrop(imgsize)])),
            size=size)

def add_scale_offset_info(model, layer_names):
    '''
    Creates a 'scale_offset' property on the model which guesses
    how to offset the featuremap, in cases where the convolutional
    padding does not exacly correspond to keeping featuremap pixels
    centered on the downsampled regions of the input.  This mainly
    shows up in AlexNet: ResNet and VGG pad convolutions to keep
    them centered and do not need this.
    '''
    model.scale_offset = {}
    seen = set()
    sequence = []
    aka_map = {}
    for name in layer_names:
        aka = name
        if not isinstance(aka, str):
            name, aka = name
        aka_map[name] = aka
    for name, layer in model.named_modules():
        sequence.append(layer)
        if name in aka_map:
            seen.add(name)
            aka = aka_map[name]
            model.scale_offset[aka] = sequence_scale_offset(sequence)
    for name in aka_map:
        assert name in seen, ('Layer %s not found' % name)

def dilation_scale_offset(dilations):
    '''Composes a list of (k, s, p) into a single total scale and offset.'''
    if len(dilations) == 0:
        return (1, 0)
    scale, offset = dilation_scale_offset(dilations[1:])
    kernel, stride, padding = dilations[0]
    scale *= stride
    offset *= stride
    offset += (kernel - 1) / 2.0 - padding
    return scale, offset

def dilations(modulelist):
    '''Converts a list of modules to (kernel_size, stride, padding)'''
    result = []
    for module in modulelist:
        settings = tuple(getattr(module, n, d)
            for n, d in (('kernel_size', 1), ('stride', 1), ('padding', 0)))
        settings = (((s, s) if not isinstance(s, tuple) else s)
            for s in settings)
        if settings != ((1, 1), (1, 1), (0, 0)):
            result.append(zip(*settings))
    return zip(*result)

def sequence_scale_offset(modulelist):
    '''Returns (yscale, yoffset), (xscale, xoffset) given a list of modules'''
    return tuple(dilation_scale_offset(d) for d in dilations(modulelist))


def strfloat(s):
    try:
        return float(s)
    except:
        return s

class FloatRange(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __eq__(self, other):
        return isinstance(other, float) and self.start <= other <= self.end
    def __repr__(self):
        return '[%g-%g]' % (self.start, self.end)

# Many models use this normalization.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STDEV = [0.229, 0.224, 0.225]

if __name__ == '__main__':
    main()