import numbers
import sys
import torch
from netdissect.autoeval import autoimport_eval
from netdissect import pbar
from netdissect.nethook import InstrumentedModel
from netdissect.easydict import EasyDict
def create_instrumented_model(args, **kwargs):
'''
Creates an instrumented model out of a namespace of arguments that
correspond to ArgumentParser command-line args:
model: a string to evaluate as a constructor for the model.
pthfile: (optional) filename of .pth file for the model.
layers: a list of layers to instrument, defaulted if not provided.
edit: True to instrument the layers for editing.
gen: True for a generator model. One-pixel input assumed.
imgsize: For non-generator models, (y, x) dimensions for RGB input.
cuda: True to use CUDA.
The constructed model will be decorated with the following attributes:
input_shape: (usually 4d) tensor shape for single-image input.
output_shape: 4d tensor shape for output.
feature_shape: map of layer names to 4d tensor shape for featuremaps.
retained: map of layernames to tensors, filled after every evaluation.
ablation: if editing, map of layernames to [0..1] alpha values to fill.
replacement: if editing, map of layernames to values to fill.
When editing, the feature value x will be replaced by:
`x = (replacement * ablation) + (x * (1 - ablation))`
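
    For example, a hypothetical invocation (illustrative only; assumes
    torchvision is available, which this module does not require):
        from argparse import Namespace
        model = create_instrumented_model(Namespace(
            model='torchvision.models.resnet18()',
            layers=['layer3', 'layer4'],
            gen=False, imgsize=(224, 224), cuda=False))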
'''
args = EasyDict(vars(args), **kwargs)
# Construct the network
if args.model is None:
pbar.print('No model specified')
return None
if isinstance(args.model, torch.nn.Module):
model = args.model
else:
model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model: its single child is the
    # underlying module (the same object as model.module).
if isinstance(model, torch.nn.DataParallel):
model = next(model.children())
# Load its state dict
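    # The .pth checkpoint is assumed to look roughly like (illustrative):
    #   {'state_dict': {...weights...}, 'epoch': 50, 'best_acc': 0.93}
    # Scalar entries saved beside the weights are collected into meta.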
meta = {}
if getattr(args, 'pthfile', None) is not None:
data = torch.load(args.pthfile)
modelkey = getattr(args, 'modelkey', 'state_dict')
if modelkey in data:
meta = {}
for key in data:
if isinstance(data[key], numbers.Number):
meta[key] = data[key]
data = data[modelkey]
submodule = getattr(args, 'submodule', None)
if submodule is not None and len(submodule):
remove_prefix = submodule + '.'
data = { k[len(remove_prefix):]: v for k, v in data.items()
if k.startswith(remove_prefix)}
if not len(data):
pbar.print('No submodule %s found in %s' %
(submodule, args.pthfile))
return None
model.load_state_dict(data, strict=not getattr(args, 'unstrict', False))
# Decide which layers to instrument.
if getattr(args, 'layer', None) is not None:
args.layers = [args.layer]
    # If '?' is the only layer specified, just print out all layer names.
    # (A layer can be given as a plain name or as a (name, alias) pair.)
    if getattr(args, 'layers', None) is not None:
        if len(args.layers) == 1 and args.layers[0] in ('?', ('?', '?')):
            for name, layer in model.named_modules():
                pbar.print(name)
            sys.exit(0)
if getattr(args, 'layers', None) is None:
        # Skip wrappers that have only one named child module.
container = model
prefix = ''
while len(list(container.named_children())) == 1:
name, container = next(container.named_children())
prefix += name + '.'
# Default to all nontrivial top-level layers except last.
args.layers = [prefix + name
for name, module in container.named_children()
if type(module).__module__ not in [
# Skip ReLU and other activations.
'torch.nn.modules.activation',
# Skip pooling layers.
'torch.nn.modules.pooling']
][:-1]
pbar.print('Defaulting to layers: %s' % ' '.join(args.layers))
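        # (For torchvision's resnet18, for example, this would select
        # conv1, bn1, and layer1..layer4, skipping relu/maxpool/avgpool
        # and dropping the final fc; illustrative, depends on the model.)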
# Now wrap the model for instrumentation.
model = InstrumentedModel(model)
model.meta = meta
# Instrument the layers.
model.retain_layers(args.layers)
model.eval()
if args.cuda:
model.cuda()
# Annotate input, output, and feature shapes
annotate_model_shapes(model,
gen=getattr(args, 'gen', False),
imgsize=getattr(args, 'imgsize', None))
return model
def annotate_model_shapes(model, gen=False, imgsize=None):
assert (imgsize is not None) or gen
# Figure the input shape.
if gen:
# We can guess a generator's input shape by looking at the model.
# Examine first conv in model to determine input feature size.
first_layer = [c for c in model.modules()
if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
torch.nn.Linear))][0]
# 4d input if convolutional, 2d input if first layer is linear.
if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
input_shape = (1, first_layer.in_channels, 1, 1)
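            # e.g. (1, 512, 1, 1) for a convolutional generator with a
            # 512-channel one-pixel latent (illustrative).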
else:
input_shape = (1, first_layer.in_features)
else:
# For a classifier, the input image shape is given as an argument.
input_shape = (1, 3) + tuple(imgsize)
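        # e.g. (1, 3, 224, 224) for imgsize=(224, 224).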
# Run the model once to observe feature shapes.
device = next(model.parameters()).device
dry_run = torch.zeros(input_shape).to(device)
with torch.no_grad():
output = model(dry_run)
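    # The dry run also fills model.retained_features() for each
    # instrumented layer; feature_shape below records those shapes.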
# Annotate shapes.
model.input_shape = input_shape
model.feature_shape = { layer: feature.shape
for layer, feature in model.retained_features().items() }
model.output_shape = output.shape
return model
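
# After construction, a typical probe of the retained featuremaps might
# look like this (an illustrative sketch, not part of the module's API):
#
#   imgs = torch.zeros(model.input_shape)  # stand-in for a real batch
#   with torch.no_grad():
#       out = model(imgs)
#   acts = model.retained_features()       # dict: layer name -> tensor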