Spaces:
Sleeping
Sleeping
import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL | |
from torchvision import transforms | |
from torch.utils.data import TensorDataset | |
from netdissect.progress import verbose_progress, print_progress | |
from netdissect import InstrumentedModel, BrodenDataset, dissect | |
from netdissect import MultiSegmentDataset, GeneratorSegRunner | |
from netdissect import ImageOnlySegRunner | |
from netdissect.parallelfolder import ParallelImageFolders | |
from netdissect.zdataset import z_dataset_for_model | |
from netdissect.autoeval import autoimport_eval | |
from netdissect.modelconfig import create_instrumented_model | |
from netdissect.pidfile import exit_if_job_done, mark_job_done | |
# Usage examples shown after the option list in --help output
# (main() passes this through textwrap.dedent with a raw formatter).
help_epilog = '\n'.join([
    'Example: to dissect three layers of the pretrained alexnet in torchvision:',
    'python -m netdissect \\',
    '--model "torchvision.models.alexnet(pretrained=True)" \\',
    '--layers features.6:conv3 features.8:conv4 features.10:conv5 \\',
    '--imgsize 227 \\',
    '--outdir dissect/alexnet-imagenet',
    'To dissect a progressive GAN model:',
    'python -m netdissect \\',
    '--model "proggan.from_pth_file(\'model/churchoutdoor.pth\')" \\',
    '--gan',
    '',
])
def main():
    '''
    Command-line entry point for ``python -m netdissect``.

    Parses arguments, builds the instrumented model and the probe dataset
    (an image or segmentation dataset for classifiers, or random z vectors
    for generators run with --gen/--gan), runs dissect(), and finally marks
    the output directory as done so the same job is not repeated.

    Exits with status 0 after a download-only run (--download without
    --model), and with status 1 on usage errors or when no dataset can be
    found.
    '''
    def strpair(arg):
        # 'layer:reported' -> ('layer', 'reported'); 'layer' -> ('layer', 'layer')
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    def intpair(arg):
        # '224,256' -> (224, 256); '227' -> (227, 227)
        p = arg.split(',')
        if len(p) == 1:
            p = p + p
        return tuple(int(v) for v in p)

    parser = argparse.ArgumentParser(description='Net dissect utility',
            prog='python -m netdissect',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--unstrict', action='store_true', default=False,
            help='ignore unexpected pth parameters')
    parser.add_argument('--submodule', type=str, default=None,
            help='submodule to load from pthfile')
    parser.add_argument('--outdir', type=str, default='dissect',
            help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
            help='space-separated list of layer names to dissect' +
            ', in the form layername[:reportedname]')
    parser.add_argument('--segments', type=str, default='dataset/broden',
            help='directory containing segmentation dataset')
    parser.add_argument('--segmenter', type=str, default=None,
            help='constructor for a segmenter class')
    parser.add_argument('--download', action='store_true', default=False,
            help='downloads Broden dataset if needed')
    parser.add_argument('--imagedir', type=str, default=None,
            help='directory containing image-only dataset')
    parser.add_argument('--imgsize', type=intpair, default=(227, 227),
            help='input image size to use')
    parser.add_argument('--netname', type=str, default=None,
            help='name for network in generated reports')
    parser.add_argument('--meta', type=str, nargs='+',
            help='json files of metadata to add to report')
    parser.add_argument('--merge', type=str,
            help='json file of unit data to merge in report')
    parser.add_argument('--examples', type=int, default=20,
            help='number of image examples per unit')
    parser.add_argument('--size', type=int, default=10000,
            help='dataset subset size to use')
    parser.add_argument('--batch_size', type=int, default=100,
            help='batch size for forward pass')
    parser.add_argument('--num_workers', type=int, default=24,
            help='number of DataLoader workers')
    # strfloat turns numeric strings into floats (kept in range by
    # FloatRange) and passes other strings such as 'iqr' through.
    parser.add_argument('--quantile_threshold', type=strfloat, default=None,
            choices=[FloatRange(0.0, 1.0), 'iqr'],
            help='quantile to use for masks')
    parser.add_argument('--no-labels', action='store_true', default=False,
            help='disables labeling of units')
    parser.add_argument('--maxiou', action='store_true', default=False,
            help='enables maxiou calculation')
    parser.add_argument('--covariance', action='store_true', default=False,
            help='enables covariance calculation')
    parser.add_argument('--rank_all_labels', action='store_true', default=False,
            help='include low-information labels in rankings')
    parser.add_argument('--no-images', action='store_true', default=False,
            help='disables generation of unit images')
    parser.add_argument('--no-report', action='store_true', default=False,
            help='disables generation report summary')
    parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA usage')
    parser.add_argument('--gen', action='store_true', default=False,
            help='test a generator model (e.g., a GAN)')
    parser.add_argument('--gan', action='store_true', default=False,
            help='synonym for --gen')
    parser.add_argument('--perturbation', default=None,
            help='filename of perturbation attack to apply')
    parser.add_argument('--add_scale_offset', action='store_true', default=None,
            help='offsets masks according to stride and padding')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    # Turn the negative --no-* flags into positive switches for dissect().
    args.images = not args.no_images
    args.report = not args.no_report
    args.labels = not args.no_labels
    if args.gan:
        args.gen = args.gan
    # Set up console output
    verbose_progress(not args.quiet)
    # Exit right away if job is already done or being done.
    if args.outdir is not None:
        exit_if_job_done(args.outdir)
    # Speed up pytorch
    torch.backends.cudnn.benchmark = True
    # Special case: download flag without model to test.
    if args.model is None and args.download:
        from netdissect.broden import ensure_broden_downloaded
        for resolution in [224, 227, 384]:
            ensure_broden_downloaded(args.segments, resolution, 1)
        from netdissect.segmenter import ensure_upp_segmenter_downloaded
        ensure_upp_segmenter_downloaded('dataset/segmodel')
        sys.exit(0)
    # Help if broden is not present
    if not args.gen and not args.imagedir and not os.path.isdir(args.segments):
        print_progress('Segmentation dataset not found at %s.' % args.segments)
        print_progress('Specify dataset directory using --segments [DIR]')
        print_progress('To download Broden, run: netdissect --download')
        sys.exit(1)
    # Default segmenter class
    if args.gen and args.segmenter is None:
        args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                "segsizes=[256], segdiv='quad')")
    # Default threshold
    if args.quantile_threshold is None:
        if args.gen:
            args.quantile_threshold = 'iqr'
        else:
            args.quantile_threshold = 0.005
    # Set up CUDA (cudnn.benchmark was already enabled above).
    # args.cuda is kept on args because args is handed to
    # create_instrumented_model() and recorded in the dissect settings.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # Construct the network with specified layers instrumented
    if args.model is None:
        print_progress('No model specified')
        sys.exit(1)
    # NOTE(review): --model/--segmenter are strings that appear to be
    # evaluated as Python expressions; only pass trusted values.
    model = create_instrumented_model(args)
    # Update any metadata from files, if any
    meta = getattr(model, 'meta', {})
    if args.meta:
        for mfilename in args.meta:
            with open(mfilename) as f:
                meta.update(json.load(f))
    # Load any merge data from files
    mergedata = None
    if args.merge:
        with open(args.merge) as f:
            mergedata = json.load(f)
    # Set up the output directory, verify write access
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        exit_if_job_done(args.outdir)
    print_progress('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)
    train_dataset = None
    if not args.gen:
        # Load dataset for classifier case.
        # Load perturbation
        perturbation = (numpy.load(args.perturbation)
                if args.perturbation else None)
        segrunner = None
        # Load broden dataset
        if args.imagedir is not None:
            dataset = try_to_load_images(args.imagedir, args.imgsize,
                    perturbation, args.size)
            segrunner = ImageOnlySegRunner(dataset)
        else:
            dataset = try_to_load_broden(args.segments, args.imgsize, 1,
                    perturbation, args.download, args.size)
            if dataset is None:
                dataset = try_to_load_multiseg(args.segments, args.imgsize,
                        perturbation, args.size)
            if dataset is None:
                # Format the message before printing, as the other
                # print_progress calls in this function do.
                print_progress('No segmentation dataset found in %s'
                        % args.segments)
                print_progress('use --download to download Broden.')
                sys.exit(1)
    else:
        # For segmenter case the dataset is just a random z
        dataset = z_dataset_for_model(model, args.size)
        train_dataset = z_dataset_for_model(model, args.size, seed=2)
        segrunner = GeneratorSegRunner(autoimport_eval(args.segmenter))
    # Run dissect
    dissect(args.outdir, model, dataset,
            train_dataset=train_dataset,
            segrunner=segrunner,
            examples_per_unit=args.examples,
            netname=args.netname,
            quantile_threshold=args.quantile_threshold,
            meta=meta,
            merge=mergedata,
            make_images=args.images,
            make_labels=args.labels,
            make_maxiou=args.maxiou,
            make_covariance=args.covariance,
            make_report=args.report,
            make_row_images=args.images,
            make_single_images=True,
            rank_all_labels=args.rank_all_labels,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            settings=vars(args))
    # Mark the directory so that it's not done again.
    mark_job_done(args.outdir)
class AddPerturbation(object):
    '''
    Image transform that adds a fixed perturbation array to an image.

    With a perturbation of None the input is returned unchanged.  Otherwise
    the image is converted to a float32 array, the perturbation is cropped
    around its centre to the image's height and width and added, and the
    result is clipped to [0, 255] and scaled to [0, 1].
    '''
    def __init__(self, perturbation):
        # Array to add to each image, or None to disable the transform.
        self.perturbation = perturbation

    def __call__(self, pic):
        if self.perturbation is None:
            return pic
        # Convert to a numpy float32 array.  numpy.asarray (rather than
        # numpy.array(..., copy=False)) copies only when necessary and does
        # not raise under NumPy 2 when a copy or cast is required.
        npyimg = numpy.asarray(pic, dtype=numpy.uint8).astype(numpy.float32)
        height, width = npyimg.shape[0], npyimg.shape[1]
        pheight, pwidth = self.perturbation.shape[0], self.perturbation.shape[1]
        # A smaller perturbation would give negative crop offsets and a
        # confusing broadcast failure; report the mismatch explicitly.
        if pheight < height or pwidth < width:
            raise ValueError(
                'perturbation of shape %s is smaller than image of shape %s'
                % (tuple(self.perturbation.shape), tuple(npyimg.shape)))
        # Center the perturbation on the image.
        oy = (pheight - height) // 2
        ox = (pwidth - width) // 2
        npyimg += self.perturbation[oy:oy + height, ox:ox + width]
        # Pytorch conventions: as a float it should be [0..1]
        numpy.clip(npyimg, 0, 255, out=npyimg)
        return npyimg / 255.0
def test_dissection():
    '''
    Manual smoke test: dissects the five conv layers of a pretrained
    AlexNet against a 100-image subset of dataset/broden, writing results
    into dissect/test with 10 example images per unit.
    '''
    verbose_progress(True)
    from torchvision.models import alexnet
    from torchvision import transforms
    # Load an alexnet and instrument it.
    net = InstrumentedModel(alexnet(pretrained=True))
    net.eval()
    # (module name, reported name) for each conv layer to keep.
    conv_layers = [
        ('features.0', 'conv1'),
        ('features.3', 'conv2'),
        ('features.6', 'conv3'),
        ('features.8', 'conv4'),
        ('features.10', 'conv5'),
    ]
    net.retain_layers(conv_layers)
    # Broden images as normalized tensors, using IMAGE_MEAN/IMAGE_STDEV.
    to_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
    ])
    broden = BrodenDataset('dataset/broden', transform=to_input, size=100)
    # run dissect
    dissect('dissect/test', net, broden, examples_per_unit=10)
def try_to_load_images(directory, imgsize, perturbation, size):
    '''
    Wraps a plain image directory in a ParallelImageFolders dataset.

    Each image is resized to imgsize, optionally perturbed, converted to a
    tensor and normalized with IMAGE_MEAN/IMAGE_STDEV.  Despite the name,
    this always returns a dataset (it never returns None).
    '''
    # TODO: allow other normalizations.
    pipeline = [
        transforms.Resize(imgsize),
        AddPerturbation(perturbation),
        transforms.ToTensor(),
        transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
    ]
    return ParallelImageFolders([directory],
                                transform=transforms.Compose(pipeline),
                                size=size)
def try_to_load_broden(directory, imgsize, broden_version, perturbation,
        download, size):
    '''
    Loads the Broden dataset from `directory`, or returns None if it is absent.

    The smallest available Broden resolution (224, 227 or 384) that covers
    the longest side of `imgsize` is used.  `imgsize` may be a sequence such
    as (h, w) or a single int.  When the dataset's index.csv is missing and
    `download` is true, the dataset is fetched with ensure_broden_downloaded
    before giving up.  Images are resized to imgsize, optionally perturbed,
    converted to tensors and normalized with IMAGE_MEAN/IMAGE_STDEV.
    '''
    longest = max(imgsize) if hasattr(imgsize, '__iter__') else imgsize
    if longest <= 224:
        ds_resolution = 224
    elif longest <= 227:
        ds_resolution = 227
    else:
        ds_resolution = 384
    index_file = os.path.join(directory,
            'broden%d_%d' % (broden_version, ds_resolution), 'index.csv')
    if not os.path.isfile(index_file):
        # Previously this returned None even when download was requested,
        # so --download could never fetch the dataset here.
        if not download:
            return None
        from netdissect.broden import ensure_broden_downloaded
        ensure_broden_downloaded(directory, ds_resolution, broden_version)
        if not os.path.isfile(index_file):
            return None
    return BrodenDataset(directory,
            resolution=ds_resolution,
            download=download,
            broden_version=broden_version,
            transform=transforms.Compose([
                transforms.Resize(imgsize),
                AddPerturbation(perturbation),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=size)
def try_to_load_multiseg(directory, imgsize, perturbation, size):
    '''
    Loads a MultiSegmentDataset from `directory`, or returns None when the
    directory has no labelnames.json.

    Both images and their paired targets are resized on the shorter side
    and center-cropped to imgsize.  imgsize may be a sequence or an int.
    '''
    if not os.path.isfile(os.path.join(directory, 'labelnames.json')):
        return None
    if hasattr(imgsize, '__iter__'):
        minsize = min(imgsize)
    else:
        minsize = imgsize
    image_transform = transforms.Compose([
        transforms.Resize(minsize),
        transforms.CenterCrop(imgsize),
        AddPerturbation(perturbation),
        transforms.ToTensor(),
        transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)])
    # Second transform of the pair (presumably applied to segmentation
    # maps by MultiSegmentDataset -- confirm) uses nearest-neighbour
    # resampling so that discrete label values are not interpolated.
    label_transform = transforms.Compose([
        transforms.Resize(minsize, interpolation=PIL.Image.NEAREST),
        transforms.CenterCrop(imgsize)])
    return MultiSegmentDataset(directory,
                               transform=(image_transform, label_transform),
                               size=size)
def add_scale_offset_info(model, layer_names):
    '''
    Sets model.scale_offset to a dict mapping each reported layer name to
    the per-axis (scale, offset) estimated by sequence_scale_offset() from
    all modules visited up to and including that layer.

    The offset corrects for convolution padding that does not keep
    featuremap pixels centered on the downsampled input regions.  This
    mainly matters for AlexNet; ResNet and VGG pad convolutions so that
    they stay centered.

    layer_names entries are either 'name' or (name, reported_name).
    Fails an assertion if any requested layer is not found in the model.
    '''
    model.scale_offset = {}
    reported = {}
    for entry in layer_names:
        if isinstance(entry, str):
            reported[entry] = entry
        else:
            module_name, alias = entry
            reported[module_name] = alias
    found = set()
    # Every module visited so far, in named_modules() order.
    prefix = []
    for module_name, module in model.named_modules():
        prefix.append(module)
        if module_name not in reported:
            continue
        found.add(module_name)
        model.scale_offset[reported[module_name]] = sequence_scale_offset(prefix)
    for module_name in reported:
        assert module_name in found, ('Layer %s not found' % module_name)
def dilation_scale_offset(dilations):
    '''
    Composes a sequence of (kernel, stride, padding) triples for one axis,
    ordered from the input side outward, into a total (scale, offset).

    Composing layer by layer gives: input coordinate = scale * output
    coordinate + offset.  An empty sequence yields (1, 0).  Any iterable
    is accepted.
    '''
    scale, offset = 1, 0
    # Fold from the last layer back toward the input: each earlier layer
    # multiplies the accumulated offset by its stride and adds its own
    # shift.  Iterating avoids the original recursion, which sliced the list
    # at every level (quadratic copying) and could hit the recursion limit
    # on deep module lists.  The float operations keep the original order.
    for kernel, stride, padding in reversed(list(dilations)):
        scale *= stride
        offset = offset * stride + ((kernel - 1) / 2.0 - padding)
    return scale, offset
def dilations(modulelist):
    '''
    Converts a list of modules into per-axis tuples of
    (kernel_size, stride, padding).

    Missing attributes default to kernel_size=1, stride=1, padding=0, and
    scalar settings apply to both axes.  Modules whose settings are the
    identity ((1, 1), (1, 1), (0, 0)) are skipped.

    Returns a list with one entry per axis (y first, then x).  Each entry
    is a tuple holding one (kernel, stride, padding) triple per remaining
    module, in module order.  If no module changes geometry, returns
    [(), ()] so callers still get a y and an x entry.
    '''
    result = []
    for module in modulelist:
        settings = tuple(getattr(module, n, d)
                for n, d in (('kernel_size', 1), ('stride', 1), ('padding', 0)))
        # Normalize scalars to (y, x) pairs.  This must be a tuple: the
        # original generator never compared equal to the identity tuple
        # below, so identity modules were never skipped.
        settings = tuple(s if isinstance(s, tuple) else (s, s)
                for s in settings)
        if settings != ((1, 1), (1, 1), (0, 0)):
            # (kernels, strides, paddings) -> one (k, s, p) per axis
            result.append(tuple(zip(*settings)))
    if not result:
        return [(), ()]
    return list(zip(*result))
def sequence_scale_offset(modulelist):
    '''
    Returns ((yscale, yoffset), (xscale, xoffset)) for a list of modules:
    one dilation_scale_offset() result per axis produced by dilations().
    '''
    return tuple(map(dilation_scale_offset, dilations(modulelist)))
def strfloat(s):
    '''
    Argparse type converter: returns float(s) when s parses as a number,
    otherwise returns s unchanged (so keyword choices such as 'iqr' reach
    the argparse choices check).
    '''
    try:
        return float(s)
    except (TypeError, ValueError):
        # Only conversion failures fall through; a bare except would also
        # swallow KeyboardInterrupt and SystemExit.
        return s
class FloatRange(object):
    '''
    Closed interval [start, end] for use in argparse ``choices``.

    An instance compares equal to any real number (bools excluded) inside
    the interval.  As a result, ``value in [FloatRange(0.0, 1.0), 'iqr']``
    accepts in-range numbers as well as the literal 'iqr'.  NaN is never
    inside the range.
    '''
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __eq__(self, other):
        # bool is a numbers.Real subtype; keep True/False from counting as 1/0.
        if isinstance(other, bool) or not isinstance(other, numbers.Real):
            return False
        return self.start <= other <= self.end

    # Defining __eq__ already disables hashing; make that explicit.
    __hash__ = None

    def __contains__(self, value):
        return self == value

    def __repr__(self):
        # Also used by argparse when listing the allowed choices.
        return '[%g-%g]' % (self.start, self.end)
# Normalization statistics shared by many models: one mean and one
# standard deviation per image channel (presumably RGB order), passed to
# transforms.Normalize after ToTensor by the dataset loaders above.
IMAGE_MEAN = [
    0.485,
    0.456,
    0.406,
]
IMAGE_STDEV = [
    0.229,
    0.224,
    0.225,
]

if __name__ == '__main__':
    main()