File size: 7,957 Bytes
6064c9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
'''
A simple tool to generate a sample of the output of a GAN,
subject to filtering, sorting, or intervention.
'''
import argparse
import errno
import numbers
import os
import shutil
import sys

import numpy
import torch
from PIL import Image
from torch.utils.data import TensorDataset

from netdissect.zdataset import standard_z_sample
from netdissect.progress import default_progress, verbose_progress
from netdissect.autoeval import autoimport_eval
from netdissect.workerpool import WorkerBase, WorkerPool
from netdissect.nethook import retain_layers
from netdissect.runningstats import RunningTopK
def main():
    """Command-line entry point for the GAN unit-sample generator.

    Parses the command line, constructs the model named by ``--model``
    (optionally loading weights from ``--pthfile``), instruments the layer
    named by ``--layer``, draws ``--test_size`` latent vectors (default:
    20 times ``--size``), selects indices with get_all_highest_znums, and
    writes the corresponding images under ``--outdir``.

    Prints usage and exits with status 1 when run with no arguments.
    """
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
                        help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
                        help='number of images to output')
    parser.add_argument('--test_size', type=int, default=None,
                        help='number of images to test')
    parser.add_argument('--layer', type=str, default=None,
                        help='layer to inspect')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        # Nothing useful can be done without arguments: show usage and stop.
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    verbose_progress(not args.quiet)

    # Instantiate the model
    model = autoimport_eval(args.model)
    if args.pthfile is not None:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            # The checkpoint wraps the weights together with other entries.
            data = data['state_dict']
        # Apply the loaded weights; without this the file would be read
        # and then ignored.
        model.load_state_dict(data)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())
    # Examine first conv (or linear) layer in model to determine input
    # feature size.
    input_layer_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                         torch.nn.Linear)
    first_layer = next(
        (c for c in model.modules() if isinstance(c, input_layer_types)),
        None)
    if first_layer is None:
        raise ValueError(
            'model has no Conv2d, ConvTranspose2d, or Linear layer')
    # 4d input if convolutional, 2d input if first layer is linear.
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        z_channels = first_layer.in_channels
        spatialdims = (1, 1)
    else:
        z_channels = first_layer.in_features
        spatialdims = ()
    # Instrument the model
    retain_layers(model, [args.layer])
    # The sampling helpers feed CUDA tensors to the model, so it must
    # live on the GPU as well.
    model.cuda()

    if args.test_size is None:
        args.test_size = args.size * 20
    # NOTE(review): assumes standard_z_sample accepts a ``seed`` keyword;
    # confirm against netdissect.zdataset.
    z_universe = standard_z_sample(args.test_size, z_channels,
                                   seed=args.seed)
    z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims)
    indexes = get_all_highest_znums(
        model, z_universe, args.size, seed=args.seed)
    save_chosen_unit_images(args.outdir, model, z_universe, indexes,
                            lightbox=True)
def get_all_highest_znums(model, z_universe, size,
                          batch_size=10, seed=1):
    """Find the top-activating latent indices for the retained layer.

    Feeds every latent vector in ``z_universe`` through ``model`` in
    batches, reduces the retained layer's activation to one maximum per
    sample and channel, and tracks the ``size`` largest with RunningTopK.

    Args:
        model: model already instrumented with exactly one retained layer
            (read from ``model.retained``); it receives CUDA tensors.
        z_universe: tensor of latent vectors, one per index of dimension 0.
        size: the ``k`` passed to RunningTopK.
        batch_size: number of latent vectors per forward pass.
        seed: unused; kept so existing callers keep working.

    Returns:
        The index tensor from ``RunningTopK.result()``, sorted ascending
        along dimension 1 (presumably one row per unit; confirm against
        netdissect.runningstats).
    """
    # The model should have been instrumented already
    retained_items = list(model.retained.items())
    assert len(retained_items) == 1
    layer = retained_items[0][0]
    progress = default_progress()
    with torch.no_grad():
        # Pass 1: collect max activation stats
        z_loader = torch.utils.data.DataLoader(
            TensorDataset(z_universe),
            batch_size=batch_size, num_workers=2,
            pin_memory=True)
        rtk = RunningTopK(k=size)
        for [z] in progress(z_loader, desc='Finding max activations'):
            z = z.cuda()
            # Run the forward pass so model.retained holds this batch's
            # activations; the output itself is not needed here.
            model(z)
            feature = model.retained[layer]
            num_units = feature.shape[1]
            # Flatten everything after the channel axis and keep the
            # largest value per (sample, channel).
            max_feature = feature.view(
                feature.shape[0], num_units, -1).max(2)[0]
            # NOTE(review): assumes RunningTopK accumulates batches via
            # add(); confirm against netdissect.runningstats.
            rtk.add(max_feature)
        _, ti = rtk.result()
        highest = ti.sort(1)[0]
    return highest
def save_chosen_unit_images(dirname, model, z_universe, indices,
                            shared_dir="shared_images",
                            unitdir_template="unit_{}",
                            name_template="image_{}.jpg",
                            lightbox=False, batch_size=50, seed=1):
    """Render the chosen latent vectors and link them into per-unit folders.

    Every distinct index in ``indices`` is rendered once into
    ``dirname/shared_dir/<hashdir>/``; then, for each row ``u`` of
    ``indices``, a directory ``dirname/unitdir_template.format(u)`` is
    filled with links to that row's images.

    Args:
        dirname: root output directory.
        model: generator; fed CUDA tensors. Its output is mapped with
            ``(x + 1) / 2 * 255`` and clamped to [0, 255], so it is
            presumably in [-1, 1] (confirm against the model).
        z_universe: tensor of all candidate latent vectors.
        indices: 2-d integer tensor of indices into ``z_universe``, one
            row per unit.
        shared_dir: subdirectory of ``dirname`` holding the rendered images.
        unitdir_template: format string producing a unit directory name
            from the row number.
        name_template: format string producing an image filename from a
            latent index.
        lightbox: when true, copy the lightbox page into each unit
            directory.
        batch_size: number of images generated per forward pass.
        seed: unused; kept so existing callers keep working.

    NOTE(review): the defaults of ``shared_dir``, ``unitdir_template`` and
    ``name_template`` were not visible in the reviewed copy; confirm them.
    """
    all_indices = torch.unique(indices.view(-1), sorted=True)
    z_sample = z_universe[all_indices]
    progress = default_progress()
    sdir = os.path.join(dirname, shared_dir)
    # Create each hash bucket directory once, not once per index.
    for hd in sorted({hashdir(index) for index in range(len(z_universe))}):
        os.makedirs(os.path.join(sdir, hd), exist_ok=True)
    with torch.no_grad():
        # Pass 2: now generate images
        z_loader = torch.utils.data.DataLoader(
            TensorDataset(z_sample),
            batch_size=batch_size, num_workers=2,
            pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        for batch_num, [z] in enumerate(progress(z_loader,
                                                 desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            # Rescale to bytes and move channels last for PIL.
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = all_indices[i + start_index].item()
                filename = os.path.join(sdir, hashdir(index),
                                        name_template.format(index))
                saver.add(im[i].numpy(), filename)
        # NOTE(review): assumes WorkerPool.join() waits for queued work;
        # the images must exist before they can be linked below.
        saver.join()
    linker = WorkerPool(MakeLinkWorker)
    for u in progress(range(len(indices)), desc='Making links'):
        udir = os.path.join(dirname, unitdir_template.format(u))
        os.makedirs(udir, exist_ok=True)
        for r in range(indices.shape[1]):
            index = indices[u, r].item()
            fn = name_template.format(index)
            sourcename = os.path.join(sdir, hashdir(index), fn)
            targname = os.path.join(udir, fn)
            linker.add(sourcename, targname)
        if lightbox:
            copy_lightbox_to(udir)
    linker.join()
def copy_lightbox_to(dirname):
    """Copy ``lightbox.html`` from this module's directory into ``dirname``.

    The copy is written as ``+lightbox.html``.
    """
    module_dir = os.path.dirname(__file__)
    # Resolve the module directory against the current working directory.
    srcdir = os.path.realpath(os.path.join(os.getcwd(), module_dir))
    source = os.path.join(srcdir, 'lightbox.html')
    destination = os.path.join(dirname, '+lightbox.html')
    shutil.copy(source, destination)
def hashdir(index):
    """Return the two-digit bucket name (``'00'``..``'99'``) for ``index``.

    Keeps the number of files per shared directory low by spreading them
    over 100 subdirectories keyed on ``index`` modulo 100.
    """
    _, bucket = divmod(index, 100)
    return '%02d' % (bucket,)
class SaveImageWorker(WorkerBase):
    """Pool worker that encodes an image array and writes it to disk.

    Saving images can be sped up by sending jpeg encoding and
    file-writing work to a pool.
    """
    def work(self, data, filename):
        """Build a PIL image from ``data`` and save it as ``filename``."""
        image = Image.fromarray(data)
        image.save(filename, optimize=True, quality=100)
class MakeLinkWorker(WorkerBase):
    """Pool worker that hard-links ``sourcename`` to ``targname``.

    Creating links is a bit slow and can be done faster in parallel
    rather than waiting for each to be created.
    """
    def work(self, sourcename, targname):
        """Create the link, replacing whatever already exists at the target.

        Raises:
            OSError: for any failure other than the target already existing.
        """
        try:
            os.link(sourcename, targname)
        except OSError as e:
            if e.errno != errno.EEXIST:
                # Only "already exists" is recoverable; surface the rest.
                raise
            # Replace the stale target left by an earlier run.
            os.remove(targname)
            os.link(sourcename, targname)
class MakeSyminkWorker(WorkerBase):
    """Pool worker that symlinks ``targname`` to ``sourcename``.

    Creating symbolic links is a bit slow and can be done faster
    in parallel rather than waiting for each to be created.
    """
    def work(self, sourcename, targname):
        """Create the symlink, replacing whatever already exists there.

        Raises:
            OSError: for any failure other than the target already existing.
        """
        try:
            os.symlink(sourcename, targname)
        except OSError as e:
            if e.errno != errno.EEXIST:
                # Only "already exists" is recoverable; surface the rest.
                raise
            # The retry can only succeed once the old entry is gone.
            os.remove(targname)
            os.symlink(sourcename, targname)
if __name__ == '__main__':