Spaces:
Build error
Build error
import os | |
import numpy as np | |
import cv2 | |
from how.utils.html import HTML | |
def visualize_attention_map(dataset_name, imgpaths, attentions, scales, outdir):
    """Save attention-map overlays for a set of images and an HTML index page.

    For each image ``imgpaths[i]`` this writes, into ``outdir``:

    * ``<basename>_scale<s>.jpg`` for every ``s`` in ``scales``: the image
      resized by ``s`` with ``attentions[i][j]`` overlaid as a JET heatmap;
    * ``<basename>_aggregated.jpg``: the full-size image with the mean of all
      attention maps of that image (each resized to the image size) overlaid.

    Finally ``index.html`` is written to ``outdir``. It has one table row per
    image, linking the original image and all of its overlays.

    Args:
        dataset_name: title of the HTML page.
        imgpaths: paths (str) of the images to visualize.
        attentions: one entry per image; ``attentions[i][j]`` is a 2-D map for
            ``scales[j]``. Values are expected in [0, 1] (assumption - TODO
            confirm against the producer); they are clipped to that range so
            the conversion to ``uint8`` cannot wrap around.
        scales: resize factors, aligned with the inner index of ``attentions``.
        outdir: output directory, created if missing. A trailing path
            separator is optional.

    Raises:
        AssertionError: if ``imgpaths`` and ``attentions`` differ in length.
        FileNotFoundError: if an image cannot be read by ``cv2.imread``.
    """
    assert len(imgpaths) == len(attentions)
    os.makedirs(outdir, exist_ok=True)

    def resize_to(att, img):
        # cv2.resize takes (width, height); float32 also covers float16 / array-likes
        return cv2.resize(np.asarray(att, dtype=np.float32), (img.shape[1], img.shape[0]))

    def blend(img, att):
        # clip so 255*att stays inside uint8 range, then blend heatmap and image 50/50
        heatmap = cv2.applyColorMap((255 * np.clip(att, 0, 1)).astype(np.uint8), cv2.COLORMAP_JET)
        return cv2.addWeighted(heatmap, 0.5, img, 0.5, 0)

    def scale_name(basename, s):
        return '{:s}_scale{:g}.jpg'.format(basename, s)

    def aggregated_name(basename):
        return '{:s}_aggregated.jpg'.format(basename)

    def img_cell(src):
        return '<a href="{img:s}"><img src="{img:s}"/></a>'.format(img=src)

    basenames = [os.path.splitext(os.path.basename(p))[0] for p in imgpaths]

    # overlay images: one per scale, plus one for the aggregation over scales
    for imgpath, basename, atts in zip(imgpaths, basenames, attentions):
        img = cv2.imread(imgpath)
        if img is None:  # cv2.imread reports failure by returning None, not by raising
            raise FileNotFoundError('could not read image: {:s}'.format(imgpath))
        for j, s in enumerate(scales):
            img_s = cv2.resize(img, None, fx=s, fy=s)
            overlay = blend(img_s, resize_to(atts[j], img_s))
            cv2.imwrite(os.path.join(outdir, scale_name(basename, s)), overlay)
        agg_atts = sum(resize_to(a, img) for a in atts) / len(atts)
        cv2.imwrite(os.path.join(outdir, aggregated_name(basename)), blend(img, agg_atts))

    # HTML page; overlay links are relative to outdir, where index.html is saved
    doc = HTML()
    doc.header().title(dataset_name)
    body = doc.body()
    body.h(1, dataset_name + ' (attention map)')
    table = body.table(cellpadding=2, border=1)
    has_scale1 = any(s == 1 for s in scales)
    other_scales = [s for s in scales if s != 1]
    header = (['info', 'image', 'agg']
              + (['scale 1'] if has_scale1 else [])
              + ['scale ' + str(s) for s in other_scales])
    for i, (imgpath, basename) in enumerate(zip(imgpaths, basenames)):
        if i % 3 == 0:  # the header row is repeated every 3 images
            table.row(header, header=True)
        row = table.row()
        row.cell(str(i) + ': ' + basename)
        row.cell(img_cell(imgpath))
        row.cell(img_cell(aggregated_name(basename)))
        if has_scale1:
            row.cell(img_cell(scale_name(basename, 1)))
        for s in other_scales:
            row.cell(img_cell(scale_name(basename, s)))
    doc.save(os.path.join(outdir, 'index.html'))
def visualize_region_maps(dataset_name, imgpaths, attentions, regions, scales, outdir, topk=10):
    """Save per-region heatmap overlays (scale 1 only) and an HTML index page.

    Let ``j1`` be the position of scale 1 in ``scales``. For each image, every
    region map ``regions[i][j1][-1][k]`` is amplified by 100 and clipped to
    [0, 1]. It is then overlaid on the image and saved as
    ``<basename>_region<k>_scale1.jpg`` in ``outdir``.

    ``index.html`` shows, per image, up to ``topk`` regions sorted by
    decreasing ``attentions[i][j1]``. Each entry gives the region's index,
    attention value and maximum region-map value.

    Args:
        dataset_name: title of the HTML page.
        imgpaths: paths (str) of the images.
        attentions: one entry per image; ``attentions[i][j]`` holds one score
            per region at ``scales[j]`` (assumed 1-D - TODO confirm).
        regions: one entry per image; ``regions[i][j]`` is a history (list)
            of region maps at ``scales[j]``, of which the last entry is used.
            That entry is indexed as ``[region, row, col]``.
        scales: resize factors; must contain 1.
        outdir: output directory, created if missing. A trailing path
            separator is optional.
        topk: maximum number of regions listed per image in the HTML page
            (fewer if the image has fewer regions).

    Raises:
        AssertionError: on inconsistent input lengths or if 1 is not in ``scales``.
        FileNotFoundError: if an image cannot be read by ``cv2.imread``.
    """
    assert len(imgpaths) == len(attentions)
    assert len(attentions) == len(regions)
    assert 1 in scales  # we display the regions only for scale 1 (at least so far)
    os.makedirs(outdir, exist_ok=True)

    gain = 100  # region maps are amplified (then saturated at 1) for easier visualization
    j1 = next(j for j, s in enumerate(scales) if s == 1)  # position of scale 1
    s1 = scales[j1]

    def region_name(basename, ir):
        return '{:s}_region{:d}_scale{:g}.jpg'.format(basename, int(ir), s1)

    def img_cell(src):
        return '<a href="{img:s}"><img src="{img:s}"/></a>'.format(img=src)

    basenames = [os.path.splitext(os.path.basename(p))[0] for p in imgpaths]

    # one overlay image per region, at scale 1 (i.e. the original image size)
    for imgpath, basename, regs in zip(imgpaths, basenames, regions):
        img = cv2.imread(imgpath)
        if img is None:  # cv2.imread reports failure by returning None, not by raising
            raise FileNotFoundError('could not read image: {:s}'.format(imgpath))
        r = np.asarray(regs[j1][-1])  # -1: last entry of the region history
        size = (img.shape[1], img.shape[0])  # cv2 expects (width, height)
        for ir in range(r.shape[0]):
            amplified = np.clip(gain * r[ir, :, :], 0, 1).astype(np.float32)
            heatmap = cv2.applyColorMap((255 * cv2.resize(amplified, size)).astype(np.uint8), cv2.COLORMAP_JET)
            overlay = cv2.addWeighted(heatmap, 0.5, img, 0.5, 0)
            cv2.imwrite(os.path.join(outdir, region_name(basename, ir)), overlay)

    # HTML page: top-k regions per image by decreasing attention
    doc = HTML()
    doc.header().title(dataset_name)
    body = doc.body()
    body.h(1, dataset_name + ' (region maps)')
    table = body.table(cellpadding=2, border=1)
    header = ['info', 'image'] + ['scale 1 - region {:d}'.format(k) for k in range(topk)]
    for i, (imgpath, basename, atts, regs) in enumerate(zip(imgpaths, basenames, attentions, regions)):
        a = np.asarray(atts[j1])
        rr = np.asarray(regs[j1][-1])  # -1 because it is a list of the history of regions
        order = np.argsort(-a)  # region indices by decreasing attention
        if i % 3 == 0:  # the header row is repeated every 3 images
            table.row(header, header=True)
        row = table.row()
        row.cell(str(i) + ': ' + basename)
        row.cell(img_cell(imgpath))
        for index in order[:topk]:  # slicing avoids IndexError when there are fewer than topk regions
            caption = '<br>index: {index:d}, att: {att:g}, rmax: {rmax:g}'.format(
                index=int(index), att=float(a[index]), rmax=float(rr[index, :, :].max()))
            row.cell(img_cell(region_name(basename, index)) + caption)
    doc.save(os.path.join(outdir, 'index.html'))
if __name__ == '__main__':
    # Ad-hoc driver: visualizes precomputed attention maps of the roxford5k query images.
    import pickle

    from how.utils import data_helpers

    dataset = 'roxford5k'
    _images, qimages, _bbxs, _gnd = data_helpers.load_dataset(
        dataset, data_root="/tmp-network/user/pweinzae/CNNImageRetrieval/data/")
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    # NOTE(review): assumes the pickled attentions are aligned with qimages - confirm.
    with open('/tmp-network/user/pweinzae/roxford5k_features_attentions.pkl', 'rb') as fid:
        _features, attentions = pickle.load(fid)
    # The function is visualize_attention_map (singular) and needs dataset_name first;
    # the trailing '/' keeps outputs inside the dataset directory.
    visualize_attention_map(
        dataset, qimages, attentions,
        scales=[2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25],
        outdir='/tmp-network/user/pweinzae/tmp/visu_attention_maps/' + dataset + '/')