from pathlib import Path

import torch
import torch.nn as nn

from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder
from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder
from climategan.deeplab.mobilenet_v3 import MobileNetV2
from climategan.deeplab.resnet101_v3 import ResNet101
from climategan.deeplab.resnetmulti_v2 import ResNetMulti


def create_encoder(opts, no_init=False, verbose=0):
    """Build the encoder selected by ``opts.gen.encoder.architecture``.

    Args:
        opts: Options object; reads ``opts.gen.encoder.architecture`` and, for
            ``"deeplabv3"``, ``opts.gen.deeplabv3.backbone``.
        no_init: Forwarded to the encoder constructor / backbone builder.
        verbose: When > 0, print which encoder is being added. Also forwarded
            to the encoder constructor / backbone builder.

    Returns:
        A ``DeeplabV2Encoder`` for ``"deeplabv2"``, or the backbone returned by
        ``build_v3_backbone`` for ``"deeplabv3"``.

    Raises:
        NotImplementedError: If the architecture is neither ``"deeplabv2"`` nor
            ``"deeplabv3"``.
    """
    architecture = opts.gen.encoder.architecture

    if architecture == "deeplabv2":
        if verbose > 0:
            print(" - Add Deeplabv2 Encoder")
        return DeeplabV2Encoder(opts, no_init, verbose)

    if architecture == "deeplabv3":
        if verbose > 0:
            backbone = opts.gen.deeplabv3.backbone
            print(" - Add Deeplabv3 ({}) Encoder".format(backbone))
        # Forward `verbose` so the backbone honours the caller's verbosity
        # (it was previously dropped here and always defaulted to 0).
        return build_v3_backbone(opts, no_init, verbose)

    raise NotImplementedError("Unknown encoder: {}".format(architecture))


def create_segmentation_decoder(opts, no_init=False, verbose=0):
    """Build the segmentation decoder selected by ``opts.gen.s.architecture``.

    Args:
        opts: Options object; reads ``opts.gen.s.architecture``.
        no_init: Forwarded to ``DeepLabV3Decoder`` only (``DeepLabV2Decoder``
            is constructed from ``opts`` alone).
        verbose: When > 0, print which decoder is being added.

    Returns:
        A ``DeepLabV2Decoder`` or ``DeepLabV3Decoder`` instance.

    Raises:
        NotImplementedError: If the architecture is neither ``"deeplabv2"`` nor
            ``"deeplabv3"``.
    """
    architecture = opts.gen.s.architecture

    if architecture == "deeplabv2":
        if verbose > 0:
            print(" - Add DeepLabV2Decoder")
        return DeepLabV2Decoder(opts)

    if architecture == "deeplabv3":
        if verbose > 0:
            print(" - Add DeepLabV3Decoder")
        return DeepLabV3Decoder(opts, no_init)

    raise NotImplementedError(
        "Unknown Segmentation architecture: {}".format(architecture)
    )


def build_v3_backbone(opts, no_init, verbose=0):
    """Build the DeepLabv3 backbone named by ``opts.gen.deeplabv3.backbone``.

    Args:
        opts: Options object; reads ``opts.gen.deeplabv3.backbone``,
            ``opts.gen.deeplabv3.output_stride`` and the matching entry of
            ``opts.gen.deeplabv3.pretrained_model``.
        no_init: When true, the ResNet branch skips loading the checkpoint.
            Forwarded to the backbone constructor in both branches.
        verbose: Forwarded to ``ResNet101``.

    Returns:
        A ``ResNet101`` (``"resnet"``) or ``MobileNetV2`` (``"mobilenet"``).

    Raises:
        AssertionError: If the required pre-trained weights file is missing.
        NotImplementedError: If the backbone name is not recognised.
    """
    backbone = opts.gen.deeplabv3.backbone
    output_stride = opts.gen.deeplabv3.output_stride

    if backbone == "resnet":
        resnet = ResNet101(
            output_stride=output_stride,
            BatchNorm=nn.BatchNorm2d,
            verbose=verbose,
            no_init=no_init,
        )
        if not no_init:
            weights_path = opts.gen.deeplabv3.pretrained_model.resnet
            assert Path(
                weights_path
            ).exists(), f"Pre-trained ResNet101 weights not found: {weights_path}"
            # Load onto CPU so checkpoints saved from a GPU also work on
            # CPU-only hosts; load_state_dict copies into the model's tensors.
            std = torch.load(weights_path, map_location="cpu")
            # Keep only the "backbone." entries of the checkpoint and remove
            # that marker from their keys before loading.
            resnet.load_state_dict(
                {
                    k.replace("backbone.", ""): v
                    for k, v in std.items()
                    if k.startswith("backbone.")
                }
            )
            print(
                " - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder"
            )
        return resnet

    if backbone == "mobilenet":
        # NOTE(review): unlike the ResNet branch, the path check and the
        # "Loaded" message below run even when no_init is true -- confirm
        # against MobileNetV2 whether that is intended.
        weights_path = opts.gen.deeplabv3.pretrained_model.mobilenet
        assert Path(
            weights_path
        ).exists(), f"Pre-trained MobileNetV2 weights not found: {weights_path}"
        mobilenet = MobileNetV2(
            no_init=no_init,
            pretrained_path=weights_path,
        )
        print(" - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
        return mobilenet

    raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))


class DeeplabV2Encoder(nn.Module):
    """Deeplab architecture encoder wrapping a ``ResNetMulti`` model."""

    def __init__(self, opts, no_init=False, verbose=0):
        """Create the ``ResNetMulti`` model and optionally load weights.

        Args:
            opts: Options object; reads ``opts.gen.deeplabv2.nblocks``,
                ``opts.gen.encoder.n_res``, ``opts.gen.deeplabv2.use_pretrained``
                and ``opts.gen.deeplabv2.pretrained_model``.
            no_init: When true, pre-trained weights are not loaded even if
                ``use_pretrained`` is set.
            verbose: When > 0, print a message after weights are loaded.
        """
        super().__init__()

        self.model = ResNetMulti(opts.gen.deeplabv2.nblocks, opts.gen.encoder.n_res)

        if opts.gen.deeplabv2.use_pretrained and not no_init:
            # Load onto CPU so GPU-saved checkpoints also work on CPU-only
            # hosts; load_state_dict copies values into the model's tensors.
            saved_state_dict = torch.load(
                opts.gen.deeplabv2.pretrained_model, map_location="cpu"
            )
            # Start from the model's current parameters so that entries which
            # are skipped below keep their existing values.
            new_params = self.model.state_dict().copy()
            for key, value in saved_state_dict.items():
                key_parts = key.split(".")
                # Entries whose second component is "layer5" or "resblock" are
                # not taken from the checkpoint.
                if key_parts[1] not in ("layer5", "resblock"):
                    # Drop the first dotted component of the checkpoint key.
                    new_params[".".join(key_parts[1:])] = value
            self.model.load_state_dict(new_params)
            if verbose > 0:
                print(" - Loaded pretrained weights")

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)