Spaces:
Runtime error
Runtime error
File size: 7,085 Bytes
97069e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
'''
A simple tool to generate samples of a GAN's output,
subject to filtering, sorting, or intervention.
'''
import torch, numpy, os, argparse, numbers, sys, shutil
from PIL import Image
from torch.utils.data import TensorDataset
from netdissect.zdataset import standard_z_sample
from netdissect.progress import default_progress, verbose_progress
from netdissect.autoeval import autoimport_eval
from netdissect.workerpool import WorkerBase, WorkerPool
from netdissect.nethook import edit_layers, retain_layers
def main():
    '''
    Command-line entry point: sample images from a GAN generator.

    Parses arguments, builds the model (optionally loading a .pth file),
    picks z vectors (either the first --size samples, or the --size samples
    out of --test_size candidates that most activate --maximize_units in
    --layer), optionally ablates --ablate_units in --layer, writes one image
    per z vector into --outdir, and copies the lightbox viewer there.
    '''
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
                        help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
                        help='number of images to output')
    parser.add_argument('--test_size', type=int, default=None,
                        help='number of images to test')
    parser.add_argument('--layer', type=str, default=None,
                        help='layer to inspect')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed')
    parser.add_argument('--maximize_units', type=int, nargs='+', default=None,
                        help='units to maximize')
    parser.add_argument('--ablate_units', type=int, nargs='+', default=None,
                        help='units to ablate')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    # Validate before doing any expensive work: without these, None would be
    # handed to autoimport_eval / retain_layers / edit_layers.
    if args.model is None:
        parser.error('--model is required')
    if args.layer is None and (args.maximize_units is not None
                               or args.ablate_units):
        parser.error('--layer is required with --maximize_units '
                     'or --ablate_units')
    verbose_progress(not args.quiet)
    model = _load_model(args.model, args.pthfile)
    z_channels, spatialdims = _z_input_shape(model)
    # Instrument the model if needed
    if args.maximize_units is not None:
        retain_layers(model, [args.layer])
    model.cuda()
    # Get the sample of z vectors
    if args.maximize_units is None:
        indexes = torch.arange(args.size)
        z_sample = standard_z_sample(args.size, z_channels, seed=args.seed)
        z_sample = z_sample.view(tuple(z_sample.shape) + spatialdims)
    else:
        # By default, if maximizing units, score 20x as many candidates as
        # requested and keep the top 5%.
        if args.test_size is None:
            args.test_size = args.size * 20
        z_universe = standard_z_sample(args.test_size, z_channels,
                                       seed=args.seed)
        z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims)
        indexes = get_highest_znums(model, z_universe, args.maximize_units,
                                    args.size, seed=args.seed)
        z_sample = z_universe[indexes]
    if args.ablate_units:
        # Install the ablation hook on the layer; the ablation mask itself is
        # assigned by save_znum_images before any image is generated.
        edit_layers(model, [args.layer])
    save_znum_images(args.outdir, model, z_sample, indexes,
                     args.layer, args.ablate_units)
    copy_lightbox_to(args.outdir)


def _load_model(model_expr, pthfile):
    '''
    Build the model from the --model constructor expression and optionally
    load weights from *pthfile*.

    Checkpoints whose top level contains a 'state_dict' key are unwrapped to
    that entry. A torch.nn.DataParallel wrapper is stripped from the result.
    '''
    model = autoimport_eval(model_expr)
    if pthfile is not None:
        # NOTE(review): torch.load unpickles the file, which can run arbitrary
        # code; only pass checkpoints from trusted sources.
        data = torch.load(pthfile)
        if 'state_dict' in data:
            data = data['state_dict']
        model.load_state_dict(data)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())
    return model


def _z_input_shape(model):
    '''
    Infer the z input layout from the first Conv2d, ConvTranspose2d or
    Linear module of *model*.

    Returns (z_channels, spatialdims): the channel count and the trailing
    dims appended to each z sample's shape, i.e. (1, 1) for a convolutional
    first layer (4-d input) and () for a linear one (2-d input).

    Raises ValueError if the model has none of those layer types.
    '''
    layer_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear)
    first_layer = next(
        (c for c in model.modules() if isinstance(c, layer_types)), None)
    if first_layer is None:
        raise ValueError('cannot infer z size: model has no Conv2d, '
                         'ConvTranspose2d or Linear layer')
    if isinstance(first_layer, torch.nn.Linear):
        return first_layer.in_features, ()
    return first_layer.in_channels, (1, 1)
def get_highest_znums(model, z_universe, max_units, size,
                      batch_size=100, seed=1):
    '''
    Return the indexes of the *size* z vectors in *z_universe* that most
    strongly activate the units *max_units*.

    The model must already be instrumented so that ``model.retained`` holds
    exactly one layer (main does this via retain_layers). Each z vector is
    scored by taking, for every unit in *max_units*, that unit's maximum
    activation over all remaining feature dims, then summing those maxima.

    Returns a 1-d LongTensor of at most *size* indexes into *z_universe*,
    sorted ascending (empty if *z_universe* is empty). *seed* is accepted for
    interface compatibility but is not used.

    Raises ValueError if ``model.retained`` does not hold exactly one layer.
    '''
    # The model should have been instrumented already
    retained_items = list(model.retained.items())
    if len(retained_items) != 1:
        raise ValueError('expected exactly one retained layer, found %d'
                         % len(retained_items))
    layer = retained_items[0][0]
    progress = default_progress()
    scores = []
    with torch.no_grad():
        # Pass 1: collect max activation stats
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_universe),
                                               batch_size=batch_size,
                                               num_workers=2,
                                               pin_memory=True)
        for [z] in progress(z_loader, desc='Finding max activations'):
            model(z.cuda())
            feature = model.retained[layer]
            unit_maxes = feature[:, max_units, ...].view(
                feature.shape[0], len(max_units), -1).max(2)[0]
            scores.append(unit_maxes.sum(1).cpu())
    if not scores:
        # torch.cat([]) would raise on an empty universe.
        return torch.zeros(0, dtype=torch.long)
    scores = torch.cat(scores, 0)
    # Sort descending to select the top `size`, then re-sort the selected
    # indexes ascending so the output follows z_universe order.
    return (-scores).sort(0)[1][:size].sort(0)[0]
def save_znum_images(dirname, model, z_sample, indexes, layer, ablated_units,
                     name_template="image_{}.png", lightbox=False,
                     batch_size=100, seed=1):
    '''
    Run every z vector in *z_sample* through *model* and save the outputs
    as images in *dirname* (created if missing).

    Each image is named ``name_template.format(index)``, where index is the
    z vector's position in *z_sample*, or ``indexes[position].item()`` when
    *indexes* is not None. Model outputs are mapped linearly from [-1, 1] to
    [0, 255], clamped, and converted to uint8 HWC arrays before saving.

    If *ablated_units* is non-empty, ``model.ablation[layer]`` is set to a
    (1, dims, 1, 1) mask that is 1 at those units and 0 elsewhere; the model
    presumably was prepared with edit_layers (as main does) — confirm.

    *lightbox* and *seed* are accepted for interface compatibility but are
    not used here.
    '''
    progress = default_progress()
    os.makedirs(dirname, exist_ok=True)
    with torch.no_grad():
        # Pass 2: now generate images
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample),
                                               batch_size=batch_size,
                                               num_workers=2,
                                               pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        try:
            # len() check rather than truthiness so tensors work, and an
            # empty list no longer crashes on max([]).
            if ablated_units is not None and len(ablated_units) > 0:
                # At least 2 entries: a length-1 mask would broadcast across
                # every channel of the layer.
                dims = max(2, max(ablated_units) + 1)
                mask = torch.zeros(dims)
                mask[ablated_units] = 1
                model.ablation[layer] = mask[None, :, None, None].cuda()
            batches = progress(z_loader, desc='Saving images')
            for batch_num, [z] in enumerate(batches):
                start_index = batch_num * batch_size
                im = ((model(z.cuda()) + 1) / 2 * 255).clamp(0, 255).byte(
                    ).permute(0, 2, 3, 1).cpu()
                for offset, image in enumerate(im):
                    index = start_index + offset
                    if indexes is not None:
                        index = indexes[index].item()
                    filename = os.path.join(dirname,
                                            name_template.format(index))
                    saver.add(image.numpy(), filename)
        finally:
            # Join the pool even if rendering raises, so pending saves are
            # not abandoned and workers are not left running.
            saver.join()
def copy_lightbox_to(dirname):
    '''
    Copy the 'lightbox.html' file that sits next to this script into
    *dirname*, saving it there as '+lightbox.html'.
    '''
    script_dir = os.path.dirname(__file__)
    # Anchor a possibly-relative script path at the working directory,
    # then resolve symlinks.
    source_dir = os.path.realpath(os.path.join(os.getcwd(), script_dir))
    source_path = os.path.join(source_dir, 'lightbox.html')
    target_path = os.path.join(dirname, '+lightbox.html')
    shutil.copy(source_path, target_path)
class SaveImageWorker(WorkerBase):
    '''
    WorkerBase subclass handed to WorkerPool; each work() call writes one
    image array to disk.
    '''

    def work(self, data, filename):
        '''
        Convert the array *data* to a PIL image and save it to *filename*,
        passing optimize=True and quality=100 to PIL's save.
        '''
        picture = Image.fromarray(data)
        picture.save(filename, optimize=True, quality=100)
if __name__ == '__main__':
    # Run the command-line tool only when executed as a script, not on import.
    main()
|