'''
A simple tool to generate samples of a GAN's output,
and apply semantic segmentation to that output.
'''

import torch
import numpy
import os
import argparse
import sys
import shutil
from PIL import Image
from torch.utils.data import TensorDataset
from netdissect.zdataset import standard_z_sample, z_dataset_for_model
from netdissect.progress import default_progress, verbose_progress
from netdissect.autoeval import autoimport_eval
from netdissect.workerpool import WorkerBase, WorkerPool
from netdissect.nethook import edit_layers, retain_layers
from netdissect.segviz import segment_visualization
from netdissect.segmenter import UnifiedParsingSegmenter
from scipy.io import savemat


def main():
    '''
    Generate GAN samples and write them, with segmentations, to --outdir.

    For each of --size latent vectors, writes into --outdir:
      <index>_img.jpg  -- the generated image,
      <index>_seg.mat  -- the segmentation array, stored under key 'seg',
      <index>_seg.png  -- a visualization of that segmentation.
    Also writes labels.txt (one "label category" pair per line) and copies
    lightbox.html from this script's directory to +lightbox.html.
    '''
    args = _parse_args()
    verbose_progress(not args.quiet)

    # Instantiate the model from the constructor expression on the command line.
    model = autoimport_eval(args.model)
    # Make the standard z, honoring --seed so runs are reproducible per seed.
    z_dataset = z_dataset_for_model(model, size=args.size, seed=args.seed)
    # Make the segmenter
    segmenter = UnifiedParsingSegmenter()

    dirname = args.outdir
    # The output directory may not exist yet; create it before any file
    # (starting with labels.txt) is written into it.
    os.makedirs(dirname, exist_ok=True)

    _write_labels(segmenter, dirname)

    # Move models to cuda
    model.cuda()
    _generate_and_save(model, segmenter, z_dataset, dirname)
    _copy_lightbox(dirname)


def _parse_args():
    '''Parse the command-line options for this tool.'''
    parser = argparse.ArgumentParser(description='GAN output segmentation util')
    parser.add_argument('--model', type=str, default=
            'netdissect.proggan.from_pth_file("' +
            'models/karras/churchoutdoor_lsun.pth")',
            help='constructor for the model to test')
    parser.add_argument('--outdir', type=str, default='images',
            help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
            help='number of images to output')
    parser.add_argument('--seed', type=int, default=1,
            help='seed')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    return parser.parse_args()


def _write_labels(segmenter, dirname):
    '''Write one "label category" line per segmenter label to labels.txt.'''
    labels, _cats = segmenter.get_label_and_category_names()
    with open(os.path.join(dirname, 'labels.txt'), 'w') as f:
        for label, cat in labels:
            f.write('%s %s\n' % (label, cat))


def _generate_and_save(model, segmenter, z_dataset, dirname, batch_size=10):
    '''
    Run the generator over z_dataset in batches, segment each output, and
    save the image, raw segmentation, and segmentation visualization.

    Files are numbered by the sample's position in z_dataset.
    '''
    progress = default_progress()
    with torch.no_grad():
        z_loader = torch.utils.data.DataLoader(z_dataset,
                batch_size=batch_size, num_workers=2, pin_memory=True)
        for batch_num, [z] in enumerate(
                progress(z_loader, desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            tensor_im = model(z)
            # Scale to 0..255 bytes (assumes generator output in [-1, 1] --
            # TODO confirm) and permute dims (0, 2, 3, 1) into the
            # channels-last layout PIL expects.
            byte_im = ((tensor_im + 1) / 2 * 255).clamp(0, 255).byte().permute(
                    0, 2, 3, 1).cpu()
            seg = segmenter.segment_batch(tensor_im)
            for i in range(len(tensor_im)):
                index = i + start_index
                Image.fromarray(byte_im[i].numpy()).save(
                        os.path.join(dirname, '%d_img.jpg' % index),
                        optimize=True, quality=100)
                # Convert once; both the .mat dump and the visualization use it.
                seg_np = seg[i].cpu().numpy()
                savemat(os.path.join(dirname, '%d_seg.mat' % index),
                        dict(seg=seg_np))
                Image.fromarray(segment_visualization(
                        seg_np, tensor_im.shape[2:])).save(
                        os.path.join(dirname, '%d_seg.png' % index))


def _copy_lightbox(dirname):
    '''Copy lightbox.html from this script's directory into dirname.'''
    srcdir = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    shutil.copy(os.path.join(srcdir, 'lightbox.html'),
            os.path.join(dirname, '+lightbox.html'))


if __name__ == '__main__':
    main()