'''
A simple tool to generate samples of the output of a GAN,
subject to filtering, sorting, or intervention.
'''

import torch, os, argparse, sys, shutil, errno, numbers
from PIL import Image
from torch.utils.data import TensorDataset
from netdissect.zdataset import standard_z_sample
from netdissect.progress import default_progress, verbose_progress
from netdissect.autoeval import autoimport_eval
from netdissect.workerpool import WorkerBase, WorkerPool
from netdissect.nethook import retain_layers
from netdissect.runningstats import RunningTopK

def main():
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    parser.add_argument('--model', type=str, default=None,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
            help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
            help='number of images to output')
    parser.add_argument('--test_size', type=int, default=None,
            help='number of images to test')
    parser.add_argument('--layer', type=str, default=None,
            help='layer to inspect')
    parser.add_argument('--seed', type=int, default=1,
            help='random seed')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    verbose_progress(not args.quiet)

    # Instantiate the model.
    model = autoimport_eval(args.model)
    if args.pthfile is not None:
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            # Keep any scalar metadata stored alongside the weights.
            meta = {}
            for key in data:
                if isinstance(data[key], numbers.Number):
                    meta[key] = data[key]
            data = data['state_dict']
        model.load_state_dict(data)
    # Unwrap any DataParallel-wrapped model.
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())
    # Examine the first convolutional or linear layer to determine the
    # input feature size: 4d z if convolutional, 2d z if linear.
    first_layer = [c for c in model.modules()
            if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                torch.nn.Linear))][0]
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        z_channels = first_layer.in_channels
        spatialdims = (1, 1)
    else:
        z_channels = first_layer.in_features
        spatialdims = ()
    # Instrument the model so activations at the chosen layer are retained.
    retain_layers(model, [args.layer])
    model.cuda()

    # Draw the candidate pool: by default, 20x the number of images kept.
    if args.test_size is None:
        args.test_size = args.size * 20
    z_universe = standard_z_sample(args.test_size, z_channels,
            seed=args.seed)
    z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims)
    indexes = get_all_highest_znums(
            model, z_universe, args.size, seed=args.seed)
    save_chosen_unit_images(args.outdir, model, z_universe, indexes,
            lightbox=True)
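# A hypothetical invocation, for illustration only: the constructor string,
# checkpoint filename, and layer name below are placeholders, not values
# shipped with this tool.
#
#   python makesample.py \
#       --model 'mypackage.mygan.Generator()' \
#       --pthfile generator.pth \
#       --layer layer4 \
#       --outdir images/sample \
#       --size 100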
def get_all_highest_znums(model, z_universe, size,
        batch_size=10, seed=1):
    # The model should have been instrumented already.
    retained_items = list(model.retained.items())
    assert len(retained_items) == 1
    layer = retained_items[0][0]
    progress = default_progress()
    num_units = None
    with torch.no_grad():
        # Pass 1: collect max-activation stats over the whole z universe.
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_universe),
                batch_size=batch_size, num_workers=2,
                pin_memory=True)
        rtk = RunningTopK(k=size)
        for [z] in progress(z_loader, desc='Finding max activations'):
            z = z.cuda()
            model(z)
            feature = model.retained[layer]
            num_units = feature.shape[1]
            # Reduce each unit's featuremap to its single peak activation.
            max_feature = feature.view(
                    feature.shape[0], num_units, -1).max(2)[0]
            rtk.add(max_feature)
        td, ti = rtk.result()
        # For each unit, report its top-k z indexes in ascending order.
        highest = ti.sort(1)[0]
    return highest

def save_chosen_unit_images(dirname, model, z_universe, indices,
        shared_dir="shared_images",
        unitdir_template="unit_{}",
        name_template="image_{}.jpg",
        lightbox=False, batch_size=50, seed=1):
    # Generate each chosen image once, in the shared directory.
    all_indices = torch.unique(indices.view(-1), sorted=True)
    z_sample = z_universe[all_indices]
    progress = default_progress()
    sdir = os.path.join(dirname, shared_dir)
    created_hashdirs = set()
    for index in range(len(z_universe)):
        hd = hashdir(index)
        if hd not in created_hashdirs:
            created_hashdirs.add(hd)
            os.makedirs(os.path.join(sdir, hd), exist_ok=True)
    with torch.no_grad():
        # Pass 2: generate and save images for the union of chosen z's.
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample),
                batch_size=batch_size, num_workers=2,
                pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        for batch_num, [z] in enumerate(progress(z_loader,
                desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            # Scale the generator output from [-1, 1] to byte RGB, HWC order.
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                    0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = all_indices[i + start_index].item()
                filename = os.path.join(sdir, hashdir(index),
                        name_template.format(index))
                saver.add(im[i].numpy(), filename)
        saver.join()
    # Then hard-link each unit's top images into its own directory.
    linker = WorkerPool(MakeLinkWorker)
    for u in progress(range(len(indices)), desc='Making links'):
        udir = os.path.join(dirname, unitdir_template.format(u))
        os.makedirs(udir, exist_ok=True)
        for r in range(indices.shape[1]):
            index = indices[u, r].item()
            fn = name_template.format(index)
            # sourcename = os.path.join('..', shared_dir, fn)
            sourcename = os.path.join(sdir, hashdir(index), fn)
            targname = os.path.join(udir, fn)
            linker.add(sourcename, targname)
        if lightbox:
            copy_lightbox_to(udir)
    linker.join()

def copy_lightbox_to(dirname):
    srcdir = os.path.realpath(
            os.path.join(os.getcwd(), os.path.dirname(__file__)))
    shutil.copy(os.path.join(srcdir, 'lightbox.html'),
            os.path.join(dirname, '+lightbox.html'))

def hashdir(index):
    # To keep the number of files in the shared directory low, split it
    # into 100 subdirectories named by the last two digits of the index.
    return '%02d' % (index % 100)
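# For example, hashdir(1234) returns '34', so image_1234.jpg is written to
# <outdir>/shared_images/34/image_1234.jpg and hard-linked from every
# unit_<u>/ directory whose top activations include z index 1234.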
class SaveImageWorker(WorkerBase):
    # Saving images can be sped up by sending the jpeg-encoding and
    # file-writing work to a pool.
    def work(self, data, filename):
        Image.fromarray(data).save(filename, optimize=True, quality=100)

class MakeLinkWorker(WorkerBase):
    # Creating hard links one-by-one is a bit slow, so it is done
    # in parallel rather than waiting for each to be created.
    def work(self, sourcename, targname):
        try:
            os.link(sourcename, targname)
        except OSError as e:
            if e.errno == errno.EEXIST:
                os.remove(targname)
                os.link(sourcename, targname)
            else:
                raise

class MakeSymlinkWorker(WorkerBase):
    # Same as MakeLinkWorker, but creates symbolic links instead.
    def work(self, sourcename, targname):
        try:
            os.symlink(sourcename, targname)
        except OSError as e:
            if e.errno == errno.EEXIST:
                os.remove(targname)
                os.symlink(sourcename, targname)
            else:
                raise

if __name__ == '__main__':
    main()
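# Output layout produced by the defaults above (illustrative):
#
#   <outdir>/
#       shared_images/00/ .. 99/image_<index>.jpg   one copy of each image
#       unit_<u>/image_<index>.jpg                  hard links to the unit's top images
#       unit_<u>/+lightbox.html                     viewer page (when lightbox=True)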