| | import torch |
| | import torch.nn.functional as F |
| | from numpy import ceil |
| |
|
| |
|
def cat_k(input):
    """Fold the second dimension into the first (batch) dimension.

    A tensor of shape ``(B, K, *rest)`` is returned with shape
    ``(B * K, *rest)``; ``split_k`` performs the inverse split.
    """
    return torch.flatten(input, start_dim=0, end_dim=1)
| |
|
| |
|
def split_k(input, size: int, dim: int = 0):
    """Split dimension ``dim`` of ``input`` into two dimensions ``(size, -1)``.

    This undoes ``cat_k``: a dimension of length ``size * K`` becomes
    ``size`` followed by an inferred ``K``. A negative ``dim`` is counted
    from the end. The result is a view of ``input``.
    """
    axis = dim + input.dim() if dim < 0 else dim
    new_shape = list(input.size())
    new_shape[axis] = size
    # The trailing piece of the split is inferred by ``view``.
    new_shape.insert(axis + 1, -1)
    return input.view(new_shape)
| |
|
| |
|
class Alignment(torch.nn.Module):
    """Image Alignment for model downsample requirement.

    ``align`` enlarges the last two dimensions (height, width) of a tensor up
    to the next multiple of ``divisor``, and ``resume`` maps a result back to
    the remembered original height/width.

    Args:
        divisor: target multiple for the last two dimensions (stored as float).
        mode: ``'pad'`` pads at the bottom/right with ``F.pad``; ``'resize'``
            resamples bilinearly with ``F.interpolate``. Any other value makes
            ``align``/``resume`` raise ``ValueError`` once an actual alignment
            is required.
        padding_mode: ``mode`` argument forwarded to ``F.pad`` in pad mode.

    NOTE: the original shape is stored on the instance (``self._tmp_shape``),
    so an ``align`` call should be matched by its ``resume`` before aligning
    another input with the same instance, unless ``shape`` is passed to
    ``resume`` explicitly.
    """

    def __init__(self, divisor=64., mode='pad', padding_mode='replicate'):
        super().__init__()
        self.divisor = float(divisor)
        self.mode = mode
        self.padding_mode = padding_mode
        # Shape of the last aligned input that had to be changed, or None if
        # it was already a multiple of ``divisor`` (then resume is a no-op).
        self._tmp_shape = None

    def extra_repr(self):
        s = 'divisor={divisor}, mode={mode}'
        if self.mode == 'pad':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def _mode_error(self):
        """Build the error raised for an unsupported ``self.mode``."""
        return ValueError(
            f"Alignment mode must be 'pad' or 'resize', got {self.mode!r}")

    @staticmethod
    def _resize(input, size):
        """Bilinearly resample the last two dimensions of ``input`` to ``size``."""
        return F.interpolate(input, size, mode='bilinear', align_corners=False)

    def _align(self, input):
        """Align one tensor whose last two dims are height and width.

        Returns ``input`` itself when no change is needed. Otherwise returns
        the padded/resized tensor and records ``input.size()`` for ``_resume``.

        Raises:
            ValueError: if alignment is needed and ``self.mode`` is unsupported.
        """
        H, W = input.size()[-2:]
        H_ = int(ceil(H / self.divisor) * self.divisor)
        W_ = int(ceil(W / self.divisor) * self.divisor)
        pad_H, pad_W = H_ - H, W_ - W
        if pad_H == pad_W == 0:
            self._tmp_shape = None
            return input

        if self.mode == 'pad':
            # Pad only bottom/right so _resume can crop with [:H, :W].
            output = F.pad(input, (0, pad_W, 0, pad_H), mode=self.padding_mode)
        elif self.mode == 'resize':
            output = self._resize(input, size=(H_, W_))
        else:
            raise self._mode_error()
        # Record the shape only after the operation succeeded, so a failure
        # does not leave stale state behind.
        self._tmp_shape = input.size()
        return output

    def _resume(self, input, shape=None):
        """Undo ``_align`` on one tensor.

        Args:
            input: tensor whose last two dims are to be restored.
            shape: optional shape whose last two entries are the target
                height/width; when given it also replaces the stored shape.

        Raises:
            ValueError: if restoration is needed and ``self.mode`` is unsupported.
        """
        if shape is not None:
            self._tmp_shape = shape
        if self._tmp_shape is None:
            return input

        H, W = self._tmp_shape[-2:]
        if self.mode == 'pad':
            return input[..., :H, :W]
        if self.mode == 'resize':
            return self._resize(input, size=(H, W))
        raise self._mode_error()

    def align(self, input):
        """Align a 4-D tensor, or a 5-D tensor by folding dims 0 and 1 together.

        Raises:
            ValueError: if ``input`` is neither 4-D nor 5-D, or the mode is
                unsupported.
        """
        if input.dim() == 4:
            return self._align(input)
        if input.dim() == 5:
            # Merge dims 0-1 so the 4-D path (and F.interpolate) applies,
            # then split them back apart.
            return split_k(self._align(cat_k(input)), input.size(0))
        raise ValueError(
            f"Alignment.align expects a 4-D or 5-D tensor, got {input.dim()}-D")

    def resume(self, input, shape=None):
        """Restore a 4-D or 5-D tensor to the size recorded by ``align``.

        Raises:
            ValueError: if ``input`` is neither 4-D nor 5-D, or the mode is
                unsupported.
        """
        if input.dim() == 4:
            return self._resume(input, shape)
        if input.dim() == 5:
            return split_k(self._resume(cat_k(input), shape), input.size(0))
        raise ValueError(
            f"Alignment.resume expects a 4-D or 5-D tensor, got {input.dim()}-D")

    def forward(self, func, *args, **kwargs):
        # NOTE(review): placeholder -- does nothing and returns None; callers
        # use align()/resume() directly. The signature suggests it was meant
        # to wrap ``func`` with align/resume -- confirm intent before adding.
        pass