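# CharacterGAN / netdissect / fullablate.py
# (last commit 8f87579 "Init code", by mfrashad)
# Measures how cumulatively ablating the top-ranked units of one generator
# layer reduces the number of pixels segmented as a chosen class.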
import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL
from torchvision import transforms
from torch.utils.data import TensorDataset
from netdissect.progress import default_progress, post_progress, desc_progress
from netdissect.progress import verbose_progress, print_progress
from netdissect.nethook import edit_layers
from netdissect.zdataset import standard_z_sample
from netdissect.autoeval import autoimport_eval
from netdissect.easydict import EasyDict
from netdissect.modelconfig import create_instrumented_model
# Epilog shown by `--help` (printed verbatim via RawDescriptionHelpFormatter).
# The example command uses `\\` so a literal backslash-newline survives in
# the string and the command is displayed on several lines; a single `\`
# would be a string line-continuation and silently join the lines.
help_epilog = '''\
Example:

python -m netdissect.fullablate \\
    --segmenter "netdissect.GanImageSegmenter(segvocab='lowres', segsizes=[160,288], segdiv='quad')" \\
    --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \\
    --outdir dissect/dissectdir \\
    --classname tree \\
    --layer layer4 \\
    --metric iou \\
    --unitcount 30 \\
    --size 1000

Output layout:
dissectdir/layer4/fullablation/tree-iou.json
{ classname: "tree",
  classnum: 12,
  baseline: 1234531.0,
  layer: "layer4",
  metric: "iou",
  ablation_units: [341, 23, 12, 142, 83, ...],
  ablation_values: [1.0, 1.0, 1.0, ...],
  ablation_effects: [1143242.0, 1032344.0, ...]
}
'''
def main():
    '''
    Command-line entry point for full (cumulative) unit ablation.

    Ablates the first 0..--unitcount units of one layer, in the order given
    by the chosen ranking, and records how many generated pixels the
    segmenter still assigns to --classname after each step.  Settings not
    given on the command line (--model, --pthfile, --segmenter, --layer)
    are taken from <outdir>/dissect.json.  The result is written to
    <outdir>/<layer>/fullablation/<classname>-<metric>.json.
    '''
    def strpair(arg):
        # "name" -> ("name", "name"); "name:alias" -> ("name", "alias").
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(description='Ablation eval',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect', required=True,
            help='directory for dissection output')
    parser.add_argument('--layer', type=strpair,
            help='layer name to edit, in the form layername[:reportedname]')
    parser.add_argument('--classname', type=str,
            help='class name to ablate')
    parser.add_argument('--metric', type=str, default='iou',
            help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
            help='number of units to ablate')
    parser.add_argument('--segmenter', type=str,
            help='constructor expression for the segmenter')
    parser.add_argument('--netname', type=str, default=None,
            help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=25,
            help='batch size for forward pass')
    parser.add_argument('--mixed_units', action='store_true', default=False,
            help='true to keep alpha for non-zeroed units')
    parser.add_argument('--size', type=int, default=200,
            help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    # The class is needed for file paths and rankings below; fail early
    # with a usage message instead of a KeyError/TypeError later.
    if args.classname is None:
        parser.error('--classname is required')

    # Set up console output
    verbose_progress(not args.quiet)
    # Speed up pytorch
    torch.backends.cudnn.benchmark = True
    # NOTE(review): presumably read by create_instrumented_model to choose
    # the device -- confirm.
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter
    if args.layer is None:
        # Normalize to the (layername, reportedname) tuple that --layer
        # produces.  JSON yields a list (not a tuple), or possibly a bare
        # string, for which args.layer[1] would be a single character.
        default_layer = dissection.settings.layers[0]
        if isinstance(default_layer, str):
            args.layer = strpair(default_layer)
        else:
            args.layer = tuple(default_layer)
    args.layers = [args.layer]

    # Load the analysis that provides the unit ranking.
    layername = args.layer[1]
    if args.metric == 'iou':
        # IoU rankings are stored in dissect.json itself.
        summary = dissection
    else:
        with open(os.path.join(args.outdir, layername, args.metric,
                args.classname, 'summary.json')) as f:
            summary = EasyDict(json.load(f))

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)
    device = next(model.parameters()).device
    input_shape = model.input_shape
    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=3).view(
            (args.size,) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)
    labelnames, _ = segmenter.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}
    segloader = torch.utils.data.DataLoader(dataset,
            batch_size=args.batch_size, num_workers=10,
            pin_memory=(device.type == 'cuda'))

    # Start from a model with no layer ablated.
    for l in model.ablation:
        model.ablation[l] = None

    classname = args.classname
    if classname not in labelnum_from_name:
        print('Unknown class %r' % classname)
        sys.exit(1)
    classnum = labelnum_from_name[classname]

    # Get iou ranking from dissect.json
    iou_rankname = '%s-%s' % (classname, 'iou')
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}
    iou_ranking = next(r for r in dissect_layer[layername].rankings
            if r.name == iou_rankname)
    # Get the requested ranking from the summary loaded above.
    rankname = '%s-%s' % (classname, args.metric)
    summary_layer = {lrec.layer: lrec for lrec in summary.layers}
    ranking = next(r for r in summary_layer[layername].rankings
            if r.name == rankname)
    # Unit order: ascending by (ranking score, iou score, unit index).
    ordering = [t[2] for t in sorted([(s1, s2, i)
            for i, (s1, s2) in enumerate(zip(ranking.score, iou_ranking.score))])]
    # Per-unit ablation strength: the negated ranking score when
    # --mixed_units is set, otherwise 1 for every unit.
    values = (-numpy.array(ranking.score))[ordering]
    if not args.mixed_units:
        values[...] = 1

    ablationdir = os.path.join(args.outdir, layername, 'fullablation')
    measurements = measure_full_ablation(segmenter, segloader,
            model, classnum, layername,
            ordering[:args.unitcount], values[:args.unitcount])
    measurements = measurements.cpu().numpy().tolist()
    os.makedirs(ablationdir, exist_ok=True)
    with open(os.path.join(ablationdir, '%s.json' % rankname), 'w') as f:
        json.dump(dict(
            classname=classname,
            classnum=classnum,
            # Class pixel count with no units ablated.
            baseline=measurements[0],
            layer=layername,
            metric=args.metric,
            # Full unit ordering and strengths; only the first --unitcount
            # entries are measured in ablation_effects.
            ablation_units=ordering,
            ablation_values=values.tolist(),
            ablation_effects=measurements[1:]), f)
def measure_full_ablation(segmenter, loader, model, classnum, layer,
        ordering, values):
    '''
    Quick and easy counting of segmented pixels reduced by ablating units.

    For each k in 0..len(ordering), sets the ablation vector of `layer` so
    that unit ordering[j] has strength values[j] for every j < k, and all
    other units have strength 0. It then generates images from every batch
    in `loader`, segments them and counts the pixels labelled `classnum`.

    Returns a 1-D float tensor of length len(ordering) + 1. Entry k is the
    class pixel count summed over the whole loader with the first k units
    ablated, so entry 0 is the unablated baseline.
    '''
    progress = default_progress()
    device = next(model.parameters()).device
    feature_units = model.feature_shape[layer][1]
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    # Diagnostic output of the units and strengths being ablated.
    print(ordering)
    print(values.tolist())
    # Convert the strengths once rather than once per batch and step.
    # The dtype matches the torch.zeros ablation vectors built below.
    unit_values = torch.tensor(values).to(device, torch.get_default_dtype())
    with torch.no_grad():
        for l in model.ablation:
            model.ablation[l] = None
        for [ibz] in progress(loader):
            # Follow the model's device. A hard-coded .cuda() fails
            # under --no-cuda or on CPU-only machines.
            ibz = ibz.to(device)
            for num_units in progress(range(repeats + 1)):
                ablation = torch.zeros(feature_units, device=device)
                ablation[ordering[:num_units]] = unit_values[:num_units]
                model.ablation[layer] = ablation
                tensor_images = model(ibz)
                seg = segmenter.segment_batch(tensor_images, downsample=2)
                # Reduce over dim 1 (presumably the label-channel axis of
                # seg), so that a pixel counts once if any channel matches.
                mask = (seg == classnum).max(1)[0]
                total_scores[num_units] += mask.sum().float().cpu()
    return total_scores
def count_segments(segmenter, loader, model, num_classes=None):
    '''
    Return the per-label pixel frequency of images generated by `model`.

    Each batch from `loader` is unpacked as a one-element sequence of
    generator inputs, as yielded by a DataLoader over a TensorDataset like
    the one built in main. The generated images are segmented with
    downsample=2, and every label occurrence in the segmentation is counted.

    Returns a 1-D float tensor. Entry c is the number of occurrences of
    label c divided by the number of pixel positions seen
    (seg.shape[0] * seg.shape[2] * seg.shape[3], summed over batches). If the
    segmentation carries several label channels per pixel, the entries can
    sum to more than 1.

    num_classes: optional minimum length of the result. Labels that never
    occur get 0. By default the length is the largest label seen plus 1.
    If loader yields nothing, an all-zero tensor of that minimum length
    is returned.
    '''
    minlength = num_classes or 0
    device = next(model.parameters()).device
    progress = default_progress()
    total_bincount = torch.zeros(minlength)
    data_size = 0
    with torch.no_grad():
        for [z_batch] in progress(loader):
            tensor_images = model(z_batch.to(device))
            seg = segmenter.segment_batch(tensor_images, downsample=2)
            bc = seg.flatten().bincount(minlength=minlength).float().cpu()
            if bc.shape[0] > total_bincount.shape[0]:
                # A larger label appeared; grow the accumulator to fit it.
                grown = torch.zeros(bc.shape[0])
                grown[:total_bincount.shape[0]] = total_bincount
                total_bincount = grown
            total_bincount[:bc.shape[0]] += bc
            data_size += seg.shape[0] * seg.shape[2] * seg.shape[3]
    if data_size == 0:
        # Nothing was counted; avoid dividing by zero.
        return total_bincount
    return total_bincount / data_size
if __name__ == "__main__":
    # Run the CLI only when executed as a script (for example
    # `python -m netdissect.fullablate`), not on import.
    main()