import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


def recursive_glob(rootdir='.', suffix=''):
    """Performs a recursive glob with the given suffix and rootdir.

    :param rootdir: the root directory to walk
    :param suffix: the file suffix to search for
    """
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames if filename.endswith(suffix)]


def get_cityscapes_labels():
    """Load the mapping that associates Cityscapes train classes with label colors.

    Returns:
        np.ndarray with dimensions (19, 3)
    """
    return np.array([
        # [0, 0, 0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates PASCAL VOC classes with label colors.

    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])


def get_mhp_labels():
    """Load the mapping that associates MHP (Multi-Human Parsing) classes with label colors.

    Returns:
        np.ndarray with dimensions (59, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128],
                       # 21
                       [96, 0, 0], [0, 96, 0], [96, 96, 0], [0, 0, 96],
                       [96, 0, 96], [0, 96, 96], [96, 96, 96], [32, 0, 0],
                       [160, 0, 0], [32, 96, 0], [160, 96, 0], [32, 0, 96],
                       [160, 0, 96], [32, 96, 96], [160, 96, 96], [0, 32, 0],
                       [96, 32, 0], [0, 160, 0], [96, 160, 0], [0, 32, 96],
                       # 41
                       [48, 0, 0], [0, 48, 0], [48, 48, 0], [0, 0, 96],
                       [48, 0, 48], [0, 48, 48], [48, 48, 48], [16, 0, 0],
                       [80, 0, 0], [16, 48, 0], [80, 48, 0], [16, 0, 48],
                       [80, 0, 48], [16, 48, 48], [80, 48, 48], [0, 16, 0],
                       [48, 16, 0], [0, 80, 0],
                       # 59
                       ])


def encode_segmap(mask):
    """Encode a segmentation label image as Pascal class indices.

    Args:
        mask (np.ndarray): raw segmentation label image of dimension
            (M, N, 3), in which the Pascal classes are encoded as colours.

    Returns:
        (np.ndarray): class map with dimensions (M, N), where the value at
            a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    label_mask = label_mask.astype(int)
    return label_mask


def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Decode a batch of label masks and stack them into an (N, 3, H, W) tensor."""
    rgb_masks = []
    for label_mask in label_masks:
        rgb_mask = decode_segmap(label_mask, dataset)
        rgb_masks.append(rgb_mask)
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
    return rgb_masks


def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): one of 'pascal', 'cityscapes' or 'mhp', selecting the
            colour palette to use.
        plot (bool, optional): whether to show the resulting color image
            in a figure.

    Returns:
        (np.ndarray, optional): the resulting decoded color image.
""" if dataset == 'pascal': n_classes = 21 label_colours = get_pascal_labels() elif dataset == 'cityscapes': n_classes = 19 label_colours = get_cityscapes_labels() elif dataset == 'mhp': n_classes = 59 label_colours = get_mhp_labels() else: raise NotImplementedError r = label_mask.copy() g = label_mask.copy() b = label_mask.copy() for ll in range(0, n_classes): r[label_mask == ll] = label_colours[ll, 0] g[label_mask == ll] = label_colours[ll, 1] b[label_mask == ll] = label_colours[ll, 2] rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) rgb[:, :, 0] = r / 255.0 rgb[:, :, 1] = g / 255.0 rgb[:, :, 2] = b / 255.0 if plot: plt.imshow(rgb) plt.show() else: return rgb def generate_param_report(logfile, param): log_file = open(logfile, 'w') for key, val in param.items(): log_file.write(key + ':' + str(val) + '\n') log_file.close() def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): n, c, h, w = logit.size() # logit = logit.permute(0, 2, 3, 1) target = target.squeeze(1) if weight is None: criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index,size_average=size_average) else: criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=size_average) loss = criterion(logit, target.long()) return loss def cross_entropy2d_dataparallel(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): n, c, h, w = logit.size() # logit = logit.permute(0, 2, 3, 1) target = target.squeeze(1) if weight is None: criterion = nn.DataParallel(nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index,size_average=size_average)) else: criterion = nn.DataParallel(nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=size_average)) loss = criterion(logit, target.long()) return loss.sum() def lr_poly(base_lr, iter_, max_iter=100, power=0.9): return base_lr * ((1 - float(iter_) / max_iter) ** power) def get_iou(pred, gt, n_classes=21): total_iou = 0.0 for i in range(len(pred)): pred_tmp = pred[i] gt_tmp = gt[i] intersect = [0] * n_classes union = [0] * n_classes for j in range(n_classes): match = (pred_tmp == j) + (gt_tmp == j) it = torch.sum(match == 2).item() un = torch.sum(match > 0).item() intersect[j] += it union[j] += un iou = [] for k in range(n_classes): if union[k] == 0: continue iou.append(intersect[k] / union[k]) img_iou = (sum(iou) / len(iou)) total_iou += img_iou return total_iou def scale_tensor(input,size=512,mode='bilinear'): print(input.size()) # b,h,w = input.size() _, _, h, w = input.size() if mode == 'nearest': if h == 512 and w == 512: return input return F.upsample_nearest(input,size=(size,size)) if h>512 and w > 512: return F.upsample(input, size=(size,size), mode=mode, align_corners=True) return F.upsample(input, size=(size,size), mode=mode, align_corners=True) def scale_tensor_list(input,): output = [] for i in range(len(input)-1): output_item = [] for j in range(len(input[i])): _, _, h, w = input[-1][j].size() output_item.append(F.upsample(input[i][j], size=(h,w), mode='bilinear', align_corners=True)) output.append(output_item) output.append(input[-1]) return output def scale_tensor_list_0(input,base_input): output = [] assert len(input) == len(base_input) for j in range(len(input)): _, _, h, w = base_input[j].size() after_size = F.upsample(input[j], size=(h,w), mode='bilinear', align_corners=True) base_input[j] = base_input[j] + after_size # 
    return base_input


if __name__ == '__main__':
    print(lr_poly(0.007, iter_=99, max_iter=150))
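    # Illustrative sketch: round-trip a small Pascal-style colour mask through
    # encode_segmap / decode_segmap. The 4x4 spatial size and the random class
    # ids below are arbitrary example values, not part of any dataset.
    dummy_ids = np.random.randint(0, 21, size=(4, 4))
    dummy_rgb = get_pascal_labels()[dummy_ids]      # (4, 4, 3) colour-coded mask
    recovered = encode_segmap(dummy_rgb)            # back to (4, 4) class indices
    assert np.array_equal(recovered, dummy_ids)
    print(decode_segmap(recovered, dataset='pascal').shape)  # -> (4, 4, 3), floats in [0, 1]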