|
import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL |
|
from torchvision import transforms |
|
from torch.utils.data import TensorDataset |
|
from netdissect import pbar |
|
from netdissect.nethook import edit_layers |
|
from netdissect.zdataset import standard_z_sample |
|
from netdissect.autoeval import autoimport_eval |
|
from netdissect.easydict import EasyDict |
|
from netdissect.modelconfig import create_instrumented_model |
|
|
|
# Epilog appended to the --help output: a sample invocation followed by a
# sketch of the JSON file that main() writes for each (layer, class) pair.
# The shell line-continuation backslashes are escaped ("\\") so they are
# printed literally, one option per line; an unescaped backslash-newline
# inside a non-raw string literal would be removed by Python and the whole
# command would be joined onto a single line.
help_epilog = '''\
Example:

python -m netdissect.evalablate \\
      --segmenter "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')" \\
      --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \\
      --outdir dissect/dissectdir \\
      --classes mirror coffeetable tree \\
      --layers layer4 \\
      --size 1000

Output layout:
dissect/dissectdir/layer4/pixablation/mirror-iou.json
{ classname: "mirror",
  classnum: 43,
  baseline: 1234.5,
  layer: "layer4",
  metric: "iou",
  ablation_units: [341, 23, 12, 142, 83, ...],
  ablation_effects: [1143.2, 1032.3, 929.9, ...]
}

'''
|
|
|
def main():
    """Command-line entry point for the ablation evaluation.

    Reads ``dissect.json`` from ``--outdir``, builds the instrumented
    generator, and for every requested class and every editable layer
    calls ``measure_ablation`` on the units ordered by the stored
    ``<class>-<metric>`` ranking.  Results are written to
    ``<outdir>/<layer>/pixablation/<class>-<metric>.json``.

    An existing result file is left alone (the layer is skipped) when its
    stored unit ordering disagrees with the current one, or when it
    already holds at least ``--unitcount`` measurements.

    Raises:
        SystemExit: with status 1 when no arguments are given, no model
            can be constructed, a class name is unknown to the segmenter,
            or the requested ranking is missing from ``dissect.json``.
    """

    def strpair(arg):
        """Parse ``layername[:reportedname]`` into a tuple of strings.

        A bare ``layername`` is duplicated, giving ``(name, name)``.
        """
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Ablation eval',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect',
                        required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes', type=str, nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric', type=str, default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
                        help='number of units to ablate')
    # The value is evaluated by autoimport_eval below, so it is a
    # constructor expression rather than a dataset directory.
    parser.add_argument('--segmenter', type=str,
                        help='constructor for the segmenter to use')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size', type=int, default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Console progress output follows --quiet.
    pbar.verbose(not args.quiet)

    # Let cuDNN auto-tune its kernels; set once, regardless of device.
    torch.backends.cudnn.benchmark = True

    # Stored on args; presumably consumed by create_instrumented_model.
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Fall back to the settings recorded in dissect.json for anything not
    # given on the command line.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate the generator with editable (ablatable) layers.
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    device = next(model.parameters()).device
    input_shape = model.input_shape

    # Fixed-seed latent sample so repeated runs see the same inputs.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view(
        (args.size,) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Build the segmenter from its constructor expression.
    segmenter = autoimport_eval(args.segmenter)

    # Map each label name to its index in the segmenter's label list.
    labelnames, _catnames = segmenter.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size, num_workers=10,
        pin_memory=(device.type == 'cuda'))

    # Index the dissection records by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # Start with every layer un-ablated.
    for l in model.ablation:
        model.ablation[l] = None

    for classname in pbar(args.classes):
        pbar.post(c=classname)
        for layername in pbar(model.ablation):
            pbar.post(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            if classname not in labelnum_from_name:
                print('%s not found' % classname)
                sys.exit(1)
            classnum = labelnum_from_name[classname]
            # Catch only lookup failures, so that Ctrl-C and other
            # unrelated errors are not reported as a missing ranking.
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except (KeyError, AttributeError, StopIteration):
                print('%s not found' % rankname)
                sys.exit(1)
            # Unit indices sorted by ascending ranking score.
            ordering = numpy.argsort(ranking.score)

            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            resultfile = os.path.join(ablationdir, '%s.json' % rankname)
            if os.path.isfile(resultfile):
                with open(resultfile) as f:
                    data = EasyDict(json.load(f))
                # A stored ordering that disagrees with the current one is
                # left untouched rather than overwritten.
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    continue
                # Enough units were already measured: nothing to do.
                if len(data.ablation_effects) >= args.unitcount:
                    continue
            measurements = measure_ablation(
                segmenter, segloader, model, classnum, layername,
                ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(resultfile, 'w') as f:
                # measurements[0] is the un-ablated baseline; the rest are
                # the effects of ablating the first 1, 2, ... units.
                json.dump(dict(
                    classname=classname,
                    classnum=classnum,
                    baseline=measurements[0],
                    layer=layername,
                    metric=args.metric,
                    ablation_units=ordering.tolist(),
                    ablation_effects=measurements[1:]), f)
|
|
|
def measure_ablation(segmenter, loader, model, classnum, layer, ordering):
    """Measure how much of a segmented class survives cumulative ablation.

    For every batch the un-ablated model output is segmented and the mask
    of pixels labelled ``classnum`` is average-pooled down to the spatial
    size of ``layer``.  For each pooled location with a nonzero value the
    model is then re-run ``len(ordering)`` times with an ablation mask
    covering that single location; the k-th run (1-based) marks the first
    k units of ``ordering``.  The pooled class value at that same
    location is summed over all locations and batches.

    Args:
        segmenter: object with ``segment_batch(images, downsample=2)``.
        loader: iterable of batches whose first element is the latent
            input tensor.
        model: instrumented model exposing ``ablation`` (a per-layer
            dict) and ``feature_shape`` (a per-layer shape lookup).
        classnum: integer label compared against the segmentation.
        layer: name of the layer to ablate.
        ordering: sequence of unit indices, ablated cumulatively in order.

    Returns:
        A CPU tensor of length ``len(ordering) + 1``.  Entry 0 is the
        un-ablated baseline total; entry k is the total after ablating
        the first k units.

    Side effects:
        Clears every entry of ``model.ablation`` on entry.  On return
        ``model.ablation[layer]`` holds whatever was assigned last (the
        final intervention mask, or ``None``).
    """
    # NOTE(review): forward passes are not wrapped in torch.no_grad();
    # confirm whether the model/segmenter already disable autograd.
    device = next(model.parameters()).device
    for l in model.ablation:
        model.ablation[l] = None
    # Index 1 of the layer's feature shape is used as the unit count, the
    # remaining trailing dimensions as the spatial size.
    feature_units = model.feature_shape[layer][1]
    feature_shape = model.feature_shape[layer][2:]
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    for batch in pbar(loader):
        z_batch = batch[0]
        # Baseline: run without any ablation.
        model.ablation[layer] = None
        tensor_images = model(z_batch.to(device))
        seg = segmenter.segment_batch(tensor_images, downsample=2)
        # True wherever any entry along dim 1 of the segmentation equals
        # classnum.
        mask = (seg == classnum).max(1)[0]
        downsampled_seg = torch.nn.functional.adaptive_avg_pool2d(
            mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]
        total_scores[0] += downsampled_seg.sum().cpu()

        # One intervention is needed for every location that had a
        # nonzero pooled value; rows are (batch index, location...).
        interventions_needed = downsampled_seg.nonzero()
        location_count = len(interventions_needed)
        if location_count == 0:
            continue
        # Tile the location list once per ablation step.
        interventions_needed = interventions_needed.repeat(repeats, 1)
        inter_z = z_batch[interventions_needed[:, 0]].to(device)
        inter_chan = torch.zeros(repeats, location_count, feature_units,
                                 device=device)
        # Step j marks units ordering[0..j]: unit u is switched on for
        # step j and every later step.
        for j, u in enumerate(ordering):
            inter_chan[j:, :, u] = 1
        inter_chan = inter_chan.view(len(inter_z), feature_units)
        inter_loc = interventions_needed[:, 1:]
        scores = torch.zeros(len(inter_z))
        # Process the interventions in chunks no larger than the
        # incoming batch.
        batch_size = len(z_batch)
        for j in range(0, len(inter_z), batch_size):
            ibz = inter_z[j:j+batch_size]
            ibl = inter_loc[j:j+batch_size].t()
            # Spatial mask with a single 1 per sample, at its location.
            imask = torch.zeros((len(ibz),) + feature_shape,
                                device=ibz.device)
            imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1
            ibc = inter_chan[j:j+batch_size]
            # Outer product: selected units at the selected location.
            model.ablation[layer] = (
                imask.float()[:, None, :, :] * ibc[:, :, None, None])
            tensor_images = model(ibz)
            seg = segmenter.segment_batch(tensor_images, downsample=2)
            mask = (seg == classnum).max(1)[0]
            downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d(
                mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]
            # Read the pooled class value back at the ablated location.
            scores[j:j+batch_size] = downsampled_iseg[
                (torch.arange(len(ibz)),) + tuple(ibl)]
        # Sum over locations to get one total per ablation step.
        scores = scores.view(repeats, location_count).sum(1)
        total_scores[1:] += scores
    return total_scores
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|