import cv2 import numpy as np from spiga.data.visualize.inspect_dataset import DatasetInspector, inspect_parser class HeatmapInspector(DatasetInspector): def __init__(self, database, anns_type, data_aug=True, image_shape=(256,256)): super().__init__(database, anns_type, data_aug=data_aug, pose=False, image_shape=image_shape) self.data_config.aug_names.append('heatmaps2D') self.data_config.heatmap2D_norm = False self.data_config.aug_names.append('boundaries') self.data_config.shuffle = False self.reload_dataset() def show_dataset(self, ids_list=None): if ids_list is None: ids = self.get_idx(shuffle=self.data_config.shuffle) else: ids = ids_list for img_id in ids: data_dict = self.dataset[img_id] crop_imgs, _ = self.plot_features(data_dict) # Plot landmark crop cv2.imshow('crop', crop_imgs['lnd']) # Plot landmarks 2D (group) crop_allheats = self._plot_heatmaps2D(data_dict) # Plot boundaries shape cv2.imshow('boundary', np.max(data_dict['boundary'], axis=0)) for lnd_idx in range(self.data_config.database.num_landmarks): # Heatmaps 2D crop_heats = self._plot_heatmaps2D(data_dict, lnd_idx) maps = cv2.hconcat([crop_allheats['heatmaps2D'], crop_heats['heatmaps2D']]) cv2.imshow('heatmaps', maps) key = cv2.waitKey() if key == ord('q'): break if key == ord('n'): break if key == ord('q'): break def _plot_heatmaps2D(self, data_dict, heatmap_id=None): # Variables heatmaps = {} image = data_dict['image'] if heatmap_id is None: heatmaps2D = data_dict['heatmap2D'] heatmaps2D = np.max(heatmaps2D, axis=0) else: heatmaps2D = data_dict['heatmap2D'][heatmap_id] # Plot maps heatmaps['heatmaps2D'] = self._merge_imgmap(image, heatmaps2D) return heatmaps def _merge_imgmap(self, image, maps): crop_maps = cv2.applyColorMap(np.uint8(255 * maps), cv2.COLORMAP_JET) return cv2.addWeighted(image, 0.7, crop_maps, 0.3, 0) if __name__ == '__main__': args = inspect_parser() data_aug = True database = args.database anns_type = args.anns select_img = args.img if args.clean: data_aug = False if len(args.shape) != 2: raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256') else: img_shape = tuple(args.shape) visualizer = HeatmapInspector(database, anns_type, data_aug, image_shape=img_shape) visualizer.show_dataset(ids_list=select_img)