import argparse
import os
import sys

import numpy
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from . import pbar
from . import zdataset
from . import segmenter
from . import frechet_distance
from . import parallelfolder
from .sampler import FixedRandomSubsetSampler, FixedSubsetSampler

# Number of object classes produced by the UnifiedParsingSegmenter;
# every tally row has one area-fraction entry per class.
NUM_OBJECTS = 336


def main():
    """Command-line entry point.

    Computes the Frechet Segmentation Distance (FSD) between the images in
    two directories, printing the distance and its mean/covariance
    components, and optionally saving a comparison histogram figure.
    """
    parser = argparse.ArgumentParser(description='Net dissect utility',
            prog='python -m %s.fsd' % __package__)
    parser.add_argument('true_dir')
    parser.add_argument('gen_dir')
    parser.add_argument('--size', type=int, default=10000)
    parser.add_argument('--cachedir', default=None)
    parser.add_argument('--histout', default=None)
    parser.add_argument('--maxscale', type=float, default=50)
    parser.add_argument('--labelcount', type=int, default=30)
    parser.add_argument('--dpi', type=float, default=100)
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    true_dir, gen_dir = args.true_dir, args.gen_dir
    # When comparing a directory with itself, use a different sampling seed
    # for the second tally so the two subsets are not trivially identical.
    seed1, seed2 = [1, 1 if true_dir != gen_dir else 2]
    true_tally, gen_tally = [
            cached_tally_directory(d, size=args.size,
                cachedir=args.cachedir, seed=seed)
            for d, seed in [(true_dir, seed1), (gen_dir, seed2)]]
    # Tallies hold area fractions; scale by 100 to work in percentages.
    fsd, meandiff, covdiff = frechet_distance.sample_frechet_distance(
            true_tally * 100, gen_tally * 100, return_components=True)
    print('fsd: %f; meandiff: %f; covdiff: %f' % (fsd, meandiff, covdiff))
    if args.histout is not None:
        diff_figure(true_tally * 100, gen_tally * 100,
                labelcount=args.labelcount,
                maxscale=args.maxscale,
                dpi=args.dpi
                ).savefig(args.histout)


def cached_tally_directory(directory, size=10000, cachedir=None, seed=1,
        download_from=None):
    """Return the segmentation tally for `directory`, using a .npy cache.

    The tally is loaded from a cached file (optionally fetched from
    `download_from`) when present; otherwise it is computed with
    `tally_directory` and saved for next time.
    """
    basename = ('%s_segtally_%d.npy' % (directory, size)).replace('/', '_')
    if seed != 1:
        basename = '%d_%s' % (seed, basename)
    if cachedir is not None:
        filename = os.path.join(cachedir, basename.replace('/', '_'))
    else:
        filename = basename
    if not os.path.isfile(filename) and download_from:
        from urllib.request import urlretrieve
        from urllib.parse import urljoin
        with pbar.reporthook() as hook:
            urlretrieve(urljoin(download_from, basename), filename,
                    reporthook=hook)
    if os.path.isfile(filename):
        return numpy.load(filename)
    # Bug fix: only create the cache directory when one was actually given;
    # os.makedirs(None) raises TypeError.
    if cachedir is not None:
        os.makedirs(cachedir, exist_ok=True)
    result = tally_directory(directory, size, seed=seed)
    numpy.save(filename, result)
    return result


def tally_directory(directory, size=10000, seed=1):
    """Segment a random subset of `size` images from `directory`.

    Returns a (size, NUM_OBJECTS) float64 array; each row holds, for one
    image, the fraction of image area covered by each object class.
    """
    dataset = parallelfolder.ParallelImageFolders(
            [directory],
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(256),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
    loader = DataLoader(dataset,
            sampler=FixedRandomSubsetSampler(dataset, end=size, seed=seed),
            batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    # Bug fix: numpy.float was removed in NumPy 1.24; float64 is identical.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
            dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            seg_result = upp.segment_batch(batch.cuda())
            for i in range(len(batch)):
                # Per-class pixel counts (channel 0 appears to hold the
                # object labels), normalized by image area.
                batch_result[i] = (
                        seg_result[i, 0].view(-1).bincount(
                            minlength=NUM_OBJECTS).float()
                        / (seg_result.shape[2] * seg_result.shape[3]))
            # Bug fix: slice to len(batch) so a final partial batch does not
            # cause a shape mismatch on assignment.
            result[batch_index:batch_index + len(batch)] = (
                    batch_result[:len(batch)].cpu().numpy())
            batch_index += len(batch)
    return result


def tally_dataset_objects(dataset, size=10000):
    """Segment a fixed random subset of `size` images from `dataset`.

    Returns a (size, NUM_OBJECTS) float64 array of per-image object area
    fractions, like `tally_directory` but for an existing dataset object.
    """
    loader = DataLoader(dataset,
            sampler=FixedRandomSubsetSampler(dataset, end=size),
            batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    # Bug fix: numpy.float was removed in NumPy 1.24; float64 is identical.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
            dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            seg_result = upp.segment_batch(batch.cuda())
            for i in range(len(batch)):
                # Per-class pixel counts normalized by image area.
                batch_result[i] = (
                        seg_result[i, 0].view(-1).bincount(
                            minlength=NUM_OBJECTS).float()
                        / (seg_result.shape[2] * seg_result.shape[3]))
            # Bug fix: slice to len(batch) for the final partial batch.
            result[batch_index:batch_index + len(batch)] = (
                    batch_result[:len(batch)].cpu().numpy())
            batch_index += len(batch)
    return result


def tally_generated_objects(model, size=10000):
    """Segment `size` images sampled from a generative `model`.

    Draws a fixed z dataset, runs the model to generate images, and returns
    a (size, NUM_OBJECTS) float64 array of object area fractions.
    """
    zds = zdataset.z_dataset_for_model(model, size)
    loader = DataLoader(zds, batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    # Bug fix: numpy.float was removed in NumPy 1.24; float64 is identical.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
            dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [zbatch] in pbar(loader):
            img = model(zbatch.cuda())
            seg_result = upp.segment_batch(img)
            for i in range(len(zbatch)):
                # Per-class pixel counts normalized by image area.
                batch_result[i] = (
                        seg_result[i, 0].view(-1).bincount(
                            minlength=NUM_OBJECTS).float()
                        / (seg_result.shape[2] * seg_result.shape[3]))
            # Bug fix: slice to len(zbatch) for the final partial batch.
            result[batch_index:batch_index + len(zbatch)] = (
                    batch_result[:len(zbatch)].cpu().numpy())
            batch_index += len(zbatch)
    return result


def diff_figure(ttally, gtally, labelcount=30, labelleft=True, dpi=100,
        maxscale=50.0, legend=False):
    """Build a two-panel figure comparing training and generated tallies.

    Top panel: relative change per object class, (gen - train) / train.
    Bottom panel: mean area per class on a log scale, training as bars and
    generated as a red line.  Returns the matplotlib Figure.
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg \
            as FigureCanvas
    from matplotlib.figure import Figure
    tresult, gresult = [t.mean(0) for t in [ttally, gtally]]
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    x = []
    labels = []
    gen_amount = []
    change_frac = []
    true_amount = []
    # Walk labels from most to least common in the training tally, skipping
    # the background label (0) and 'material'-category labels, and stopping
    # at labelcount labels or the first zero-area label.
    for label in numpy.argsort(-tresult):
        if label == 0 or labelnames[label][1] == 'material':
            continue
        if tresult[label] == 0:
            break
        x.append(len(x))
        labels.append(labelnames[label][0].split()[0])
        true_amount.append(tresult[label].item())
        gen_amount.append(gresult[label].item())
        change_frac.append((float(gresult[label] - tresult[label])
            / tresult[label]))
        if len(x) >= labelcount:
            break
    fig = Figure(dpi=dpi, figsize=(1.4 + 5.0 * labelcount / 30, 4.0))
    FigureCanvas(fig)
    a1, a0 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]})
    a0.bar(x, change_frac, label='relative delta')
    a0.set_xticks(x)
    a0.set_xticklabels(labels, rotation='vertical')
    if labelleft:
        a0.set_ylabel('relative delta\n(gen - train) / train')
    a0.set_xlim(-1.0, len(x))
    a0.set_ylim([-1, 1.1])
    a0.grid(axis='y', antialiased=False, alpha=0.25)
    if legend:
        a0.legend(loc=2)
    # Bars that overflow the top of the axis get their value printed above;
    # a clipped bar immediately following another clipped bar is nudged up
    # slightly so the labels do not collide.
    prev_high = None
    for ix, cf in enumerate(change_frac):
        if cf > 1.15:
            if prev_high == (ix - 1):
                offset = 0.1
            else:
                offset = 0.0
            prev_high = ix
            a0.text(ix, 1.15 + offset, '%.1f' % cf,
                    horizontalalignment='center', size=6)
    a1.bar(x, true_amount, label='training')
    a1.plot(x, gen_amount, linewidth=3, color='red', label='generated')
    a1.set_yscale('log')
    a1.set_xlim(-1.0, len(x))
    a1.set_ylim(maxscale / 5000, maxscale)
    if labelleft:
        a1.set_ylabel('mean area\nlog scale')
    if legend:
        a1.legend()
    a1.set_yticks([1e-2, 1e-1, 1.0, 1e+1])
    # Minor ticks at 1-9 per decade, within the visible range.  Bug fix:
    # positional `minor` was removed from set_yticks in modern matplotlib.
    a1.set_yticks([a * b for a in [1e-2, 1e-1, 1.0, 1e+1]
            for b in range(1, 10)
            if maxscale / 5000 <= a * b <= maxscale], minor=True)
    a1.set_xticks([])
    fig.tight_layout()
    return fig


if __name__ == '__main__':
    main()