Spaces:
Running
Running
'''
To run dissection:

1. Load up the convolutional model you wish to dissect, and wrap it in
   an InstrumentedModel; then call model.retain_layers([layernames,..])
   to instrument the layers of interest.
2. Load the segmentation dataset using the BrodenDataset class;
   use the transform_image argument to normalize images to be
   suitable for the model, or the size argument to truncate the dataset.
3. Choose a directory in which to write the output, and call
   dissect(outdir, model, dataset).

Example:

    from dissect import InstrumentedModel, dissect
    from broden import BrodenDataset

    model = InstrumentedModel(load_my_model())
    model.eval()
    model.cuda()
    model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5'])
    bds = BrodenDataset('dataset/broden1_227',
            transform_image=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
            size=1000)
    dissect('result/dissect', model, bds,
            examples_per_unit=10)
'''
import torch, numpy, os, re, json, shutil, types, tempfile, torchvision | |
# import warnings | |
# warnings.simplefilter('error', UserWarning) | |
from PIL import Image | |
from xml.etree import ElementTree as et | |
from collections import OrderedDict, defaultdict | |
from .progress import verbose_progress, default_progress, print_progress | |
from .progress import desc_progress | |
from .runningstats import RunningQuantile, RunningTopK | |
from .runningstats import RunningCrossCovariance, RunningConditionalQuantile | |
from .sampler import FixedSubsetSampler | |
from .actviz import activation_visualization | |
from .segviz import segment_visualization, high_contrast | |
from .workerpool import WorkerBase, WorkerPool | |
from .segmenter import UnifiedParsingSegmenter | |
def dissect(outdir, model, dataset,
        segrunner=None,
        train_dataset=None,
        model_segmenter=None,
        quantile_threshold=0.005,
        iou_threshold=0.05,
        iqr_threshold=0.01,
        examples_per_unit=100,
        batch_size=100,
        num_workers=24,
        seg_batch_size=5,
        make_images=True,
        make_labels=True,
        make_maxiou=False,
        make_covariance=False,
        make_report=True,
        make_row_images=True,
        make_single_images=False,
        rank_all_labels=False,
        netname=None,
        meta=None,
        merge=None,
        settings=None,
        ):
    '''
    Runs net dissection in-memory, using pytorch, and saves visualizations
    and metadata into outdir.

    Arguments:
        outdir: directory handed to every collect_*/generate_* step.
        model: instrumented model; must already be in eval() mode.
        dataset: dataset used for quantiles, top-k, images, labels and
            covariance.
        segrunner: segmentation runner; defaults to
            ClassifierSegRunner(dataset).
        train_dataset: dataset used for the iqr and maxiou passes;
            defaults to dataset.
        model_segmenter: accepted for compatibility; not used here.
        quantile_threshold: either a fraction q (activation levels are
            taken at quantile 1 - q), or the string 'iqr' to pick levels
            from collect_iqr instead.
        iou_threshold, iqr_threshold: forwarded to the report as part of
            labeldata.
        examples_per_unit: number of top examples tracked per unit, and
            the image strip length.
        batch_size: loader batch size for the quantile/top-k pass.
        num_workers: worker count for every DataLoader and for
            generate_images.
        seg_batch_size: batch size for image generation and covariance.
        make_images, make_labels, make_maxiou, make_covariance,
        make_report: enable the corresponding stage.
        make_row_images, make_single_images: forwarded to generate_images.
        rank_all_labels, netname, meta, merge, settings: forwarded to
            generate_report (merge is passed as mergedata).

    Returns:
        (quantiledata, labeldata) where quantiledata is
        (topk, quantiles, levels, quantile_threshold) and labeldata is
        None unless make_labels is set.
    '''
    assert not model.training, 'Run model.eval() before dissection'
    if netname is None:
        netname = type(model).__name__
    if segrunner is None:
        segrunner = ClassifierSegRunner(dataset)
    if train_dataset is None:
        train_dataset = dataset
    make_iqr = (quantile_threshold == 'iqr')
    with torch.no_grad():
        device = next(model.parameters()).device

        def make_loader(data, loader_batch_size):
            # Every pass uses the same worker count and pinning policy.
            return torch.utils.data.DataLoader(data,
                    batch_size=loader_batch_size, num_workers=num_workers,
                    pin_memory=(device.type == 'cuda'))

        maxioudata, iqrdata = None, None
        labeldata, cov = None, None
        labelnames, catnames = segrunner.get_label_and_category_names()
        # First, always collect quantiles and topk information.
        quantiles, topk = collect_quantiles_and_topk(outdir, model,
                make_loader(dataset, batch_size), segrunner,
                k=examples_per_unit)
        # Thresholds can be automatically chosen by maximizing iqr
        if make_iqr:
            # Get thresholds based on an IQR optimization
            iqrdata = collect_iqr(outdir, model,
                    make_loader(train_dataset, 1), segrunner)
            max_iqr, full_iqr_levels = iqrdata[:2]
            # For every unit, take the level of the label with the best iqr.
            levels = {layer: full_iqr_levels[layer][
                    max_iqr[layer].max(0)[1],
                    torch.arange(max_iqr[layer].shape[1])].to(device)
                for layer in full_iqr_levels}
        else:
            levels = {k: qc.quantiles([1.0 - quantile_threshold])[:,0]
                    for k, qc in quantiles.items()}
        quantiledata = (topk, quantiles, levels, quantile_threshold)
        if make_images:
            # generate_images builds its own loader over the needed subset.
            generate_images(outdir, model, dataset, topk, levels, segrunner,
                    row_length=examples_per_unit, batch_size=seg_batch_size,
                    row_images=make_row_images,
                    single_images=make_single_images,
                    num_workers=num_workers)
        if make_maxiou:
            assert train_dataset, "Need training dataset for maxiou."
            maxioudata = collect_maxiou(outdir, model,
                    make_loader(train_dataset, 1), segrunner)
        if make_labels:
            # collect_bincounts asserts a batch size of one.
            iou_scores, iqr_scores, tcs, lcs, ccs, ics = (
                    collect_bincounts(outdir, model,
                        make_loader(dataset, 1), levels, segrunner))
            labeldata = (iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold,
                    iqr_threshold)
        if make_covariance:
            cov = collect_covariance(outdir, model,
                    make_loader(dataset, seg_batch_size), segrunner)
        if make_report:
            generate_report(outdir,
                    quantiledata=quantiledata,
                    labelnames=labelnames,
                    catnames=catnames,
                    labeldata=labeldata,
                    maxioudata=maxioudata,
                    iqrdata=iqrdata,
                    covariancedata=cov,
                    rank_all_labels=rank_all_labels,
                    netname=netname,
                    meta=meta,
                    mergedata=merge,
                    settings=settings)
        return quantiledata, labeldata
def _labels_by_frequency(all_layers, label_field):
    '''
    Returns the distinct values of unitrec[label_field] over all units of
    all layers (with any trailing quadrant suffix such as '-t' or '-bl'
    stripped), ordered by descending frequency and then by name.
    '''
    counted_labels = defaultdict(int)
    for record in all_layers:
        for unitrec in record['units']:
            label = re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
                    unitrec[label_field])
            counted_labels[label] += 1
    return [label for _, label in sorted(
        (-count, label) for label, count in counted_labels.items())]


def generate_report(outdir, quantiledata, labelnames=None, catnames=None,
        labeldata=None, maxioudata=None, iqrdata=None, covariancedata=None,
        rank_all_labels=False, netname='Model', meta=None, settings=None,
        mergedata=None):
    '''
    Creates dissect.json reports and summary bargraph.svg files in the
    specified output directory, and copies the dissect.html and edit.html
    interface files to go along with it.

    Arguments:
        outdir: output directory; one subdirectory is used per layer.
        quantiledata: (topk, quantiles, levels, quantile_threshold).
        labelnames: sequence of (label, category) pairs; required in
            practice because the segmentation color table is built from it.
        catnames: category names; needed when labeldata is given.
        labeldata: optional (iou_scores, iqr_scores, lcs, ccs, ics,
            iou_threshold, iqr_threshold).
        maxioudata: optional (max_iou, max_iou_level, max_iou_quantile),
            each indexed by layer.
        iqrdata: optional 5-item sequence (max_iqr, max_iqr_level,
            max_iqr_quantile, max_iqr_iou, max_iqr_agreement).
        covariancedata: optional map from layer to an object providing
            correlation().
        rank_all_labels: if set, per-label rankings are emitted for every
            label rather than only labels assigned to some unit.
        netname, meta, settings: copied into the top-level record.
        mergedata: optional dict with a 'layers' list whose per-unit
            records are merged into the generated unit records.
    '''
    all_layers = []
    # Current source code directory, for html to copy.
    srcdir = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # Unpack arguments
    topk, quantiles, levels, quantile_threshold = quantiledata
    top_record = dict(
        netname=netname,
        meta=meta,
        default_ranking='unit',
        quantile_threshold=quantile_threshold)
    if settings is not None:
        top_record['settings'] = settings
    if labeldata is not None:
        iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold, iqr_threshold = (
            labeldata)
        catorder = {'object': -7, 'scene': -6, 'part': -5,
                    'piece': -4,
                    'material': -3, 'texture': -2, 'color': -1}
        for i, cat in enumerate(c for c in catnames if c not in catorder):
            catorder[cat] = i
        catnumber = {n: i for i, n in enumerate(catnames)}
        catnumber['-'] = 0
        top_record['default_ranking'] = 'label'
        top_record['iou_threshold'] = iou_threshold
        top_record['iqr_threshold'] = iqr_threshold
    # Unpack optional data once, so that it is bound for the per-layer
    # section and for the cross-layer rankings below.
    if maxioudata is not None:
        max_iou, max_iou_level, max_iou_quantile = maxioudata
    if iqrdata is not None:
        [max_iqr, max_iqr_level, max_iqr_quantile,
         max_iqr_iou, max_iqr_agreement] = iqrdata
    # Needed by every per-label ranking, not only the iou one.
    labelnumber = {name[0]: num for num, name in enumerate(labelnames)}
    # Make a segmentation color dictionary
    segcolors = {}
    for i, name in enumerate(labelnames):
        key = ','.join(str(s) for s in high_contrast[i % len(high_contrast)])
        if key in segcolors:
            segcolors[key] += '/' + name[0]
        else:
            segcolors[key] = name[0]
    top_record['segcolors'] = segcolors
    for layer in topk.keys():
        units, rankings = [], []
        record = dict(layer=layer, units=units, rankings=rankings)
        # For every unit, we always have basic visualization information.
        topa, topi = topk[layer].result()
        lev = levels[layer]
        for u in range(len(topa)):
            units.append(dict(
                unit=u,
                interp=True,
                level=lev[u].item(),
                top=[dict(imgnum=i.item(), maxact=a.item())
                    for i, a in zip(topi[u], topa[u])],
                ))
        rankings.append(dict(name="unit", score=list(range(len(topa)))))
        # TODO: consider including stats and ranking based on quantiles,
        # variance, connectedness here.

        # if we have labeldata, then every unit also gets a bunch of other info
        if labeldata is not None:
            lscore, qscore, cc, ic = [dat[layer]
                    for dat in [iou_scores, iqr_scores, ccs, ics]]
            if iqrdata is not None:
                # If we have IQR thresholds, assign labels based on that
                best_label = max_iqr[layer].max(0)[1]
                best_score = lscore[best_label, torch.arange(lscore.shape[1])]
                best_qscore = qscore[best_label, torch.arange(lscore.shape[1])]
            else:
                # Otherwise, assign labels based on max iou
                best_score, best_label = lscore.max(0)
                best_qscore = qscore[best_label, torch.arange(qscore.shape[1])]
            # A scalar, matching top_record['iou_threshold'] (a stray
            # trailing comma used to store a 1-tuple here).
            record['iou_threshold'] = iou_threshold
            for u, urec in enumerate(units):
                unit_iou, unit_iqr, label = (
                        best_score[u], best_qscore[u], best_label[u])
                labelnum = label.item()
                urec.update(dict(
                    iou=unit_iou.item(),
                    iou_iqr=unit_iqr.item(),
                    lc=lcs[label].item(),
                    cc=cc[catnumber[labelnames[labelnum][1]], u].item(),
                    ic=ic[label, u].item(),
                    interp=(unit_iqr.item() > iqr_threshold and
                        unit_iou.item() > iou_threshold),
                    iou_labelnum=labelnum,
                    iou_label=labelnames[labelnum][0],
                    iou_cat=labelnames[labelnum][1],
                    ))
        if maxioudata is not None:
            qualified_iou = max_iou[layer].clone()
            # qualified_iou[max_iou_quantile[layer] > 0.75] = 0
            best_score, best_label = qualified_iou.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    maxiou=best_score[u].item(),
                    maxiou_label=labelnames[best_label[u].item()][0],
                    maxiou_cat=labelnames[best_label[u].item()][1],
                    maxiou_level=max_iou_level[layer][best_label[u], u].item(),
                    maxiou_quantile=max_iou_quantile[layer][
                        best_label[u], u].item()))
        if iqrdata is not None:
            qualified_iqr = max_iqr[layer].clone()
            qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
            best_score, best_label = qualified_iqr.max(0)
            for u, urec in enumerate(units):
                urec.update(dict(
                    iqr=best_score[u].item(),
                    iqr_label=labelnames[best_label[u].item()][0],
                    iqr_cat=labelnames[best_label[u].item()][1],
                    iqr_level=max_iqr_level[layer][best_label[u], u].item(),
                    iqr_quantile=max_iqr_quantile[layer][
                        best_label[u], u].item(),
                    iqr_iou=max_iqr_iou[layer][best_label[u], u].item()
                    ))
        if covariancedata is not None:
            score = covariancedata[layer].correlation()
            best_score, best_label = score.max(1)
            for u, urec in enumerate(units):
                urec.update(dict(
                    cor=best_score[u].item(),
                    cor_label=labelnames[best_label[u].item()][0],
                    cor_cat=labelnames[best_label[u].item()][1]
                    ))
        if mergedata is not None:
            # Final step: if the user passed any data to merge into the
            # units, merge them now.  This can be used, for example, to
            # indicate that a unit is not interpretable based on some
            # outside analysis of unit statistics.
            lrec = next((r for r in mergedata.get('layers', [])
                if r['layer'] == layer), None)
            for u, urec in enumerate(lrec.get('units', []) if lrec else []):
                units[u].update(urec)
        # After populating per-unit info, populate per-layer ranking info
        if labeldata is not None:
            # Collect all labeled units
            labelunits = defaultdict(list)
            all_labelunits = defaultdict(list)
            for u, urec in enumerate(units):
                if urec['interp']:
                    labelunits[urec['iou_labelnum']].append(u)
                all_labelunits[urec['iou_labelnum']].append(u)
            # Sort all units in order with most popular label first.
            label_ordering = sorted(units,
                # Sort by:
                key=lambda r: (-1 if r['interp'] else 0,  # interpretable
                    -len(labelunits[r['iou_labelnum']]),  # label freq, score
                    -max([units[u]['iou']
                        for u in labelunits[r['iou_labelnum']]], default=0),
                    r['iou_labelnum'],                    # label
                    -r['iou']))                           # unit score
            # Add label and iou ranking.
            rankings.append(dict(name="label", score=(numpy.argsort(
                [ur['unit'] for ur in label_ordering])).tolist()))
            rankings.append(dict(name="max iou", metric="iou",
                score=[-ur['iou'] for ur in units]))
            # Collate labels by category then frequency.  The length filter
            # drops labels whose (defaultdict) entry is empty.
            record['labels'] = [dict(
                        label=labelnames[label][0],
                        labelnum=label,
                        units=labelunits[label],
                        cat=labelnames[label][1])
                    for label in (sorted(labelunits.keys(),
                        # Sort by:
                        key=lambda l: (catorder.get(          # category
                                labelnames[l][1], 0),
                            -len(labelunits[l]),              # label freq
                            -max([units[u]['iou'] for u in labelunits[l]],
                                default=0) # score
                            ))) if len(labelunits[label])]
            # Total number of interpretable units.
            record['interpretable'] = sum(len(group['units'])
                    for group in record['labels'])
            # Make a bargraph of labels
            os.makedirs(os.path.join(outdir, safe_dir_name(layer)),
                    exist_ok=True)
            catgroups = OrderedDict()
            for _, cat in sorted([(v, k) for k, v in catorder.items()]):
                catgroups[cat] = []
            for rec in record['labels']:
                if rec['cat'] not in catgroups:
                    catgroups[rec['cat']] = []
                catgroups[rec['cat']].append(rec['label'])
            make_svg_bargraph(
                    [rec['label'] for rec in record['labels']],
                    [len(rec['units']) for rec in record['labels']],
                    [(cat, len(group)) for cat, group in catgroups.items()],
                    filename=os.path.join(outdir, safe_dir_name(layer),
                        'bargraph.svg'))
            # Only show the bargraph if it is non-empty.
            if len(record['labels']):
                record['bargraph'] = 'bargraph.svg'
        if maxioudata is not None:
            rankings.append(dict(name="max maxiou", metric="maxiou",
                score=[-ur['maxiou'] for ur in units]))
        if iqrdata is not None:
            rankings.append(dict(name="max iqr", metric="iqr",
                score=[-ur['iqr'] for ur in units]))
        if covariancedata is not None:
            rankings.append(dict(name="max cor", metric="cor",
                score=[-ur['cor'] for ur in units]))

        all_layers.append(record)
    # Now add the same rankings to every layer...
    # The label list is fixed by the first data source that is present.
    all_labels = None
    if rank_all_labels:
        all_labels = [name for name, cat in labelnames]
    if labeldata is not None:
        # Count layers+quadrants with a given label, and sort by freq
        if all_labels is None:
            all_labels = _labels_by_frequency(all_layers, 'iou_label')
        for record in all_layers:
            layer = record['layer']
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iou" % label,
                    concept=label, metric='iou',
                    score=(-iou_scores[layer][labelnum, :]).tolist()))
    if maxioudata is not None:
        if all_labels is None:
            all_labels = _labels_by_frequency(all_layers, 'maxiou_label')
        for record in all_layers:
            layer = record['layer']
            # Must be recomputed per layer: computing it once outside this
            # loop would reuse whichever layer was processed last above.
            qualified_iou = max_iou[layer].clone()
            qualified_iou[max_iou_quantile[layer] > 0.5] = 0
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-maxiou" % label,
                    concept=label, metric='maxiou',
                    score=(-qualified_iou[labelnum, :]).tolist()))
    if iqrdata is not None:
        if all_labels is None:
            all_labels = _labels_by_frequency(all_layers, 'iqr_label')
        for record in all_layers:
            layer = record['layer']
            qualified_iqr = max_iqr[layer].clone()
            # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-iqr" % label,
                    concept=label, metric='iqr',
                    score=(-qualified_iqr[labelnum, :]).tolist()))
    if covariancedata is not None:
        if all_labels is None:
            all_labels = _labels_by_frequency(all_layers, 'cor_label')
        for record in all_layers:
            layer = record['layer']
            score = covariancedata[layer].correlation()
            for label in all_labels:
                labelnum = labelnumber[label]
                record['rankings'].append(dict(name="%s-cor" % label,
                    concept=label, metric='cor',
                    score=(-score[:, labelnum]).tolist()))

    for record in all_layers:
        layer = record['layer']
        layerdir = os.path.join(outdir, safe_dir_name(layer))
        # The directory may not exist yet if no images or labels were made.
        os.makedirs(layerdir, exist_ok=True)
        # Dump per-layer json inside per-layer directory
        record['dirname'] = '.'
        with open(os.path.join(layerdir, 'dissect.json'), 'w') as jsonfile:
            top_record['layers'] = [record]
            json.dump(top_record, jsonfile, indent=1)
        # Copy the per-layer html
        shutil.copy(os.path.join(srcdir, 'dissect.html'),
                os.path.join(layerdir, 'dissect.html'))
        # In the all-layer json the record points at its subdirectory.
        record['dirname'] = safe_dir_name(layer)

    # Dump all-layer json in parent directory
    with open(os.path.join(outdir, 'dissect.json'), 'w') as jsonfile:
        top_record['layers'] = all_layers
        json.dump(top_record, jsonfile, indent=1)
    # Copy the all-layer html
    shutil.copy(os.path.join(srcdir, 'dissect.html'),
            os.path.join(outdir, 'dissect.html'))
    shutil.copy(os.path.join(srcdir, 'edit.html'),
            os.path.join(outdir, 'edit.html'))
def generate_images(outdir, model, dataset, topk, levels,
        segrunner, row_length=None, gap_pixels=5,
        row_images=True, single_images=False, prefix='',
        batch_size=100, num_workers=24):
    '''
    Creates image strip files for every unit of every retained layer
    of the model, in the format
    [outdir]/[layername]/[prefix]image/[unitnum]-[suffix], where suffix is
    one of top.jpg, orig.jpg, seg.png, mask.png.  When single_images is
    set, the first image of each strip is also saved under
    [outdir]/[layername]/[prefix]s-image/.
    Assumes that the indexes of topk refer to the indexes of dataset.
    Limits each strip to the top row_length images.
    Does nothing if topk lists no images.
    '''
    progress = default_progress()
    needed_images = {}
    if row_images is False:
        row_length = 1
    # Pass 1: needed_images lists all images that are topk for some unit.
    strip_length = None
    for layer in topk:
        topresult = topk[layer].result()[1].cpu()
        for unit, row in enumerate(topresult):
            clipped = row[:row_length]
            # The actual strip length may be shorter than requested.
            strip_length = len(clipped)
            for rank, imgnum in enumerate(clipped):
                needed_images.setdefault(imgnum.item(), []).append(
                        (layer, unit, rank))
    if not needed_images:
        # No units or no top images: there is nothing to render.
        return
    levels = {k: v.cpu().numpy() for k, v in levels.items()}
    row_length = strip_length
    needed_sample = FixedSubsetSampler(sorted(needed_images.keys()))
    device = next(model.parameters()).device
    segloader = torch.utils.data.DataLoader(dataset,
            batch_size=batch_size, num_workers=num_workers,
            pin_memory=(device.type == 'cuda'),
            sampler=needed_sample)
    vizgrid, maskgrid, origrid, seggrid = [{} for _ in range(4)]
    # Pass 2: populate vizgrid with visualizations of top units.
    pool = None
    for i, batch in enumerate(
            progress(segloader, desc='Making images')):
        # Reverse transformation to get the image in byte form.
        seg, _, byte_im, _ = segrunner.run_and_segment_batch(batch, model,
                want_rgb=True)
        torch_features = model.retained_features()
        scale_offset = getattr(model, 'scale_offset', None)
        if pool is None:
            # Distribute the work across processes: create shared mmaps.
            # The grid sizes are only known once the first batch is seen.
            for layer, tf in torch_features.items():
                [vizgrid[layer], maskgrid[layer], origrid[layer],
                        seggrid[layer]] = [
                    create_temp_mmap_grid((tf.shape[1],
                        byte_im.shape[1], row_length,
                        byte_im.shape[2] + gap_pixels, depth),
                        dtype='uint8',
                        fill=255)
                    for depth in [3, 4, 3, 3]]
            # Pass those mmaps to worker processes.
            pool = WorkerPool(worker=VisualizeImageWorker,
                    memmap_grid_info=[
                        {layer: (g.filename, g.shape, g.dtype)
                            for layer, g in grid.items()}
                        for grid in [vizgrid, maskgrid, origrid, seggrid]])
        byte_im = byte_im.cpu().numpy()
        numpy_seg = seg.cpu().numpy()
        features = {}
        for index in range(len(byte_im)):
            # Map the position within the batch back to the dataset index.
            imgnum = needed_sample.samples[index + i*segloader.batch_size]
            for layer, unit, rank in needed_images[imgnum]:
                if layer not in features:
                    features[layer] = torch_features[layer].cpu().numpy()
                pool.add(layer, unit, rank,
                        byte_im[index],
                        features[layer][index, unit],
                        levels[layer][unit],
                        scale_offset[layer] if scale_offset else None,
                        numpy_seg[index])
    if pool is not None:
        pool.join()
    # Pass 3: save image strips as [outdir]/[layer]/[unitnum]-[top/orig].jpg
    pool = WorkerPool(worker=SaveImageWorker)
    for layer, vg in progress(vizgrid.items(), desc='Saving images'):
        os.makedirs(os.path.join(outdir, safe_dir_name(layer),
            prefix + 'image'), exist_ok=True)
        if single_images:
            os.makedirs(os.path.join(outdir, safe_dir_name(layer),
                prefix + 's-image'), exist_ok=True)
        og, sg, mg = origrid[layer], seggrid[layer], maskgrid[layer]
        for unit in progress(range(len(vg)), desc='Units'):
            for suffix, grid in [('top.jpg', vg), ('orig.jpg', og),
                    ('seg.png', sg), ('mask.png', mg)]:
                # Flatten (rank, width) so the images sit side by side.
                strip = grid[unit].reshape(
                        (grid.shape[1], grid.shape[2] * grid.shape[3],
                            grid.shape[4]))
                if row_images:
                    filename = os.path.join(outdir, safe_dir_name(layer),
                            prefix + 'image', '%d-%s' % (unit, suffix))
                    # Drop the trailing gap.  An explicit end index is used
                    # because [:-0] would select nothing when gap_pixels=0.
                    pool.add(strip[:,:strip.shape[1] - gap_pixels,:].copy(),
                            filename)
                if single_images:
                    single_filename = os.path.join(outdir,
                            safe_dir_name(layer),
                            prefix + 's-image', '%d-%s' % (unit, suffix))
                    pool.add(strip[:,:strip.shape[1] // row_length
                        - gap_pixels,:].copy(), single_filename)
    pool.join()
    # Delete the shared memory map files
    clear_global_shared_files([g.filename
        for grid in [vizgrid, maskgrid, origrid, seggrid]
        for g in grid.values()])
# Registry of open memory-mapped grids, keyed by backing file name.
global_shared_files = {}
# Temporary directories made by create_temp_mmap_grid, keyed by the same
# file name, so they can be removed together with the file.
_temp_mmap_dirs = {}


def create_temp_mmap_grid(shape, dtype, fill):
    '''
    Creates a numpy.memmap of the given shape and dtype, backed by a new
    file in a fresh temporary directory, fills it with `fill`, and
    registers it in global_shared_files under its file name.
    The open file object is kept on the returned array as `.fid`.
    '''
    dtype = numpy.dtype(dtype)
    dirname = tempfile.mkdtemp()
    filename = os.path.abspath(os.path.join(dirname, 'temp-%s-%s.mmap' %
            ('x'.join('%d' % s for s in shape), dtype.name)))
    fid = open(filename, mode='w+b')
    try:
        original = numpy.memmap(fid, dtype=dtype, mode='w+', shape=shape)
    except Exception:
        # Do not leave the handle, file or directory behind on failure.
        fid.close()
        try:
            os.unlink(filename)
            os.rmdir(dirname)
        except OSError:
            pass
        raise
    original.fid = fid
    original[...] = fill
    global_shared_files[filename] = original
    _temp_mmap_dirs[filename] = dirname
    return original


def shared_temp_mmap_grid(filename, shape, dtype):
    '''
    Returns the memmap registered for filename, opening the existing file
    in read/write mode (and registering it) if it is not yet registered
    in this process.
    '''
    if filename not in global_shared_files:
        global_shared_files[filename] = numpy.memmap(
                filename, dtype=dtype, mode='r+', shape=shape)
    return global_shared_files[filename]


def clear_global_shared_files(filenames):
    '''
    Unregisters each named file, closes its file handle if one was kept,
    deletes the file, and removes the temporary directory that
    create_temp_mmap_grid made for it.  Errors while deleting are ignored.
    '''
    for fn in filenames:
        grid = global_shared_files.pop(fn, None)
        fid = getattr(grid, 'fid', None)
        if fid is not None:
            fid.close()
        try:
            os.unlink(fn)
        except OSError:
            pass
        dirname = _temp_mmap_dirs.pop(fn, None)
        if dirname is not None:
            try:
                # Only succeeds if the directory is empty.
                os.rmdir(dirname)
            except OSError:
                pass
class VisualizeImageWorker(WorkerBase):
    '''
    Worker that renders one (layer, unit, rank) cell of the shared image
    grids: the original image, its activation visualization and mask, and
    its segmentation visualization.
    '''
    def setup(self, memmap_grid_info):
        '''
        Opens the shared grids.  memmap_grid_info is a sequence of four
        dicts (viz, mask, orig, seg), each mapping a layer name to the
        (filename, shape, dtype) arguments of shared_temp_mmap_grid.
        '''
        self.vizgrid, self.maskgrid, self.origrid, self.seggrid = [
                {layer: shared_temp_mmap_grid(*info)
                    for layer, info in grid.items()}
                for grid in memmap_grid_info]

    def work(self, layer, unit, rank,
            byte_im, acts, level, scale_offset, seg):
        '''
        Writes the images for one unit/rank into the grids.
        '''
        # The grid's fourth axis holds image columns plus a gap, so it is
        # limited to the image width (byte_im.shape[1]).  Using shape[0]
        # here only worked for square images.
        # NOTE(review): assumes byte_im is laid out (rows, columns,
        # channels), as the shape[0:2] size below suggests - confirm.
        width = byte_im.shape[1]
        self.origrid[layer][unit,:,rank,:width,:] = byte_im
        [self.vizgrid[layer][unit,:,rank,:width,:],
                self.maskgrid[layer][unit,:,rank,:width,:]] = (
            activation_visualization(
                byte_im,
                acts,
                level,
                scale_offset=scale_offset,
                return_mask=True))
        self.seggrid[layer][unit,:,rank,:width,:] = (
            segment_visualization(seg, byte_im.shape[0:2]))
class SaveImageWorker(WorkerBase):
    '''
    Worker that writes one image array to disk.
    '''
    def work(self, data, filename):
        '''
        Converts the array to a PIL image and saves it to filename.
        '''
        image = Image.fromarray(data)
        image.save(filename, optimize=True, quality=80)
def score_tally_stats(label_category, tc, truth, cc, ic):
    '''
    Computes two scores from tallied counts and returns them as (iou, iqr):
    iou is intersection / (pred + truth - intersection), and iqr is the
    mutual information of the 2x2 table of counts (normalized by the
    per-label total) divided by its joint entropy.  tc and cc are indexed
    by label_category to get one row per label.
    '''
    epsilon = 1e-20 # avoid division-by-zero
    label_pred = cc[label_category]
    label_total = tc[label_category][:, None]
    label_truth = truth[:, None]
    union = label_pred + label_truth - ic
    iou = ic.double() / (union.double() + epsilon)
    # 2x2 table: rows index truth present/absent, columns pred
    # present/absent (in that order).
    cells = (
        (ic, label_pred - ic),
        (label_truth - ic, label_total - union),
    )
    table = torch.empty(size=(2, 2) + ic.shape, dtype=ic.dtype,
            device=ic.device)
    for r, row in enumerate(cells):
        for c, cell in enumerate(row):
            table[r, c] = cell
    table = table.double() / label_total.double()
    mi = mutual_information(table)
    je = joint_entropy(table)
    iqr = mi / je
    iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0
    return iou, iqr
def collect_quantiles_and_topk(outdir, model, segloader,
        segrunner, k=100, resolution=1024):
    '''
    Gathers a RunningQuantile and a RunningTopK (with the given k) for
    every retained layer of the model, by running the model over
    segloader.  Results are saved as quantiles.npz and topk.npz in each
    layer's directory under outdir; if both files load for every layer,
    the cached objects are returned without running the model.
    Returns (quantiles, topks), two maps keyed by layer name.
    '''
    device = next(model.parameters()).device
    cpu = torch.device('cpu')
    layer_names = list(model.retained_features().keys())
    cached_quantiles = {}
    for layer in layer_names:
        cached_quantiles[layer] = load_quantile_if_present(
                os.path.join(outdir, safe_dir_name(layer)),
                'quantiles.npz', device=cpu)
    cached_topks = {}
    for layer in layer_names:
        cached_topks[layer] = load_topk_if_present(
                os.path.join(outdir, safe_dir_name(layer)),
                'topk.npz', device=cpu)
    have_quantiles = all(
            value is not None for value in cached_quantiles.values())
    have_topks = all(
            value is not None for value in cached_topks.values())
    if have_quantiles and have_topks:
        return cached_quantiles, cached_topks

    # Layers are handled a few at a time; each group makes its own pass
    # over the data and is moved to the cpu afterwards.
    layer_batch_size = 8
    quantiles, topks = {}, {}
    progress = default_progress()
    for start in range(0, len(layer_names), layer_batch_size):
        layer_batch = layer_names[start:start + layer_batch_size]
        for batch in progress(segloader, desc='Quantiles'):
            # We don't actually care about the model output.
            model(batch[0].to(device))
            features = model.retained_features()
            # We care about the retained values
            for key in layer_batch:
                value = features[key]
                if topks.get(key, None) is None:
                    topks[key] = RunningTopK(k)
                if quantiles.get(key, None) is None:
                    quantiles[key] = RunningQuantile(resolution=resolution)
                topvalue = value
                ndim = len(value.shape)
                if ndim > 2:
                    channels = value.shape[1]
                    # Max over everything after the channel dimension.
                    flat = value.view(*(value.shape[:2] + (-1,)))
                    topvalue, _ = flat.max(2)
                    # Put the channel index last.
                    order = (0,) + tuple(range(2, ndim)) + (1,)
                    value = value.permute(order).contiguous().view(
                            -1, channels)
                quantiles[key].add(value)
                topks[key].add(topvalue)
        # Save GPU memory
        for key in layer_batch:
            quantiles[key].to_(cpu)
            topks[key].to_(cpu)
    for layer in quantiles:
        layerdir = safe_dir_name(layer)
        save_state_dict(quantiles[layer],
                os.path.join(outdir, layerdir, 'quantiles.npz'))
        save_state_dict(topks[layer],
                os.path.join(outdir, layerdir, 'topk.npz'))
    return quantiles, topks
def collect_bincounts(outdir, model, segloader, levels, segrunner):
    '''
    Counts, across the data set, the pixels of intersection between
    upsampled, thresholded model featuremaps and the segmentation classes
    produced by segrunner, and scores every (label, unit) pair.

    Returns a tuple (iou_scores, iqr_scores, total_counts, label_counts,
    category_activation_counts, intersection_counts):

    iou_scores, iqr_scores (one per layer): scores computed by
        score_tally_stats from the counts below.
    total_counts (independent of layer): for each category, the number
        of pixels in images where some label of that category is present.
    label_counts (independent of model): pixels across the data set that
        are labeled with the given label.
    category_activation_counts (one per layer): for each feature channel,
        pixels across the dataset where the channel exceeds the level
        threshold.  There is one count per category: activations only
        contribute to the categories for which any category labels are
        present on the images.
    intersection_counts (one per layer): for each feature channel and
        label, pixels across the dataset where the channel exceeds
        the level, and the labeled segmentation class is also present.

    Results are cached per layer in bincounts.npz; if that file exists
    for every retained layer, the cached values are returned.

    This is a performance-sensitive function.  Best performance is
    achieved with a counting scheme which assumes a segloader with
    batch_size 1.
    '''
    # Load cached data if present
    (iou_scores, iqr_scores,
        total_counts, label_counts, category_activation_counts,
        intersection_counts) = {}, {}, None, None, {}, {}
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'bincounts.npz')
        if os.path.isfile(filename):
            # Close the archive once its arrays have been read.
            with numpy.load(filename) as data:
                iou_scores[layer] = torch.from_numpy(data['iou_scores'])
                iqr_scores[layer] = torch.from_numpy(data['iqr_scores'])
                total_counts = torch.from_numpy(data['total_counts'])
                label_counts = torch.from_numpy(data['label_counts'])
                category_activation_counts[layer] = torch.from_numpy(
                        data['category_activation_counts'])
                intersection_counts[layer] = torch.from_numpy(
                        data['intersection_counts'])
        else:
            found_all = False
    if found_all:
        return (iou_scores, iqr_scores,
            total_counts, label_counts, category_activation_counts,
            intersection_counts)

    device = next(model.parameters()).device
    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
                for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # One-hot vector of category for each label
    labelcat = torch.zeros(num_labels, num_categories,
            dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
        dtype='int64')).to(device)[:,None], 1)
    # Running bincounts
    assert segloader.batch_size == 1 # category_activation_counts needs this.
    category_activation_counts = {}
    intersection_counts = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    total_counts = torch.zeros(num_categories, dtype=torch.long, device=device)
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    for i, batch in enumerate(progress(segloader, desc='Bincounts')):
        seg, batch_label_counts, _, imshape = segrunner.run_and_segment_batch(
                batch, model, want_bincount=True, want_rgb=True)
        bc = batch_label_counts.cpu()
        batch_label_counts = batch_label_counts.to(device)
        seg = seg.to(device)
        features = model.retained_features()
        # Accumulate bincounts and identify nonzeros
        label_counts += batch_label_counts[0]
        batch_labels = bc[0].nonzero()[:,0]
        # 1 for every category with at least one label in this image.
        batch_categories = labelcat[batch_labels].max(0)[0]
        total_counts += batch_categories * (
                seg.shape[0] * seg.shape[2] * seg.shape[3])
        for key, value in features.items():
            if key not in upsample_grids:
                upsample_grids[key] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(key, None)
                            if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[key], padding_mode='border')
            amask = (upsampled > levels[key][None,:,None,None].to(
                upsampled.device))
            ac = amask.int().view(amask.shape[1], -1).sum(1)
            # Create the tally up front, so that a layer still gets scored
            # (with zeros) if it never sees a non-background label.
            if key not in intersection_counts:
                intersection_counts[key] = torch.zeros(num_labels,
                        amask.shape[1], dtype=torch.long, device=device)
            # The fastest approach: sum over each label separately!
            for label in batch_labels.tolist():
                if label == 0:
                    continue  # ignore the background label
                imask = amask * ((seg == label).max(dim=1, keepdim=True)[0])
                ic = imask.int().view(imask.shape[1], -1).sum(1)
                intersection_counts[key][label] += ic
            # Count activations within images that have category labels.
            # Note: This only makes sense with batch-size one
            cc = batch_categories[:,None] * ac[None,:]
            if key not in category_activation_counts:
                category_activation_counts[key] = cc
            else:
                category_activation_counts[key] += cc
    iou_scores = {}
    iqr_scores = {}
    for k in intersection_counts:
        iou_scores[k], iqr_scores[k] = score_tally_stats(
            label_category, total_counts, label_counts,
            category_activation_counts[k], intersection_counts[k])
    for k in intersection_counts:
        layerdir = os.path.join(outdir, safe_dir_name(k))
        os.makedirs(layerdir, exist_ok=True)
        numpy.savez(os.path.join(layerdir, 'bincounts.npz'),
            iou_scores=iou_scores[k].cpu().numpy(),
            iqr_scores=iqr_scores[k].cpu().numpy(),
            total_counts=total_counts.cpu().numpy(),
            label_counts=label_counts.cpu().numpy(),
            category_activation_counts=category_activation_counts[k]
                .cpu().numpy(),
            intersection_counts=intersection_counts[k].cpu().numpy(),
            levels=levels[k].cpu().numpy())
    return (iou_scores, iqr_scores,
        total_counts, label_counts, category_activation_counts,
        intersection_counts)
def collect_cond_quantiles(outdir, model, segloader, segrunner):
    '''
    Collects conditional quantile statistics for every retained layer.

    For each layer a RunningConditionalQuantile is filled with the
    upsampled activations under the conditions ('all',), ('label', i)
    for each label present in an image, and, when there is more than
    one category, ('cat', j) for each category present in an image.
    Also computes label_fracs, the fraction of segmentation pixels
    carrying each label, shaped (num_labels, 1, 1).

    Results are cached in outdir ('label_fracs.npy' and one
    'cond_quantiles.npz' per layer directory); when every cached file
    is present they are returned without touching the data.

    Returns (conditional_quantiles, label_fracs), where
    conditional_quantiles is a dict keyed by layer name.

    This is a performance-sensitive function.  The counting scheme
    requires a segloader with batch_size 1 (asserted below).
    '''
    device = next(model.parameters()).device
    cached_cond_quantiles = {
        layer: load_conditional_quantile_if_present(os.path.join(outdir,
            safe_dir_name(layer)), 'cond_quantiles.npz') # on cpu
        for layer in model.retained_features() }
    label_fracs = load_npy_if_present(outdir, 'label_fracs.npy', 'cpu')
    # load_npy_if_present reports a missing file with a non-tensor
    # sentinel; checking for a tensor avoids both the identity test
    # against an int literal and returning the sentinel to the caller.
    have_label_fracs = torch.is_tensor(label_fracs)
    if have_label_fracs and all(
            value is not None for value in cached_cond_quantiles.values()):
        return cached_cond_quantiles, label_fracs

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])

    # One-hot vector of category for each label
    labelcat = torch.zeros(num_labels, num_categories,
            dtype=torch.long, device=device)
    labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
        dtype='int64')).to(device)[:,None], 1)
    # The per-image label bookkeeping below reads only row 0 of the
    # bincount, so batches must hold exactly one image.
    assert segloader.batch_size == 1
    conditional_quantiles = {}
    label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
    pixel_count = 0
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    common_conditions = set()
    if not have_label_fracs:
        # First pass: only count label pixels, so rare labels can be
        # skipped during the (expensive) quantile pass.
        for i, batch in enumerate(progress(segloader, desc='label fracs')):
            seg, batch_label_counts, _, _ = segrunner.run_and_segment_batch(
                    batch, model, want_bincount=True, want_rgb=True)
            batch_label_counts = batch_label_counts.to(device)
            # Accumulate bincounts
            label_counts += batch_label_counts[0]
            pixel_count += seg.shape[2] * seg.shape[3]
        label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]
        numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)

    # Labels covering less than this fraction of pixels are not tracked.
    skip_threshold = 1e-4
    skip_labels = set(i.item()
        for i in (label_fracs.view(-1) < skip_threshold).nonzero().view(-1))

    for layer in progress(model.retained_features().keys(), desc='CQ layers'):
        if cached_cond_quantiles.get(layer, None) is not None:
            conditional_quantiles[layer] = cached_cond_quantiles[layer]
            continue

        for i, batch in enumerate(progress(segloader, desc='Condquant')):
            seg, batch_label_counts, _, imshape = (
                    segrunner.run_and_segment_batch(
                        batch, model, want_bincount=True, want_rgb=True))
            bc = batch_label_counts.cpu()
            batch_label_counts = batch_label_counts.to(device)
            features = model.retained_features()
            # Accumulate bincounts and identify nonzeros.  Counts and
            # pixels grow together on every pass, so their ratio is
            # unaffected by how many passes are made over the data.
            label_counts += batch_label_counts[0]
            pixel_count += seg.shape[2] * seg.shape[3]
            batch_labels = bc[0].nonzero()[:,0]
            batch_categories = labelcat[batch_labels].max(0)[0]
            cpu_seg = None
            value = features[layer]
            if layer not in upsample_grids:
                upsample_grids[layer] = upsample_grid(value.shape[2:],
                        seg.shape[2:], imshape,
                        scale_offset=scale_offset_map.get(layer, None)
                            if scale_offset_map is not None else None,
                        dtype=value.dtype, device=value.device)
            if layer not in conditional_quantiles:
                conditional_quantiles[layer] = RunningConditionalQuantile(
                        resolution=2048)
            # Flatten to (channels, pixels); valid because batch size is 1.
            upsampled = torch.nn.functional.grid_sample(value,
                    upsample_grids[layer], padding_mode='border').view(
                            value.shape[1], -1)
            conditional_quantiles[layer].add(('all',), upsampled.t())
            cpu_upsampled = None
            for label in batch_labels.tolist():
                if label in skip_labels:
                    continue
                label_key = ('label', label)
                if label_key in common_conditions:
                    # Common conditions are accumulated without moving
                    # the activations off their current device.
                    imask = (seg == label).max(dim=1)[0].view(-1)
                    intersected = upsampled[:, imask]
                    conditional_quantiles[layer].add(label_key,
                            intersected.t())
                else:
                    # Other conditions are accumulated from CPU copies,
                    # made lazily at most once per image.
                    if cpu_seg is None:
                        cpu_seg = seg.cpu()
                    if cpu_upsampled is None:
                        cpu_upsampled = upsampled.cpu()
                    imask = (cpu_seg == label).max(dim=1)[0].view(-1)
                    intersected = cpu_upsampled[:, imask]
                    conditional_quantiles[layer].add(label_key,
                            intersected.t())
            if num_categories > 1:
                for cat in batch_categories.nonzero()[:,0]:
                    conditional_quantiles[layer].add(('cat', cat.item()),
                            upsampled.t())
            # Refresh the set of most common conditions and push all
            # the others to the CPU.
            if i and not i & (i - 1):  # if i is a power of 2:
                cq = conditional_quantiles[layer]
                common_conditions = set(cq.most_common_conditions(64))
                cq.to_('cpu', [k for k in cq.running_quantiles.keys()
                        if k not in common_conditions])
        # When a layer is done, get it off the GPU
        conditional_quantiles[layer].to_('cpu')

    label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]

    for cq in conditional_quantiles.values():
        cq.to_('cpu')

    for layer in conditional_quantiles:
        save_state_dict(conditional_quantiles[layer],
            os.path.join(outdir, safe_dir_name(layer), 'cond_quantiles.npz'))
    numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)

    return conditional_quantiles, label_fracs
def collect_maxiou(outdir, model, segloader, segrunner):
    '''
    Finds, per layer, the best IoU between each label and each unit.

    Conditional quantiles and label fractions come from
    collect_cond_quantiles.  A sweep over 100 log-spaced fractions
    between 1e-3 and 1 gives candidate thresholds; the IoU is computed
    for every candidate and the maximum along the threshold axis is
    kept, together with the threshold level and the fraction at which
    it occurred.  Each layer's result is written to 'max_iou.npz' in
    that layer's directory under outdir.

    Returns (max_iou, max_iou_level, max_iou_quantile), three dicts
    keyed by layer name.
    '''
    device = next(model.parameters()).device
    conditional_quantiles, label_fracs = collect_cond_quantiles(
        outdir, model, segloader, segrunner)

    labelcat, categories = segrunner.get_label_and_category_names()
    # Index of each label's category; unknown categories map to 0.
    label_category = [
        categories.index(cat) if cat in categories else 0
        for _, cat in labelcat]
    num_labels = len(labelcat)
    num_categories = len(categories)
    label_list = [('label', index) for index in range(num_labels)]
    if num_categories <= 1:
        category_list = [('all',)]
    else:
        category_list = [('cat', index) for index in range(num_categories)]

    max_iou = {}
    max_iou_level = {}
    max_iou_quantile = {}
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(conditional_quantiles.items(), desc='Maxiou'):
        # Candidate thresholds: the overall quantiles at 1 - fracs.
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        denoms = 1 - cq.collected_normalize(category_list, levels)
        isects = (1 - cq.collected_normalize(label_list, levels)) * label_fracs
        unions = label_fracs + denoms[label_category, :, :] - isects
        iou = isects / unions
        # TODO: erase any for which threshold is bad
        best_iou, level_bucket = iou.max(2)
        unit_index = torch.arange(levels.shape[0])[None, :]
        max_iou[layer] = best_iou
        max_iou_level[layer] = levels[unit_index, level_bucket]
        max_iou_quantile[layer] = fracs[level_bucket]

    for layer in model.retained_features():
        target = os.path.join(outdir, safe_dir_name(layer), 'max_iou.npz')
        numpy.savez(target,
            max_iou=max_iou[layer].cpu().numpy(),
            max_iou_level=max_iou_level[layer].cpu().numpy(),
            max_iou_quantile=max_iou_quantile[layer].cpu().numpy())
    return (max_iou, max_iou_level, max_iou_quantile)
def collect_iqr(outdir, model, segloader, segrunner):
    '''
    Computes, per layer, the best information quality ratio (mutual
    information divided by joint entropy) between each label and each
    unit over a sweep of activation thresholds.

    If an 'iqr.npz' file already exists for every retained layer, the
    cached values are returned.  Otherwise conditional quantiles are
    obtained from collect_cond_quantiles, every layer is recomputed,
    and the results (including the full mi, je and iqr sweeps) are
    saved to 'iqr.npz' in each layer directory under outdir.

    Returns (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
    max_iqr_agreement), five dicts keyed by layer name.
    '''
    max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou = {}, {}, {}, {}
    max_iqr_agreement = {}
    found_all = True
    for layer in model.retained_features():
        filename = os.path.join(outdir, safe_dir_name(layer), 'iqr.npz')
        if os.path.isfile(filename):
            data = numpy.load(filename)
            max_iqr[layer] = torch.from_numpy(data['max_iqr'])
            max_iqr_level[layer] = torch.from_numpy(data['max_iqr_level'])
            max_iqr_quantile[layer] = torch.from_numpy(data['max_iqr_quantile'])
            max_iqr_iou[layer] = torch.from_numpy(data['max_iqr_iou'])
            max_iqr_agreement[layer] = torch.from_numpy(
                    data['max_iqr_agreement'])
        else:
            found_all = False
    if found_all:
        return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
                max_iqr_agreement)

    device = next(model.parameters()).device
    conditional_quantiles, label_fracs = collect_cond_quantiles(
            outdir, model, segloader, segrunner)

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [categories.index(c) if c in categories else 0
            for l, c in labelcat]
    num_labels, num_categories = (len(n) for n in [labelcat, categories])
    label_list = [('label', i) for i in range(num_labels)]
    category_list = [('all',)] if num_categories <= 1 else (
            [('cat', i) for i in range(num_categories)])
    full_mi, full_je, full_iqr = {}, {}, {}
    fracs = torch.logspace(-3, 0, 100)
    progress = default_progress()
    for layer, cq in progress(conditional_quantiles.items(), desc='IQR'):
        # Candidate thresholds: the overall quantiles at 1 - fracs.
        levels = cq.conditional(('all',)).quantiles(1 - fracs)
        truth = label_fracs.to(device)
        preds = (1 - cq.collected_normalize(category_list, levels)
                )[label_category, :, :].to(device)
        cond_isects = 1 - cq.collected_normalize(label_list, levels).to(device)
        isects = cond_isects * truth
        unions = truth + preds - isects
        # 2x2 joint table over the events "isects", "preds only",
        # "truth only" and "neither", with trailing dimensions matching
        # isects.
        arr = torch.empty(size=(2, 2) + isects.shape, dtype=isects.dtype,
                device=device)
        arr[0, 0] = isects
        arr[0, 1] = preds - isects
        arr[1, 0] = truth - isects
        arr[1, 1] = 1 - unions
        arr.clamp_(0, 1)
        mi = mutual_information(arr)
        mi[:,:,-1] = 0 # at the 1.0 quantile should be no MI.
        # Don't trust mi when label_frac is less than 1e-3,
        # because our samples are too small.
        mi[label_fracs.view(-1) < 1e-3, :, :] = 0
        je = joint_entropy(arr)
        iqr = mi / je
        iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0
        full_mi[layer] = mi.cpu()
        full_je[layer] = je.cpu()
        full_iqr[layer] = iqr.cpu()
        del mi, je
        agreement = isects + arr[1, 1]
        # When optimizing, maximize only over those pairs where the
        # unit is positively correlated with the label, and where the
        # threshold level is positive.  Work on a copy: when device is
        # the CPU, iqr.cpu() above shares storage with iqr, and masking
        # in place would also alter the saved full_iqr sweep.
        positive_iqr = iqr.clone()
        positive_iqr[agreement <= 0.8] = 0
        positive_iqr[(levels <= 0.0)[None, :, :].expand(positive_iqr.shape)] = 0
        # TODO: erase any for which threshold is bad
        maxiqr, level_bucket = positive_iqr.max(2)
        max_iqr[layer] = maxiqr.cpu()
        max_iqr_level[layer] = levels.to(device)[
                torch.arange(levels.shape[0])[None,:], level_bucket].cpu()
        max_iqr_quantile[layer] = fracs.to(device)[level_bucket].cpu()
        max_iqr_agreement[layer] = agreement[
                torch.arange(agreement.shape[0])[:, None],
                torch.arange(agreement.shape[1])[None, :],
                level_bucket].cpu()

        # Compute the iou that goes with each maximized iqr
        matching_iou = (isects[
                torch.arange(isects.shape[0])[:, None],
                torch.arange(isects.shape[1])[None, :],
                level_bucket] /
            unions[
                torch.arange(unions.shape[0])[:, None],
                torch.arange(unions.shape[1])[None, :],
                level_bucket])
        matching_iou[torch.isnan(matching_iou)] = 0
        max_iqr_iou[layer] = matching_iou.cpu()
    for layer in model.retained_features():
        numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'iqr.npz'),
            max_iqr=max_iqr[layer].cpu().numpy(),
            max_iqr_level=max_iqr_level[layer].cpu().numpy(),
            max_iqr_quantile=max_iqr_quantile[layer].cpu().numpy(),
            max_iqr_iou=max_iqr_iou[layer].cpu().numpy(),
            max_iqr_agreement=max_iqr_agreement[layer].cpu().numpy(),
            full_mi=full_mi[layer].cpu().numpy(),
            full_je=full_je[layer].cpu().numpy(),
            full_iqr=full_iqr[layer].cpu().numpy())
    return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
            max_iqr_agreement)
def mutual_information(arr):
    '''
    Mutual information of a joint probability table.

    The first two dimensions of arr index the joint outcomes; any
    further dimensions are treated elementwise.  Terms that evaluate
    to NaN (such as 0 * log 0) contribute zero, and the result is
    clamped to be non-negative.
    '''
    num_rows, num_cols = arr.shape[0], arr.shape[1]
    # Marginals: sum over the second index for rows, the first for columns.
    row_marginals = [arr[j, :].sum(dim=0) for j in range(num_rows)]
    col_marginals = [arr[:, k].sum(dim=0) for k in range(num_cols)]
    total = 0
    for j in range(num_rows):
        for k in range(num_cols):
            joint = arr[j, k]
            independent = row_marginals[j] * col_marginals[k]
            term = joint * (joint / independent).log()
            term = torch.where(
                torch.isnan(term), torch.zeros_like(term), term)
            total = total + term
    return total.clamp_(0)
def joint_entropy(arr):
    '''
    Joint entropy of a probability table.

    The first two dimensions of arr index the joint outcomes; any
    further dimensions are treated elementwise.  Terms that evaluate
    to NaN (0 * log 0) contribute zero, and the result is clamped to
    be non-negative.
    '''
    total = 0
    for j in range(arr.shape[0]):
        for k in range(arr.shape[1]):
            cell = arr[j, k]
            term = cell * cell.log()
            term = torch.where(
                torch.isnan(term), torch.zeros_like(term), term)
            total = total + term
    negated = -total
    return negated.clamp_(0)
def information_quality_ratio(arr):
    '''
    Ratio of mutual information to joint entropy for the joint table
    arr, with any NaN result (from 0/0) replaced by zero.
    '''
    information = mutual_information(arr)
    entropy = joint_entropy(arr)
    ratio = information / entropy
    ratio[torch.isnan(ratio)] = 0
    return ratio
def collect_covariance(outdir, model, segloader, segrunner):
    '''
    Accumulates, for each retained layer, a RunningCrossCovariance
    between the upsampled unit activations and the one-hot encoded
    segmentation labels (label 0 is ignored).

    If a 'covariance.npz' file exists for every retained layer, the
    cached objects are returned.  Otherwise the whole segloader is
    processed and each layer's state is saved to 'covariance.npz' in
    that layer's directory under outdir.

    Returns a dict mapping layer name to RunningCrossCovariance.
    '''
    device = next(model.parameters()).device
    cached_covariance = {}
    for layer in model.retained_features():
        layer_dir = os.path.join(outdir, safe_dir_name(layer))
        cached_covariance[layer] = load_covariance_if_present(
            layer_dir, 'covariance.npz', device=device)
    if not any(entry is None for entry in cached_covariance.values()):
        return cached_covariance

    labelcat, categories = segrunner.get_label_and_category_names()
    label_category = [
        categories.index(cat) if cat in categories else 0
        for _, cat in labelcat]
    num_labels = len(labelcat)
    num_categories = len(categories)

    # Running covariance
    cov = {}
    progress = default_progress()
    scale_offset_map = getattr(model, 'scale_offset', None)
    upsample_grids = {}
    for i, batch in enumerate(progress(segloader, desc='Covariance')):
        seg, _, _, imshape = segrunner.run_and_segment_batch(
            batch, model, want_rgb=True)
        features = model.retained_features()
        ohfeats = multilabel_onehot(seg, num_labels, ignore_index=0)
        for key, value in features.items():
            if key not in upsample_grids:
                if scale_offset_map is not None:
                    layer_scale_offset = scale_offset_map.get(key, None)
                else:
                    layer_scale_offset = None
                upsample_grids[key] = upsample_grid(
                    value.shape[2:], seg.shape[2:], imshape,
                    scale_offset=layer_scale_offset,
                    dtype=value.dtype, device=value.device)
            # The stored grid has batch dimension 1; expand it to the
            # size of this batch before sampling.
            grid = upsample_grids[key]
            batch_grid = grid.expand((value.shape[0],) + grid.shape[1:])
            upsampled = torch.nn.functional.grid_sample(
                value, batch_grid, padding_mode='border')
            if key not in cov:
                cov[key] = RunningCrossCovariance()
            cov[key].add(upsampled, ohfeats)

    for layer in cov:
        target = os.path.join(outdir, safe_dir_name(layer), 'covariance.npz')
        save_state_dict(cov[layer], target)
    return cov
def multilabel_onehot(labels, num_labels, dtype=None, ignore_index=None):
    '''
    Scatters a multilabel tensor into a one-hot tensor.

    labels has shape (samples, multilabels, y, x) and the result has
    shape (samples, num_labels, y, x), with a 1 wherever any of the
    multilabels at that pixel equals the channel index.
    If ignore_index is given (it must be zero or negative), labels
    equal to it produce no 1 in the output.
    Each x in labels should be 0 <= x < num_labels, or x == ignore_index.
    The output dtype defaults to torch.float.
    '''
    assert ignore_index is None or ignore_index <= 0
    out_dtype = torch.float if dtype is None else dtype
    # A negative ignore_index is handled by shifting all labels up so
    # that the ignored value lands in extra leading channels, which are
    # sliced away after scattering.
    shifted = bool(ignore_index) and ignore_index < 0
    chans = num_labels + (-ignore_index if ignore_index else 0)
    full_shape = (labels.shape[0], chans) + labels.shape[2:]
    onehot = torch.zeros(full_shape, device=labels.device, dtype=out_dtype)
    if shifted:
        labels = labels + (-ignore_index)
    onehot.scatter_(1, labels, 1)
    if shifted:
        return onehot[:, -ignore_index:]
    if ignore_index is not None:
        onehot[:, ignore_index] = 0
    return onehot
def load_npy_if_present(outdir, filename, device):
    '''
    Loads outdir/filename with numpy.load and returns it as a tensor
    on the given device.  Returns the integer 0 when the file does
    not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return 0
    array = numpy.load(filepath)
    return torch.from_numpy(array).to(device)
def load_npz_if_present(outdir, filename, varnames, device):
    '''
    Loads the arrays named in varnames from the npz archive
    outdir/filename and returns them, in that order, as a tuple of
    tensors on the given device.  Returns None when the file does
    not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    # The archive keeps its file handle open until closed, so read all
    # requested arrays inside a context manager.
    with numpy.load(filepath) as archive:
        arrays = [archive[name] for name in varnames]
    return tuple(torch.from_numpy(array).to(device) for array in arrays)
def load_quantile_if_present(outdir, filename, device):
    '''
    Restores a RunningQuantile from the saved state in outdir/filename
    and moves it to the given device.  Returns None when the file does
    not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    quantile = RunningQuantile(state=numpy.load(filepath))
    quantile.to_(device)
    return quantile
def load_conditional_quantile_if_present(outdir, filename):
    '''
    Restores a RunningConditionalQuantile from the saved state in
    outdir/filename.  Returns None when the file does not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    return RunningConditionalQuantile(state=numpy.load(filepath))
def load_topk_if_present(outdir, filename, device):
    '''
    Restores a RunningTopK from the saved state in outdir/filename and
    moves it to the given device.  Returns None when the file does not
    exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    topk = RunningTopK(state=numpy.load(filepath))
    topk.to_(device)
    return topk
def load_covariance_if_present(outdir, filename, device):
    '''
    Restores a RunningCrossCovariance from the saved state in
    outdir/filename and moves it to the given device.  Returns None
    when the file does not exist.
    '''
    filepath = os.path.join(outdir, filename)
    if not os.path.isfile(filepath):
        return None
    covariance = RunningCrossCovariance(state=numpy.load(filepath))
    covariance.to_(device)
    return covariance
def save_state_dict(obj, filepath):
    '''
    Saves obj.state_dict() to filepath as an npz archive, creating the
    containing directory if needed.
    '''
    dirname = os.path.dirname(filepath)
    # A bare filename has an empty dirname, which os.makedirs rejects.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    dic = obj.state_dict()
    numpy.savez(filepath, **dic)
def upsample_grid(data_shape, target_shape, input_shape=None,
        scale_offset=None, dtype=torch.float, device=None):
    '''Builds a sampling grid for grid_sample that resamples features
    of spatial size data_shape to target_shape.  By default the
    features are stretched to fill the target.  A scale_offset (one
    (scale, offset) pair per spatial dimension) maps feature pixels to
    input_shape pixels instead; it is then rescaled to target pixels,
    assuming target_shape is a uniform downsampling of input_shape.
    The returned grid has shape (1, target_y, target_x, 2).'''
    # Default is that nothing is resized.
    if target_shape is None:
        target_shape = data_shape
    if scale_offset is not None:
        scale, offset = zip(*scale_offset)
        # Handle downsampling for different input vs target shape.
        if input_shape is not None:
            scale = tuple(s * (ts - 1) / (ns - 1)
                    for s, ns, ts in zip(scale, input_shape, target_shape))
            offset = tuple(o * (ts - 1) / (ns - 1)
                    for o, ns, ts in zip(offset, input_shape, target_shape))
    else:
        # Default scale_offset stretches the features over the target.
        scale = tuple(float(ts) / ds
                for ts, ds in zip(target_shape, data_shape))
        offset = tuple(0.5 * s - 0.5 for s in scale)
    # Pytorch needs target coordinates in terms of source coordinates [-1..1]
    axes = []
    for ts, ss, s, o in zip(target_shape, data_shape, scale, offset):
        pixels = torch.arange(ts, dtype=dtype, device=device)
        axes.append((pixels - o) * (2 / (s * (ss - 1))) - 1)
    ty, tx = axes
    # grid_sample expects the last dimension ordered (x, y), not (y, x).
    grid_x = tx[None, :].expand(target_shape)
    grid_y = ty[:, None].expand(target_shape)
    grid = torch.stack((grid_x, grid_y), 2)[None, :, :, :]
    return grid.expand((1, target_shape[0], target_shape[1], 2))
def safe_dir_name(filename):
    '''
    Reduces a name to characters safe for a directory name: keeps
    alphanumerics, space, '.', '_' and '-', drops everything else,
    and strips trailing whitespace.
    '''
    allowed_punctuation = (' ', '.', '_', '-')
    kept = [ch for ch in filename
            if ch.isalnum() or ch in allowed_punctuation]
    return ''.join(kept).rstrip()
# Color pairs used by make_svg_bargraph, cycled by category index:
# element 0 fills the individual bars of a category and element 1
# fills that category's background rectangle.
bargraph_palette = [
    ('#4B4CBF', '#B6B6F2'),
    ('#55B05B', '#B6F2BA'),
    ('#50BDAC', '#A5E5DB'),
    ('#81C679', '#C0FF9B'),
    ('#F0883B', '#F2CFB6'),
    ('#D4CF24', '#F2F1B6'),
    ('#D92E2B', '#F2B6B6'),
    ('#AB6BC6', '#CFAAFF'),
]
def make_svg_bargraph(labels, heights, categories,
        barheight=100, barwidth=12, show_labels=True, filename=None):
    '''
    Draws a bar graph as SVG.

    labels and heights give the text and height of each bar, in order.
    categories is a sequence of (name, count) pairs; consecutive bars
    are assigned to categories according to the counts, and each
    category gets its own color pair from bargraph_palette, a
    background rectangle and a caption.  barheight and barwidth are
    pixel sizes; show_labels controls whether the per-bar labels (and
    the space reserved for them) are drawn.

    Returns the serialized svg element as bytes.  If filename is
    given, the same markup is also written there, preceded by an XML
    declaration and the SVG 1.1 doctype.
    '''
    # if len(labels) == 0:
    #     return # Nothing to do
    unitheight = float(barheight) / max(max(heights, default=1), 1)
    textheight = barheight if show_labels else 0
    labelsize = float(barwidth)
    gap = float(barwidth) / 4
    textsize = barwidth + gap
    rollup = max(heights, default=1)
    textmargin = float(labelsize) * 2 / 3
    leftmargin = 32
    rightmargin = 8
    svgwidth = len(heights) * (barwidth + gap) + 2 * leftmargin + rightmargin
    svgheight = barheight + textheight

    # create an SVG XML element
    svg = et.Element('svg', width=str(svgwidth), height=str(svgheight),
            version='1.1', xmlns='http://www.w3.org/2000/svg')

    # Draw the bar graph
    basey = svgheight - textheight
    x = leftmargin
    # Add units scale on left
    if len(heights):
        for h in [1, (max(heights) + 1) // 2, max(heights)]:
            et.SubElement(svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;' +
                'text-anchor:end;alignment-baseline:hanging;' +
                'transform:translate(%dpx, %dpx);') %
                (textsize, x - gap, basey - h * unitheight)).text = str(h)
        # h is now max(heights), so this caption is centered on the axis.
        et.SubElement(svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;' +
                'text-anchor:middle;' +
                'transform:translate(%dpx, %dpx) rotate(-90deg)') %
                (textsize, x - gap - textsize, basey - h * unitheight / 2)
                ).text = 'units'
    # Draw big category background rectangles
    for catindex, (cat, catcount) in enumerate(categories):
        if not catcount:
            continue
        et.SubElement(svg, 'rect', x=str(x), y=str(basey - rollup * unitheight),
                width=(str((barwidth + gap) * catcount - gap)),
                height = str(rollup*unitheight),
                fill=bargraph_palette[catindex % len(bargraph_palette)][1])
        x += (barwidth + gap) * catcount
    # Draw small bars as well as 45degree text labels
    x = leftmargin
    catindex = -1
    catcount = 0
    for label, height in zip(labels, heights):
        # Advance to the next category that still has bars to draw,
        # stopping at the last category so the index stays in range
        # even when there are more bars than the counts account for.
        while not catcount and catindex + 1 < len(categories):
            catindex += 1
            catcount = categories[catindex][1]
        color = bargraph_palette[catindex % len(bargraph_palette)][0]
        et.SubElement(svg, 'rect', x=str(x), y=str(basey-(height * unitheight)),
                width=str(barwidth), height=str(height * unitheight),
                fill=color)
        x += barwidth
        if show_labels:
            et.SubElement(svg, 'text', x='0', y='0',
                style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+
                'transform:translate(%dpx, %dpx) rotate(-45deg);') %
                (labelsize, x, basey + textmargin)).text = readable(label)
        x += gap
        catcount -= 1
    # Text labels for each category
    x = leftmargin
    for cat, catcount in categories:
        if not catcount:
            continue
        et.SubElement(svg, 'text', x='0', y='0',
            style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+
            'transform:translate(%dpx, %dpx) rotate(-90deg);') %
            (textsize, x + (barwidth + gap) * catcount - gap,
                basey - rollup * unitheight + gap)).text = '%d %s' % (
                        catcount, readable(cat + ('s' if catcount != 1 else '')))
        x += (barwidth + gap) * catcount
    # Output - this is the bare svg.
    result = et.tostring(svg)
    if filename:
        # When writing to a file a special header is needed.
        header = ''.join([
            '<?xml version=\"1.0\" standalone=\"no\"?>\n',
            '<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n',
            '\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n']
            ).encode('utf-8')
        with open(filename, 'wb') as f:
            f.write(header)
            f.write(result)
    return result
# Ordered (compiled pattern, replacement) pairs applied by readable():
# first drop a trailing '-s' or '-c' suffix, then turn underscores
# into spaces.
readable_replacements = [
    (re.compile(r'-[sc]$'), ''),
    (re.compile(r'_'), ' '),
]

def readable(label):
    '''
    Returns label with every substitution in readable_replacements
    applied in order.
    '''
    for regex, replacement in readable_replacements:
        label = regex.sub(replacement, label)
    return label
def reverse_normalize_from_transform(transform):
    '''
    Searches a transform (or an object wrapping one) for a
    torchvision Normalize and returns the corresponding
    ReverseNormalize, or None if no normalization is found.

    The search follows a 'transform' attribute when present;
    otherwise it scans a 'transforms' sequence from last to first and
    returns the first match.
    '''
    if isinstance(transform, torchvision.transforms.Normalize):
        return ReverseNormalize(transform.mean, transform.std)
    wrapped = getattr(transform, 'transform', None)
    if wrapped is not None:
        return reverse_normalize_from_transform(wrapped)
    children = getattr(transform, 'transforms', None)
    if children is None:
        return None
    # Lazily evaluated so the scan stops at the first match.
    candidates = (reverse_normalize_from_transform(child)
                  for child in reversed(children))
    return next((found for found in candidates if found is not None), None)
class ReverseNormalize:
    '''
    Undoes torchvision.transforms.Normalize: multiplies by the
    per-channel stdev and adds the per-channel mean.
    '''
    def __init__(self, mean, stdev):
        self.mean = self._per_channel(mean)
        self.stdev = self._per_channel(stdev)

    @staticmethod
    def _per_channel(values):
        # Shape (1, channels, 1, 1) float tensor, so that it broadcasts
        # against (n, channels, y, x) data.
        as_array = numpy.array(values)
        return torch.from_numpy(as_array)[None, :, None, None].float()

    def __call__(self, data):
        target = data.device
        # mul allocates a new tensor, so the caller's data is untouched.
        scaled = data.mul(self.stdev.to(target))
        return scaled.add_(self.mean.to(target))
class ImageOnlySegRunner:
    '''
    Segrunner for datasets that provide only images: the model is run
    on each batch, and stub segmentation data with a single label is
    returned.
    '''
    def __init__(self, dataset, recover_image=None):
        # When no recover_image is given, look for a Normalize
        # transform attached to the dataset and invert it.
        if recover_image is None:
            recover_image = reverse_normalize_from_transform(dataset)
        self.recover_image = recover_image
        self.dataset = dataset

    def get_label_and_category_names(self):
        '''
        Returns a single placeholder label and a single placeholder
        category.
        '''
        return [('-', '-')], ['-']

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the model on the one-element batch [im] and returns
        (seg, bc, rgb, shape): seg is an all-zero (n, 1, 1, 1) long
        tensor, bc is an all-one (n, 1) long tensor, rgb is an
        (n, y, x, rgb) byte image tensor (or None unless want_rgb),
        and shape is the (y, x) size of the images.
        '''
        [im] = batch
        device = next(model.parameters()).device
        if want_rgb:
            recovered = im.clone()
            # recover_image is None when no Normalize transform was
            # found; the image is then used as-is (assumes it is
            # already in [0, 1] -- TODO confirm).
            if self.recover_image is not None:
                recovered = self.recover_image(recovered)
            rgb = recovered.permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        else:
            rgb = None
        # Stubs for seg and bc
        seg = torch.zeros(im.shape[0], 1, 1, 1, dtype=torch.long)
        bc = torch.ones(im.shape[0], 1, dtype=torch.long)
        # Run the model.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]
class ClassifierSegRunner:
    '''
    Segrunner for datasets that contain explicit segmentations: each
    batch supplies the image, its segmentation and its label bincount.
    '''
    def __init__(self, dataset, recover_image=None):
        # The dataset contains explicit segmentations
        if recover_image is None:
            recover_image = reverse_normalize_from_transform(dataset)
        self.recover_image = recover_image
        self.dataset = dataset

    def get_label_and_category_names(self):
        '''
        Returns ([(label name, category name), ...], category names),
        built from the dataset's labels, label_category and categories
        attributes.
        '''
        catnames = self.dataset.categories
        label_and_cat_names = [(readable(label),
            catnames[self.dataset.label_category[i]])
            for i, label in enumerate(self.dataset.labels)]
        return label_and_cat_names, catnames

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one batch of the dataset, and
        returns a multilabel semantic segmentation for the data.
        Given a batch of size (n, c, y, x) the segmentation should
        be a (long integer) tensor of size (n, d, y//r, x//r) where
        d is the maximum number of simultaneous labels given to a pixel,
        and where r is some (optional) resolution reduction factor.
        In the segmentation returned, the label `0` is reserved for
        the background "no-label".

        In addition to the segmentation, bc, rgb, and shape are returned
        where bc is a per-image bincount counting returned label pixels,
        rgb is a viewable (n, y, x, rgb) byte image tensor for the data
        for visualizations (reversing normalizations, for example), and
        shape is the (y, x) size of the data.  If want_bincount or
        want_rgb are False, those return values may be None.

        Here seg and bc are taken directly from the batch, so bc is
        returned regardless of want_bincount.
        '''
        im, seg, bc = batch
        device = next(model.parameters()).device
        if want_rgb:
            recovered = im.clone()
            # recover_image is None when no Normalize transform was
            # found; the image is then used as-is (assumes it is
            # already in [0, 1] -- TODO confirm).
            if self.recover_image is not None:
                recovered = self.recover_image(recovered)
            rgb = recovered.permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
        else:
            rgb = None
        # Run the model.
        model(im.to(device))
        return seg, bc, rgb, im.shape[2:]
class GeneratorSegRunner:
    '''
    Segrunner for generator models: the model output is treated as the
    image, and its segmentation is computed by a segmenter.
    '''
    def __init__(self, segmenter):
        # The segmentations are given by an algorithm
        if segmenter is None:
            segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')
        self.segmenter = segmenter
        label_names, _ = segmenter.get_label_and_category_names()
        self.num_classes = len(label_names)

    def get_label_and_category_names(self):
        '''
        Returns the segmenter's label and category names.
        '''
        return self.segmenter.get_label_and_category_names()

    def run_and_segment_batch(self, batch, model,
            want_bincount=False, want_rgb=False):
        '''
        Runs the dissected model on one batch of the dataset, and
        returns a multilabel semantic segmentation for the data.
        Given a batch of size (n, c, y, x) the segmentation should
        be a (long integer) tensor of size (n, d, y//r, x//r) where
        d is the maximum number of simultaneous labels given to a pixel,
        and where r is some (optional) resolution reduction factor.
        In the segmentation returned, the label `0` is reserved for
        the background "no-label".

        In addition to the segmentation, bc, rgb, and shape are returned
        where bc is a per-image bincount counting returned label pixels,
        rgb is a viewable (n, y, x, rgb) byte image tensor for the data
        for visualizations (reversing normalizations, for example), and
        shape is the (y, x) size of the data.  If want_bincount or
        want_rgb are False, those return values may be None.
        '''
        device = next(model.parameters()).device
        z_batch = batch[0]
        batch_size = z_batch.shape[0]
        tensor_images = model(z_batch.to(device))
        seg = self.segmenter.segment_batch(tensor_images, downsample=2)
        bc = None
        if want_bincount:
            # Offset each image's labels into its own range of
            # num_classes bins so that a single bincount call yields
            # one histogram per image.
            image_index = torch.arange(
                batch_size, dtype=torch.long, device=device)
            offset_seg = seg + (
                image_index[:, None, None, None] * self.num_classes)
            flat_counts = offset_seg.view(-1).bincount(
                minlength=batch_size * self.num_classes)
            bc = flat_counts.view(batch_size, self.num_classes)
        rgb = None
        if want_rgb:
            # Rescale by (x + 1) / 2 * 255 and clamp to the byte range.
            images = ((tensor_images + 1) / 2 * 255)
            rgb = images.permute(0, 2, 3, 1).clamp(0, 255).byte()
        return seg, bc, rgb, tensor_images.shape[2:]