import math
import types
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log

from .unet import UNet, Upsample
from ..utils import resize


def upsample_forward_func(self, x):
    """Upsample in fp32 to avoid precision issues, then cast back."""
    dtype = x.dtype
    x = x.float()
    if not self.size:
        size = [int(t * self.scale_factor) for t in x.shape[-2:]]
    else:
        size = self.size
    return resize(x, size, None, self.mode, self.align_corners).to(dtype)


class UNetHead(UNet):
    def __init__(self, upsample_input=None,
                 normalize_input=False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_seg = nn.Conv2d(self.base_channels, 1, kernel_size=1)

        # Patch every Upsample module to run in fp32 (see upsample_forward_func).
        for module in self.modules():
            if isinstance(module, Upsample):
                print_log("Replacing upsample forward function")
                module.forward = types.MethodType(
                    upsample_forward_func, module)

        self.init_weights()
        self.upsample_input = upsample_input
        self.normalize_input = normalize_input

    @property
    def dtype(self):
        return self.conv_seg.weight.dtype

    def forward(self, x):
        h, w = x.shape[-2:]
        if self.normalize_input:
            # Expect a non-negative input map (values in [0, 1]) and
            # normalize it to sum to 1 over the spatial dimensions.
            assert x.min() >= 0.0 and x.max() <= 1.0
            x_sum = x.sum((-2, -1), keepdim=True).clamp(min=1e-12)
            x = x / x_sum

        if self.upsample_input is not None:
            # Upsample the low-res input to get better results;
            # never downsample (scale_factor >= 1).
            scale_factor = max(1.0, self.upsample_input / max(h, w))
            x = F.interpolate(
                x.float(), scale_factor=scale_factor, mode='bilinear').to(x)
            h, w = x.shape[-2:]

        # Zero-pad so the spatial size is divisible by 2**(num_stages - 1),
        # which the UNet's successive downsampling stages require.
        dividend = 2**(self.num_stages - 1)
        padded_h = math.ceil(h / dividend) * dividend
        padded_w = math.ceil(w / dividend) * dividend

        padded_x = x.new_zeros(*x.shape[:2], padded_h, padded_w)
        padded_x[..., :h, :w] = x
        x = super().forward(padded_x)[-1][..., :h, :w]
        return self.conv_seg(x)
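
# The padding arithmetic in `forward` can be checked in isolation. The sketch
# below uses a hypothetical `padded_size` helper (not part of this module) that
# mirrors the same computation:
#
#     import math
#
#     def padded_size(h, w, num_stages):
#         # Each of the (num_stages - 1) downsampling steps halves the spatial
#         # size, so pad H and W up to a multiple of 2**(num_stages - 1).
#         dividend = 2 ** (num_stages - 1)
#         return (math.ceil(h / dividend) * dividend,
#                 math.ceil(w / dividend) * dividend)
#
#     # e.g. a 37x50 input with a 5-stage UNet pads to a multiple of 16: 48x64
#     padded_size(37, 50, 5)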