vict0rsch's picture
initial commit from cc-ai/climateGAN
448ebbd
raw
history blame
3.81 kB
from pathlib import Path
import torch
import torch.nn as nn
from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder
from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder
from climategan.deeplab.mobilenet_v3 import MobileNetV2
from climategan.deeplab.resnet101_v3 import ResNet101
from climategan.deeplab.resnetmulti_v2 import ResNetMulti
def create_encoder(opts, no_init=False, verbose=0):
    """Instantiate the encoder selected by ``opts.gen.encoder.architecture``.

    Args:
        opts: configuration object. Reads ``opts.gen.encoder.architecture``;
            for "deeplabv3" it also reads ``opts.gen.deeplabv3.backbone``
            when ``verbose > 0``.
        no_init: forwarded to the encoder / backbone builder; when True,
            pretrained weights are not loaded.
        verbose: if > 0, print which encoder is added. Also forwarded to the
            encoder constructor / backbone builder.

    Returns:
        A ``DeeplabV2Encoder`` for "deeplabv2", or the module returned by
        ``build_v3_backbone`` for "deeplabv3".

    Raises:
        NotImplementedError: if the architecture is neither of the above.
    """
    architecture = opts.gen.encoder.architecture

    if architecture == "deeplabv2":
        if verbose > 0:
            print(" - Add Deeplabv2 Encoder")
        return DeeplabV2Encoder(opts, no_init, verbose)

    if architecture == "deeplabv3":
        if verbose > 0:
            backbone = opts.gen.deeplabv3.backbone
            print(" - Add Deeplabv3 ({}) Encoder".format(backbone))
        # Forward verbose, as the deeplabv2 branch does, so the caller's
        # verbosity reaches the backbone builder instead of being dropped.
        return build_v3_backbone(opts, no_init, verbose)

    raise NotImplementedError("Unknown encoder: {}".format(architecture))
def create_segmentation_decoder(opts, no_init=False, verbose=0):
    """Instantiate the segmentation decoder named by ``opts.gen.s.architecture``.

    Args:
        opts: configuration object; ``opts.gen.s.architecture`` selects the
            decoder and ``opts`` itself is passed to the decoder constructor.
        no_init: passed to ``DeepLabV3Decoder`` only.
        verbose: if > 0, print which decoder is added.

    Returns:
        A ``DeepLabV2Decoder`` ("deeplabv2") or ``DeepLabV3Decoder``
        ("deeplabv3").

    Raises:
        NotImplementedError: for any other architecture name.
    """
    arch = opts.gen.s.architecture

    if arch == "deeplabv2":
        if verbose > 0:
            print(" - Add DeepLabV2Decoder")
        # NOTE(review): no_init is not forwarded here; confirm whether
        # DeepLabV2Decoder's constructor accepts it.
        return DeepLabV2Decoder(opts)

    if arch == "deeplabv3":
        if verbose > 0:
            print(" - Add DeepLabV3Decoder")
        return DeepLabV3Decoder(opts, no_init)

    raise NotImplementedError(
        "Unknown Segmentation architecture: {}".format(arch)
    )
def build_v3_backbone(opts, no_init, verbose=0):
    """Build the DeepLabV3+ backbone used as the generator encoder.

    Args:
        opts: configuration object. Reads ``opts.gen.deeplabv3.backbone``
            ("resnet" or "mobilenet"), ``opts.gen.deeplabv3.output_stride``
            and ``opts.gen.deeplabv3.pretrained_model.{resnet,mobilenet}``.
        no_init: when True, pretrained weights are not loaded and the
            pretrained checkpoint file is not required to exist.
        verbose: forwarded to ``ResNet101``.

    Returns:
        The ``ResNet101`` or ``MobileNetV2`` backbone module.

    Raises:
        FileNotFoundError: if weights are to be loaded (``no_init`` False)
            and the configured pretrained checkpoint does not exist.
        NotImplementedError: if the backbone name is not recognised.
    """
    v3_opts = opts.gen.deeplabv3
    backbone = v3_opts.backbone
    output_stride = v3_opts.output_stride

    if backbone == "resnet":
        resnet = ResNet101(
            output_stride=output_stride,
            BatchNorm=nn.BatchNorm2d,
            verbose=verbose,
            no_init=no_init,
        )
        if not no_init:
            ckpt_path = v3_opts.pretrained_model.resnet
            # Explicit raise instead of `assert`: asserts vanish under -O.
            if not Path(ckpt_path).exists():
                raise FileNotFoundError(
                    "Pretrained resnet checkpoint not found: {}".format(ckpt_path)
                )
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict copies values into the module's tensors.
            std = torch.load(ckpt_path, map_location="cpu")
            prefix = "backbone."
            # Keep only the backbone weights, stripping the leading prefix
            # (str.replace would also remove later occurrences in the key).
            resnet.load_state_dict(
                {k[len(prefix):]: v for k, v in std.items() if k.startswith(prefix)}
            )
            print(
                " - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder"
            )
        return resnet

    if backbone == "mobilenet":
        ckpt_path = v3_opts.pretrained_model.mobilenet
        # Match the resnet branch: the checkpoint is only required when
        # weights are actually going to be loaded.
        # NOTE(review): assumes MobileNetV2 does not read pretrained_path when
        # no_init=True — confirm in climategan.deeplab.mobilenet_v3.
        if not no_init and not Path(ckpt_path).exists():
            raise FileNotFoundError(
                "Pretrained mobilenet checkpoint not found: {}".format(ckpt_path)
            )
        mobilenet = MobileNetV2(
            no_init=no_init,
            pretrained_path=ckpt_path,
        )
        if not no_init:
            print(" - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
        return mobilenet

    raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))
class DeeplabV2Encoder(nn.Module):
    """DeepLabV2 encoder wrapping a ``ResNetMulti`` model.

    Optionally initialises the wrapped model from a pretrained checkpoint.
    """

    # Second key component (after the checkpoint's top-level prefix) whose
    # weights are NOT copied from the checkpoint into the encoder.
    _SKIPPED_KEY_PARTS = ("layer5", "resblock")

    def __init__(self, opts, no_init=False, verbose=0):
        """Build the encoder.

        Args:
            opts: configuration object. Reads ``opts.gen.deeplabv2.nblocks``,
                ``opts.gen.deeplabv2.use_pretrained``,
                ``opts.gen.deeplabv2.pretrained_model`` and
                ``opts.gen.encoder.n_res``.
            no_init: when True, pretrained weights are never loaded.
            verbose: if > 0, print a message after weights are loaded.
        """
        super().__init__()
        self.model = ResNetMulti(opts.gen.deeplabv2.nblocks, opts.gen.encoder.n_res)
        if opts.gen.deeplabv2.use_pretrained and not no_init:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict copies values into the module's tensors.
            saved_state_dict = torch.load(
                opts.gen.deeplabv2.pretrained_model, map_location="cpu"
            )
            new_params = self._remap_pretrained_state(
                saved_state_dict, self.model.state_dict()
            )
            self.model.load_state_dict(new_params)
            if verbose > 0:
                print(" - Loaded pretrained weights")

    @classmethod
    def _remap_pretrained_state(cls, saved_state_dict, model_state):
        """Return a copy of ``model_state`` updated with checkpoint weights.

        Each checkpoint key has its first dotted component dropped, e.g.
        ``"<prefix>.conv1.weight"`` becomes ``"conv1.weight"``. Keys whose
        second component is in ``_SKIPPED_KEY_PARTS`` are ignored.
        ``model_state`` itself is not modified (a shallow copy is updated).
        Assumes every checkpoint key contains at least one "." — a key
        without one raises IndexError, as before.
        """
        new_params = model_state.copy()
        for key, value in saved_state_dict.items():
            parts = key.split(".")
            if parts[1] not in cls._SKIPPED_KEY_PARTS:
                new_params[".".join(parts[1:])] = value
        return new_params

    def forward(self, x):
        """Run the wrapped ``ResNetMulti`` model on ``x``."""
        return self.model(x)