# Source: SuperFeatures/how/utils/visualize.py
# Last commit: 32408ed ("temp state") by YannisK; file size 4.94 kB.
import os
import numpy as np
import cv2
from how.utils.html import HTML
def visualize_attention_map(dataset_name, imgpaths, attentions, scales, outdir):
    """Save attention-map overlays for a set of images plus an HTML index page.

    For every image, one JPEG is written per scale: the image resized by that
    scale factor, blended 50/50 with a JET-colored heatmap of the attention at
    that scale. One extra JPEG shows the attention averaged over all scales at
    the original image resolution. An ``index.html`` table linking every
    result is written to ``outdir``.

    Args:
        dataset_name: Title of the HTML page.
        imgpaths: Paths of the images to visualize.
        attentions: One entry per image; ``attentions[i][j]`` is the attention
            map of image ``i`` at ``scales[j]``. Presumably a 2-D float array
            with values in [0, 1] -- TODO confirm. Values outside [0, 1] are
            clipped before colorization.
        scales: Resize factors, in the same order as each per-image attention
            list.
        outdir: Output directory, created if missing. Paths are built with
            ``os.path.join``, so a trailing separator is optional.

    Output files (in ``outdir``):
        ``<basename>_scale<s>.jpg``, ``<basename>_aggregated.jpg``,
        ``index.html``.

    Raises:
        AssertionError: If ``imgpaths`` and ``attentions`` differ in length.
        FileNotFoundError: If OpenCV cannot read one of the images.
    """
    assert len(imgpaths) == len(attentions)
    os.makedirs(outdir, exist_ok=True)

    def blend(att, background):
        # ``att`` must already have the spatial size of ``background``.
        # Clip so the uint8 cast cannot wrap around for out-of-range values.
        att_u8 = (255 * np.clip(att, 0, 1)).astype(np.uint8)
        heatmap = cv2.applyColorMap(att_u8, cv2.COLORMAP_JET)
        return cv2.addWeighted(heatmap, 0.5, background, 0.5, 0)

    for i, imgpath in enumerate(imgpaths):
        img_basename = os.path.splitext(os.path.basename(imgpath))[0]
        atts = attentions[i]
        img = cv2.imread(imgpath)
        if img is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError('cannot read image: {:s}'.format(imgpath))
        height, width = img.shape[:2]

        # One overlay per scale, drawn on the image resized by that scale.
        for j, s in enumerate(scales):
            img_s = cv2.resize(img, None, fx=s, fy=s)
            att_s = cv2.resize(atts[j], (img_s.shape[1], img_s.shape[0]))
            out_path = os.path.join(outdir, '{:s}_scale{:g}.jpg'.format(img_basename, s))
            cv2.imwrite(out_path, blend(att_s, img_s))

        # Average of all per-scale maps, each resized to the original image size.
        agg_atts = sum(cv2.resize(a, (width, height)) for a in atts) / len(atts)
        out_path = os.path.join(outdir, '{:s}_aggregated.jpg'.format(img_basename))
        cv2.imwrite(out_path, blend(agg_atts, img))

    # HTML page: one row per image, header repeated every 3 rows.
    img_html = '<a href="{img:s}"><img src="{img:s}"/></a>'.format
    has_scale1 = 1 in scales  # a "scale 1" column is only meaningful if it was rendered
    other_scales = [s for s in scales if s != 1]
    header = (['info', 'image', 'agg']
              + (['scale 1'] if has_scale1 else [])
              + ['scale ' + str(s) for s in other_scales])

    doc = HTML()
    doc.header().title(dataset_name)
    body = doc.body()
    body.h(1, dataset_name + ' (attention map)')
    table = body.table(cellpadding=2, border=1)
    for i, imgpath in enumerate(imgpaths):
        img_basename = os.path.splitext(os.path.basename(imgpath))[0]
        if i % 3 == 0:
            table.row(header, header=True)
        row = table.row()
        row.cell(str(i) + ': ' + img_basename)
        row.cell(img_html(img=imgpath))
        row.cell(img_html(img='{:s}_aggregated.jpg'.format(img_basename)))
        if has_scale1:
            row.cell(img_html(img='{:s}_scale1.jpg'.format(img_basename)))
        for s in other_scales:
            row.cell(img_html(img='{:s}_scale{:g}.jpg'.format(img_basename, s)))
    doc.save(os.path.join(outdir, 'index.html'))
def visualize_region_maps(dataset_name, imgpaths, attentions, regions, scales, outdir, topk=10):
    """Save per-region heatmap overlays (scale 1 only) plus an HTML index page.

    For every image, each region map at scale 1 is multiplied by 100, clipped
    to [0, 1] so faint regions stay visible, colorized with the JET colormap
    and blended 50/50 over the image. The HTML page shows, per image, the
    ``topk`` regions with the highest attention, together with their index,
    attention value and peak region value.

    Args:
        dataset_name: Title of the HTML page.
        imgpaths: Paths of the images to visualize.
        attentions: One entry per image; ``attentions[i][j]`` holds one
            attention score per region at ``scales[j]``. It is indexed by
            region index and ranked with ``np.argsort``, so presumably a 1-D
            numpy array -- TODO confirm.
        regions: One entry per image; ``regions[i][j]`` is the history (a
            list) of region maps at ``scales[j]``. Only the last entry is
            used; it is indexed as ``[region, row, col]``.
        scales: Scale factors; must contain 1.
        outdir: Output directory, created if missing. Paths are built with
            ``os.path.join``, so a trailing separator is optional.
        topk: Number of top-attention regions shown per image in the HTML page
            (fewer if an image has fewer regions).

    Output files (in ``outdir``):
        ``<basename>_region<k>_scale1.jpg`` and ``index.html``.

    Raises:
        AssertionError: On mismatched input lengths or if 1 is not in ``scales``.
        FileNotFoundError: If OpenCV cannot read one of the images.
    """
    assert len(imgpaths) == len(attentions)
    assert len(attentions) == len(regions)
    assert 1 in scales  # regions are only displayed for scale 1 (at least so far)
    os.makedirs(outdir, exist_ok=True)

    # Position of scale 1; list.index compares with ==, so 1 and 1.0 both match.
    j1 = list(scales).index(1)
    s1 = scales[j1]

    def blend(values, background):
        # ``values`` is expected in [0, 1]; resize it to the background size.
        values = cv2.resize(values, (background.shape[1], background.shape[0]))
        heatmap = cv2.applyColorMap((255 * values).astype(np.uint8), cv2.COLORMAP_JET)
        return cv2.addWeighted(heatmap, 0.5, background, 0.5, 0)

    # One image per region at scale 1. Scale 1 needs no resize of the image.
    for imgpath, regs in zip(imgpaths, regions):
        img_basename = os.path.splitext(os.path.basename(imgpath))[0]
        img = cv2.imread(imgpath)
        if img is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError('cannot read image: {:s}'.format(imgpath))
        region_maps = regs[j1][-1]  # -1: last entry of the region history
        for ir in range(region_maps.shape[0]):
            # Factor 100 then saturate to [0, 1] for easier visualization.
            vis = np.clip(100 * region_maps[ir, :, :], 0, 1)
            out_path = os.path.join(outdir, '{:s}_region{:d}_scale{:g}.jpg'.format(img_basename, ir, s1))
            cv2.imwrite(out_path, blend(vis, img))

    # HTML page: one row per image, header repeated every 3 rows.
    cell_fmt = ('<a href="{img:s}"><img src="{img:s}"/></a>'
                '<br>index: {index:d}, att: {att:g}, rmax: {rmax:g}')
    header = ['info', 'image'] + ['scale 1 - region {:d}'.format(ir) for ir in range(topk)]

    doc = HTML()
    doc.header().title(dataset_name)
    body = doc.body()
    body.h(1, dataset_name + ' (region maps)')
    table = body.table(cellpadding=2, border=1)
    for i, imgpath in enumerate(imgpaths):
        att = attentions[i][j1]
        region_maps = regions[i][j1][-1]  # -1 because it is a list of the history of regions
        order = np.argsort(-att)  # region indices by decreasing attention
        img_basename = os.path.splitext(os.path.basename(imgpath))[0]
        if i % 3 == 0:
            table.row(header, header=True)
        row = table.row()
        row.cell(str(i) + ': ' + img_basename)
        row.cell('<a href="{img:s}"><img src="{img:s}"/></a>'.format(img=imgpath))
        # Slicing caps at the number of regions instead of raising IndexError.
        for index in order[:topk]:
            row.cell(cell_fmt.format(
                img='{:s}_region{:d}_scale{:g}.jpg'.format(img_basename, index, s1),
                index=index,
                att=att[index],
                rmax=region_maps[index, :, :].max()))
    doc.save(os.path.join(outdir, 'index.html'))
if __name__ == '__main__':
    # Ad-hoc driver: renders attention maps of the roxford5k query images from
    # precomputed features/attentions stored in a pickle (hard-coded paths).
    import pickle
    from how.utils import data_helpers

    dataset = 'roxford5k'
    images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root="/tmp-network/user/pweinzae/CNNImageRetrieval/data/")
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open('/tmp-network/user/pweinzae/roxford5k_features_attentions.pkl', 'rb') as fid:
        features, attentions = pickle.load(fid)
    # The trailing '' makes outdir end with a separator, so outputs land inside
    # the dataset directory even if outdir is concatenated with file names.
    visualize_attention_map(
        dataset,
        qimages,
        attentions,
        scales=[2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25],
        outdir=os.path.join('/tmp-network/user/pweinzae/tmp/visu_attention_maps', dataset, ''),
    )