|
from pathlib import Path |
|
|
|
import torch |
|
import torch.nn as nn |
|
from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder |
|
from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder |
|
from climategan.deeplab.mobilenet_v3 import MobileNetV2 |
|
from climategan.deeplab.resnet101_v3 import ResNet101 |
|
from climategan.deeplab.resnetmulti_v2 import ResNetMulti |
|
|
|
|
|
def create_encoder(opts, no_init=False, verbose=0):
    """Build the encoder selected by ``opts.gen.encoder.architecture``.

    Args:
        opts: Configuration object. Reads ``opts.gen.encoder.architecture``
            and, for ``"deeplabv3"`` with ``verbose > 0``,
            ``opts.gen.deeplabv3.backbone`` (for the log line only).
        no_init: Forwarded unchanged to the encoder constructor / builder.
        verbose: When > 0 a one-line message is printed; the value is also
            forwarded to the encoder constructor / builder.

    Returns:
        A ``DeeplabV2Encoder`` for ``"deeplabv2"``, or whatever
        ``build_v3_backbone`` returns for ``"deeplabv3"``.

    Raises:
        NotImplementedError: If the architecture is neither ``"deeplabv2"``
            nor ``"deeplabv3"``.
    """
    architecture = opts.gen.encoder.architecture

    if architecture == "deeplabv2":
        if verbose > 0:
            print(" - Add Deeplabv2 Encoder")
        return DeeplabV2Encoder(opts, no_init, verbose)

    if architecture == "deeplabv3":
        if verbose > 0:
            backbone = opts.gen.deeplabv3.backbone
            print(" - Add Deeplabv3 ({}) Encoder".format(backbone))
        # Forward ``verbose`` so the backbone builder honours the caller's
        # verbosity instead of silently falling back to its default of 0.
        return build_v3_backbone(opts, no_init, verbose)

    raise NotImplementedError("Unknown encoder: {}".format(architecture))
|
|
|
|
|
def create_segmentation_decoder(opts, no_init=False, verbose=0):
    """Instantiate the segmentation decoder named in ``opts.gen.s.architecture``.

    Args:
        opts: Configuration object; ``opts.gen.s.architecture`` selects the
            decoder and ``opts`` itself is handed to the decoder constructor.
        no_init: Passed on to ``DeepLabV3Decoder`` only (the V2 decoder is
            built from ``opts`` alone).
        verbose: When > 0, print which decoder is being added.

    Returns:
        A ``DeepLabV2Decoder`` or ``DeepLabV3Decoder`` instance.

    Raises:
        NotImplementedError: For any other architecture name.
    """
    arch = opts.gen.s.architecture
    announce = verbose > 0

    if arch == "deeplabv2":
        if announce:
            print(" - Add DeepLabV2Decoder")
        return DeepLabV2Decoder(opts)

    if arch == "deeplabv3":
        if announce:
            print(" - Add DeepLabV3Decoder")
        return DeepLabV3Decoder(opts, no_init)

    message = "Unknown Segmentation architecture: {}".format(arch)
    raise NotImplementedError(message)
|
|
|
|
|
def build_v3_backbone(opts, no_init, verbose=0):
    """Build the DeepLabv3 backbone named in ``opts.gen.deeplabv3.backbone``.

    Args:
        opts: Configuration object. Reads ``opts.gen.deeplabv3.backbone``,
            ``opts.gen.deeplabv3.output_stride`` and the matching entry of
            ``opts.gen.deeplabv3.pretrained_model`` (``resnet`` / ``mobilenet``).
        no_init: Forwarded to the backbone constructor. For the resnet
            backbone it also disables loading the pre-trained checkpoint.
        verbose: Forwarded to ``ResNet101``.

    Returns:
        The constructed ``ResNet101`` or ``MobileNetV2`` module.

    Raises:
        FileNotFoundError: If a required pre-trained weights file is missing.
        NotImplementedError: If the backbone is neither ``"resnet"`` nor
            ``"mobilenet"``.
    """
    backbone = opts.gen.deeplabv3.backbone
    output_stride = opts.gen.deeplabv3.output_stride

    if backbone == "resnet":
        resnet = ResNet101(
            output_stride=output_stride,
            BatchNorm=nn.BatchNorm2d,
            verbose=verbose,
            no_init=no_init,
        )
        if not no_init:
            weights_path = opts.gen.deeplabv3.pretrained_model.resnet
            # Explicit raise instead of ``assert``: assertions are stripped
            # under ``python -O`` and give no hint about which path is missing.
            if not Path(weights_path).exists():
                raise FileNotFoundError(
                    "Pre-trained resnet backbone weights not found: {}".format(
                        weights_path
                    )
                )

            std = torch.load(weights_path)
            # Keep only the backbone entries of the checkpoint and strip the
            # leading "backbone." prefix (prefix only, not every occurrence).
            prefix = "backbone."
            resnet.load_state_dict(
                {
                    k[len(prefix):]: v
                    for k, v in std.items()
                    if k.startswith(prefix)
                }
            )
            print(
                " - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder"
            )
        return resnet

    if backbone == "mobilenet":
        weights_path = opts.gen.deeplabv3.pretrained_model.mobilenet
        # NOTE(review): unlike the resnet branch, the path is required even
        # when ``no_init`` is set; kept as-is because how MobileNetV2 uses
        # ``pretrained_path`` is not visible from here -- confirm.
        if not Path(weights_path).exists():
            raise FileNotFoundError(
                "Pre-trained mobilenet backbone weights not found: {}".format(
                    weights_path
                )
            )
        mobilenet = MobileNetV2(
            no_init=no_init,
            pretrained_path=weights_path,
        )
        print(" - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
        return mobilenet

    raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))
|
|
|
|
|
class DeeplabV2Encoder(nn.Module):
    """Encoder module that delegates to a ``ResNetMulti`` network."""

    def __init__(self, opts, no_init=False, verbose=0):
        """Create the ``ResNetMulti`` model and optionally load saved weights.

        Args:
            opts: Configuration object. Reads ``opts.gen.deeplabv2.nblocks``,
                ``opts.gen.encoder.n_res``, ``opts.gen.deeplabv2.use_pretrained``
                and ``opts.gen.deeplabv2.pretrained_model``.
            no_init: When true, the saved weights are never loaded.
            verbose: When > 0, print a message after weights were loaded.
        """
        super().__init__()

        v2_opts = opts.gen.deeplabv2
        self.model = ResNetMulti(v2_opts.nblocks, opts.gen.encoder.n_res)

        # Nothing else to do unless pre-trained weights are both requested
        # and allowed.
        if not v2_opts.use_pretrained or no_init:
            return

        checkpoint = torch.load(v2_opts.pretrained_model)
        merged = self.model.state_dict().copy()
        excluded = ("layer5", "resblock")

        for key, tensor in checkpoint.items():
            parts = key.split(".")
            # The second dotted component names the sub-module; entries under
            # the excluded sub-modules keep the model's own values.
            if parts[1] in excluded:
                continue
            # Drop the first dotted component of the checkpoint key.
            merged[".".join(parts[1:])] = tensor

        self.model.load_state_dict(merged)

        if verbose > 0:
            print(" - Loaded pretrained weights")

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
|
|