"""Multi-scale, multi-GPU sliding-window evaluation for LSeg."""
import math
import threading

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch._utils import ExceptionWrapper

# Default upsampling arguments for resizing score maps back to input size.
up_kwargs = {'mode': 'bilinear', 'align_corners': True}

__all__ = ['LSeg_MultiEvalModule']


class LSeg_MultiEvalModule(DataParallel):
    """Multi-size Segmentation Evaluator"""

    def __init__(self, module, device_ids=None, flip=True,
                 scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]):
        super(LSeg_MultiEvalModule, self).__init__(module, device_ids)
        self.base_size = module.base_size
        self.crop_size = module.crop_size
        self.scales = scales
        self.flip = flip
        print('MultiEvalModule: base_size {}, crop_size {}'.format(
            self.base_size, self.crop_size))

    def parallel_forward(self, inputs, label_set='', **kwargs):
        """Multi-GPU multi-size evaluation.

        Args:
            inputs: list of Tensors (one image per device)
        """
        if len(label_set) < 10:
            print('** MultiEvalModule parallel_forward phase: {} **'.format(label_set))
        self.nclass = len(label_set)
        inputs = [(input.unsqueeze(0).cuda(device),)
                  for input, device in zip(inputs, self.device_ids)]
        replicas = self.replicate(self, self.device_ids[:len(inputs)])
        # Scatter keyword arguments across the devices actually in use.
        kwargs = scatter(kwargs, self.device_ids[:len(inputs)]) if kwargs else []
        if len(inputs) < len(kwargs):
            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
        elif len(kwargs) < len(inputs):
            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
        outputs = parallel_apply(replicas, inputs, label_set, kwargs)
        return outputs

    def forward(self, image, label_set=''):
        """Multi-size evaluation of a single image."""
        if len(label_set) < 10:
            print('** MultiEvalModule forward phase: {} **'.format(label_set))
        batch, _, h, w = image.size()
        assert batch == 1
        self.nclass = len(label_set)
        stride_rate = 2.0 / 3.0
        crop_size = self.crop_size
        # Adjacent crops overlap by one third of the crop size.
        stride = int(crop_size * stride_rate)
        with torch.cuda.device_of(image):
            scores = image.new().resize_(batch, self.nclass, h, w).zero_().cuda()

        for scale in self.scales:
            long_size = int(math.ceil(self.base_size * scale))
            # Rescale so the longer image side equals long_size,
            # preserving the aspect ratio.
            if h > w:
                height = long_size
                width = int(1.0 * w * long_size / h + 0.5)
                short_size = width
            else:
                width = long_size
                height = int(1.0 * h * long_size / w + 0.5)
                short_size = height
            # Alternative (short-side) scaling, kept for reference:
            # short_size = int(math.ceil(self.base_size * scale))
            # if h > w:
            #     width = short_size
            #     height = int(1.0 * h * short_size / w)
            #     long_size = height
            # else:
            #     height = short_size
            #     width = int(1.0 * w * short_size / h)
            #     long_size = width
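            # Illustrative numbers (not from the source): with base_size=520
            # and scale=0.75, long_size=390; a 480x640 (h x w) image becomes
            # 293x390, so short_size=293.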
            cur_img = resize_image(image, height, width, **self.module._up_kwargs)
            if long_size <= crop_size:
                # The whole rescaled image fits inside a single crop.
                pad_img = pad_image(cur_img, self.module.mean,
                                    self.module.std, crop_size)
                outputs = module_inference(self.module, pad_img, label_set, self.flip)
                outputs = crop_image(outputs, 0, height, 0, width)
            else:
                if short_size < crop_size:
                    # Pad so the sliding window never exceeds the image.
                    pad_img = pad_image(cur_img, self.module.mean,
                                        self.module.std, crop_size)
                else:
                    pad_img = cur_img
                _, _, ph, pw = pad_img.shape
                assert ph >= height and pw >= width
                # Slide a crop_size x crop_size window over the padded image.
                h_grids = int(math.ceil(1.0 * (ph - crop_size) / stride)) + 1
                w_grids = int(math.ceil(1.0 * (pw - crop_size) / stride)) + 1
                with torch.cuda.device_of(image):
                    outputs = image.new().resize_(batch, self.nclass, ph, pw).zero_().cuda()
                    count_norm = image.new().resize_(batch, 1, ph, pw).zero_().cuda()
                for idh in range(h_grids):
                    for idw in range(w_grids):
                        h0 = idh * stride
                        w0 = idw * stride
                        h1 = min(h0 + crop_size, ph)
                        w1 = min(w0 + crop_size, pw)
                        crop_img = crop_image(pad_img, h0, h1, w0, w1)
                        # Border windows can be smaller than crop_size; pad them.
                        pad_crop_img = pad_image(crop_img, self.module.mean,
                                                 self.module.std, crop_size)
                        output = module_inference(self.module, pad_crop_img,
                                                  label_set, self.flip)
                        outputs[:, :, h0:h1, w0:w1] += crop_image(
                            output, 0, h1 - h0, 0, w1 - w0)
                        count_norm[:, :, h0:h1, w0:w1] += 1
                assert (count_norm == 0).sum() == 0
                # Average the overlapping window predictions.
                outputs = outputs / count_norm
                outputs = outputs[:, :, :height, :width]
            score = resize_image(outputs, h, w, **self.module._up_kwargs)
            scores += score
        return scores
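
# A minimal usage sketch (assumption: `load_lseg_model` is a hypothetical
# loader; the wrapper only requires a module exposing `base_size`,
# `crop_size`, `mean`, `std`, `_up_kwargs`, and an
# `evaluate_random(image, label_set)` method):
#
#     model = load_lseg_model().cuda().eval()
#     evaluator = LSeg_MultiEvalModule(model, device_ids=[0, 1], flip=True)
#     labels = ['other', 'cat', 'grass']
#     images = [img0, img1]  # CHW tensors, one per device
#     scores = evaluator.parallel_forward(images, labels)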


def module_inference(module, image, label_set, flip=True):
    output = module.evaluate_random(image, label_set)
    if flip:
        # Horizontal-flip test-time augmentation: sum predictions from the
        # original and flipped image (the factor of two does not change the
        # argmax).
        fimg = flip_image(image)
        foutput = module.evaluate_random(fimg, label_set)
        output += flip_image(foutput)
    return output


def resize_image(img, h, w, **up_kwargs):
    return F.interpolate(img, (h, w), **up_kwargs)


def pad_image(img, mean, std, crop_size):
    b, c, h, w = img.shape
    assert c == 3
    padh = crop_size - h if h < crop_size else 0
    padw = crop_size - w if w < crop_size else 0
    # Pad with the normalized value of a zero pixel: (0 - mean) / std.
    pad_values = -np.array(mean) / np.array(std)
    img_pad = img.new().resize_(b, c, h + padh, w + padw)
    for i in range(c):
        img_pad[:, i, :, :] = F.pad(img[:, i, :, :], (0, padw, 0, padh),
                                    value=pad_values[i])
    assert img_pad.size(2) >= crop_size and img_pad.size(3) >= crop_size
    return img_pad
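
# For example, with the common ImageNet statistics mean=(0.485, 0.456, 0.406)
# and std=(0.229, 0.224, 0.225), the red-channel pad value is
# -0.485 / 0.229 ≈ -2.118, i.e. a black pixel after normalization.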


def crop_image(img, h0, h1, w0, w1):
    return img[:, :, h0:h1, w0:w1]


def flip_image(img):
    assert img.dim() == 4
    with torch.cuda.device_of(img):
        idx = torch.arange(img.size(3) - 1, -1, -1).type_as(img).long()
    return img.index_select(3, idx)
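
# Note: on modern PyTorch, torch.flip(img, dims=[3]) is an equivalent,
# simpler way to express this horizontal flip.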


def get_a_var(obj):
    # Return any tensor found in a (possibly nested) argument structure,
    # so a default device can be inferred from it.
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None
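# Example: get_a_var(((t,), {'k': 1})) returns the tensor t, which lets
# _worker below infer a default device from its positional inputs.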


def parallel_apply(modules, inputs, label_set, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`, forwarding `label_set` to every replica.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover,
    each element of :attr:`inputs` can either be a single object as the only
    argument to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    lock = threading.Lock()
    results = {}
    # Propagate the caller's grad and autocast state into the worker threads.
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(i, module, input, label_set, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device), autocast(enabled=autocast_enabled):
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, label_set, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, label_set, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], label_set, kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
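
# A minimal sketch of calling parallel_apply directly (assumption: two
# replicas of a compatible module already placed on cuda:0 and cuda:1):
#
#     outs = parallel_apply([replica0, replica1],
#                           [(img0,), (img1,)],
#                           label_set=['other', 'cat'],
#                           devices=[0, 1])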