'''Some helper functions for PyTorch.'''
import os
import sys
import time
import math

import numpy as np
import torch
import torch.nn as nn


def get_sub_image(mega_image, overlap=0.2, ratio=1, crop_size=512):
    '''Cut an image into overlapping square tiles.

    Args:
      mega_image: (ndarray) source image, 3-D; the first two axes are tiled
        and the third is kept whole.
      overlap: (float) fraction of the tile side shared by neighbouring tiles.
      ratio: (float) multiplier applied to crop_size to get the tile side.
      crop_size: (int) base tile side.

    Returns:
      (list, list) the tiles and, index-aligned, the [axis0, axis1] offset of
      each tile's upper-left corner inside the (possibly padded) image.

    Notes:
      - An image smaller than the tile side along either axis is zero-padded
        (bottom/right) up to the tile side first; offsets then refer to the
        padded image and are never negative.
      - The final tile along each axis is anchored to the far edge, so it may
        coincide with or heavily overlap the previous tile.
      - Non-final tiles are taken at multiples of the stride and are NOT
        clamped, so one that runs past the border is returned truncated.
    '''
    size = int(ratio * crop_size)
    stride = int(size * (1 - overlap))

    # The number of strided tiles is derived from the un-padded extent, so
    # padding a small image does not add extra duplicate tiles.
    rows, cols, _ = mega_image.shape
    num_rows = int(rows / stride)
    num_cols = int(cols / stride)

    if rows < size or cols < size:
        mega_image = image_padding(mega_image, size)
        # Refresh the extents: the edge-anchored offsets below must be
        # computed against the image that is actually being sliced.
        rows, cols = mega_image.shape[:2]

    row_starts = [stride * i for i in range(num_rows)] + [rows - size]
    col_starts = [stride * j for j in range(num_cols)] + [cols - size]

    sub_image_list = []
    coor_list = []
    for top in row_starts:
        for left in col_starts:
            sub_image_list.append(
                mega_image[top:top + size, left:left + size, :])
            coor_list.append([top, left])
    return sub_image_list, coor_list


def image_padding(mega_image, min_size=512):
    '''Zero-pad an image at the bottom/right to at least min_size per side.

    Args:
      mega_image: (ndarray) 3-D image; the first two axes are padded.
      min_size: (int) minimum extent of each of the first two axes.

    Returns:
      (ndarray) new uint8 array with 3 channels holding the input in its
      upper-left corner and zeros elsewhere.
    '''
    rows, cols, _ = mega_image.shape
    result = np.zeros((max(min_size, rows), max(min_size, cols), 3),
                      dtype=np.uint8)
    result[0:rows, 0:cols] = mega_image
    return result


def py_cpu_nms(dets, thresh):
    '''Pure Python NMS baseline.

    Args:
      dets: (array-like) rows of [x1, y1, x2, y2, score].
      thresh: (float) boxes whose IoU with a kept box exceeds this are dropped.

    Returns:
      (list) indices into dets of the kept boxes, highest score first.
    '''
    dets = np.asarray(dets)
    if dets.size == 0:
        # Nothing to suppress; column indexing below needs a 2-D array.
        return []

    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    # Inclusive pixel convention: a box from 0 to 9 is 10 wide.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        rest = order[1:]
        keep.append(i)

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        order = rest[np.where(ovr <= thresh)[0]]
    return keep


def sort_key(row):
    '''Return the last element of row (used as a sort key).'''
    return row[-1]


def filter_small_fp(bbox_list):
    '''Drop boxes whose area deviates strongly from that of the top boxes.

    The list is sorted IN PLACE by its last column (descending). The mean
    area of the top 5% (at least one box) is the reference; a box is kept when
    its relative area deviation from that reference is below 0.8.

    Args:
      bbox_list: (list) rows of [x1, y1, x2, y2, ..., key]; mutated by sorting.

    Returns:
      (list) the surviving rows, in sorted order.
    '''
    bbox_list.sort(key=sort_key, reverse=True)

    top_count = max(int(0.05 * len(bbox_list)), 1)
    bbox_area_list = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                      for bbox in bbox_list[:top_count]]
    print(len(bbox_area_list))
    if not bbox_area_list:
        # Empty input: avoid taking the mean of an empty list.
        return []

    average_area = np.mean(bbox_area_list)
    new_bbox_list = []
    for bbox in bbox_list:
        bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
        if abs(bbox_area - average_area) / average_area < 0.8:
            new_bbox_list.append(bbox)
    return new_bbox_list


def get_mean_and_std(dataset, max_load=10000):
    '''Compute the mean and std value of dataset.

    Args:
      dataset: object supporting len() and load(1); load(1) must return a
        3-tuple whose first item is indexable as im[:, j, :, :] for j in 0..2.
        NOTE(review): presumably load(1) yields a batch of one image sized
        [1,3,H,W] - confirm against the dataset class.
      max_load: (int) upper bound on the number of load() calls.

    Returns:
      (tensor, tensor) per-channel mean and std, each sized [3,], averaged
      over the loaded samples.
    '''
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    N = min(max_load, len(dataset))
    for i in range(N):
        print(i)
        im, _, _ = dataset.load(1)
        for j in range(3):
            mean[j] += im[:, j, :, :].mean()
            std[j] += im[:, j, :, :].std()
    mean.div_(N)
    std.div_(N)
    return mean, std


def mask_select(input, mask, dim=0):
    '''Select tensor rows/cols using a mask tensor.

    Args:
      input: (tensor) input tensor, sized [N,M].
      mask: (tensor) 1-D mask tensor, sized [N,] or [M,].
      dim: (int) dimension of input the mask applies to.

    Returns:
      (tensor) selected rows/cols.

    Example:
    >>> a = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> mask_select(a, a[:, 0] > 2, 0)
    tensor([[3., 4.],
            [5., 6.]])
    '''
    index = mask.nonzero().squeeze(1)
    return input.index_select(dim, index)


def meshgrid(x, y, row_major=True):
    '''Return meshgrid in range x & y.

    Args:
      x: (int) first dim range.
      y: (int) second dim range.
      row_major: (bool) row major or column major.

    Returns:
      (tensor) meshgrid, sized [x*y,2]; the dtype is whatever torch.arange
      produces for the given x and y.

    Example:
    >> meshgrid(3,2)
    0  0
    1  0
    2  0
    0  1
    1  1
    2  1

    >> meshgrid(3,2,row_major=False)
    0  0
    0  1
    0  2
    1  0
    1  1
    1  2
    '''
    a = torch.arange(0, x)
    b = torch.arange(0, y)
    xx = a.repeat(y).view(-1, 1)
    yy = b.view(-1, 1).repeat(1, x).view(-1, 1)
    return torch.cat([xx, yy], 1) if row_major else torch.cat([yy, xx], 1)


def change_box_order(boxes, order):
    '''Change box order between (xmin,ymin,xmax,ymax) and (xcenter,ycenter,width,height).

    Args:
      boxes: (tensor) bounding boxes, sized [N,4].
      order: (str) either 'xyxy2xywh' or 'xywh2xyxy'.

    Returns:
      (tensor) converted bounding boxes, sized [N,4].

    Note:
      'xyxy2xywh' computes width/height as max-min+1, while 'xywh2xyxy' uses
      center -/+ size/2 without a matching offset, so the two conversions are
      not exact inverses of each other.
    '''
    assert order in ['xyxy2xywh', 'xywh2xyxy']
    a = boxes[:, :2]
    b = boxes[:, 2:]
    if order == 'xyxy2xywh':
        return torch.cat([(a + b) / 2, b - a + 1], 1)
    return torch.cat([a - b / 2, a + b / 2], 1)


def box_iou(box1, box2, order='xyxy'):
    '''Compute the intersection over union of two set of boxes.

    The default box order is (xmin, ymin, xmax, ymax).

    Args:
      box1: (tensor) bounding boxes, sized [N,4].
      box2: (tensor) bounding boxes, sized [M,4].
      order: (str) box order, either 'xyxy' or 'xywh'.

    Return:
      (tensor) iou, sized [N,M].

    Reference:
      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    '''
    if order == 'xywh':
        box1 = change_box_order(box1, 'xywh2xyxy')
        box2 = change_box_order(box2, 'xywh2xyxy')

    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]

    wh = (rb - lt + 1).clamp(min=0)                 # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]               # [N,M]

    area1 = (box1[:, 2] - box1[:, 0] + 1) * (box1[:, 3] - box1[:, 1] + 1)  # [N,]
    area2 = (box2[:, 2] - box2[:, 0] + 1) * (box2[:, 3] - box2[:, 1] + 1)  # [M,]
    iou = inter / (area1[:, None] + area2 - inter)
    return iou


def box_nms(bboxes, scores, threshold=0.5, mode='union'):
    '''Non maximum suppression.

    Args:
      bboxes: (tensor) bounding boxes, sized [N,4] (a single [4,] box is
        also accepted).
      scores: (tensor) bbox scores, sized [N,].
      threshold: (float) overlap threshold.
      mode: (str) 'union' or 'min'.

    Returns:
      keep: (LongTensor) selected indices, highest score first.

    Raises:
      TypeError: if mode is neither 'union' nor 'min' (only checked once at
        least two candidates have to be compared).

    Reference:
      https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
    '''
    if len(bboxes.shape) == 1:
        bboxes = bboxes.unsqueeze(0)

    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    _, order = scores.sort(0, descending=True)
    # Keep the candidate list 1-D at all times so it can always be sliced.
    order = order.reshape(-1)

    keep = []
    while order.numel() > 0:
        # Plain Python int: keeps `keep` homogeneous and usable as an index.
        i = int(order[0])
        keep.append(i)

        if order.numel() == 1:
            break

        rest = order[1:]
        xx1 = x1[rest].clamp(min=x1[i].item())
        yy1 = y1[rest].clamp(min=y1[i].item())
        xx2 = x2[rest].clamp(max=x2[i].item())
        yy2 = y2[rest].clamp(max=y2[i].item())

        w = (xx2 - xx1 + 1).clamp(min=0)
        h = (yy2 - yy1 + 1).clamp(min=0)
        inter = w * h

        if mode == 'union':
            ovr = inter / (areas[i] + areas[rest] - inter)
        elif mode == 'min':
            ovr = inter / areas[rest].clamp(max=areas[i].item())
        else:
            raise TypeError('Unknown nms mode: %s.' % mode)

        # Boolean-mask indexing preserves the 1-D shape even when exactly one
        # (or no) candidate survives.
        order = rest[ovr <= threshold]
    return torch.tensor(keep, dtype=torch.long)


def softmax(x):
    '''Softmax along dimension 1.

    Args:
      x: (tensor) input tensor, sized [N,D].

    Returns:
      (tensor) softmaxed tensor, sized [N,D].
    '''
    # Subtract the row-wise max before exponentiating for numerical stability.
    row_max, _ = x.max(1)
    exps = (x - row_max.view(-1, 1)).exp()
    return exps / exps.sum(1).view(-1, 1)


def one_hot_embedding(labels, num_classes):
    '''Embedding labels to one-hot form.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N,#classes].
    '''
    y = torch.eye(num_classes)  # [D,D]
    return y[labels]            # [N,D]


def msr_init(net):
    '''Initialize layer parameters.

    Iterates over `net` and, matching on the exact layer type:
      - nn.Conv2d: weight ~ N(0, sqrt(2/n)) with n = kh*kw*out_channels,
        bias zeroed;
      - nn.BatchNorm2d: weight filled with 1, bias zeroed;
      - nn.Linear: bias zeroed.
    Parameters that are absent (e.g. bias=False, affine=False) are skipped.

    Args:
      net: iterable of layers (e.g. an nn.Sequential).
    '''
    for layer in net:
        # Exact-type comparison (not isinstance) so subclasses are left alone,
        # as before.
        kind = type(layer)
        if kind is nn.Conv2d:
            n = layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels
            layer.weight.data.normal_(0, math.sqrt(2. / n))
            if layer.bias is not None:
                layer.bias.data.zero_()
        elif kind is nn.BatchNorm2d:
            if layer.weight is not None:
                layer.weight.data.fill_(1)
            if layer.bias is not None:
                layer.bias.data.zero_()
        elif kind is nn.Linear:
            if layer.bias is not None:
                layer.bias.data.zero_()


#_, term_width = os.popen('stty size', 'r').read().split()
term_width = 80
TOTAL_BAR_LENGTH = 86.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    '''Write a one-line text progress bar to stdout.

    Args:
      current: (int) zero-based index of the current step; 0 restarts the
        total-time clock.
      total: (int) number of steps; must be non-zero.
      msg: (str) optional text appended after the timing information.

    Side effects:
      Updates the module-level last_time/begin_time and writes to sys.stdout,
      ending with a carriage return, or with a newline on the last step.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    out = sys.stdout
    out.write(' [' + '=' * cur_len + '>' + '.' * rest_len + ']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    pieces = [
        '  Step: %s' % format_time(step_time),
        ' | Tot: %s' % format_time(tot_time),
    ]
    if msg:
        pieces.append(' | ' + msg)
    msg = ''.join(pieces)
    out.write(msg)
    # A non-positive count yields an empty string, matching an empty range().
    out.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Go back to the center of the bar.
    out.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2)))
    out.write(' %d/%d ' % (current + 1, total))

    out.write('\r' if current < total - 1 else '\n')
    out.flush()


def format_time(seconds):
    '''Format a duration using its two most significant non-zero units.

    Args:
      seconds: (number) duration in seconds.

    Returns:
      (str) e.g. '1h1m', '1m5s', '500ms'; '0ms' when no unit is positive.
    '''
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
             (secondsf, 's'), (millis, 'ms'))
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(shown) if shown else '0ms'