import math
import types

import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log

from .unet import UNet, Upsample
from ..utils import resize


def upsample_forward_func(self, x):
    """Replacement forward for Upsample that resizes in fp32 and casts back."""
    dtype = x.dtype
    x = x.float()
    if not self.size:
        size = [int(t * self.scale_factor) for t in x.shape[-2:]]
    else:
        size = self.size
    # resize in float32, then return to the original dtype
    return resize(x, size, None, self.mode, self.align_corners).to(dtype)


class UNetHead(UNet):
    """UNet-based head that predicts a single-channel map from its input."""

    def __init__(self,
                 upsample_input=None,
                 normalize_input=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # 1x1 conv producing a single-channel prediction from decoder features
        self.conv_seg = nn.Conv2d(self.base_channels, 1, kernel_size=1)
        # patch every Upsample module to resize in fp32 (see upsample_forward_func)
        for module in self.modules():
            if isinstance(module, Upsample):
                print_log("Replace upsample forward function")
                module.forward = types.MethodType(
                    upsample_forward_func, module)
        self.init_weights()
        self.upsample_input = upsample_input
        self.normalize_input = normalize_input

    @property
    def dtype(self):
        return self.conv_seg.weight.dtype

    def forward(self, x):
        h, w = x.shape[-2:]
        if self.normalize_input:
            # expect a non-negative map and normalize it to sum to 1 over the
            # spatial dimensions
            assert x.min() >= 0.0 and x.max() <= 1.0
            x_sum = x.sum((-2, -1), keepdims=True).clamp(min=1e-12)
            x = x / x_sum
        if self.upsample_input is not None:
            # upsample the low-res input to get better results
            scale_factor = max(1.0, self.upsample_input / max(h, w))
            x = F.interpolate(
                x.float(), scale_factor=scale_factor, mode='bilinear').to(x)
        h, w = x.shape[-2:]
        # pad so both spatial dims are divisible by the UNet downsampling rate
        dividend = 2**(self.num_stages - 1)
        padded_h = math.ceil(h / dividend) * dividend
        padded_w = math.ceil(w / dividend) * dividend
        padded_x = x.new_zeros(*x.shape[:2], padded_h, padded_w)
        padded_x[..., :h, :w] = x
        # run the UNet, crop the padding off, and predict a single-channel map
        x = super().forward(padded_x)[-1][..., :h, :w]
        return self.conv_seg(x)
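

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the original
    # module). The keyword arguments forwarded to UNet assume the local
    # `.unet` follows the mmseg-style defaults (base_channels=64,
    # num_stages=5, ...); adjust them to the actual signature in `.unet`.
    import torch

    head = UNetHead(
        upsample_input=128,    # upsample inputs so that max(h, w) reaches 128
        normalize_input=True,  # inputs are treated as per-map distributions
        in_channels=1,         # e.g. single-channel attention/saliency maps
    )
    maps = torch.rand(2, 1, 64, 64)  # non-negative values in [0, 1)
    logits = head(maps)              # (2, 1, 128, 128) single-channel logits
    print(logits.shape)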