# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import os import functools from torch.autograd import Variable from util.image_pool import ImagePool from .base_model import BaseModel from . import networks import math class Mapping_Model_with_mask(nn.Module): def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): super(Mapping_Model_with_mask, self).__init__() norm_layer = networks.get_norm_layer(norm_type=norm) activation = nn.ReLU(True) model = [] tmp_nc = 64 n_up = 4 for i in range(n_up): ic = min(tmp_nc * (2 ** i), mc) oc = min(tmp_nc * (2 ** (i + 1)), mc) model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] self.before_NL = nn.Sequential(*model) if opt.NL_res: self.NL = networks.NonLocalBlock2D_with_mask_Res( mc, mc, opt.NL_fusion_method, opt.correlation_renormalize, opt.softmax_temperature, opt.use_self, opt.cosin_similarity, ) print("You are using NL + Res") model = [] for i in range(n_blocks): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] for i in range(n_up - 1): ic = min(64 * (2 ** (4 - i)), mc) oc = min(64 * (2 ** (3 - i)), mc) model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] if opt.feat_dim > 0 and opt.feat_dim < 64: model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] # model += [nn.Conv2d(64, 1, 1, 1, 0)] self.after_NL = nn.Sequential(*model) def forward(self, input, mask): x1 = self.before_NL(input) del input x2 = self.NL(x1, mask) del x1, mask x3 = self.after_NL(x2) del x2 return x3 class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): super(Mapping_Model_with_mask_2, self).__init__() norm_layer = networks.get_norm_layer(norm_type=norm) activation = nn.ReLU(True) model = [] tmp_nc = 64 n_up = 4 for i in range(n_up): ic = min(tmp_nc * (2 ** i), mc) oc = min(tmp_nc * (2 ** (i + 1)), mc) model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] print("Mapping: You are using multi-scale patch attention, conv combine + mask input") self.before_NL = nn.Sequential(*model) if opt.mapping_exp==1: self.NL_scale_1=networks.Patch_Attention_4(mc,mc,8) model = [] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] self.res_block_1 = nn.Sequential(*model) if opt.mapping_exp==1: self.NL_scale_2=networks.Patch_Attention_4(mc,mc,4) model = [] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] self.res_block_2 = nn.Sequential(*model) if opt.mapping_exp==1: self.NL_scale_3=networks.Patch_Attention_4(mc,mc,2) # self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2) model = [] for i in range(2): model += [ networks.ResnetBlock( mc, padding_type=padding_type, activation=activation, norm_layer=norm_layer, opt=opt, dilation=opt.mapping_net_dilation, ) ] for i in range(n_up - 1): ic = min(64 * (2 ** (4 - i)), mc) oc = min(64 * (2 ** (3 - i)), mc) model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] if opt.feat_dim > 0 and opt.feat_dim < 64: model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] # model += [nn.Conv2d(64, 1, 1, 1, 0)] self.after_NL = nn.Sequential(*model) def forward(self, input, mask): x1 = self.before_NL(input) x2 = self.NL_scale_1(x1,mask) x3 = self.res_block_1(x2) x4 = self.NL_scale_2(x3,mask) x5 = self.res_block_2(x4) x6 = self.NL_scale_3(x5,mask) x7 = self.after_NL(x6) return x7 def inference_forward(self, input, mask): x1 = self.before_NL(input) del input x2 = self.NL_scale_1.inference_forward(x1,mask) del x1 x3 = self.res_block_1(x2) del x2 x4 = self.NL_scale_2.inference_forward(x3,mask) del x3 x5 = self.res_block_2(x4) del x4 x6 = self.NL_scale_3.inference_forward(x5,mask) del x5 x7 = self.after_NL(x6) del x6 return x7