File size: 4,944 Bytes
32408ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import numpy as np
import cv2


from how.utils.html import HTML

def visualize_attention_map(dataset_name, imgpaths, attentions, scales, outdir):
  """Save attention-map overlays for each image and an HTML index page.

  For every image, one overlay per scale is written to
  ``<outdir>/<basename>_scale<s>.jpg`` and one overlay of the attention
  averaged over all scales to ``<outdir>/<basename>_aggregated.jpg``.
  ``<outdir>/index.html`` then shows them in a table, one row per image.

  Args:
    dataset_name: name used for the title and heading of the HTML page.
    imgpaths: paths of the images to visualize.
    attentions: ``attentions[i][j]`` is the attention map of image ``i`` at
      ``scales[j]``. It is resized with cv2 and multiplied by 255 before the
      cast to uint8, so it is presumably a 2D float array with values in
      [0, 1] -- TODO confirm against the caller.
    scales: resize factors applied to the image, one per attention map.
    outdir: output directory; created if it does not exist.

  Raises:
    AssertionError: if ``imgpaths`` and ``attentions`` differ in length.
    IOError: if an image cannot be read.
  """
  assert len(imgpaths) == len(attentions)
  os.makedirs(outdir, exist_ok=True)
  for i, imgpath in enumerate(imgpaths): # for each image
    img_basename = os.path.splitext(os.path.basename(imgpath))[0]
    atts = attentions[i]
    # load image (cv2.imread signals failure by returning None, not by raising)
    img = cv2.imread(imgpath)
    if img is None:
      raise IOError('could not read image: {:s}'.format(imgpath))
    # generate the visu for each scale independently
    for j, s in enumerate(scales):
      img_s = cv2.resize(img, None, fx=s, fy=s)
      # cv2.resize takes the target size as (width, height)
      att_s = cv2.resize(atts[j], (img_s.shape[1], img_s.shape[0]))
      heatmap_s = cv2.applyColorMap((255*att_s).astype(np.uint8), cv2.COLORMAP_JET)
      overlay = cv2.addWeighted(heatmap_s, 0.5, img_s, 0.5, 0)
      # os.path.join so that outdir works with or without a trailing separator
      cv2.imwrite(os.path.join(outdir, '{:s}_scale{:g}.jpg'.format(img_basename, s)), overlay)
    # generate the visu for the aggregation over scales (mean at the original image size)
    agg_atts = sum([cv2.resize(a, (img.shape[1], img.shape[0])) for a in atts]) / len(atts)
    heatmap_agg = cv2.applyColorMap((255*agg_atts).astype(np.uint8), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(heatmap_agg, 0.5, img, 0.5, 0)
    cv2.imwrite(os.path.join(outdir, '{:s}_aggregated.jpg'.format(img_basename)), overlay)
  # generate a html webpage for visualization
  cell_fmt = '<a href="{img:s}"><img src="{img:s}"/></a>'
  doc = HTML()
  doc.header().title(dataset_name)
  b = doc.body()
  b.h(1, dataset_name+' (attention map)')
  t = b.table(cellpadding=2, border=1)
  for i, imgpath in enumerate(imgpaths):
    img_basename = os.path.splitext(os.path.basename(imgpath))[0]
    # repeat the header row every 3 images
    if i%3 == 0:
      t.row(['info', 'image', 'agg', 'scale 1']+['scale '+str(s) for s in scales if s != 1], header=True)
    r = t.row()
    r.cell(str(i)+': '+img_basename)
    r.cell(cell_fmt.format(img=imgpath))
    # overlay images are referenced relative to index.html, which lives in outdir too
    r.cell(cell_fmt.format(img='{:s}_aggregated.jpg'.format(img_basename)))
    # scale 1 is always shown first, the other scales follow in their given order
    r.cell(cell_fmt.format(img='{:s}_scale1.jpg'.format(img_basename)))
    for s in scales:
      if s == 1:
        continue
      r.cell(cell_fmt.format(img='{:s}_scale{:g}.jpg'.format(img_basename, s)))
  doc.save(os.path.join(outdir, 'index.html'))

  
def visualize_region_maps(dataset_name, imgpaths, attentions, regions, scales, outdir, topk=10):
  """Save per-region heatmap overlays at scale 1 and an HTML index page.

  For every image and every region of the scale-1 entry, an overlay is
  written to ``<outdir>/<basename>_region<k>_scale1.jpg``.
  ``<outdir>/index.html`` then shows, for each image, the ``topk`` regions
  with the highest attention value.

  Args:
    dataset_name: name used for the title and heading of the HTML page.
    imgpaths: paths of the images to visualize.
    attentions: ``attentions[i][j]`` holds the attention values of image ``i``
      at ``scales[j]``. It is indexed here by region index, so it is
      presumably a 1D array with one value per region -- TODO confirm.
    regions: ``regions[i][j]`` is a list (history of regions) whose last
      element is indexed as ``[region, :, :]``.
    scales: the scales for which attentions/regions were computed; must
      contain 1, the only scale that is displayed.
    outdir: output directory; created if it does not exist.
    topk: number of highest-attention regions shown per image. If fewer
      attention values are available, only those are shown.

  Raises:
    AssertionError: if the input lists differ in length or 1 is not in
      ``scales``.
    IOError: if an image cannot be read.
  """
  assert len(imgpaths) == len(attentions)
  assert len(attentions) == len(regions)
  assert 1 in scales # we display the regions only for scale 1 (at least so far)
  os.makedirs(outdir, exist_ok=True)
  # generate visualization of each region
  for i, imgpath in enumerate(imgpaths): # for each image
    img_basename = os.path.splitext(os.path.basename(imgpath))[0]
    regs = regions[i]
    # load image (cv2.imread signals failure by returning None, not by raising)
    img = cv2.imread(imgpath)
    if img is None:
      raise IOError('could not read image: {:s}'.format(imgpath))
    # for each scale
    for j, s in enumerate(scales):
      if s != 1:
        continue    # just consider scale 1
      r = regs[j][-1] # -1 because it is a list of the history of regions
      img_s = cv2.resize(img, None, fx=s, fy=s)
      for ir in range(r.shape[0]):
        # factor 100 (clipped to 1) for easier visualization
        region_map = cv2.resize(np.minimum(1, 100*r[ir, :, :]), (img_s.shape[1], img_s.shape[0]))
        heatmap_s = cv2.applyColorMap((255*region_map).astype(np.uint8), cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(heatmap_s, 0.5, img_s, 0.5, 0)
        # os.path.join so that outdir works with or without a trailing separator
        cv2.imwrite(os.path.join(outdir, '{:s}_region{:d}_scale{:g}.jpg'.format(img_basename, ir, s)), overlay)
  # generate a html webpage for visualization (once, after all overlays are written)
  cell_fmt = '<a href="{img:s}"><img src="{img:s}"/></a>'
  doc = HTML()
  doc.header().title(dataset_name)
  b = doc.body()
  b.h(1, dataset_name+' (region maps)')
  t = b.table(cellpadding=2, border=1)
  for i, imgpath in enumerate(imgpaths):
    atts = attentions[i]
    regs = regions[i]
    # pick the entries of the first scale equal to 1 (guaranteed to exist by the assert above)
    for j, s in enumerate(scales):
      a = atts[j]
      rr = regs[j][-1] # -1 because it is a list of the history of regions
      if s == 1:
        break
    argsort = np.argsort(-a) # region indices by decreasing attention
    img_basename = os.path.splitext(os.path.basename(imgpath))[0]
    # repeat the header row every 3 images
    if i%3 == 0:
      t.row(['info', 'image']+['scale 1 - region {:d}'.format(ir) for ir in range(topk)], header=True)
    r = t.row()
    r.cell(str(i)+': '+img_basename)
    r.cell(cell_fmt.format(img=imgpath))
    # do not index past the available regions when there are fewer than topk
    for ir in range(min(topk, len(argsort))):
      index = int(argsort[ir])
      region_img = '{:s}_region{:d}_scale{:g}.jpg'.format(img_basename, index, s)
      r.cell((cell_fmt+'<br>index: {index:d}, att: {att:g}, rmax: {rmax:g}').format(
        img=region_img, index=index, att=a[index], rmax=rr[index, :, :].max()))
  doc.save(os.path.join(outdir, 'index.html'))
  
if __name__=='__main__':
  # Ad-hoc example run; the data and output paths below are hard-coded for
  # the original author's environment and must be adapted before use.
  import pickle
  from how.utils import data_helpers
  dataset = 'roxford5k'
  images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root="/tmp-network/user/pweinzae/CNNImageRetrieval/data/")
  # NOTE(review): pickle.load can execute arbitrary code -- only load files from a trusted source.
  with open('/tmp-network/user/pweinzae/roxford5k_features_attentions.pkl', 'rb') as fid:
    features, attentions = pickle.load(fid)
  # The function is named visualize_attention_map and takes the dataset name
  # as first argument. The trailing '/' makes outdir an unambiguous directory.
  visualize_attention_map(dataset, qimages, attentions, scales=[2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], outdir='/tmp-network/user/pweinzae/tmp/visu_attention_maps/'+dataset+'/')