import argparse  # required by the CLI below; was commented out in the original (NameError)
import socket
import timeit
import numpy as np
from PIL import Image
from datetime import datetime
import os
import sys
from collections import OrderedDict
sys.path.append('./')
# PyTorch includes
import torch
from torch.autograd import Variable
from torchvision import transforms
import cv2

# Custom includes
from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr

import torch.nn.functional as F

import warnings
warnings.filterwarnings("ignore")

# RGB palette for the 20 CIHP human-parsing classes (index 0 = background).
label_colours = [(0, 0, 0),
                 (128, 0, 0), (255, 0, 0), (0, 85, 0), (170, 0, 51), (255, 85, 0),
                 (0, 0, 85), (0, 119, 221), (85, 85, 0), (0, 85, 85), (85, 51, 0),
                 (52, 86, 128), (0, 128, 0), (0, 0, 255), (51, 170, 221), (0, 255, 255),
                 (85, 255, 170), (170, 255, 85), (255, 255, 0), (255, 170, 0)]


def flip(x, dim):
    """Reverse tensor ``x`` along dimension ``dim`` (torch analogue of np.flip)."""
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]


def flip_cihp(tail_list):
    """Swap the left/right paired class channels of a CIHP score tensor.

    :param tail_list: tensor of size n_class(=20) x h x w
    :return: tensor of the same shape with channel pairs (14,15), (16,17)
             and (18,19) exchanged; channels 0-13 are kept as-is.
    """
    tail_list_rev = [None] * 20
    for xx in range(14):
        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
    tail_list_rev[14] = tail_list[15].unsqueeze(0)
    tail_list_rev[15] = tail_list[14].unsqueeze(0)
    tail_list_rev[16] = tail_list[17].unsqueeze(0)
    tail_list_rev[17] = tail_list[16].unsqueeze(0)
    tail_list_rev[18] = tail_list[19].unsqueeze(0)
    tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev, dim=0)


def decode_labels(mask, num_images=1, num_classes=20):
    """Decode batch of segmentation masks.

    Args:
      mask: result of inference after taking argmax, shape (n, h, w).
      num_images: number of images to decode from the batch.
      num_classes: number of classes to predict (including background).

    Returns:
      A batch with num_images RGB images of the same size as the input.
    """
    n, h, w = mask.shape
    assert n >= num_images, \
        'Batch size %d should be greater or equal than number of images to save %d.' \
        % (n, num_images)
    palette = np.array(label_colours, dtype=np.uint8)
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        # Vectorized palette lookup replaces the original per-pixel Python
        # loop; labels >= num_classes are left black, exactly as before.
        valid = mask[i] < num_classes
        outputs[i][valid] = palette[mask[i][valid]]
    return outputs


def read_img(img_path):
    """Load the image at ``img_path`` and return it as an RGB PIL image."""
    return Image.open(img_path).convert('RGB')


def img_transform(img, transform=None):
    """Wrap ``img`` in the sample dict expected by the custom transforms."""
    sample = {'image': img, 'label': 0}
    return transform(sample)


def inference(net, img_path='', output_path='./', output_name='f', use_gpu=True):
    """Run multi-scale, flip-averaged human parsing on a single image.

    Writes two files into ``output_path``: a color visualization
    ``<output_name>.png`` and a raw label map ``<output_name>_gray.png``.

    :param net: loaded deeplab_xception_transfer network (must be on GPU).
    :param img_path: path of the input image.
    :param output_path: directory that receives the two result images.
    :param output_name: basename (without extension) for the result images.
    :param use_gpu: expected truthy; inputs and adjacencies are moved to CUDA.
    """
    # Graph adjacency matrices used by the graph-transfer heads.
    # NOTE(review): the numbering is inherited from the original code --
    # adj1_test is the 20x20 CIHP graph and adj3_test the 7x7 Pascal graph.
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3)

    adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()

    # Multi-scale test-time augmentation, each scale also evaluated flipped.
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    img = read_img(img_path)
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose([
            tr.Scale_only_img(pv),
            tr.Normalize_xception_tf_only_img(),
            tr.ToTensor_only_img()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_only_img(pv),
            tr.HorizontalFlip_only_img(),
            tr.Normalize_xception_tf_only_img(),
            tr.ToTensor_only_img()])

        testloader_list.append(img_transform(img, composed_transforms_ts))
        testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip))

    start_time = timeit.default_timer()
    # One testing epoch over all scales: 1 0.5 0.75 1.25 1.5 1.75 ; plus flip.
    net.eval()
    for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)):
        inputs = sample_batched[0]['image'].unsqueeze(0)
        inputs_f = sample_batched[1]['image'].unsqueeze(0)
        # Batch the plain and flipped versions: index 0 plain, index 1 flipped.
        inputs = torch.cat((inputs, inputs_f), dim=0)
        if iii == 0:
            # Target resolution for accumulation = the scale-1.0 pass.
            _, _, h, w = inputs.size()

        inputs = Variable(inputs, requires_grad=False)
        with torch.no_grad():
            # NOTE(review): with a boolean, `use_gpu >= 0` is always True
            # (False >= 0 holds), so inputs always go to CUDA; kept as-is
            # because the adjacencies above are unconditionally on GPU.
            if use_gpu >= 0:
                inputs = inputs.cuda()
            outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
            # Average the plain prediction with the un-flipped flipped one.
            outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
            outputs = outputs.unsqueeze(0)

            if iii > 0:
                # F.upsample is deprecated; F.interpolate is its exact replacement.
                outputs = F.interpolate(outputs, size=(h, w), mode='bilinear',
                                        align_corners=True)
                outputs_final = outputs_final + outputs
            else:
                outputs_final = outputs.clone()

    ################ plot pic
    predictions = torch.max(outputs_final, 1)[1]
    results = predictions.cpu().numpy()
    vis_res = decode_labels(results)

    parsing_im = Image.fromarray(vis_res[0])
    parsing_im.save(output_path + '/{}.png'.format(output_name))
    cv2.imwrite(output_path + '/{}_gray.png'.format(output_name), results[0, :, :])

    end_time = timeit.default_timer()
    print('time used for the multi-scale image inference' + ' is :' +
          str(end_time - start_time))


if __name__ == '__main__':
    '''argparse begin'''
    parser = argparse.ArgumentParser()
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--img_path', default='', type=str)
    parser.add_argument('--output_path', default='', type=str)
    parser.add_argument('--output_name', default='', type=str)
    parser.add_argument('--use_gpu', default=1, type=int)
    opts = parser.parse_args()

    net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
        n_classes=20, hidden_layers=128, source_classes=7)

    if opts.loadmodel:
        x = torch.load(opts.loadmodel)
        net.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no model load !!!!!!!!')
        raise RuntimeError('No model!!!!')

    if opts.use_gpu > 0:
        net.cuda()
        use_gpu = True
    else:
        # CPU inference is not supported: inputs and adjacency matrices are
        # moved to CUDA unconditionally inside inference().
        raise RuntimeError('must use the gpu!!!!')

    inference(net=net, img_path=opts.img_path, output_path=opts.output_path,
              output_name=opts.output_name, use_gpu=use_gpu)