File size: 2,871 Bytes
d015578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import cv2
import numpy as np

from spiga.data.visualize.inspect_dataset import DatasetInspector, inspect_parser


class HeatmapInspector(DatasetInspector):
    """Interactive OpenCV viewer for landmark heatmaps and boundary maps.

    Extends ``DatasetInspector`` by appending the ``'heatmaps2D'`` and
    ``'boundaries'`` augmentations, disabling heatmap normalisation and
    shuffling, and then reloading the dataset so that samples expose
    ``'heatmap2D'`` and ``'boundary'`` entries.
    """

    def __init__(self, database, anns_type, data_aug=True, image_shape=(256,256)):
        """Configure the inspector and reload the dataset with heatmap targets.

        Args:
            database: Dataset identifier, forwarded to ``DatasetInspector``.
            anns_type: Annotation type, forwarded to ``DatasetInspector``.
            data_aug: Whether data augmentation is enabled (forwarded).
            image_shape: Image shape forwarded to ``DatasetInspector``
                (the CLI describes it as width and height).
        """
        super().__init__(database, anns_type, data_aug=data_aug, pose=False, image_shape=image_shape)

        self.data_config.aug_names.append('heatmaps2D')
        self.data_config.heatmap2D_norm = False
        self.data_config.aug_names.append('boundaries')
        self.data_config.shuffle = False
        # Presumably rebuilds the dataset so the config changes above apply.
        self.reload_dataset()

    def show_dataset(self, ids_list=None):
        """Step through samples showing the crop, boundaries and landmark heatmaps.

        For every sample the 'crop', 'boundary' and 'heatmaps' windows are
        refreshed. The 'heatmaps' window shows the max over all landmark
        heatmaps next to the heatmap of a single landmark. Any key advances
        to the next landmark, 'n' skips to the next sample and 'q' quits.

        Args:
            ids_list: Optional iterable of dataset indices to display. When
                None, indices come from ``self.get_idx`` using the configured
                shuffle flag.
        """
        if ids_list is None:
            ids = self.get_idx(shuffle=self.data_config.shuffle)
        else:
            ids = ids_list

        for img_id in ids:
            data_dict = self.dataset[img_id]

            crop_imgs, _ = self.plot_features(data_dict)

            # Plot landmark crop
            cv2.imshow('crop', crop_imgs['lnd'])

            # Overlay of all landmark heatmaps (max over axis 0); reused as the
            # left half of every per-landmark view below.
            crop_allheats = self._plot_heatmaps2D(data_dict)

            # Plot boundaries shape (max over axis 0 of the boundary maps)
            cv2.imshow('boundary', np.max(data_dict['boundary'], axis=0))

            # Initialise so the quit check after the loop is well-defined even
            # when the landmark loop body never runs (num_landmarks == 0).
            key = -1
            for lnd_idx in range(self.data_config.database.num_landmarks):
                # Heatmaps 2D: all-landmarks overlay | this landmark's overlay
                crop_heats = self._plot_heatmaps2D(data_dict, lnd_idx)
                maps = cv2.hconcat([crop_allheats['heatmaps2D'], crop_heats['heatmaps2D']])
                cv2.imshow('heatmaps', maps)

                key = cv2.waitKey()
                # 'q' stops everything, 'n' jumps to the next sample.
                if key in (ord('q'), ord('n')):
                    break

            if key == ord('q'):
                break

    def _plot_heatmaps2D(self, data_dict, heatmap_id=None):
        """Blend one landmark heatmap, or the max over all of them, onto the image.

        Args:
            data_dict: Dataset sample; reads the ``'image'`` and
                ``'heatmap2D'`` entries.
            heatmap_id: Index into axis 0 of ``'heatmap2D'``. When None, the
                element-wise max over axis 0 is shown instead.

        Returns:
            dict with a single ``'heatmaps2D'`` key holding the blended image.
        """
        image = data_dict['image']
        heatmaps2D = data_dict['heatmap2D']

        if heatmap_id is None:
            heatmaps2D = np.max(heatmaps2D, axis=0)
        else:
            heatmaps2D = heatmaps2D[heatmap_id]

        # Plot maps
        return {'heatmaps2D': self._merge_imgmap(image, heatmaps2D)}

    def _merge_imgmap(self, image, maps):
        """Colour-map ``maps`` with JET and alpha-blend it over ``image``.

        Args:
            image: Image to blend onto. ``cv2.addWeighted`` requires it to
                match the colour-mapped heatmap in size and type (presumably
                a uint8 3-channel image; confirm against the dataset).
            maps: 2D heatmap. Values are scaled by 255, so [0, 1] spans the
                full colour range; anything outside is clipped.

        Returns:
            ``0.7 * image + 0.3 * colour-mapped heatmap``.
        """
        # Clip before the uint8 cast: out-of-range values would otherwise wrap
        # around (or be undefined for floats) and appear as spurious colours.
        heat_u8 = np.clip(255 * maps, 0, 255).astype(np.uint8)
        crop_maps = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
        return cv2.addWeighted(image, 0.7, crop_maps, 0.3, 0)


if __name__ == '__main__':

    cli = inspect_parser()
    database = cli.database
    anns_type = cli.anns
    select_img = cli.img
    # Augmentation stays enabled unless --clean was passed.
    data_aug = not cli.clean

    # The inspector needs exactly two dimensions (width, height); fail fast otherwise.
    if len(cli.shape) != 2:
        raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256')
    img_shape = tuple(cli.shape)

    HeatmapInspector(database, anns_type, data_aug, image_shape=img_shape).show_dataset(ids_list=select_img)