import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL from torchvision import transforms from torch.utils.data import TensorDataset from netdissect.progress import default_progress, post_progress, desc_progress from netdissect.progress import verbose_progress, print_progress from netdissect.nethook import edit_layers from netdissect.zdataset import standard_z_sample from netdissect.autoeval import autoimport_eval from netdissect.easydict import EasyDict from netdissect.modelconfig import create_instrumented_model help_epilog = '''\ Example: python -m netdissect.evalablate \ --segmenter "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')" \ --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \ --outdir dissect/dissectdir \ --classes mirror coffeetable tree \ --layers layer4 \ --size 1000 Output layout: dissectdir/layer5/ablation/mirror-iqr.json { class: "mirror", classnum: 43, pixel_total: 41342300, class_pixels: 1234531, layer: "layer5", ranking: "mirror-iqr", ablation_units: [341, 23, 12, 142, 83, ...] ablation_pixels: [143242, 132344, 429931, ...] } ''' def main(): # Training settings def strpair(arg): p = tuple(arg.split(':')) if len(p) == 1: p = p + p return p parser = argparse.ArgumentParser(description='Ablation eval', epilog=textwrap.dedent(help_epilog), formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument('--model', type=str, default=None, help='constructor for the model to test') parser.add_argument('--pthfile', type=str, default=None, help='filename of .pth file for the model') parser.add_argument('--outdir', type=str, default='dissect', required=True, help='directory for dissection output') parser.add_argument('--layers', type=strpair, nargs='+', help='space-separated list of layer names to edit' + ', in the form layername[:reportedname]') parser.add_argument('--classes', type=str, nargs='+', help='space-separated list of class names to ablate') parser.add_argument('--metric', type=str, default='iou', help='ordering metric for selecting units') parser.add_argument('--unitcount', type=int, default=30, help='number of units to ablate') parser.add_argument('--segmenter', type=str, help='directory containing segmentation dataset') parser.add_argument('--netname', type=str, default=None, help='name for network in generated reports') parser.add_argument('--batch_size', type=int, default=5, help='batch size for forward pass') parser.add_argument('--size', type=int, default=200, help='number of images to test') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA usage') parser.add_argument('--quiet', action='store_true', default=False, help='silences console output') if len(sys.argv) == 1: parser.print_usage(sys.stderr) sys.exit(1) args = parser.parse_args() # Set up console output verbose_progress(not args.quiet) # Speed up pytorch torch.backends.cudnn.benchmark = True # Set up CUDA args.cuda = not args.no_cuda and torch.cuda.is_available() if args.cuda: torch.backends.cudnn.benchmark = True # Take defaults for model constructor etc from dissect.json settings. with open(os.path.join(args.outdir, 'dissect.json')) as f: dissection = EasyDict(json.load(f)) if args.model is None: args.model = dissection.settings.model if args.pthfile is None: args.pthfile = dissection.settings.pthfile if args.segmenter is None: args.segmenter = dissection.settings.segmenter # Instantiate generator model = create_instrumented_model(args, gen=True, edit=True) if model is None: print('No model specified') sys.exit(1) # Instantiate model device = next(model.parameters()).device input_shape = model.input_shape # 4d input if convolutional, 2d input if first layer is linear. raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view( (args.size,) + input_shape[1:]) dataset = TensorDataset(raw_sample) # Create the segmenter segmenter = autoimport_eval(args.segmenter) # Now do the actual work. labelnames, catnames = ( segmenter.get_label_and_category_names(dataset)) label_category = [catnames.index(c) if c in catnames else 0 for l, c in labelnames] labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)} segloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=10, pin_memory=(device.type == 'cuda')) # Index the dissection layers by layer name. dissect_layer = {lrec.layer: lrec for lrec in dissection.layers} # First, collect a baseline for l in model.ablation: model.ablation[l] = None # For each sort-order, do an ablation progress = default_progress() for classname in progress(args.classes): post_progress(c=classname) for layername in progress(model.ablation): post_progress(l=layername) rankname = '%s-%s' % (classname, args.metric) classnum = labelnum_from_name[classname] try: ranking = next(r for r in dissect_layer[layername].rankings if r.name == rankname) except: print('%s not found' % rankname) sys.exit(1) ordering = numpy.argsort(ranking.score) # Check if already done ablationdir = os.path.join(args.outdir, layername, 'pixablation') if os.path.isfile(os.path.join(ablationdir, '%s.json'%rankname)): with open(os.path.join(ablationdir, '%s.json'%rankname)) as f: data = EasyDict(json.load(f)) # If the unit ordering is not the same, something is wrong if not all(a == o for a, o in zip(data.ablation_units, ordering)): continue if len(data.ablation_effects) >= args.unitcount: continue # file already done. measurements = data.ablation_effects measurements = measure_ablation(segmenter, segloader, model, classnum, layername, ordering[:args.unitcount]) measurements = measurements.cpu().numpy().tolist() os.makedirs(ablationdir, exist_ok=True) with open(os.path.join(ablationdir, '%s.json'%rankname), 'w') as f: json.dump(dict( classname=classname, classnum=classnum, baseline=measurements[0], layer=layername, metric=args.metric, ablation_units=ordering.tolist(), ablation_effects=measurements[1:]), f) def measure_ablation(segmenter, loader, model, classnum, layer, ordering): total_bincount = 0 data_size = 0 device = next(model.parameters()).device progress = default_progress() for l in model.ablation: model.ablation[l] = None feature_units = model.feature_shape[layer][1] feature_shape = model.feature_shape[layer][2:] repeats = len(ordering) total_scores = torch.zeros(repeats + 1) for i, batch in enumerate(progress(loader)): z_batch = batch[0] model.ablation[layer] = None tensor_images = model(z_batch.to(device)) seg = segmenter.segment_batch(tensor_images, downsample=2) mask = (seg == classnum).max(1)[0] downsampled_seg = torch.nn.functional.adaptive_avg_pool2d( mask.float()[:,None,:,:], feature_shape)[:,0,:,:] total_scores[0] += downsampled_seg.sum().cpu() # Now we need to do an intervention for every location # that had a nonzero downsampled_seg, if any. interventions_needed = downsampled_seg.nonzero() location_count = len(interventions_needed) if location_count == 0: continue interventions_needed = interventions_needed.repeat(repeats, 1) inter_z = batch[0][interventions_needed[:,0]].to(device) inter_chan = torch.zeros(repeats, location_count, feature_units, device=device) for j, u in enumerate(ordering): inter_chan[j:, :, u] = 1 inter_chan = inter_chan.view(len(inter_z), feature_units) inter_loc = interventions_needed[:,1:] scores = torch.zeros(len(inter_z)) batch_size = len(batch[0]) for j in range(0, len(inter_z), batch_size): ibz = inter_z[j:j+batch_size] ibl = inter_loc[j:j+batch_size].t() imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device) imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1 ibc = inter_chan[j:j+batch_size] model.ablation[layer] = ( imask.float()[:,None,:,:] * ibc[:,:,None,None]) tensor_images = model(ibz) seg = segmenter.segment_batch(tensor_images, downsample=2) mask = (seg == classnum).max(1)[0] downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d( mask.float()[:,None,:,:], feature_shape)[:,0,:,:] scores[j:j+batch_size] = downsampled_iseg[ (torch.arange(len(ibz)),) + tuple(ibl)] scores = scores.view(repeats, location_count).sum(1) total_scores[1:] += scores return total_scores def count_segments(segmenter, loader, model): total_bincount = 0 data_size = 0 progress = default_progress() for i, batch in enumerate(progress(loader)): tensor_images = model(z_batch.to(device)) seg = segmenter.segment_batch(tensor_images, downsample=2) bc = (seg + index[:, None, None, None] * self.num_classes).view(-1 ).bincount(minlength=z_batch.shape[0] * self.num_classes) data_size += seg.shape[0] * seg.shape[2] * seg.shape[3] total_bincount += batch_label_counts.float().sum(0) normalized_bincount = total_bincount / data_size return normalized_bincount if __name__ == '__main__': main()