Paul Engstler
Initial commit
92f0e98
raw
history blame
7.24 kB
import os
import numpy
from scipy.interpolate import RectBivariateSpline
def activation_visualization(image, data, level, alpha=0.5, source_shape=None,
                             target_shape=None, crop=False, zoom=None,
                             border=2, negate=False, return_mask=False,
                             **kwargs):
    """
    Makes a visualization image of activation data overlaid on the image.

    Regions where the upsampled activation is at or above `level` keep
    their brightness; everything else is multiplied by `alpha`.

    Params:
        image         The original image: 2-d grayscale (expanded to three
                      channels) or 3-d with the channel axis last.
        data          The single channel feature map.
        level         The threshold of activation levels to highlight.
        alpha         The darkening to apply in inactive regions of the image.
        source_shape  Forwarded to activation_surface. When cropping, the
                      image and surface are first cut to a centered window
                      of this shape.
        target_shape  If given and different from the current size, the
                      image and surface are rescaled to this (height, width)
                      before the mask is computed.
        crop          Falsy for no cropping. True keeps the current surface
                      shape; a scalar n means (n, n); a pair is used as the
                      output (height, width).
        zoom          If not None, passed as max_zoom to best_sub_rect so the
                      crop is focused on the active region.
        border        If truthy, a (200, 200, 0) border is drawn at the mask
                      edge; the value is also used as padding when choosing
                      the zoom rectangle.
        negate        If true, highlight activations at or below `level`
                      instead (both surface and level are negated).
        return_mask   If true, also return a four-channel overlay image.
        **kwargs      Forwarded to activation_surface.

    Returns the overlaid image, or (image, mask_image) when return_mask is
    set.
    """
    if image.ndim == 2:
        # Puff up grayscale image to RGB.
        image = image[:, :, None] * numpy.array([[[1, 1, 1]]])

    surface = activation_surface(
        data, target_shape=image.shape[:2], source_shape=source_shape,
        **kwargs)

    if negate:
        surface, level = -surface, -level

    if crop:
        if source_shape is not None:
            # Cut both arrays down to the centered source_shape window.
            top, left = ((full - part) // 2 for part, full in zip(
                source_shape, image.shape[:2]))
            window = (slice(top, top + source_shape[0]),
                      slice(left, left + source_shape[1]))
            image = image[window]
            surface = surface[window]

        # Normalize `crop` into an explicit (height, width) output size.
        if crop is True:
            crop = surface.shape
        elif not hasattr(crop, '__len__'):
            crop = (crop, crop)

        if zoom is None:
            source_rect = (0, surface.shape[0], 0, surface.shape[1])
        else:
            source_rect = best_sub_rect(surface >= level, crop, zoom,
                                        pad=border)
        image = zoom_image(image, source_rect, crop)
        surface = zoom_image(surface, source_rect, crop)

    # Rescale visualization to target_shape if specified.
    needs_rescale = target_shape is not None and (
        image.shape[0] != target_shape[0]
        or image.shape[1] != target_shape[1])
    if needs_rescale:
        image = zoom_image(image, target_shape=target_shape)
        surface = zoom_image(surface, target_shape=target_shape)

    mask = (surface >= level)

    # Full brightness inside the mask, `alpha` brightness outside of it.
    brightness = mask[:, :, None] * (1 - alpha) + alpha
    result = brightness * image

    edge = None
    if border:
        # Add a yellow border at the edge of the mask for contrast.
        edge = mask_border(mask)[:, :, None]
        result = numpy.maximum(edge * numpy.array([[[200, 200, 0]]]), result)

    if return_mask:
        # Four-channel overlay: the last channel is nonzero only outside
        # the mask (and on the border, when one is drawn).
        mask_image = (1 - mask[:, :, None]) * numpy.array(
            [[[0, 0, 0, 255 * (1 - alpha)]]], dtype=numpy.uint8)
        if border:
            mask_image = numpy.maximum(
                edge * numpy.array([[[200, 200, 0, 255]]]), mask_image)
        return result, mask_image
    return result
def activation_surface(data, target_shape=None, source_shape=None,
                       scale_offset=None, deg=1, pad=True):
    """
    Generates an upsampled activation sample.

    Each data sample i along an axis is placed at coordinate
    i * scale + offset in the output grid, and a spline of degree `deg`
    through those samples is evaluated at every integer output coordinate.

    Params:
        data          The 2-d array of activations to upsample.
        target_shape  Shape of the output array. Defaults to data.shape.
        source_shape  The centered shape of the output to match with data
                      when upscaling. Defaults to the whole target_shape.
        scale_offset  The amount by which to scale, then offset data
                      dimensions to end up with target dimensions. A pair
                      of (scale, offset) pairs, one per axis.
        deg           Degree of interpolation to apply (1 = linear, etc).
        pad           True to zero-pad the edge instead of doing a funny
                      edge interp.

    Returns the interpolated array of shape target_shape.
    """
    # Default is that nothing is resized.
    if target_shape is None:
        target_shape = data.shape

    if scale_offset is not None:
        scale, offset = zip(*scale_offset)
    else:
        # Fill the whole target; the half-pixel shift aligns the centers
        # of the coarse cells with the centers of the fine pixels.
        scale = tuple(float(out_n) / in_n
                      for out_n, in_n in zip(target_shape, data.shape))
        offset = tuple(0.5 * factor - 0.5 for factor in scale)

    # Now we adjust offsets to take into account cropping and so on.
    if source_shape is not None:
        offset = tuple(
            shift + (out_n - src_n) / 2.0
            for shift, src_n, out_n in zip(offset, source_shape,
                                           target_shape))

    if pad:
        # Surround the data with one ring of zeros for sensible edge
        # behavior, and move the origin back by one (scaled) sample so the
        # original samples stay where they were.
        padded = numpy.zeros(
            (data.shape[0] + 2, data.shape[1] + 2), dtype=data.dtype)
        padded[1:-1, 1:-1] = data
        data = padded
        offset = tuple(shift - factor
                       for shift, factor in zip(offset, scale))

    # Output pixel coordinates, and the coordinates of each data sample.
    out_rows, out_cols = target_shape
    ty = numpy.arange(out_rows)
    tx = numpy.arange(out_cols)
    sy, sx = (numpy.arange(count) * factor + shift
              for count, factor, shift in zip(data.shape, scale, offset))

    spline = RectBivariateSpline(sy, sx, data, kx=deg, ky=deg)
    return spline(ty, tx, grid=True)
def mask_border(mask, border=2):
    """Given a mask computes a border mask.

    The border is the band of pixels lying just outside the mask: pixels
    that are not in `mask` but are removed when the complement of the mask
    is eroded `border` times with a full 3x3 structuring element.

    Params:
        mask    2-d array; nonzero entries are treated as inside the mask.
        border  Number of erosion steps, i.e. the border width in pixels.

    Returns a boolean array with the same shape as `mask`.
    """
    from scipy import ndimage
    # Coerce to bool so that `~` is a logical NOT even for integer masks
    # (bitwise NOT of 0/1 integers would give -1/-2, both nonzero).
    mask = numpy.asarray(mask, dtype=bool)
    outside = ~mask
    struct = ndimage.generate_binary_structure(2, 2)
    eroded = outside
    for _ in range(border):
        # border_value=1 treats everything beyond the array as "outside the
        # mask", so the image edges never erode on their own. This replaces
        # a fixed 5-pixel padding that produced spurious borders along the
        # image edges whenever border > 5.
        eroded = ndimage.binary_erosion(eroded, struct, border_value=1)
    return outside ^ eroded
def bounding_rect(mask, pad=0):
    """Returns (t, b, l, r) boundaries so that all nonzero pixels in mask
    have locations (i, j) with t <= i < b, and l <= j < r.

    The box is grown by `pad` pixels on every side and clipped to the mask
    extent. If the mask has no nonzero pixels, the full extent of the mask
    is returned."""
    coords = mask.nonzero()
    if coords[0].size == 0:
        # Nothing is set: fall back to the whole mask.
        return (0, mask.shape[0], 0, mask.shape[1])
    bounds = []
    for axis_coords, extent in zip(coords, mask.shape):
        low = max(0, axis_coords.min() - pad)
        high = min(extent, axis_coords.max() + 1 + pad)
        bounds.append((low, high))
    (t, b), (l, r) = bounds
    return (t, b, l, r)
def best_sub_rect(mask, shape, max_zoom=None, pad=2):
    """Finds the smallest subrectangle containing all the nonzeros of mask,
    matching the aspect ratio of shape, and where the zoom-up ratio is no
    more than max_zoom.

    Params:
        mask      2-d array whose nonzero pixels must be covered.
        shape     (height, width) of the output the rectangle will be
                  zoomed to; it fixes the aspect ratio.
        max_zoom  If given, the rectangle is at least shape[0] / max_zoom
                  pixels high, which bounds the magnification.
        pad       Padding passed to bounding_rect around the nonzeros.

    Returns (top, bottom, left, right) with 0 <= top <= bottom <=
    mask.shape[0] and 0 <= left <= right <= mask.shape[1]. If the
    aspect-matched rectangle would not fit inside the mask, it is clamped
    to the mask extent (giving up the exact aspect ratio).
    """
    t, b, l, r = bounding_rect(mask, pad=pad)
    # Tall enough for the bounding box, and for its width at the target
    # aspect ratio.
    height = max(b - t, int(round(float(shape[0]) * (r - l) / shape[1])))
    if max_zoom is not None:
        height = int(max(round(float(shape[0]) / max_zoom), height))
    width = int(round(float(shape[1]) * height / shape[0]))
    # Never ask for more pixels than the mask has. Without this clamp the
    # origin below goes negative, and a negative slice start is read as an
    # offset from the end of the array, selecting the wrong region.
    height = min(height, mask.shape[0])
    width = min(width, mask.shape[1])
    # Center on the bounding box, then shift to stay inside the mask.
    nt = min(mask.shape[0] - height, max(0, (b + t - height) // 2))
    nb = nt + height
    nl = min(mask.shape[1] - width, max(0, (r + l - width) // 2))
    nr = nl + width
    return (int(nt), int(nb), int(nl), int(nr))
def zoom_image(img, source_rect=None, target_shape=None):
    """Zooms pixels from the source_rect of img to target_shape.

    Params:
        img           Array whose first two axes are (height, width); any
                      further axes (e.g. channels) are left unscaled.
        source_rect   (top, bottom, left, right) region of img to use.
                      Defaults to the whole image.
        target_shape  Desired output size. Only its first two entries are
                      used, so (height, width) pairs, lists and full array
                      shapes are all accepted. Defaults to the size of img.

    Returns the resampled array. When the selected region already has the
    requested size it is returned as-is (a view of img, not resampled).
    """
    import warnings
    from scipy.ndimage import zoom
    if target_shape is None:
        target_shape = img.shape
    # Only the two spatial dimensions are ever resized. Normalizing to a
    # 2-tuple keeps the comparisons below valid for colour images (whose
    # shape has three entries) and for list-valued target shapes.
    target_shape = tuple(target_shape[:2])
    if source_rect is None:
        st, sb, sl, sr = (0, img.shape[0], 0, img.shape[1])
    else:
        st, sb, sl, sr = source_rect
    source = img[st:sb, sl:sr]
    if source.shape[:2] == target_shape:
        return source
    zoom_tuple = tuple(float(t) / s
                       for t, s in zip(target_shape, source.shape[:2])
                       ) + (1,) * (img.ndim - 2)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', UserWarning)  # "output shape of zoom"
        target = zoom(source, zoom_tuple)
    # Sanity check on scipy's rounding of the output shape.
    assert target.shape[:2] == target_shape, (target.shape, target_shape)
    return target
def choose_level(feature_map, percentile=0.8):
    """
    Chooses the activation level found at the given fraction of the sorted
    values (by default the level below which 80% of the values fall).

    `percentile` is a fraction in [0, 1], not a number out of 100. Levels
    between two data values are linearly interpolated.
    """
    ordered = numpy.sort(feature_map, axis=None)
    # Spread the sorted values evenly over [0, 1] and read off the level.
    positions = numpy.linspace(0, 1, len(ordered))
    return numpy.interp(percentile, positions, ordered)