lang-seg / additional_utils /encoding_models.py
akhaliq's picture
akhaliq HF staff
add files
0870534
raw
history blame
6.72 kB
###########################################################################
# Referred to: https://github.com/zhanghang1989/PyTorch-Encoding
###########################################################################
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter
import threading
import torch
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch._utils import ExceptionWrapper
up_kwargs = {'mode': 'bilinear', 'align_corners': True}
__all__ = ['MultiEvalModule']
class MultiEvalModule(DataParallel):
    """Multi-scale segmentation evaluator.

    Wraps a segmentation network and evaluates a single image at several
    scales, using overlapping sliding-window crops when the rescaled image
    is larger than the crop size. The per-scale score maps are resized back
    to the input resolution and summed.

    The wrapped ``module`` must provide the attributes read here:
    ``base_size``, ``crop_size``, ``mean``, ``std`` and ``_up_kwargs``; it is
    also handed to ``module_inference`` for the actual forward passes.
    """

    def __init__(self, module, nclass, device_ids=None, flip=True,
                 scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75)):
        """Store the evaluation settings.

        Args:
            module: network to evaluate (see class docstring for the
                attributes it must expose).
            nclass: number of channels of the score maps.
            device_ids: forwarded unchanged to ``DataParallel``.
            flip: forwarded to ``module_inference`` for every crop.
            scales: multipliers applied to ``module.base_size`` to obtain
                the long side of each rescaled image.
        """
        super(MultiEvalModule, self).__init__(module, device_ids)
        self.nclass = nclass
        self.base_size = module.base_size
        self.crop_size = module.crop_size
        # Keep a private copy: neither the default nor a caller-owned
        # sequence may be shared between instances.
        self.scales = list(scales)
        self.flip = flip
        print('MultiEvalModule: base_size {}, crop_size {}'.format(
            self.base_size, self.crop_size))

    def parallel_forward(self, inputs, **kwargs):
        """Multi-GPU multi-size evaluation.

        Each input tensor gets a leading batch dimension, is moved to one of
        ``self.device_ids`` (paired in order), and is evaluated by a replica
        of this evaluator on that device.

        Args:
            inputs: list of Tensors, one image per device.
            **kwargs: extra keyword arguments, scattered across the devices
                in use along ``self.dim``.

        Returns:
            Whatever ``self.parallel_apply`` returns for the replicas.
        """
        inputs = [(img.unsqueeze(0).cuda(device),)
                  for img, device in zip(inputs, self.device_ids)]
        # Only as many devices as there are images take part.
        target_gpus = self.device_ids[:len(inputs)]
        replicas = self.replicate(self, target_gpus)
        kwargs = list(scatter(kwargs, target_gpus, self.dim)) if kwargs else []
        # parallel_apply needs inputs and kwargs of equal length.
        if len(inputs) < len(kwargs):
            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
        elif len(kwargs) < len(inputs):
            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return outputs

    def forward(self, image):
        """Multi-size evaluation of one image.

        Args:
            image: 4-D tensor whose first dimension (batch) must be 1; the
                last two dimensions are taken as height and width.

        Returns:
            CUDA tensor of size ``(batch, nclass, h, w)`` holding the sum of
            the score maps obtained at every scale in ``self.scales``.
        """
        # only single image is supported for evaluation
        batch, _, h, w = image.size()
        assert(batch == 1)
        stride_rate = 2.0 / 3.0
        crop_size = self.crop_size
        # Neighbouring sliding-window crops overlap by one third.
        stride = int(crop_size * stride_rate)
        with torch.cuda.device_of(image):
            scores = image.new().resize_(batch, self.nclass, h, w).zero_().cuda()

        for scale in self.scales:
            # The longer image side is scaled to base_size * scale; the
            # shorter side follows the aspect ratio (rounded to nearest).
            long_size = int(math.ceil(self.base_size * scale))
            if h > w:
                height = long_size
                width = int(1.0 * w * long_size / h + 0.5)
                short_size = width
            else:
                width = long_size
                height = int(1.0 * h * long_size / w + 0.5)
                short_size = height

            # resize image to current size
            cur_img = resize_image(image, height, width,
                                   **self.module._up_kwargs)
            if long_size <= crop_size:
                # Whole image fits into one crop: pad, infer, un-pad.
                pad_img = pad_image(cur_img, self.module.mean,
                                    self.module.std, crop_size)
                outputs = module_inference(self.module, pad_img, self.flip)
                outputs = crop_image(outputs, 0, height, 0, width)
            else:
                if short_size < crop_size:
                    # pad if needed
                    pad_img = pad_image(cur_img, self.module.mean,
                                        self.module.std, crop_size)
                else:
                    pad_img = cur_img
                _, _, ph, pw = pad_img.size()
                assert(ph >= height and pw >= width)
                # grid forward and normalize
                h_grids = int(math.ceil(1.0 * (ph - crop_size) / stride)) + 1
                w_grids = int(math.ceil(1.0 * (pw - crop_size) / stride)) + 1
                with torch.cuda.device_of(image):
                    outputs = image.new().resize_(
                        batch, self.nclass, ph, pw).zero_().cuda()
                    count_norm = image.new().resize_(
                        batch, 1, ph, pw).zero_().cuda()
                # grid evaluation
                for idh in range(h_grids):
                    for idw in range(w_grids):
                        h0 = idh * stride
                        w0 = idw * stride
                        h1 = min(h0 + crop_size, ph)
                        w1 = min(w0 + crop_size, pw)
                        crop_img = crop_image(pad_img, h0, h1, w0, w1)
                        # pad if needed (border crops may be smaller)
                        pad_crop_img = pad_image(crop_img, self.module.mean,
                                                 self.module.std, crop_size)
                        output = module_inference(self.module, pad_crop_img,
                                                  self.flip)
                        outputs[:, :, h0:h1, w0:w1] += crop_image(
                            output, 0, h1 - h0, 0, w1 - w0)
                        count_norm[:, :, h0:h1, w0:w1] += 1
                # Every pixel must be covered by at least one crop.
                assert((count_norm == 0).sum() == 0)
                # Average the overlapping crops, then drop the padding.
                outputs = outputs / count_norm
                outputs = outputs[:, :, :height, :width]

            score = resize_image(outputs, h, w, **self.module._up_kwargs)
            scores += score

        return scores
def module_inference(module, image, flip=True):
    """Run ``module.evaluate`` on ``image``, optionally adding a flipped pass.

    Args:
        module: object exposing ``evaluate(image)``.
        image: input passed to ``module.evaluate``.
        flip: when true, the image is additionally evaluated after
            ``flip_image`` and that result, flipped back, is added in place
            to the first result.

    Returns:
        The output of ``module.evaluate(image)``, with the flipped pass
        accumulated into it when ``flip`` is true.
    """
    result = module.evaluate(image)
    if not flip:
        return result
    mirrored_result = module.evaluate(flip_image(image))
    # In-place accumulation, as the un-flipped output is reused as the sum.
    result += flip_image(mirrored_result)
    return result
def resize_image(img, h, w, **up_kwargs):
    """Resize ``img`` to spatial size ``(h, w)`` with ``F.interpolate``.

    Any extra keyword arguments are forwarded unchanged to
    ``F.interpolate``.
    """
    target_size = (h, w)
    return F.interpolate(img, target_size, **up_kwargs)
def pad_image(img, mean, std, crop_size):
    """Pad a 3-channel image at the bottom/right up to ``crop_size``.

    Each channel is padded with the constant ``-mean[i] / std[i]``. Sides
    that are already at least ``crop_size`` long are left unpadded.

    Args:
        img: 4-D tensor whose second dimension must be 3.
        mean: per-channel means.
        std: per-channel standard deviations.
        crop_size: minimum height and width of the result.

    Returns:
        A new tensor of the same type as ``img`` whose last two dimensions
        are both at least ``crop_size``.
    """
    batch, channels, height, width = img.size()
    assert(channels == 3)
    extra_h = max(crop_size - height, 0)
    extra_w = max(crop_size - width, 0)
    fill_values = -np.array(mean) / np.array(std)
    padded = img.new_empty((batch, channels, height + extra_h, width + extra_w))
    # F.pad takes the pad widths last dimension first: (left, right, top, bottom).
    pad_spec = (0, extra_w, 0, extra_h)
    for ch in range(channels):
        padded[:, ch, :, :] = F.pad(img[:, ch, :, :], pad_spec,
                                    value=fill_values[ch])
    assert(padded.size(2) >= crop_size and padded.size(3) >= crop_size)
    return padded
def crop_image(img, h0, h1, w0, w1):
    """Return the window ``[h0:h1, w0:w1]`` of the last two dims of a 4-D ``img``."""
    rows = slice(h0, h1)
    cols = slice(w0, w1)
    return img[:, :, rows, cols]
def flip_image(img):
    """Return a copy of a 4-D tensor mirrored along its last dimension.

    Args:
        img: 4-D tensor; the last dimension is reversed.

    Returns:
        A new tensor with the same size, dtype and device as ``img``.

    Raises:
        AssertionError: if ``img`` is not 4-dimensional.
    """
    assert(img.dim() == 4)
    # torch.flip reverses the axis directly. Building an index vector cast
    # to the image dtype would corrupt the indices for low-precision dtypes
    # (e.g. uint8 wraps at 256, float16 cannot represent every index).
    return torch.flip(img, dims=[3])