File size: 1,851 Bytes
bcc0f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch.nn as nn
from .encoder import ResnetEncoder as Encoder
from .decoder import CustomDecoder as Decoder


class UNet(nn.Module):
    def __init__(self, decoder_config, nclasses, input_shape=(224, 224)):
        super(UNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(config=decoder_config)

        self.output = nn.Sequential(
            nn.Conv2d(
                in_channels=decoder_config["block1"]["out_channels"],
                out_channels=nclasses,
                kernel_size=1,
            ),
            nn.UpsamplingBilinear2d(size=input_shape),
        )

    def forward(self, x):
        x, encoder_step_output = self.encoder(x)
        x = self.decoder(x, encoder_step_output)
        x = self.output(x)
        return x


if __name__ == "__main__":
    import torch
    import yaml
    from easydict import EasyDict
    from torchinfo import summary

    # load config
    config_path = "./config/resnet_config.yml"

    with open(config_path, "r") as file:
        yaml_data = yaml.safe_load(file)

    config = EasyDict(yaml_data)

    # input shape
    input_shape = (224, 224)

    # device
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # model definition
    model = UNet(
        decoder_config=config["decoder_config"], nclasses=1, input_shape=input_shape
    ).to(device)

    summary(
        model,
        input_data=torch.rand((1, 3, input_shape[0], input_shape[1])),
        device=device,
    )

    # load weights (if any)
    model_path = None

    if model_path is not None:
        checkpoint = torch.load(model_path, map_location=device)
        model.decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=False)
        model.output.load_state_dict(checkpoint["output_state_dict"], strict=False)