CharacterGAN / netdissect /dissection.py
mfrashad's picture
Init code
8f87579
'''
To run dissection:
1. Load up the convolutional model you wish to dissect, and wrap it in
an InstrumentedModel; then call imodel.retain_layers([layernames,..])
to instrument the layers of interest.
2. Load the segmentation dataset using the BrodenDataset class;
use the transform_image argument to normalize images to be
suitable for the model, or the size argument to truncate the dataset.
3. Choose a directory in which to write the output, and call
dissect(outdir, model, dataset).
Example:
from dissect import InstrumentedModel, dissect
from broden import BrodenDataset
model = InstrumentedModel(load_my_model())
model.eval()
model.cuda()
model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5'])
bds = BrodenDataset('dataset/broden1_227',
transform_image=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
size=1000)
dissect('result/dissect', model, bds,
examples_per_unit=10)
'''
import torch, numpy, os, re, json, shutil, types, tempfile, torchvision
# import warnings
# warnings.simplefilter('error', UserWarning)
from PIL import Image
from xml.etree import ElementTree as et
from collections import OrderedDict, defaultdict
from .progress import verbose_progress, default_progress, print_progress
from .progress import desc_progress
from .runningstats import RunningQuantile, RunningTopK
from .runningstats import RunningCrossCovariance, RunningConditionalQuantile
from .sampler import FixedSubsetSampler
from .actviz import activation_visualization
from .segviz import segment_visualization, high_contrast
from .workerpool import WorkerBase, WorkerPool
from .segmenter import UnifiedParsingSegmenter
def dissect(outdir, model, dataset,
        segrunner=None,
        train_dataset=None,
        model_segmenter=None,
        quantile_threshold=0.005,
        iou_threshold=0.05,
        iqr_threshold=0.01,
        examples_per_unit=100,
        batch_size=100,
        num_workers=24,
        seg_batch_size=5,
        make_images=True,
        make_labels=True,
        make_maxiou=False,
        make_covariance=False,
        make_report=True,
        make_row_images=True,
        make_single_images=False,
        rank_all_labels=False,
        netname=None,
        meta=None,
        merge=None,
        settings=None,
        ):
    '''
    Runs net dissection in-memory, using pytorch, and saves visualizations
    and metadata into outdir.

    Arguments:
        outdir: directory that receives all output files.
        model: the instrumented model; must already be in eval mode.
        dataset: dataset used for quantiles, top-k, images, labels and
            covariance.
        segrunner: object that runs the model and segments a batch;
            defaults to ClassifierSegRunner(dataset).
        train_dataset: dataset used for the iqr and maxiou passes;
            defaults to dataset.
        model_segmenter: accepted but not used by this function.
        quantile_threshold: activation quantile used as the per-unit
            threshold, or the string 'iqr' to pick thresholds with
            collect_iqr instead.
        iou_threshold, iqr_threshold: passed through to the report as
            the cutoffs for marking a unit interpretable.
        examples_per_unit: number of top images kept per unit.
        batch_size, num_workers, seg_batch_size: data loading settings.
        make_images, make_labels, make_maxiou, make_covariance,
        make_report: switches for the individual passes.
        make_row_images, make_single_images: passed to generate_images.
        rank_all_labels, netname, meta, merge, settings: passed to
            generate_report (netname defaults to the model class name).

    Returns:
        (quantiledata, labeldata) where quantiledata is the tuple
        (topk, quantiles, levels, quantile_threshold) and labeldata is
        (iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold,
        iqr_threshold), or None when make_labels is False.
    '''
    assert not model.training, 'Run model.eval() before dissection'
    if netname is None:
        netname = type(model).__name__
    if segrunner is None:
        segrunner = ClassifierSegRunner(dataset)
    if train_dataset is None:
        train_dataset = dataset
    make_iqr = (quantile_threshold == 'iqr')
    with torch.no_grad():
        device = next(model.parameters()).device
        pin_memory = (device.type == 'cuda')
        maxioudata, iqrdata, labeldata, cov = None, None, None, None
        labelnames, catnames = segrunner.get_label_and_category_names()

        # First, always collect quantiles and topk information.
        segloader = torch.utils.data.DataLoader(dataset,
                batch_size=batch_size, num_workers=num_workers,
                pin_memory=pin_memory)
        quantiles, topk = collect_quantiles_and_topk(outdir, model,
                segloader, segrunner, k=examples_per_unit)

        # Thresholds can be automatically chosen by maximizing iqr
        if make_iqr:
            # Get thresholds based on an IQR optimization
            segloader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=pin_memory)
            iqrdata = collect_iqr(outdir, model, segloader, segrunner)
            max_iqr, full_iqr_levels = iqrdata[:2]
            # For each unit, take the level of the label with the best iqr.
            levels = {layer: full_iqr_levels[layer][
                    max_iqr[layer].max(0)[1],
                    torch.arange(max_iqr[layer].shape[1])].to(device)
                for layer in full_iqr_levels}
        else:
            levels = {layer: qc.quantiles([1.0 - quantile_threshold])[:,0]
                for layer, qc in quantiles.items()}

        quantiledata = (topk, quantiles, levels, quantile_threshold)

        if make_images:
            # generate_images builds its own loader restricted to the
            # images it needs, so none is constructed here.
            generate_images(outdir, model, dataset, topk, levels, segrunner,
                    row_length=examples_per_unit, batch_size=seg_batch_size,
                    row_images=make_row_images,
                    single_images=make_single_images,
                    num_workers=num_workers)

        if make_maxiou:
            assert train_dataset, "Need training dataset for maxiou."
            segloader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=pin_memory)
            maxioudata = collect_maxiou(outdir, model, segloader,
                    segrunner)

        if make_labels:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=1, num_workers=num_workers,
                    pin_memory=pin_memory)
            iou_scores, iqr_scores, tcs, lcs, ccs, ics = (
                    collect_bincounts(outdir, model, segloader,
                        levels, segrunner))
            labeldata = (iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold,
                    iqr_threshold)

        if make_covariance:
            segloader = torch.utils.data.DataLoader(dataset,
                    batch_size=seg_batch_size,
                    num_workers=num_workers,
                    pin_memory=pin_memory)
            cov = collect_covariance(outdir, model, segloader, segrunner)

        if make_report:
            generate_report(outdir,
                    quantiledata=quantiledata,
                    labelnames=labelnames,
                    catnames=catnames,
                    labeldata=labeldata,
                    maxioudata=maxioudata,
                    iqrdata=iqrdata,
                    covariancedata=cov,
                    rank_all_labels=rank_all_labels,
                    netname=netname,
                    meta=meta,
                    mergedata=merge,
                    settings=settings)

        return quantiledata, labeldata
def _labels_by_unit_frequency(all_layers, field):
    '''
    Returns the labels stored under unitrec[field] across all units of
    all layers, most frequent first (ties in alphabetical order).  A
    trailing -t/-b/-l/-r/-tl/-tr/-bl/-br suffix is removed from each
    label before counting.
    '''
    counted_labels = defaultdict(int)
    for record in all_layers:
        for unitrec in record['units']:
            label = re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', unitrec[field])
            counted_labels[label] += 1
    return [label for count, label in sorted((-v, k)
        for k, v in counted_labels.items())]


def generate_report(outdir, quantiledata, labelnames=None, catnames=None,
        labeldata=None, maxioudata=None, iqrdata=None, covariancedata=None,
        rank_all_labels=False, netname='Model', meta=None, settings=None,
        mergedata=None):
    '''
    Creates dissect.json reports and summary bargraph.svg files in the
    specified output directory, and copies the dissect.html (and
    edit.html) interface to go along with it.

    Arguments:
        outdir: output directory; one subdirectory is used per layer.
        quantiledata: tuple (topk, quantiles, levels, quantile_threshold).
        labelnames: list of (label name, category name) pairs.
        catnames: list of category names.
        labeldata: optional tuple (iou_scores, iqr_scores, lcs, ccs, ics,
            iou_threshold, iqr_threshold).
        maxioudata: optional tuple (max_iou, max_iou_level,
            max_iou_quantile), each indexed by layer.
        iqrdata: optional 5-element sequence (max_iqr, max_iqr_level,
            max_iqr_quantile, max_iqr_iou, max_iqr_agreement).
        covariancedata: optional map from layer to an object providing
            correlation().
        rank_all_labels: if True, add per-label rankings for every label
            rather than only for labels assigned to some unit.
        netname, meta, settings: copied into the top-level json record.
        mergedata: optional dict with a 'layers' list whose per-unit
            records are merged over the generated unit records.
    '''
    all_layers = []
    # Current source code directory, for html to copy.
    srcdir = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # Unpack arguments
    topk, quantiles, levels, quantile_threshold = quantiledata
    top_record = dict(
        netname=netname,
        meta=meta,
        default_ranking='unit',
        quantile_threshold=quantile_threshold)
    if settings is not None:
        top_record['settings'] = settings
    if labeldata is not None:
        iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold, iqr_threshold = (
            labeldata)
        catorder = {'object': -7, 'scene': -6, 'part': -5,
                    'piece': -4,
                    'material': -3, 'texture': -2, 'color': -1}
        for i, cat in enumerate(c for c in catnames if c not in catorder):
            catorder[cat] = i
        catnumber = {n: i for i, n in enumerate(catnames)}
        catnumber['-'] = 0
        top_record['default_ranking'] = 'label'
        top_record['iou_threshold'] = iou_threshold
        top_record['iqr_threshold'] = iqr_threshold
    # Label name -> label number.  Built unconditionally because the
    # maxiou, iqr and covariance rankings below need it even when no
    # labeldata was supplied.
    labelnumber = dict((name[0], num)
        for num, name in enumerate(labelnames))
    # Make a segmentation color dictionary
    segcolors = {}
    for i, name in enumerate(labelnames):
        key = ','.join(str(s) for s in high_contrast[i % len(high_contrast)])
        if key in segcolors:
            segcolors[key] += '/' + name[0]
        else:
            segcolors[key] = name[0]
    top_record['segcolors'] = segcolors
    for layer in topk.keys():
        units, rankings = [], []
        record = dict(layer=layer, units=units, rankings=rankings)
        # For every unit, we always have basic visualization information.
        topa, topi = topk[layer].result()
        lev = levels[layer]
        for u in range(len(topa)):
            units.append(dict(
                unit=u,
                interp=True,
                level=lev[u].item(),
                top=[dict(imgnum=i.item(), maxact=a.item())
                    for i, a in zip(topi[u], topa[u])],
                ))
        rankings.append(dict(name="unit", score=list(range(len(topa)))))
        # TODO: consider including stats and ranking based on quantiles,
        # variance, connectedness here.

        # if we have labeldata, then every unit also gets a bunch of other info
        if labeldata is not None:
            lscore, qscore, cc, ic = [dat[layer]
                for dat in [iou_scores, iqr_scores, ccs, ics]]
            if iqrdata is not None:
                # If we have IQR thresholds, assign labels based on that
                max_iqr, max_iqr_level = iqrdata[:2]
                best_label = max_iqr[layer].max(0)[1]
                best_score = lscore[best_label, torch.arange(lscore.shape[1])]
                best_qscore = qscore[best_label, torch.arange(lscore.shape[1])]
            else:
                # Otherwise, assign labels based on max iou
                best_score, best_label = lscore.max(0)
                best_qscore = qscore[best_label, torch.arange(qscore.shape[1])]
            # Stored as a scalar, matching top_record['iou_threshold']
            # (a stray trailing comma used to make this a 1-tuple).
            record['iou_threshold'] = iou_threshold
            for u, urec in enumerate(units):
                unit_iou, unit_iqr, label = (
                    best_score[u], best_qscore[u], best_label[u])
                urec.update(dict(
                    iou=unit_iou.item(),
                    iou_iqr=unit_iqr.item(),
                    lc=lcs[label].item(),
                    cc=cc[catnumber[labelnames[label][1]], u].item(),
                    ic=ic[label, u].item(),
                    interp=(unit_iqr.item() > iqr_threshold and
                        unit_iou.item() > iou_threshold),
                    iou_labelnum=label.item(),
                    iou_label=labelnames[label.item()][0],
                    iou_cat=labelnames[label.item()][1],
                    ))
        if maxioudata is not None:
            max_iou, max_iou_level, max_iou_quantile = maxioudata
            qualified_iou = max_iou[layer].clone()
            # qualified_iou[max_iou_quantile[layer] > 0.75] = 0
            best_score, best_label = qualified_iou.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    maxiou=best_score[u].item(),
                    maxiou_label=labelnames[best_label[u].item()][0],
                    maxiou_cat=labelnames[best_label[u].item()][1],
                    maxiou_level=max_iou_level[layer][best_label[u], u].item(),
                    maxiou_quantile=max_iou_quantile[layer][
                        best_label[u], u].item()))
        if iqrdata is not None:
            [max_iqr, max_iqr_level, max_iqr_quantile,
                max_iqr_iou, max_iqr_agreement] = iqrdata
            qualified_iqr = max_iqr[layer].clone()
            qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
            best_score, best_label = qualified_iqr.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    iqr=best_score[u].item(),
                    iqr_label=labelnames[best_label[u].item()][0],
                    iqr_cat=labelnames[best_label[u].item()][1],
                    iqr_level=max_iqr_level[layer][best_label[u], u].item(),
                    iqr_quantile=max_iqr_quantile[layer][
                        best_label[u], u].item(),
                    iqr_iou=max_iqr_iou[layer][best_label[u], u].item()
                    ))
        if covariancedata is not None:
            score = covariancedata[layer].correlation()
            best_score, best_label = score.max(1)
            for u, urec in enumerate(units):
                urec.update(dict(
                    cor=best_score[u].item(),
                    cor_label=labelnames[best_label[u].item()][0],
                    cor_cat=labelnames[best_label[u].item()][1]
                    ))
        if mergedata is not None:
            # Final step: if the user passed any data to merge into the
            # units, merge them now.  This can be used, for example, to
            # indicate that a unit is not interpretable based on some
            # outside analysis of unit statistics.
            for lrec in mergedata.get('layers', []):
                if lrec['layer'] == layer:
                    break
            else:
                lrec = None
            for u, urec in enumerate(lrec.get('units', []) if lrec else []):
                units[u].update(urec)
        # After populating per-unit info, populate per-layer ranking info
        if labeldata is not None:
            # Collect all interpretable units, grouped by label number.
            labelunits = defaultdict(list)
            for u, urec in enumerate(units):
                if urec['interp']:
                    labelunits[urec['iou_labelnum']].append(u)
            # Sort all units in order with most popular label first.
            # Note: looking up labelunits inside the key creates empty
            # entries for labels with no interpretable units; those are
            # filtered out again when record['labels'] is built.
            label_ordering = sorted(units,
                # Sort by:
                key=lambda r: (-1 if r['interp'] else 0,  # interpretable
                    -len(labelunits[r['iou_labelnum']]),  # label freq, score
                    -max([units[u]['iou']
                        for u in labelunits[r['iou_labelnum']]], default=0),
                    r['iou_labelnum'],                    # label
                    -r['iou']))                           # unit score
            # Add label and iou ranking.
            rankings.append(dict(name="label", score=(numpy.argsort(list(
                ur['unit'] for ur in label_ordering))).tolist()))
            rankings.append(dict(name="max iou", metric="iou", score=list(
                -ur['iou'] for ur in units)))
            # Collate labels by category then frequency.
            record['labels'] = [dict(
                        label=labelnames[label][0],
                        labelnum=label,
                        units=labelunits[label],
                        cat=labelnames[label][1])
                    for label in (sorted(labelunits.keys(),
                        # Sort by:
                        key=lambda l: (catorder.get(          # category
                            labelnames[l][1], 0),
                            -len(labelunits[l]),              # label freq
                            -max([units[u]['iou'] for u in labelunits[l]],
                                default=0)                    # score
                            ))) if len(labelunits[label])]
            # Total number of interpretable units.
            record['interpretable'] = sum(len(group['units'])
                    for group in record['labels'])
            # Make a bargraph of labels
            os.makedirs(os.path.join(outdir, safe_dir_name(layer)),
                    exist_ok=True)
            catgroups = OrderedDict()
            for _, cat in sorted([(v, k) for k, v in catorder.items()]):
                catgroups[cat] = []
            for rec in record['labels']:
                if rec['cat'] not in catgroups:
                    catgroups[rec['cat']] = []
                catgroups[rec['cat']].append(rec['label'])
            make_svg_bargraph(
                    [rec['label'] for rec in record['labels']],
                    [len(rec['units']) for rec in record['labels']],
                    [(cat, len(group)) for cat, group in catgroups.items()],
                    filename=os.path.join(outdir, safe_dir_name(layer),
                        'bargraph.svg'))
            # Only show the bargraph if it is non-empty.
            if len(record['labels']):
                record['bargraph'] = 'bargraph.svg'
        if maxioudata is not None:
            rankings.append(dict(name="max maxiou", metric="maxiou", score=list(
                -ur['maxiou'] for ur in units)))
        if iqrdata is not None:
            rankings.append(dict(name="max iqr", metric="iqr", score=list(
                -ur['iqr'] for ur in units)))
        if covariancedata is not None:
            rankings.append(dict(name="max cor", metric="cor", score=list(
                -ur['cor'] for ur in units)))

        all_layers.append(record)
    # Now add the same rankings to every layer...
    all_labels = None
    if rank_all_labels:
        all_labels = [name for name, cat in labelnames]
    if labeldata is not None:
        if all_labels is None:
            all_labels = _labels_by_unit_frequency(all_layers, 'iou_label')
        for record in all_layers:
            layer = record['layer']
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iou" % label,
                    concept=label, metric='iou',
                    score=(-iou_scores[layer][labelnum, :]).tolist()))
    if maxioudata is not None:
        max_iou, max_iou_level, max_iou_quantile = maxioudata
        if all_labels is None:
            all_labels = _labels_by_unit_frequency(all_layers, 'maxiou_label')
        for record in all_layers:
            layer = record['layer']
            # Must be recomputed for each layer: it used to be computed
            # once, before this loop, from whichever layer was seen last.
            qualified_iou = max_iou[layer].clone()
            qualified_iou[max_iou_quantile[layer] > 0.5] = 0
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-maxiou" % label,
                    concept=label, metric='maxiou',
                    score=(-qualified_iou[labelnum, :]).tolist()))
    if iqrdata is not None:
        max_iqr = iqrdata[0]
        if all_labels is None:
            all_labels = _labels_by_unit_frequency(all_layers, 'iqr_label')
        # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
        for record in all_layers:
            layer = record['layer']
            qualified_iqr = max_iqr[layer].clone()
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iqr" % label,
                    concept=label, metric='iqr',
                    score=(-qualified_iqr[labelnum, :]).tolist()))
    if covariancedata is not None:
        if all_labels is None:
            all_labels = _labels_by_unit_frequency(all_layers, 'cor_label')
        for record in all_layers:
            layer = record['layer']
            score = covariancedata[layer].correlation()
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-cor" % label,
                    concept=label, metric='cor',
                    score=(-score[:, labelnum]).tolist()))

    for record in all_layers:
        layer = record['layer']
        layerdir = os.path.join(outdir, safe_dir_name(layer))
        # The directory is otherwise only created on the labeldata path.
        os.makedirs(layerdir, exist_ok=True)
        # Dump per-layer json inside per-layer directory
        record['dirname'] = '.'
        with open(os.path.join(layerdir, 'dissect.json'), 'w') as jsonfile:
            top_record['layers'] = [record]
            json.dump(top_record, jsonfile, indent=1)
        # Copy the per-layer html
        shutil.copy(os.path.join(srcdir, 'dissect.html'),
                os.path.join(layerdir, 'dissect.html'))
        record['dirname'] = safe_dir_name(layer)

    # Dump all-layer json in parent directory
    with open(os.path.join(outdir, 'dissect.json'), 'w') as jsonfile:
        top_record['layers'] = all_layers
        json.dump(top_record, jsonfile, indent=1)
    # Copy the all-layer html
    shutil.copy(os.path.join(srcdir, 'dissect.html'),
            os.path.join(outdir, 'dissect.html'))
    shutil.copy(os.path.join(srcdir, 'edit.html'),
            os.path.join(outdir, 'edit.html'))
def generate_images(outdir, model, dataset, topk, levels,
        segrunner, row_length=None, gap_pixels=5,
        row_images=True, single_images=False, prefix='',
        batch_size=100, num_workers=24):
    '''
    Creates an image strip file for every unit of every retained layer
    of the model, in the format [outdir]/[layername]/[unitnum]-top.jpg.
    Assumes that the indexes of topk refer to the indexes of dataset.
    Limits each strip to the top row_length images.

    For each unit four files are written: top.jpg, orig.jpg, seg.png and
    mask.png.  With row_images the full strips go in the [prefix]image
    subdirectory; with single_images only the first image of each strip
    is saved in [prefix]s-image.  gap_pixels is the white gap left to
    the right of each image in a strip.  Returns without writing
    anything if topk lists no images at all.
    '''
    progress = default_progress()
    needed_images = {}
    if row_images is False:
        row_length = 1
    # Pass 1: needed_images lists all images that are topk for some unit.
    for layer in topk:
        topresult = topk[layer].result()[1].cpu()
        for unit, row in enumerate(topresult):
            for rank, imgnum in enumerate(row[:row_length]):
                needed_images.setdefault(imgnum.item(), []).append(
                    (layer, unit, rank))
    if not needed_images:
        # Nothing to visualize.  Continuing would fail below, either on
        # the undefined `row` or on joining a pool that never started.
        return
    levels = {k: v.cpu().numpy() for k, v in levels.items()}
    # Clamp row_length to the number of entries actually available.
    row_length = len(row[:row_length])
    needed_sample = FixedSubsetSampler(sorted(needed_images.keys()))
    device = next(model.parameters()).device
    segloader = torch.utils.data.DataLoader(dataset,
            batch_size=batch_size, num_workers=num_workers,
            pin_memory=(device.type == 'cuda'),
            sampler=needed_sample)
    vizgrid, maskgrid, origrid, seggrid = [{} for _ in range(4)]
    # Pass 2: populate vizgrid with visualizations of top units.
    pool = None
    for i, batch in enumerate(
            progress(segloader, desc='Making images')):
        # Reverse transformation to get the image in byte form.
        seg, _, byte_im, _ = segrunner.run_and_segment_batch(batch, model,
                want_rgb=True)
        torch_features = model.retained_features()
        scale_offset = getattr(model, 'scale_offset', None)
        if pool is None:
            # Distribute the work across processes: create shared mmaps.
            # Grids are sized from the first batch; each is indexed
            # [unit, image row, rank, image column (+gap), channel].
            for layer, tf in torch_features.items():
                [vizgrid[layer], maskgrid[layer], origrid[layer],
                        seggrid[layer]] = [
                    create_temp_mmap_grid((tf.shape[1],
                        byte_im.shape[1], row_length,
                        byte_im.shape[2] + gap_pixels, depth),
                        dtype='uint8',
                        fill=255)
                    for depth in [3, 4, 3, 3]]
            # Pass those mmaps to worker processes.
            pool = WorkerPool(worker=VisualizeImageWorker,
                    memmap_grid_info=[
                        {layer: (g.filename, g.shape, g.dtype)
                            for layer, g in grid.items()}
                        for grid in [vizgrid, maskgrid, origrid, seggrid]])
        byte_im = byte_im.cpu().numpy()
        numpy_seg = seg.cpu().numpy()
        # Per-batch cache of layer activations converted to numpy.
        features = {}
        for index in range(len(byte_im)):
            imgnum = needed_sample.samples[index + i*segloader.batch_size]
            for layer, unit, rank in needed_images[imgnum]:
                if layer not in features:
                    features[layer] = torch_features[layer].cpu().numpy()
                pool.add(layer, unit, rank,
                        byte_im[index],
                        features[layer][index, unit],
                        levels[layer][unit],
                        scale_offset[layer] if scale_offset else None,
                        numpy_seg[index])
    pool.join()
    # Pass 3: save image strips as [outdir]/[layer]/[unitnum]-[top/orig].jpg
    pool = WorkerPool(worker=SaveImageWorker)
    for layer, vg in progress(vizgrid.items(), desc='Saving images'):
        os.makedirs(os.path.join(outdir, safe_dir_name(layer),
            prefix + 'image'), exist_ok=True)
        if single_images:
            os.makedirs(os.path.join(outdir, safe_dir_name(layer),
                prefix + 's-image'), exist_ok=True)
        og, sg, mg = origrid[layer], seggrid[layer], maskgrid[layer]
        for unit in progress(range(len(vg)), desc='Units'):
            for suffix, grid in [('top.jpg', vg), ('orig.jpg', og),
                    ('seg.png', sg), ('mask.png', mg)]:
                # Lay the row of images side by side in one wide image.
                strip = grid[unit].reshape(
                        (grid.shape[1], grid.shape[2] * grid.shape[3],
                            grid.shape[4]))
                if row_images:
                    filename = os.path.join(outdir, safe_dir_name(layer),
                            prefix + 'image', '%d-%s' % (unit, suffix))
                    # Drop the trailing gap.  An explicit end index is
                    # used because [:-0] would select nothing when
                    # gap_pixels is 0.
                    pool.add(strip[:,:strip.shape[1] - gap_pixels,:].copy(),
                            filename)
                if single_images:
                    single_filename = os.path.join(outdir,
                            safe_dir_name(layer),
                            prefix + 's-image', '%d-%s' % (unit, suffix))
                    pool.add(strip[:,:strip.shape[1] // row_length
                        - gap_pixels,:].copy(), single_filename)
    pool.join()
    # Delete the shared memory map files
    clear_global_shared_files([g.filename
        for grid in [vizgrid, maskgrid, origrid, seggrid]
        for g in grid.values()])
# Memory-mapped grids opened by this process, keyed by filename.
global_shared_files = {}
# Temporary directories made by create_temp_mmap_grid, so that
# clear_global_shared_files removes only directories this module created.
_created_temp_dirs = set()


def create_temp_mmap_grid(shape, dtype, fill):
    '''
    Creates a new memory-mapped array of the given shape and dtype,
    backed by a file in a fresh temporary directory, sets every element
    to fill, registers it in global_shared_files and returns it.
    The backing file object is kept on the array as its .fid attribute.
    '''
    dtype = numpy.dtype(dtype)
    tempdir = tempfile.mkdtemp()
    filename = os.path.join(tempdir, 'temp-%s-%s.mmap' %
            ('x'.join('%d' % s for s in shape), dtype.name))
    fid = open(filename, mode='w+b')
    original = numpy.memmap(fid, dtype=dtype, mode='w+', shape=shape)
    original.fid = fid
    original[...] = fill
    global_shared_files[filename] = original
    _created_temp_dirs.add(os.path.abspath(tempdir))
    return original


def shared_temp_mmap_grid(filename, shape, dtype):
    '''
    Returns the memory-mapped array for filename, opening the existing
    file in read/write mode and caching it in global_shared_files the
    first time it is requested in this process.
    '''
    if filename not in global_shared_files:
        global_shared_files[filename] = numpy.memmap(
                filename, dtype=dtype, mode='r+', shape=shape)
    return global_shared_files[filename]


def clear_global_shared_files(filenames):
    '''
    Forgets the given memory-map files and deletes them from disk.
    Deletion is best-effort: files that cannot be removed are skipped.
    A temporary directory made by create_temp_mmap_grid is removed too,
    so it is not left behind empty.
    '''
    for fn in filenames:
        if fn in global_shared_files:
            del global_shared_files[fn]
        try:
            os.unlink(fn)
        except OSError:
            pass
        parent = os.path.dirname(os.path.abspath(fn))
        if parent in _created_temp_dirs:
            _created_temp_dirs.discard(parent)
            try:
                os.rmdir(parent)
            except OSError:
                # Best effort, like the unlink above.
                pass
class VisualizeImageWorker(WorkerBase):
    '''
    Worker that renders one (layer, unit, rank) cell of the image grids:
    the original image, the activation visualization with its mask, and
    the segmentation visualization.
    '''
    def setup(self, memmap_grid_info):
        '''
        Opens the four shared grids (viz, mask, orig, seg).  Each entry
        of memmap_grid_info maps a layer name to the
        (filename, shape, dtype) arguments of shared_temp_mmap_grid.
        '''
        self.vizgrid, self.maskgrid, self.origrid, self.seggrid = [
            {layer: shared_temp_mmap_grid(*info)
                for layer, info in grid.items()}
            for grid in memmap_grid_info]

    def work(self, layer, unit, rank,
            byte_im, acts, level, scale_offset, seg):
        '''
        Writes the visualizations for one image into cell [unit, rank]
        of each grid for the given layer.
        '''
        # The grids are indexed [unit, row, rank, column, channel] and the
        # column axis is wider than the image by the gap, so the slice on
        # that axis must be the image width (shape[1]).  Slicing by the
        # height (shape[0]) only worked for square images.
        width = byte_im.shape[1]
        self.origrid[layer][unit,:,rank,:width,:] = byte_im
        # NOTE(review): assumes activation_visualization returns two arrays
        # with the same height and width as byte_im -- confirm in actviz.
        [self.vizgrid[layer][unit,:,rank,:width,:],
         self.maskgrid[layer][unit,:,rank,:width,:]] = (
            activation_visualization(
                byte_im,
                acts,
                level,
                scale_offset=scale_offset,
                return_mask=True))
        self.seggrid[layer][unit,:,rank,:width,:] = (
            segment_visualization(seg, byte_im.shape[0:2]))
class SaveImageWorker(WorkerBase):
    '''
    Worker that writes one image array to disk.
    '''
    def work(self, data, filename):
        '''
        Converts the array to a PIL image and saves it under filename,
        with optimize=True and quality=80.
        '''
        image = Image.fromarray(data)
        image.save(filename, optimize=True, quality=80)
def score_tally_stats(label_category, tc, truth, cc, ic):
    '''
    Computes two scores from tallied counts and returns (iou, iqr).

    label_category gives the category index of each label; tc and cc are
    indexed by category, truth by label, and ic has one row per label.
    iou is intersection / (pred + truth - intersection).  iqr is the
    mutual information of the 2x2 table of counts divided by its joint
    entropy, with NaN results replaced by zero.
    '''
    predicted = cc[label_category]
    denominator = tc[label_category][:, None]
    actual = truth[:, None]
    tiny = 1e-20  # avoid division-by-zero
    either = predicted + actual - ic
    iou = ic.double() / (either.double() + tiny)
    # 2x2 table per (label, unit): [both, predicted only],
    # [actual only, neither].
    table = torch.empty(size=(2, 2) + ic.shape, dtype=ic.dtype,
            device=ic.device)
    cells = (
        (0, 0, ic),
        (0, 1, predicted - ic),
        (1, 0, actual - ic),
        (1, 1, denominator - either),
    )
    for row, col, counts in cells:
        table[row, col] = counts
    # Turn the counts into fractions of the category total.
    table = table.double() / denominator.double()
    iqr = mutual_information(table) / joint_entropy(table)
    iqr[torch.isnan(iqr)] = 0  # Zero out any 0/0
    return iou, iqr
def collect_quantiles_and_topk(outdir, model, segloader,
        segrunner, k=100, resolution=1024):
    '''
    Gathers a RunningQuantile and a RunningTopK (keeping the top k) for
    every retained layer of the model and returns (quantiles, topks),
    two maps keyed by layer name.

    Results cached in outdir as quantiles.npz and topk.npz are returned
    directly when present for every layer.  Otherwise the whole
    segloader is run through the model once per group of 8 layers, and
    the results are saved to those files before returning.
    '''
    device = next(model.parameters()).device
    cpu = torch.device('cpu')
    retained = model.retained_features()
    cached_quantiles = {}
    for layer in retained:
        cached_quantiles[layer] = load_quantile_if_present(
            os.path.join(outdir, safe_dir_name(layer)), 'quantiles.npz',
            device=cpu)
    cached_topks = {}
    for layer in retained:
        cached_topks[layer] = load_topk_if_present(
            os.path.join(outdir, safe_dir_name(layer)), 'topk.npz',
            device=cpu)
    quantiles_cached = all(
        found is not None for found in cached_quantiles.values())
    topks_cached = all(
        found is not None for found in cached_topks.values())
    if quantiles_cached and topks_cached:
        return cached_quantiles, cached_topks

    layer_batch_size = 8
    layer_names = list(retained.keys())
    quantiles, topks = {}, {}
    progress = default_progress()
    for start in range(0, len(layer_names), layer_batch_size):
        layer_batch = layer_names[start:start + layer_batch_size]
        for batch in progress(segloader, desc='Quantiles'):
            # Only the retained activations matter, not the model output.
            model(batch[0].to(device))
            activations = model.retained_features()
            for key in layer_batch:
                value = activations[key]
                if topks.get(key, None) is None:
                    topks[key] = RunningTopK(k)
                if quantiles.get(key, None) is None:
                    quantiles[key] = RunningQuantile(resolution=resolution)
                topvalue = value
                if len(value.shape) > 2:
                    channels = value.shape[1]
                    # Top-k sees the maximum over all trailing dimensions.
                    topvalue = value.view(
                        value.shape[0], channels, -1).max(2)[0]
                    # Quantiles see one row per position, channel last.
                    channel_last = (
                        (0,) + tuple(range(2, len(value.shape))) + (1,))
                    value = value.permute(channel_last).contiguous()
                    value = value.view(-1, channels)
                quantiles[key].add(value)
                topks[key].add(topvalue)
        # Move finished statistics off the GPU to save memory.
        for key in layer_batch:
            quantiles[key].to_(torch.device('cpu'))
            topks[key].to_(torch.device('cpu'))
    for layer in quantiles:
        layer_dir = os.path.join(outdir, safe_dir_name(layer))
        save_state_dict(quantiles[layer],
                os.path.join(layer_dir, 'quantiles.npz'))
        save_state_dict(topks[layer],
                os.path.join(layer_dir, 'topk.npz'))
    return quantiles, topks
def collect_bincounts(outdir, model, segloader, levels, segrunner):
    '''
    Counts, across the data set, the pixels of intersection between
    upsampled, thresholded model featuremaps and the segmentation classes
    in the segloader, and scores every (label, unit) pair.

    Returns a 6-tuple:
    (iou_scores, iqr_scores, total_counts, label_counts,
     category_activation_counts, intersection_counts).

    iou_scores, iqr_scores (one per layer): the results of
        score_tally_stats for each label and feature channel.
    total_counts (independent of layer): for each category, the number
        of pixels in images on which a label of that category appears.
    label_counts (independent of model): pixels across the data set that
        are labeled with the given label.
    category_activation_counts (one per layer): for each feature channel,
        pixels across the dataset where the channel exceeds the level
        threshold.  There is one count per category: activations only
        contribute to the categories for which any category labels are
        present on the images.
    intersection_counts (one per layer): for each feature channel and
        label, pixels across the dataset where the channel exceeds
        the level, and the labeled segmentation class is also present.

    Results are cached per layer in [outdir]/[layer]/bincounts.npz; if
    that file exists for every retained layer the cached values are
    returned without running the model.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    # Load cached data if present
    (iou_scores, iqr_scores,
        total_counts, label_counts, category_activation_counts,
        intersection_counts) = {}, {}, None, None, {}, {}
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'bincounts.npz')
        if os.path.isfile(filename):
            data = numpy.load(filename)
            iou_scores[layer] = torch.from_numpy(data['iou_scores'])
            iqr_scores[layer] = torch.from_numpy(data['iqr_scores'])
            # total_counts and label_counts are not per-layer; each file
            # read simply overwrites the previous value.
            total_counts = torch.from_numpy(data['total_counts'])
            label_counts = torch.from_numpy(data['label_counts'])
            category_activation_counts[layer] = torch.from_numpy(
                    data['category_activation_counts'])
            intersection_counts[layer] = torch.from_numpy(
                    data['intersection_counts'])
        else:
            found_all = False
    if found_all:
        return (iou_scores, iqr_scores,
            total_counts, label_counts, category_activation_counts,
            intersection_counts)

    device = next(model.parameters()).device
    labelcat, categories = segrunner.get_label_and_category_names()
    # Category index of each label (0 if its category is not listed).
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # One-hot vector of category for each label
    labelcat = torch.zeros(num_labels, num_categories,
            dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
        dtype='int64')).to(device)[:,None], 1)
    # Running bincounts
    # activation_counts = {}
    assert segloader.batch_size == 1 # category_activation_counts needs this.
    category_activation_counts = {}
    intersection_counts = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    total_counts = torch.zeros(num_categories, dtype=torch.long, device=device)
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    # Sampling grids are built once per layer, from the first batch seen.
    upsample_grids = {}
    # total_batch_categories = torch.zeros(
    #         labelcat.shape[1], dtype=torch.long, device=device)
    for i, batch in enumerate(progress(segloader, desc='Bincounts')):
        seg, batch_label_counts, _, imshape = segrunner.run_and_segment_batch(
                batch, model, want_bincount=True, want_rgb=True)
        bc = batch_label_counts.cpu()
        batch_label_counts = batch_label_counts.to(device)
        seg = seg.to(device)
        features = model.retained_features()
        # Accumulate bincounts and identify nonzeros
        label_counts += batch_label_counts[0]
        batch_labels = bc[0].nonzero()[:,0]
        # 0/1 vector: which categories have any label in this image.
        batch_categories = labelcat[batch_labels].max(0)[0]
        total_counts += batch_categories * (
                seg.shape[0] * seg.shape[2] * seg.shape[3])
        for key, value in features.items():
            if key not in upsample_grids:
                upsample_grids[key] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(key, None)
                            if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            # Resample the featuremap onto the segmentation grid.
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[key], padding_mode='border')
            # Pixels where each channel exceeds its threshold level.
            amask = (upsampled > levels[key][None,:,None,None].to(
                upsampled.device))
            ac = amask.int().view(amask.shape[1], -1).sum(1)
            # if key not in activation_counts:
            #     activation_counts[key] = ac
            # else:
            #     activation_counts[key] += ac
            # The fastest approach: sum over each label separately!
            for label in batch_labels.tolist():
                if label == 0:
                    continue  # ignore the background label
                # A pixel counts if the label appears in any channel of seg.
                imask = amask * ((seg == label).max(dim=1, keepdim=True)[0])
                ic = imask.int().view(imask.shape[1], -1).sum(1)
                if key not in intersection_counts:
                    intersection_counts[key] = torch.zeros(num_labels,
                            amask.shape[1], dtype=torch.long, device=device)
                intersection_counts[key][label] += ic
            # Count activations within images that have category labels.
            # Note: This only makes sense with batch-size one
            # total_batch_categories += batch_categories
            cc = batch_categories[:,None] * ac[None,:]
            if key not in category_activation_counts:
                category_activation_counts[key] = cc
            else:
                category_activation_counts[key] += cc
    iou_scores = {}
    iqr_scores = {}
    # Only layers that recorded at least one intersection are scored.
    for k in intersection_counts:
        iou_scores[k], iqr_scores[k] = score_tally_stats(
            label_category, total_counts, label_counts,
            category_activation_counts[k], intersection_counts[k])
    # Cache the results (and the levels used) for each scored layer.
    for k in intersection_counts:
        numpy.savez(os.path.join(outdir, safe_dir_name(k), 'bincounts.npz'),
            iou_scores=iou_scores[k].cpu().numpy(),
            iqr_scores=iqr_scores[k].cpu().numpy(),
            total_counts=total_counts.cpu().numpy(),
            label_counts=label_counts.cpu().numpy(),
            category_activation_counts=category_activation_counts[k]
                .cpu().numpy(),
            intersection_counts=intersection_counts[k].cpu().numpy(),
            levels=levels[k].cpu().numpy())
    return (iou_scores, iqr_scores,
        total_counts, label_counts, category_activation_counts,
        intersection_counts)
def collect_cond_quantiles(outdir, model, segloader, segrunner):
    '''
    Returns (conditional_quantiles, label_fracs) across the data set.

    conditional_quantiles maps each retained layer name to a
    RunningConditionalQuantile that has been fed the upsampled
    activations of that layer under three kinds of condition keys:
    ('all',) for every pixel, ('label', i) for the pixels segmented
    with label i, and ('cat', j) for every pixel of an image in which
    some label of category j is present (only collected when there is
    more than one category).  Labels covering less than 1e-4 of the
    pixels get no ('label', i) entry.

    label_fracs is a cpu float tensor of shape (num_labels, 1, 1)
    holding per-label pixel counts divided by the number of
    segmentation pixels seen.

    Both results are cached under outdir ('label_fracs.npy' plus one
    'cond_quantiles.npz' per layer directory) and reused when present.

    This is a performance-sensitive function.  The counting scheme
    requires a segloader with batch_size 1.
    '''
    device = next(model.parameters()).device
    cached_cond_quantiles = {
        layer: load_conditional_quantile_if_present(
            os.path.join(outdir, safe_dir_name(layer)),
            'cond_quantiles.npz')  # on cpu
        for layer in model.retained_features()}
    label_fracs = load_npy_if_present(outdir, 'label_fracs.npy', 'cpu')
    # load_npy_if_present reports a missing file with the integer 0, so
    # "not a tensor" is the robust test for "nothing cached" (it also
    # covers None).  Comparing with `is 0` depends on int caching.
    have_label_fracs = torch.is_tensor(label_fracs)
    if have_label_fracs and all(
            value is not None for value in cached_cond_quantiles.values()):
        return cached_cond_quantiles, label_fracs

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                      for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # One-hot vector of category for each label
    labelcat = torch.zeros(num_labels, num_categories,
                           dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(
        label_category, dtype='int64')).to(device)[:, None], 1)

    assert segloader.batch_size == 1  # per-image label lists need this.
    conditional_quantiles = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    pixel_count = 0
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    common_conditions = set()

    if not have_label_fracs:
        # First pass: only count label pixels so that rare labels can be
        # skipped in the (much more expensive) per-layer passes below.
        for batch in progress(segloader, desc='label fracs'):
            seg, batch_label_counts, _, _ = segrunner.run_and_segment_batch(
                batch, model, want_bincount=True, want_rgb=True)
            label_counts += batch_label_counts.to(device)[0]
            pixel_count += seg.shape[2] * seg.shape[3]
        label_fracs = (
            label_counts.cpu().float() / pixel_count)[:, None, None]
        numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)

    skip_threshold = 1e-4
    skip_labels = set(
        i.item()
        for i in (label_fracs.view(-1) < skip_threshold).nonzero().view(-1))

    for layer in progress(model.retained_features().keys(), desc='CQ layers'):
        if cached_cond_quantiles.get(layer, None) is not None:
            conditional_quantiles[layer] = cached_cond_quantiles[layer]
            continue

        for i, batch in enumerate(progress(segloader, desc='Condquant')):
            seg, batch_label_counts, _, imshape = (
                segrunner.run_and_segment_batch(
                    batch, model, want_bincount=True, want_rgb=True))
            bc = batch_label_counts.cpu()
            batch_label_counts = batch_label_counts.to(device)
            features = model.retained_features()
            # Accumulate bincounts and identify nonzeros.  Counts and
            # pixels grow together on every pass, so their ratio (the
            # label fractions recomputed at the end) is unaffected by
            # how many layers are processed.
            label_counts += batch_label_counts[0]
            pixel_count += seg.shape[2] * seg.shape[3]
            batch_labels = bc[0].nonzero()[:, 0]
            batch_categories = labelcat[batch_labels].max(0)[0]
            value = features[layer]
            if layer not in upsample_grids:
                upsample_grids[layer] = upsample_grid(
                    value.shape[2:], seg.shape[2:], imshape,
                    scale_offset=scale_offset_map.get(layer, None)
                    if scale_offset_map is not None else None,
                    dtype=value.dtype, device=value.device)
            if layer not in conditional_quantiles:
                conditional_quantiles[layer] = RunningConditionalQuantile(
                    resolution=2048)
            # Flatten to (units, pixels); valid because batch_size is 1.
            upsampled = torch.nn.functional.grid_sample(
                value, upsample_grids[layer], padding_mode='border').view(
                    value.shape[1], -1)
            conditional_quantiles[layer].add(('all',), upsampled.t())
            # cpu copies are made lazily, at most once per batch.
            cpu_seg = None
            cpu_upsampled = None
            for label in batch_labels.tolist():
                if label in skip_labels:
                    continue
                label_key = ('label', label)
                if label_key in common_conditions:
                    # Frequent condition (presumably kept on the model
                    # device, see below): mask without leaving the device.
                    imask = (seg == label).max(dim=1)[0].view(-1)
                    intersected = upsampled[:, imask]
                else:
                    if cpu_seg is None:
                        cpu_seg = seg.cpu()
                    if cpu_upsampled is None:
                        cpu_upsampled = upsampled.cpu()
                    imask = (cpu_seg == label).max(dim=1)[0].view(-1)
                    intersected = cpu_upsampled[:, imask]
                conditional_quantiles[layer].add(label_key, intersected.t())
            if num_categories > 1:
                for cat in batch_categories.nonzero()[:, 0]:
                    conditional_quantiles[layer].add(
                        ('cat', cat.item()), upsampled.t())
            # Move the most common conditions to the GPU.
            if i and not i & (i - 1):  # if i is a power of 2:
                cq = conditional_quantiles[layer]
                common_conditions = set(cq.most_common_conditions(64))
                cq.to_('cpu', [k for k in cq.running_quantiles.keys()
                               if k not in common_conditions])
        # When a layer is done, get it off the GPU
        conditional_quantiles[layer].to_('cpu')

    label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]

    for cq in conditional_quantiles.values():
        cq.to_('cpu')

    for layer in conditional_quantiles:
        save_state_dict(
            conditional_quantiles[layer],
            os.path.join(outdir, safe_dir_name(layer), 'cond_quantiles.npz'))
    numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)
    return conditional_quantiles, label_fracs
def collect_maxiou(outdir, model, segloader, segrunner):
    '''
    Returns (max_iou, max_iou_level, max_iou_quantile), three dicts
    keyed by layer name.

    For each layer, 100 log-spaced quantile fractions between 1e-3
    and 1 are turned into activation thresholds, an IoU is computed
    for every threshold, and the maximum over thresholds is kept
    together with the threshold level and the fraction that achieved
    it.  The per-layer results are also written to 'max_iou.npz' in
    the layer's directory under outdir.

    The underlying statistics come from collect_cond_quantiles, which
    assumes a segloader with batch_size 1.
    '''
    device = next(model.parameters()).device
    cond_quantiles, label_fracs = collect_cond_quantiles(
        outdir, model, segloader, segrunner)
    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = []
    for l, c in labelcat:
        label_category.append(categories.index(c) if c in categories else 0)
    num_labels = len(labelcat)
    num_categories = len(categories)
    label_list = [('label', index) for index in range(num_labels)]
    if num_categories <= 1:
        category_list = [('all',)]
    else:
        category_list = [('cat', index) for index in range(num_categories)]
    max_iou = {}
    max_iou_level = {}
    max_iou_quantile = {}
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(cond_quantiles.items(), desc='Maxiou'):
        # Thresholds such that a fraction `fracs` of all pixels exceed them.
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        activated = 1 - cq.collected_normalize(category_list, levels)
        overlap = 1 - cq.collected_normalize(label_list, levels)
        isects = overlap * label_fracs
        unions = label_fracs + activated[label_category, :, :] - isects
        iou = isects / unions
        # TODO: erase any for which threshold is bad
        best_iou, level_bucket = iou.max(2)
        max_iou[layer] = best_iou
        unit_index = torch.arange(levels.shape[0])[None, :]
        max_iou_level[layer] = levels[unit_index, level_bucket]
        max_iou_quantile[layer] = fracs[level_bucket]
    for layer in model.retained_features():
        target = os.path.join(outdir, safe_dir_name(layer), 'max_iou.npz')
        numpy.savez(
            target,
            max_iou=max_iou[layer].cpu().numpy(),
            max_iou_level=max_iou_level[layer].cpu().numpy(),
            max_iou_quantile=max_iou_quantile[layer].cpu().numpy())
    return (max_iou, max_iou_level, max_iou_quantile)
def collect_iqr(outdir, model, segloader, segrunner):
    '''
    Returns (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
    max_iqr_agreement), five dicts keyed by layer name.

    For each layer, 100 log-spaced quantile fractions between 1e-3
    and 1 are turned into activation thresholds.  At every threshold
    a 2x2 joint distribution of (unit above threshold, label present)
    is formed, and the information quality ratio (mutual information
    divided by joint entropy) is computed from it.  The ratio is then
    maximized over thresholds, considering only thresholds that are
    positive and whose agreement exceeds 0.8.  The level, fraction,
    IoU and agreement at the maximizing threshold are returned.

    Results are cached in 'iqr.npz' in each layer directory under
    outdir, together with the unmasked full_mi, full_je and full_iqr
    tensors; when every layer has a cache file the cached values are
    returned without touching the data.

    The underlying statistics come from collect_cond_quantiles, which
    assumes a segloader with batch_size 1.
    '''
    max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou = {}, {}, {}, {}
    max_iqr_agreement = {}
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'iqr.npz')
        if os.path.isfile(filename):
            # An .npz archive stays open until closed explicitly.
            with numpy.load(filename) as data:
                max_iqr[layer] = torch.from_numpy(data['max_iqr'])
                max_iqr_level[layer] = torch.from_numpy(
                    data['max_iqr_level'])
                max_iqr_quantile[layer] = torch.from_numpy(
                    data['max_iqr_quantile'])
                max_iqr_iou[layer] = torch.from_numpy(data['max_iqr_iou'])
                max_iqr_agreement[layer] = torch.from_numpy(
                    data['max_iqr_agreement'])
        else:
            found_all = False
    if found_all:
        return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
                max_iqr_agreement)

    device = next(model.parameters()).device
    conditional_quantiles, label_fracs = collect_cond_quantiles(
        outdir, model, segloader, segrunner)

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                      for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])
    label_list = [('label', i) for i in range(num_labels)]
    category_list = [('all',)] if num_categories <= 1 else (
        [('cat', i) for i in range(num_categories)])
    full_mi, full_je, full_iqr = {}, {}, {}
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(conditional_quantiles.items(), desc='IQR'):
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        truth = label_fracs.to(device)
        preds = (1 - cq.collected_normalize(category_list, levels)
                 )[label_category, :, :].to(device)
        cond_isects = 1 - cq.collected_normalize(
            label_list, levels).to(device)
        isects = cond_isects * truth
        unions = truth + preds - isects
        # 2x2 joint distribution; first index is "label present",
        # second index is "unit above threshold".
        arr = torch.empty(size=(2, 2) + isects.shape, dtype=isects.dtype,
                          device=device)
        arr[0, 0] = isects
        arr[0, 1] = preds - isects
        arr[1, 0] = truth - isects
        arr[1, 1] = 1 - unions
        arr.clamp_(0, 1)
        mi = mutual_information(arr)
        mi[:, :, -1] = 0  # at the 1.0 quantile should be no MI.
        # Don't trust mi when label_frac is less than 1e-3,
        # because our samples are too small.
        mi[label_fracs.view(-1) < 1e-3, :, :] = 0
        je = joint_entropy(arr)
        iqr = mi / je
        iqr[torch.isnan(iqr)] = 0  # Zero out any 0/0
        full_mi[layer] = mi.cpu()
        full_je[layer] = je.cpu()
        full_iqr[layer] = iqr.cpu()
        del mi, je
        agreement = isects + arr[1, 1]
        # When optimizing, maximize only over those pairs where the
        # unit is positively correlated with the label, and where the
        # threshold level is positive.  Work on a copy: on a cpu model
        # iqr.cpu() returns iqr itself, so masking in place would also
        # overwrite the full_iqr tensor saved above.
        positive_iqr = iqr.clone()
        positive_iqr[agreement <= 0.8] = 0
        positive_iqr[
            (levels <= 0.0)[None, :, :].expand(positive_iqr.shape)] = 0
        # TODO: erase any for which threshold is bad
        maxiqr, level_bucket = positive_iqr.max(2)
        max_iqr[layer] = maxiqr.cpu()
        max_iqr_level[layer] = levels.to(device)[
            torch.arange(levels.shape[0])[None, :], level_bucket].cpu()
        max_iqr_quantile[layer] = fracs.to(device)[level_bucket].cpu()
        # agreement, isects and unions share one shape, so the same
        # index grids select the maximizing threshold in all of them.
        rows = torch.arange(agreement.shape[0])[:, None]
        cols = torch.arange(agreement.shape[1])[None, :]
        max_iqr_agreement[layer] = agreement[rows, cols, level_bucket].cpu()
        # Compute the iou that goes with each maximized iqr
        matching_iou = (isects[rows, cols, level_bucket] /
                        unions[rows, cols, level_bucket])
        matching_iou[torch.isnan(matching_iou)] = 0
        max_iqr_iou[layer] = matching_iou.cpu()
    for layer in model.retained_features():
        numpy.savez(
            os.path.join(outdir, safe_dir_name(layer), 'iqr.npz'),
            max_iqr=max_iqr[layer].cpu().numpy(),
            max_iqr_level=max_iqr_level[layer].cpu().numpy(),
            max_iqr_quantile=max_iqr_quantile[layer].cpu().numpy(),
            max_iqr_iou=max_iqr_iou[layer].cpu().numpy(),
            max_iqr_agreement=max_iqr_agreement[layer].cpu().numpy(),
            full_mi=full_mi[layer].cpu().numpy(),
            full_je=full_je[layer].cpu().numpy(),
            full_iqr=full_iqr[layer].cpu().numpy())
    return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
            max_iqr_agreement)
def mutual_information(arr):
    '''
    Mutual information (in nats) of a joint distribution table.

    arr is indexed as arr[j, k, ...]: the first two axes enumerate the
    joint outcomes and any remaining axes are independent problems
    handled elementwise.  Terms that evaluate to NaN (such as
    0 * log 0) contribute nothing, and the result is clamped at zero.
    '''
    accum = 0
    for row in range(arr.shape[0]):
        row_marginal = arr[row, :].sum(dim=0)
        for col in range(arr.shape[1]):
            cell = arr[row, col]
            # Probability the cell would have under independence.
            expected = row_marginal * arr[:, col].sum(dim=0)
            contribution = cell * torch.log(cell / expected)
            contribution = torch.where(
                torch.isnan(contribution),
                torch.zeros_like(contribution),
                contribution)
            accum = accum + contribution
    return accum.clamp_(0)
def joint_entropy(arr):
    '''
    Joint entropy (in nats) of a joint distribution table.

    arr is indexed as arr[j, k, ...]: the first two axes enumerate the
    joint outcomes and any remaining axes are handled elementwise.
    Terms that evaluate to NaN (0 * log 0) contribute nothing, and
    the result is clamped at zero.
    '''
    accum = 0
    for row in range(arr.shape[0]):
        for col in range(arr.shape[1]):
            prob = arr[row, col]
            plogp = prob * torch.log(prob)
            plogp = torch.where(
                torch.isnan(plogp), torch.zeros_like(plogp), plogp)
            accum = accum + plogp
    return (-accum).clamp_(0)
def information_quality_ratio(arr):
    '''
    Mutual information of arr divided by its joint entropy, with any
    undefined (NaN, e.g. 0/0) entries replaced by zero.
    '''
    ratio = mutual_information(arr) / joint_entropy(arr)
    ratio.masked_fill_(torch.isnan(ratio), 0)
    return ratio
def collect_covariance(outdir, model, segloader, segrunner):
    '''
    Returns a dict mapping each layer name to a RunningCrossCovariance
    accumulated over the data set.

    For every batch, each retained layer's activations are upsampled
    to the segmentation resolution and added to that layer's running
    statistic together with a one-hot encoding of the segmentation
    labels (label 0, the background, is left out of the encoding).

    The accumulated state is saved as 'covariance.npz' in each layer
    directory under outdir; when every layer already has such a file
    the cached objects are loaded and returned instead.
    '''
    device = next(model.parameters()).device
    cached = {}
    for layer in model.retained_features():
        cached[layer] = load_covariance_if_present(
            os.path.join(outdir, safe_dir_name(layer)),
            'covariance.npz', device=device)
    if all(entry is not None for entry in cached.values()):
        return cached

    labelcat, categories = segrunner.get_label_and_category_names()
    num_labels = len(labelcat)

    # Running covariance, one accumulator per layer
    running = {}
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    grids = {}
    for batch in progress(segloader, desc='Covariance'):
        seg, _, _, imshape = segrunner.run_and_segment_batch(
            batch, model, want_rgb=True)
        features = model.retained_features()
        ohfeats = multilabel_onehot(seg, num_labels, ignore_index=0)
        for key, value in features.items():
            if key not in grids:
                layer_scale_offset = None
                if scale_offset_map is not None:
                    layer_scale_offset = scale_offset_map.get(key, None)
                grids[key] = upsample_grid(
                    value.shape[2:], seg.shape[2:], imshape,
                    scale_offset=layer_scale_offset,
                    dtype=value.dtype, device=value.device)
            # The cached grid has batch dimension 1; broadcast it to
            # the size of this batch.
            grid = grids[key]
            batch_grid = grid.expand((value.shape[0],) + grid.shape[1:])
            upsampled = torch.nn.functional.grid_sample(
                value, batch_grid, padding_mode='border')
            if key not in running:
                running[key] = RunningCrossCovariance()
            running[key].add(upsampled, ohfeats)
    for layer, accumulator in running.items():
        save_state_dict(
            accumulator,
            os.path.join(outdir, safe_dir_name(layer), 'covariance.npz'))
    return running
def multilabel_onehot(labels, num_labels, dtype=None, ignore_index=None):
    '''
    Converts a multilabel tensor into a onehot tensor.

    The input labels is a tensor of shape (samples, multilabels, y, x).
    The output is a tensor of shape (samples, num_labels, y, x), of
    the given dtype (torch.float by default), in which channel c is 1
    wherever any of the multilabels equals c.

    If ignore_index is specified (it must be zero or negative), labels
    with that index are ignored: a zero index has its channel cleared,
    while a negative index is scattered into extra leading channels
    that are cut off before returning.
    Each x in labels should be 0 <= x < num_labels, or x == ignore_index.
    '''
    assert ignore_index is None or ignore_index <= 0
    out_dtype = torch.float if dtype is None else dtype
    # Number of scratch channels needed in front for a negative index.
    shift = -ignore_index if ignore_index else 0
    shape = (labels.shape[0], num_labels + shift) + tuple(labels.shape[2:])
    onehot = torch.zeros(shape, device=labels.device, dtype=out_dtype)
    if shift > 0:
        onehot.scatter_(1, labels + shift, 1)
        return onehot[:, shift:]
    onehot.scatter_(1, labels, 1)
    if ignore_index is not None:
        onehot[:, ignore_index] = 0
    return onehot
def load_npy_if_present(outdir, filename, device):
    '''
    Loads outdir/filename as a tensor on the given device.
    Returns the integer 0 (not None) when the file does not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return 0
    return torch.from_numpy(numpy.load(filepath)).to(device)
def load_npz_if_present(outdir, filename, varnames, device):
    '''
    Loads the arrays named in varnames from the archive outdir/filename.

    Returns a tuple of tensors on the given device, in the order of
    varnames, or None when the file does not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    # The archive object holds an open file handle; read what is needed
    # and close it before converting.
    with numpy.load(filepath) as archive:
        arrays = [archive[name] for name in varnames]
    return tuple(torch.from_numpy(array).to(device) for array in arrays)
def load_quantile_if_present(outdir, filename, device):
    '''
    Rebuilds a RunningQuantile from the saved state in outdir/filename
    and moves it to the given device.  Returns None if the file does
    not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    quantile = RunningQuantile(state=numpy.load(filepath))
    quantile.to_(device)
    return quantile
def load_conditional_quantile_if_present(outdir, filename):
    '''
    Rebuilds a RunningConditionalQuantile from the saved state in
    outdir/filename.  Returns None if the file does not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    return RunningConditionalQuantile(state=numpy.load(filepath))
def load_topk_if_present(outdir, filename, device):
    '''
    Rebuilds a RunningTopK from the saved state in outdir/filename and
    moves it to the given device.  Returns None if the file does not
    exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    topk = RunningTopK(state=numpy.load(filepath))
    topk.to_(device)
    return topk
def load_covariance_if_present(outdir, filename, device):
    '''
    Rebuilds a RunningCrossCovariance from the saved state in
    outdir/filename and moves it to the given device.  Returns None
    if the file does not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    covariance = RunningCrossCovariance(state=numpy.load(filepath))
    covariance.to_(device)
    return covariance
def save_state_dict(obj, filepath):
    '''
    Saves obj.state_dict() to filepath as a numpy .npz archive, one
    array per state entry, creating the parent directory if needed.
    '''
    dirname = os.path.dirname(filepath)
    # A bare filename has an empty dirname, which makedirs rejects.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    numpy.savez(filepath, **obj.state_dict())
def upsample_grid(data_shape, target_shape, input_shape=None,
        scale_offset=None, dtype=torch.float, device=None):
    '''Prepares a grid to use with grid_sample to upsample a batch of
    features in data_shape to the target_shape.  Can use scale_offset
    and input_shape to center the grid in a nondefault way: scale_offset
    maps feature pixels to input_shape pixels, and it is assumed that
    the target_shape is a uniform downsampling of input_shape.

    data_shape and target_shape are (y, x) sizes; scale_offset, when
    given, is one (scale, offset) pair per axis.  The result has shape
    (1, target_y, target_x, 2) with the last axis ordered (x, y).
    Coordinates are normalized by (size - 1), so source pixel 0 maps
    to -1 and the last source pixel maps to +1.'''
    # Default is that nothing is resized.
    if target_shape is None:
        target_shape = data_shape
    if scale_offset is not None:
        scale, offset = zip(*scale_offset)
        # Handle downsampling for different input vs target shape.
        if input_shape is not None:
            scale = tuple(
                s * (ts - 1) / (ns - 1)
                for s, ns, ts in zip(scale, input_shape, target_shape))
            offset = tuple(
                o * (ts - 1) / (ns - 1)
                for o, ns, ts in zip(offset, input_shape, target_shape))
    else:
        # Make a default scale_offset that fills the image.
        scale = tuple(
            float(ts) / ds for ts, ds in zip(target_shape, data_shape))
        offset = tuple(0.5 * s - 0.5 for s in scale)
    # Pytorch needs target coordinates in terms of source coordinates [-1..1]
    axes = []
    for ts, ss, s, o in zip(target_shape, data_shape, scale, offset):
        positions = torch.arange(ts, dtype=dtype, device=device)
        axes.append((positions - o) * (2 / (s * (ss - 1))) - 1)
    ty, tx = axes
    # Whoa, note that grid_sample reverses the order y, x -> x, y.
    xs = tx[None, :].expand(target_shape)
    ys = ty[:, None].expand(target_shape)
    grid = torch.stack((xs, ys), 2)[None, :, :, :]
    return grid.expand((1, target_shape[0], target_shape[1], 2))
def safe_dir_name(filename):
    '''
    Reduces a name to characters that are safe in a directory name:
    alphanumerics plus space, dot, underscore and hyphen.  Everything
    else is dropped and trailing whitespace is stripped.
    '''
    allowed = (' ', '.', '_', '-')
    kept = [ch for ch in filename if ch.isalnum() or ch in allowed]
    return ''.join(kept).rstrip()
# Color pairs used by make_svg_bargraph, cycled by category index.
# In each pair the first (darker) color fills the individual bars and
# the second (lighter) color fills the category background rectangle.
bargraph_palette = [
('#4B4CBF', '#B6B6F2'),
('#55B05B', '#B6F2BA'),
('#50BDAC', '#A5E5DB'),
('#81C679', '#C0FF9B'),
('#F0883B', '#F2CFB6'),
('#D4CF24', '#F2F1B6'),
('#D92E2B', '#F2B6B6'),
('#AB6BC6', '#CFAAFF'),
]
def make_svg_bargraph(labels, heights, categories,
        barheight=100, barwidth=12, show_labels=True, filename=None):
    '''
    Renders a bar graph as SVG and returns the serialized svg element
    as bytes.

    labels and heights give one name and one bar height per bar, in
    drawing order.  categories is a sequence of (name, count) pairs:
    consecutive bars are assigned to categories according to the
    counts, each category with a nonzero count gets a background
    rectangle and a caption, and bar colors follow the category.  If
    there are more bars than the counts account for, the extra bars
    keep the color of the last category.

    barheight is the pixel height of the tallest bar and barwidth the
    pixel width of each bar.  When show_labels is true, a rotated text
    label is drawn under every bar.  If filename is given, the svg is
    also written there, preceded by an XML/DOCTYPE header.
    '''
    tallest = max(heights, default=1)
    unitheight = float(barheight) / max(tallest, 1)
    textheight = barheight if show_labels else 0
    labelsize = float(barwidth)
    gap = float(barwidth) / 4
    textsize = barwidth + gap
    rollup = tallest
    textmargin = float(labelsize) * 2 / 3
    leftmargin = 32
    rightmargin = 8
    svgwidth = len(heights) * (barwidth + gap) + 2 * leftmargin + rightmargin
    svgheight = barheight + textheight

    # create an SVG XML element
    svg = et.Element('svg', width=str(svgwidth), height=str(svgheight),
                     version='1.1', xmlns='http://www.w3.org/2000/svg')

    # Draw the bar graph
    basey = svgheight - textheight
    x = leftmargin
    # Add units scale on left
    if len(heights):
        for h in [1, (tallest + 1) // 2, tallest]:
            et.SubElement(
                svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;' +
                       'text-anchor:end;alignment-baseline:hanging;' +
                       'transform:translate(%dpx, %dpx);') %
                (textsize, x - gap, basey - h * unitheight)).text = str(h)
        # Axis caption, centered vertically on the tallest bar.
        et.SubElement(
            svg, 'text', x='0', y='0',
            style=('font-family:sans-serif;font-size:%dpx;' +
                   'text-anchor:middle;' +
                   'transform:translate(%dpx, %dpx) rotate(-90deg)') %
            (textsize, x - gap - textsize, basey - tallest * unitheight / 2)
            ).text = 'units'

    # Draw big category background rectangles
    for catindex, (cat, catcount) in enumerate(categories):
        if not catcount:
            continue
        et.SubElement(
            svg, 'rect', x=str(x), y=str(basey - rollup * unitheight),
            width=(str((barwidth + gap) * catcount - gap)),
            height=str(rollup * unitheight),
            fill=bargraph_palette[catindex % len(bargraph_palette)][1])
        x += (barwidth + gap) * catcount

    # Draw small bars as well as 45degree text labels
    x = leftmargin
    catindex = -1
    catcount = 0
    # Fallback so that bars can still be drawn without any categories.
    color = bargraph_palette[0][0]
    for label, height in zip(labels, heights):
        # Advance to the next category that still has bars left, but
        # never step past the end of the category list.
        while not catcount and catindex + 1 < len(categories):
            catindex += 1
            catcount = categories[catindex][1]
            color = bargraph_palette[catindex % len(bargraph_palette)][0]
        et.SubElement(
            svg, 'rect', x=str(x), y=str(basey - (height * unitheight)),
            width=str(barwidth), height=str(height * unitheight),
            fill=color)
        x += barwidth
        if show_labels:
            et.SubElement(
                svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;' +
                       'transform:translate(%dpx, %dpx) rotate(-45deg);') %
                (labelsize, x, basey + textmargin)).text = readable(label)
        x += gap
        catcount -= 1

    # Text labels for each category
    x = leftmargin
    for cat, catcount in categories:
        if not catcount:
            continue
        et.SubElement(
            svg, 'text', x='0', y='0',
            style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;' +
                   'transform:translate(%dpx, %dpx) rotate(-90deg);') %
            (textsize, x + (barwidth + gap) * catcount - gap,
             basey - rollup * unitheight + gap)).text = '%d %s' % (
                 catcount, readable(cat + ('s' if catcount != 1 else '')))
        x += (barwidth + gap) * catcount

    # Output - this is the bare svg.
    result = et.tostring(svg)
    if filename:
        with open(filename, 'wb') as f:
            # When writing to a file a special header is needed.
            f.write(''.join([
                '<?xml version=\"1.0\" standalone=\"no\"?>\n',
                '<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n',
                '\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n']
                ).encode('utf-8'))
            f.write(result)
    return result
# (compiled pattern, replacement) pairs applied in order by readable():
# first a trailing '-s' or '-c' suffix is removed, then underscores
# become spaces.
readable_replacements = [
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        (r'-[sc]$', ''),
        (r'_', ' '),
    )
]

def readable(label):
    '''
    Returns label rewritten for display by applying every entry of
    readable_replacements in order.
    '''
    text = label
    for compiled, replacement in readable_replacements:
        text = compiled.sub(replacement, text)
    return text
def reverse_normalize_from_transform(transform):
    '''
    Searches a transform (or an object that carries one, such as a
    dataset) for a torchvision Normalize and returns the matching
    ReverseNormalize, or None if no normalization is found.

    The search follows a 'transform' attribute if there is one;
    otherwise it walks a 'transforms' list from last to first and
    returns the first hit.
    '''
    if isinstance(transform, torchvision.transforms.Normalize):
        return ReverseNormalize(transform.mean, transform.std)
    wrapped = getattr(transform, 'transform', None)
    if wrapped is not None:
        return reverse_normalize_from_transform(wrapped)
    children = getattr(transform, 'transforms', None)
    if children is None:
        return None
    for child in reversed(children):
        found = reverse_normalize_from_transform(child)
        if found is not None:
            return found
    return None
class ReverseNormalize:
    '''
    Applies the reverse of torchvision.transforms.Normalize:
    multiplies by the per-channel stdev and adds the per-channel mean.
    '''
    def __init__(self, mean, stdev):
        def per_channel(values):
            # Float tensor shaped (1, channels, 1, 1) so that it
            # broadcasts over (n, c, y, x) data.
            as_array = numpy.array(values)
            return torch.from_numpy(as_array)[None, :, None, None].float()
        self.mean = per_channel(mean)
        self.stdev = per_channel(stdev)

    def __call__(self, data):
        target = data.device
        scaled = data.mul(self.stdev.to(target))  # new tensor, data is kept
        return scaled.add_(self.mean.to(target))
class ImageOnlySegRunner:
    '''
    Segrunner for datasets that provide images only.  It runs the
    model and returns stub segmentations with a single '-' label.
    '''
    def __init__(self, dataset, recover_image=None):
        self.dataset = dataset
        if recover_image is None:
            recover_image = reverse_normalize_from_transform(dataset)
        self.recover_image = recover_image

    def get_label_and_category_names(self):
        labels_with_category = [('-', '-')]
        category_names = ['-']
        return labels_with_category, category_names

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the model on the single image tensor in batch and returns
        (seg, bc, rgb, shape): seg is an all-zero (n, 1, 1, 1) long
        tensor, bc an all-one (n, 1) long tensor, rgb a viewable
        (n, y, x, rgb) byte tensor (None unless want_rgb), and shape
        the (y, x) size of the images.
        '''
        [im] = batch
        device = next(model.parameters()).device
        rgb = None
        if want_rgb:
            recovered = self.recover_image(im.clone())
            rgb = recovered.permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        batch_size = im.shape[0]
        # Stubs for seg and bc
        seg = torch.zeros(batch_size, 1, 1, 1, dtype=torch.long)
        bc = torch.ones(batch_size, 1, dtype=torch.long)
        # Run the model.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]
class ClassifierSegRunner:
    '''
    Segrunner for datasets that contain explicit segmentations: each
    batch already carries the image, its segmentation and its label
    bincount.
    '''
    def __init__(self, dataset, recover_image=None):
        self.dataset = dataset
        if recover_image is None:
            recover_image = reverse_normalize_from_transform(dataset)
        self.recover_image = recover_image

    def get_label_and_category_names(self):
        dataset = self.dataset
        catnames = dataset.categories
        label_and_cat_names = []
        for i, label in enumerate(dataset.labels):
            category = catnames[dataset.label_category[i]]
            label_and_cat_names.append((readable(label), category))
        return label_and_cat_names, catnames

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one batch of the dataset, and
        returns a multilabel semantic segmentation for the data.
        Given a batch of size (n, c, y, x) the segmentation should
        be a (long integer) tensor of size (n, d, y//r, x//r) where
        d is the maximum number of simultaneous labels given to a pixel,
        and where r is some (optional) resolution reduction factor.
        In the segmentation returned, the label `0` is reserved for
        the background "no-label".

        In addition to the segmentation, bc, rgb, and shape are returned
        where bc is a per-image bincount counting returned label pixels,
        rgb is a viewable (n, y, x, rgb) byte image tensor for the data
        for visualizations (reversing normalizations, for example), and
        shape is the (y, x) size of the data.  Here seg and bc are
        passed through from the batch unchanged, and rgb is None
        unless want_rgb is set.
        '''
        im, seg, bc = batch
        device = next(model.parameters()).device
        rgb = None
        if want_rgb:
            recovered = self.recover_image(im.clone())
            rgb = recovered.permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        # Run the model.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]
class GeneratorSegRunner:
    '''
    Segrunner for generator models: the model turns each batch of
    inputs into images, and a segmenter algorithm labels those images.
    '''
    def __init__(self, segmenter):
        # The segmentations are given by an algorithm
        if segmenter is None:
            segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')
        self.segmenter = segmenter
        label_names = segmenter.get_label_and_category_names()[0]
        self.num_classes = len(label_names)

    def get_label_and_category_names(self):
        return self.segmenter.get_label_and_category_names()

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one batch of the dataset, and
        returns a multilabel semantic segmentation for the data.
        Given a batch of size (n, c, y, x) the segmentation should
        be a (long integer) tensor of size (n, d, y//r, x//r) where
        d is the maximum number of simultaneous labels given to a pixel,
        and where r is some (optional) resolution reduction factor.
        In the segmentation returned, the label `0` is reserved for
        the background "no-label".

        In addition to the segmentation, bc, rgb, and shape are returned
        where bc is a per-image bincount counting returned label pixels,
        rgb is a viewable (n, y, x, rgb) byte image tensor for the data
        for visualizations (reversing normalizations, for example), and
        shape is the (y, x) size of the data.  If want_bincount or
        want_rgb are False, those return values are None.
        '''
        device = next(model.parameters()).device
        z_batch = batch[0]
        tensor_images = model(z_batch.to(device))
        seg = self.segmenter.segment_batch(tensor_images, downsample=2)

        bc = None
        if want_bincount:
            batch_size = z_batch.shape[0]
            # Shift each image's labels into its own block of
            # num_classes bins so one bincount covers the whole batch.
            image_index = torch.arange(
                batch_size, dtype=torch.long, device=device)
            shifted = seg + image_index[:, None, None, None] * self.num_classes
            flat_counts = shifted.view(-1).bincount(
                minlength=batch_size * self.num_classes)
            bc = flat_counts.view(batch_size, self.num_classes)

        rgb = None
        if want_rgb:
            # Scales generator output from [-1, 1] to [0, 255].
            images = ((tensor_images + 1) / 2 * 255)
            rgb = images.permute(0, 2, 3, 1).clamp(0, 255).byte()
        return seg, bc, rgb, tensor_images.shape[2:]