Paul Engstler
Initial commit
92f0e98
import torch, argparse, sys, os, numpy
from .sampler import FixedRandomSubsetSampler, FixedSubsetSampler
from torch.utils.data import DataLoader
from torchvision import transforms
from . import pbar
from . import zdataset
from . import segmenter
from . import frechet_distance
from . import parallelfolder
NUM_OBJECTS=336
def main():
parser = argparse.ArgumentParser(description='Net dissect utility',
prog='python -m %s.fsd' % __package__)
parser.add_argument('true_dir')
parser.add_argument('gen_dir')
parser.add_argument('--size', type=int, default=10000)
parser.add_argument('--cachedir', default=None)
parser.add_argument('--histout', default=None)
parser.add_argument('--maxscale', type=float, default=50)
parser.add_argument('--labelcount', type=int, default=30)
parser.add_argument('--dpi', type=float, default=100)
if len(sys.argv) == 1:
parser.print_usage(sys.stderr)
sys.exit(1)
args = parser.parse_args()
true_dir, gen_dir = args.true_dir, args.gen_dir
seed1, seed2 = [1, 1 if true_dir != gen_dir else 2]
true_tally, gen_tally = [
cached_tally_directory(d, size=args.size, cachedir=args.cachedir,
seed=seed)
for d, seed in [(true_dir, seed1), (gen_dir, seed2)]]
fsd, meandiff, covdiff = frechet_distance.sample_frechet_distance(
true_tally * 100, gen_tally * 100, return_components=True)
print('fsd: %f; meandiff: %f; covdiff: %f' % (fsd, meandiff, covdiff))
if args.histout is not None:
diff_figure(true_tally * 100, gen_tally * 100,
labelcount=args.labelcount,
maxscale=args.maxscale,
dpi=args.dpi
).savefig(args.histout)
def cached_tally_directory(directory, size=10000, cachedir=None, seed=1,
download_from=None):
basename = ('%s_segtally_%d.npy' % (directory, size)).replace('/', '_')
if seed != 1:
basename = '%d_%s' % (seed, basename)
if cachedir is not None:
filename = os.path.join(cachedir, basename.replace('/', '_'))
else:
filename = basename
if not os.path.isfile(filename) and download_from:
from urllib.request import urlretrieve
from urllib.parse import urljoin
with pbar.reporthook() as hook:
urlretrieve(urljoin(download_from, basename), filename,
reporthook=hook)
if os.path.isfile(filename):
return numpy.load(filename)
os.makedirs(cachedir, exist_ok=True)
result = tally_directory(directory, size, seed=seed)
numpy.save(filename, result)
return result
def tally_directory(directory, size=10000, seed=1):
dataset = parallelfolder.ParallelImageFolders(
[directory],
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
loader = DataLoader(dataset,
sampler=FixedRandomSubsetSampler(dataset, end=size,
seed=seed),
batch_size=10, pin_memory=True)
upp = segmenter.UnifiedParsingSegmenter()
labelnames, catnames = upp.get_label_and_category_names()
result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float)
batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
dtype=torch.float).cuda()
with torch.no_grad():
batch_index = 0
for [batch] in pbar(loader):
seg_result = upp.segment_batch(batch.cuda())
for i in range(len(batch)):
batch_result[i] = (
seg_result[i,0].view(-1).bincount(
minlength=NUM_OBJECTS).float()
/ (seg_result.shape[2] * seg_result.shape[3])
)
result[batch_index:batch_index+len(batch)] = (
batch_result.cpu().numpy())
batch_index += len(batch)
return result
def tally_dataset_objects(dataset, size=10000):
loader = DataLoader(dataset,
sampler=FixedRandomSubsetSampler(dataset, end=size),
batch_size=10, pin_memory=True)
upp = segmenter.UnifiedParsingSegmenter()
labelnames, catnames = upp.get_label_and_category_names()
result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float)
batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
dtype=torch.float).cuda()
with torch.no_grad():
batch_index = 0
for [batch] in pbar(loader):
seg_result = upp.segment_batch(batch.cuda())
for i in range(len(batch)):
batch_result[i] = (
seg_result[i,0].view(-1).bincount(
minlength=NUM_OBJECTS).float()
/ (seg_result.shape[2] * seg_result.shape[3])
)
result[batch_index:batch_index+len(batch)] = (
batch_result.cpu().numpy())
batch_index += len(batch)
return result
def tally_generated_objects(model, size=10000):
zds = zdataset.z_dataset_for_model(model, size)
loader = DataLoader(zds, batch_size=10, pin_memory=True)
upp = segmenter.UnifiedParsingSegmenter()
labelnames, catnames = upp.get_label_and_category_names()
result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float)
batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
dtype=torch.float).cuda()
with torch.no_grad():
batch_index = 0
for [zbatch] in pbar(loader):
img = model(zbatch.cuda())
seg_result = upp.segment_batch(img)
for i in range(len(zbatch)):
batch_result[i] = (
seg_result[i,0].view(-1).bincount(
minlength=NUM_OBJECTS).float()
/ (seg_result.shape[2] * seg_result.shape[3])
)
result[batch_index:batch_index+len(zbatch)] = (
batch_result.cpu().numpy())
batch_index += len(zbatch)
return result
def diff_figure(ttally, gtally,
labelcount=30, labelleft=True, dpi=100,
maxscale=50.0, legend=False):
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
tresult, gresult = [t.mean(0) for t in [ttally, gtally]]
upp = segmenter.UnifiedParsingSegmenter()
labelnames, catnames = upp.get_label_and_category_names()
x = []
labels = []
gen_amount = []
change_frac = []
true_amount = []
for label in numpy.argsort(-tresult):
if label == 0 or labelnames[label][1] == 'material':
continue
if tresult[label] == 0:
break
x.append(len(x))
labels.append(labelnames[label][0].split()[0])
true_amount.append(tresult[label].item())
gen_amount.append(gresult[label].item())
change_frac.append((float(gresult[label] - tresult[label])
/ tresult[label]))
if len(x) >= labelcount:
break
fig = Figure(dpi=dpi, figsize=(1.4 + 5.0 * labelcount / 30, 4.0))
FigureCanvas(fig)
a1, a0 = fig.subplots(2, 1, gridspec_kw = {'height_ratios':[1, 2]})
a0.bar(x, change_frac, label='relative delta')
a0.set_xticks(x)
a0.set_xticklabels(labels, rotation='vertical')
if labelleft:
a0.set_ylabel('relative delta\n(gen - train) / train')
a0.set_xlim(-1.0, len(x))
a0.set_ylim([-1, 1.1])
a0.grid(axis='y', antialiased=False, alpha=0.25)
if legend:
a0.legend(loc=2)
prev_high = None
for ix, cf in enumerate(change_frac):
if cf > 1.15:
if prev_high == (ix - 1):
offset = 0.1
else:
offset = 0.0
prev_high = ix
a0.text(ix, 1.15 + offset,
'%.1f' % cf, horizontalalignment='center', size=6)
a1.bar(x, true_amount, label='training')
a1.plot(x, gen_amount, linewidth=3, color='red', label='generated')
a1.set_yscale('log')
a1.set_xlim(-1.0, len(x))
a1.set_ylim(maxscale / 5000, maxscale)
from matplotlib.ticker import LogLocator
# a1.yaxis.set_major_locator(LogLocator(subs=(1,)))
# a1.yaxis.set_minor_locator(LogLocator(subs=(1,), numdecs=10))
# a1.yaxis.set_minor_locator(LogLocator(subs=(1,2,3,4,5,6,7,8,9)))
# a1.yaxis.set_minor_locator(yminor_locator)
if labelleft:
a1.set_ylabel('mean area\nlog scale')
if legend:
a1.legend()
a1.set_yticks([1e-2, 1e-1, 1.0, 1e+1])
a1.set_yticks([a * b for a in [1e-2, 1e-1, 1.0, 1e+1] for b in range(1,10)
if maxscale / 5000 <= a * b <= maxscale],
True) # minor ticks.
a1.set_xticks([])
fig.tight_layout()
return fig
if __name__ == '__main__':
main()