|
import numbers |
|
import torch |
|
from netdissect.autoeval import autoimport_eval |
|
from netdissect import pbar |
|
from netdissect.nethook import InstrumentedModel |
|
from netdissect.easydict import EasyDict |
|
|
|
def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args:
      model: a string to evaluate as a constructor for the model, or an
          already-constructed torch.nn.Module.
      pthfile: (optional) filename of .pth file for the model.  The file is
          unpickled by torch.load, so it must come from a trusted source.
      modelkey: (optional) if the loaded checkpoint contains this key
          (default 'state_dict'), the weights are taken from that entry and
          the checkpoint's numeric top-level entries are kept as `meta`.
      submodule: (optional) only state-dict entries prefixed by
          '<submodule>.' are loaded, with that prefix stripped.
      unstrict: (optional) True to call load_state_dict with strict=False.
      layer / layers: a single layer, or a list of layers, to instrument;
          defaulted if not provided.  A lone ('?', '?') entry prints every
          module name and exits the process with status 0.
      edit: True to instrument the layers for editing.
          NOTE(review): this function never reads `edit`; confirm whether
          editing support is provided by InstrumentedModel.
      gen: True for a generator model. One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      cuda: (optional) True to use CUDA.
    Keyword arguments override same-named entries of `args`.

    Returns None (after printing a message) when no model is specified or
    when the requested submodule is absent from the checkpoint.

    The constructed model will be decorated with the following attributes:
      meta: numeric top-level checkpoint entries (empty dict if none).
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
    NOTE(review): the attributes below are not set in this function and are
    presumably provided by InstrumentedModel -- verify:
      retained: map of layernames to tensors, filled after every evaluation.
      ablation: if editing, map of layernames to [0..1] alpha values to fill.
      replacement: if editing, map of layernames to values to fill.

    When editing, the feature value x will be replaced by:
    `x = (replacement * ablation) + (x * (1 - ablation))`
    '''
    # Support for legacy argument passing: accept an argparse Namespace and
    # let keyword arguments override its entries.
    args = EasyDict(vars(args), **kwargs)

    # Construct the network.
    if getattr(args, 'model', None) is None:
        pbar.print('No model specified')
        return None
    if isinstance(args.model, torch.nn.Module):
        model = args.model
    else:
        model = autoimport_eval(args.model)
    # Unwrap a DataParallel wrapper to reach the underlying network.
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict, if a checkpoint was given.
    meta = {}
    pthfile = getattr(args, 'pthfile', None)
    if pthfile is not None:
        # Map to the CPU so GPU-saved checkpoints also load on CPU-only
        # machines; load_state_dict copies into the model's own parameters
        # and the model is moved to CUDA further below when requested.
        # NOTE: torch.load unpickles the file -- only load trusted files.
        data = torch.load(pthfile, map_location='cpu')
        modelkey = getattr(args, 'modelkey', 'state_dict')
        if modelkey in data:
            # Keep scalar top-level checkpoint entries as metadata.
            meta = {key: value for key, value in data.items()
                    if isinstance(value, numbers.Number)}
            data = data[modelkey]
        submodule = getattr(args, 'submodule', None)
        if submodule:
            remove_prefix = submodule + '.'
            data = {k[len(remove_prefix):]: v for k, v in data.items()
                    if k.startswith(remove_prefix)}
            if not data:
                pbar.print('No submodule %s found in %s' %
                        (submodule, pthfile))
                return None
        model.load_state_dict(data, strict=not getattr(args, 'unstrict', False))

    # Decide which layers to instrument.
    if getattr(args, 'layer', None) is not None:
        args.layers = [args.layer]
    layers = getattr(args, 'layers', None)
    # A lone ('?', '?') layer spec means: list every module name and quit.
    if layers is not None and len(layers) == 1 and layers[0] == ('?', '?'):
        for name, _ in model.named_modules():
            pbar.print(name)
        import sys
        sys.exit(0)
    if layers is None:
        # Descend through wrappers that have exactly one named child.
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to every top-level child except activation and pooling
        # layers, leaving out the last remaining one.
        skipped_modules = ('torch.nn.modules.activation',
                           'torch.nn.modules.pooling')
        args.layers = [prefix + name
                for name, module in container.named_children()
                if type(module).__module__ not in skipped_modules][:-1]
        pbar.print('Defaulting to layers: %s' % ' '.join(args.layers))

    # Now wrap the model for instrumentation.
    model = InstrumentedModel(model)
    model.meta = meta

    # Instrument the layers.
    model.retain_layers(args.layers)
    model.eval()
    if getattr(args, 'cuda', False):
        model.cuda()

    # Annotate input, output, and feature shapes with a dry run.
    annotate_model_shapes(model,
            gen=getattr(args, 'gen', False),
            imgsize=getattr(args, 'imgsize', None))
    return model
|
|
|
def annotate_model_shapes(model, gen=False, imgsize=None):
    '''
    Runs one dry-run forward pass on a zero input and records tensor shapes
    on the model.

    Args:
      model: module to annotate.  Must provide retained_features(), returning
          a map of layer names to the tensors captured during the most
          recent forward pass (e.g. an InstrumentedModel).
      gen: True for a generator model.  A one-pixel input is assumed, sized
          from the first Conv2d / ConvTranspose2d / Linear layer found in
          model.modules(): (1, in_channels, 1, 1) or (1, in_features).
      imgsize: for non-generator models, (y, x) dimensions of an RGB input,
          giving an input of shape (1, 3, y, x).  A single int is accepted
          as shorthand for a square (imgsize, imgsize) image.

    Sets model.input_shape, model.feature_shape (layer name -> shape) and
    model.output_shape, and returns the same model.

    Raises:
      ValueError: if gen is false and imgsize is None, or if gen is true but
          the model contains no Conv2d / ConvTranspose2d / Linear layer.
    '''
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if imgsize is None and not gen:
        raise ValueError('imgsize is required unless gen=True')

    if gen:
        # Size the one-pixel input from the first layer declaring an input
        # width.
        first_layer = next((c for c in model.modules()
                            if isinstance(c, (torch.nn.Conv2d,
                                              torch.nn.ConvTranspose2d,
                                              torch.nn.Linear))), None)
        if first_layer is None:
            raise ValueError('cannot infer generator input shape: no '
                             'Conv2d, ConvTranspose2d or Linear layer found')
        if isinstance(first_layer, torch.nn.Linear):
            input_shape = (1, first_layer.in_features)
        else:
            input_shape = (1, first_layer.in_channels, 1, 1)
    else:
        if isinstance(imgsize, numbers.Integral):
            imgsize = (imgsize, imgsize)
        input_shape = (1, 3) + tuple(imgsize)

    # Run the dry pass on the device of the model's first parameter; a model
    # with no parameters falls back to the CPU instead of raising
    # StopIteration.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cpu')
    dry_run = torch.zeros(input_shape, device=device)
    with torch.no_grad():
        output = model(dry_run)

    # Record the shapes observed during the dry run.
    model.input_shape = input_shape
    model.feature_shape = {layer: feature.shape
            for layer, feature in model.retained_features().items()}
    model.output_shape = output.shape
    return model
|
|