###########################################################################
# Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
###########################################################################
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter

up_kwargs = {'mode': 'bilinear', 'align_corners': True}

__all__ = ['MultiEvalModule']
class MultiEvalModule(DataParallel):
    """Multi-size Segmentation Evaluator"""
    def __init__(self, module, nclass, device_ids=None, flip=True,
                 scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
        super(MultiEvalModule, self).__init__(module, device_ids)
        self.nclass = nclass
        self.base_size = module.base_size
        self.crop_size = module.crop_size
        self.scales = scales
        self.flip = flip
        print('MultiEvalModule: base_size {}, crop_size {}'.format(
            self.base_size, self.crop_size))
    def parallel_forward(self, inputs, **kwargs):
        """Multi-GPU Multi-size Evaluation

        Args:
            inputs: list of Tensors
        """
        # One image per device: add a batch dimension and move each image
        # to its own GPU.
        inputs = [(input.unsqueeze(0).cuda(device),)
                  for input, device in zip(inputs, self.device_ids)]
        replicas = self.replicate(self, self.device_ids[:len(inputs)])
        kwargs = scatter(kwargs, self.device_ids[:len(inputs)], dim=0) if kwargs else []
        if len(inputs) < len(kwargs):
            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
        elif len(kwargs) < len(inputs):
            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return outputs
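    # Usage sketch for parallel_forward (hedged; evaluator/img names are
    # hypothetical): with two devices,
    #   outputs = evaluator.parallel_forward([img0, img1])
    # where each img is a 3xHxW tensor; each image gets a batch dimension,
    # is scored on its own GPU, and outputs is a list of
    # 1 x nclass x H x W score maps.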
    def forward(self, image):
        """Multi-size Evaluation"""
        # only single image is supported for evaluation
        batch, _, h, w = image.size()
        assert(batch == 1)
        # Sliding-window crops overlap by one third of the crop size.
        stride_rate = 2.0/3.0
        crop_size = self.crop_size
        stride = int(crop_size * stride_rate)
        with torch.cuda.device_of(image):
            # Class scores accumulated over all scales at the input resolution.
            scores = image.new_zeros((batch, self.nclass, h, w)).cuda()
        for scale in self.scales:
            long_size = int(math.ceil(self.base_size * scale))
            # Scale so the long side equals long_size, keeping the aspect
            # ratio. (An alternative, left disabled upstream, scales the
            # short side instead.)
            if h > w:
                height = long_size
                width = int(1.0 * w * long_size / h + 0.5)
                short_size = width
            else:
                width = long_size
                height = int(1.0 * h * long_size / w + 0.5)
                short_size = height
            # resize image to current size
            cur_img = resize_image(image, height, width, **self.module._up_kwargs)
            if long_size <= crop_size:
                # The resized image fits in a single crop: pad it up to
                # crop_size, run inference once, and crop the scores back.
                pad_img = pad_image(cur_img, self.module.mean,
                                    self.module.std, crop_size)
                outputs = module_inference(self.module, pad_img, self.flip)
                outputs = crop_image(outputs, 0, height, 0, width)
            else:
                if short_size < crop_size:
                    # pad if needed
                    pad_img = pad_image(cur_img, self.module.mean,
                                        self.module.std, crop_size)
                else:
                    pad_img = cur_img
                _, _, ph, pw = pad_img.size()
                assert(ph >= height and pw >= width)
                # grid forward and normalize
                h_grids = int(math.ceil(1.0 * (ph - crop_size) / stride)) + 1
                w_grids = int(math.ceil(1.0 * (pw - crop_size) / stride)) + 1
                with torch.cuda.device_of(image):
                    outputs = image.new_zeros((batch, self.nclass, ph, pw)).cuda()
                    count_norm = image.new_zeros((batch, 1, ph, pw)).cuda()
                # grid evaluation
                for idh in range(h_grids):
                    for idw in range(w_grids):
                        h0 = idh * stride
                        w0 = idw * stride
                        h1 = min(h0 + crop_size, ph)
                        w1 = min(w0 + crop_size, pw)
                        crop_img = crop_image(pad_img, h0, h1, w0, w1)
                        # pad if needed (edge crops can be smaller than crop_size)
                        pad_crop_img = pad_image(crop_img, self.module.mean,
                                                 self.module.std, crop_size)
                        output = module_inference(self.module, pad_crop_img, self.flip)
                        outputs[:, :, h0:h1, w0:w1] += crop_image(output,
                                                                  0, h1 - h0, 0, w1 - w0)
                        count_norm[:, :, h0:h1, w0:w1] += 1
                assert((count_norm == 0).sum() == 0)
                # Average overlapping windows, then drop the padding.
                outputs = outputs / count_norm
                outputs = outputs[:, :, :height, :width]
            # Upsample back to the input resolution and accumulate.
            score = resize_image(outputs, h, w, **self.module._up_kwargs)
            scores += score
        return scores
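# Worked example of the grid arithmetic in forward() (assumed numbers, not
# from the source): with crop_size=480, stride = int(480 * 2/3) = 320. For a
# padded 480x720 image, h_grids = ceil((480-480)/320) + 1 = 1 and
# w_grids = ceil((720-480)/320) + 1 = 2, so crops start at columns 0 and 320;
# the overlapping 320:480 band is counted twice in count_norm and averaged.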
def module_inference(module, image, flip=True):
    # Optional horizontal-flip test-time augmentation: score the mirrored
    # image and un-mirror its prediction. Summing instead of averaging is
    # harmless because the caller only compares per-class scores.
    output = module.evaluate(image)
    if flip:
        fimg = flip_image(image)
        foutput = module.evaluate(fimg)
        output += flip_image(foutput)
    return output
def resize_image(img, h, w, **up_kwargs):
    return F.interpolate(img, (h, w), **up_kwargs)
def pad_image(img, mean, std, crop_size):
    b, c, h, w = img.size()
    assert(c == 3)
    padh = crop_size - h if h < crop_size else 0
    padw = crop_size - w if w < crop_size else 0
    # The input is assumed to be normalized as (x - mean) / std, so padding
    # with -mean/std is equivalent to padding the raw image with zeros.
    pad_values = -np.array(mean) / np.array(std)
    img_pad = img.new_empty((b, c, h + padh, w + padw))
    for i in range(c):
        # note that pytorch pad params are in reversed order (last dim first)
        img_pad[:, i, :, :] = F.pad(img[:, i, :, :], (0, padw, 0, padh), value=pad_values[i])
    assert(img_pad.size(2) >= crop_size and img_pad.size(3) >= crop_size)
    return img_pad
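# Quick numeric check for pad_image (illustrative ImageNet statistics, not
# taken from this file): with mean[0]=0.485 and std[0]=0.229 the pad value
# for channel 0 is (0 - 0.485) / 0.229 ≈ -2.118, i.e. the normalized value
# of a black pixel, so the padding injects no signal.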
def crop_image(img, h0, h1, w0, w1):
    return img[:, :, h0:h1, w0:w1]
def flip_image(img):
    assert(img.dim() == 4)
    with torch.cuda.device_of(img):
        # Index the width dimension in reverse order (horizontal flip).
        idx = torch.arange(img.size(3) - 1, -1, -1).type_as(img).long()
    return img.index_select(3, idx)
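# Minimal usage sketch (an illustration, not part of the original file). The
# wrapped module is assumed to expose base_size, crop_size, mean, std,
# _up_kwargs, and an evaluate() method, since MultiEvalModule reads all of
# them above; get_model and val_loader are hypothetical stand-ins.
#
#   model = get_model(pretrained=True).cuda().eval()
#   evaluator = MultiEvalModule(model, nclass=21, flip=True,
#                               scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75])
#   with torch.no_grad():
#       for image, target in val_loader:
#           scores = evaluator(image.cuda())    # 1 x nclass x H x W
#           pred = torch.argmax(scores, dim=1)  # per-pixel class labels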