# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File   : Ushape_Trans.py
# Design based on the pix2pix framework
import copy
import datetime
import os
import time
import timeit

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d
from torch.nn import LeakyReLU
from torch.nn import ModuleList

from net.block import *
from net.block import _equalized_conv2d
from net.CMSFFT import ChannelTransformer
from net.PositionalEncoding import FixedPositionalEncoding, LearnedPositionalEncoding
from net.SGFMT import TransformerModel
# Weight initialization (normal init, as in DCGAN / pix2pix)
def weights_init_normal(m): | |
classname = m.__class__.__name__ | |
if classname.find("Conv") != -1: | |
torch.nn.init.normal_(m.weight.data, 0.0, 0.02) | |
elif classname.find("BatchNorm2d") != -1: | |
torch.nn.init.normal_(m.weight.data, 1.0, 0.02) | |
torch.nn.init.constant_(m.bias.data, 0.0) | |
class Generator(nn.Module):
    """
    Generator of the MSG-U-Net GAN (design based on pix2pix).

    A U-Net encoder/decoder in which:
      * the four encoder skip features are re-weighted by a channel-wise
        transformer (``ChannelTransformer``, net.CMSFFT);
      * the 256-channel bottleneck is flattened into patch tokens and run
        through a spatial transformer (``TransformerModel``, net.SGFMT);
      * downsampled copies of the RGB input are added into the encoder,
        and each decoder stage emits an RGB image (MSG-GAN style).
    """

    def __init__(self,
                 img_dim=256,
                 patch_dim=16,
                 embedding_dim=512,
                 num_channels=3,
                 num_heads=8,
                 num_layers=4,
                 hidden_dim=256,
                 dropout_rate=0.0,
                 attn_dropout_rate=0.0,
                 in_ch=3,
                 out_ch=3,
                 conv_patch_representation=True,
                 positional_encoding_type="learned",
                 use_eql=True):
        """
        Args:
            img_dim: input image height/width (square images assumed).
            patch_dim: side length of one transformer patch.
            embedding_dim: transformer token dimension.
            num_channels: channels of the RGB input image.
            num_heads: attention heads per transformer layer.
            num_layers: number of transformer layers.
            hidden_dim: transformer feed-forward hidden width.
            dropout_rate: dropout applied around the transformer.
            attn_dropout_rate: dropout inside the attention blocks.
            in_ch: U-Net input channels.
            out_ch: U-Net output channels.
            conv_patch_representation: if True, the bottleneck is projected
                to tokens with a 3x3 conv (``Conv_x``).
            positional_encoding_type: "learned" or "fixed".
            use_eql: accepted for interface symmetry with the
                discriminator; not referenced inside this class.
        """
        super(Generator, self).__init__()

        # token width must split across heads; image must tile into patches
        assert embedding_dim % num_heads == 0
        assert img_dim % patch_dim == 0

        self.out_ch=out_ch                          # number of output channels
        self.in_ch=in_ch                            # number of input channels
        self.img_dim = img_dim                      # input image size
        self.embedding_dim = embedding_dim          # 512
        self.num_heads = num_heads                  # number of heads in multi-head attention
        self.patch_dim = patch_dim                  # size of each patch
        self.num_channels = num_channels            # image channel count
        self.dropout_rate = dropout_rate            # dropout rate
        self.attn_dropout_rate = attn_dropout_rate  # dropout rate of the attention module
        self.conv_patch_representation = conv_patch_representation  # True
        self.num_patches = int((img_dim // patch_dim) ** 2)  # how many patches the image splits into
        self.seq_length = self.num_patches          # sequence length == number of patches
        self.flatten_dim = 128 * num_channels       # 128*3=384

        # linear token embedding (only relevant to the non-conv patch path)
        self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
        # positional encoding for the token sequence
        if positional_encoding_type == "learned":
            self.position_encoding = LearnedPositionalEncoding(
                self.seq_length, self.embedding_dim, self.seq_length
            )
        elif positional_encoding_type == "fixed":
            self.position_encoding = FixedPositionalEncoding(
                self.embedding_dim,
            )

        self.pe_dropout = nn.Dropout(p=self.dropout_rate)

        self.transformer = TransformerModel(
            embedding_dim,        # 512
            num_layers,           # 4
            num_heads,            # 8
            hidden_dim,           # feed-forward width (256 by default here)
            self.dropout_rate,
            self.attn_dropout_rate,
        )
        # layer norm applied to transformer output before reshaping back
        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        if self.conv_patch_representation:
            # Conv_x lifts the 256-channel bottleneck to the 512-channel
            # latent consumed by the transformer
            self.Conv_x = nn.Conv2d(
                256,
                self.embedding_dim,  # 512
                kernel_size=3,
                stride=1,
                padding=1
            )

        self.bn = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)

        # multi-scale adapters between RGB space and feature space
        self.rgb_to_feature=ModuleList([from_rgb(32),from_rgb(64),from_rgb(128)])
        self.feature_to_rgb=ModuleList([to_rgb(32),to_rgb(64),to_rgb(128),to_rgb(256)])

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # encoder: two conv blocks per resolution level
        self.Conv1=conv_block(self.in_ch, 16)
        self.Conv1_1 = conv_block(16, 32)
        self.Conv2 = conv_block(32, 32)
        self.Conv2_1 = conv_block(32, 64)
        self.Conv3 = conv_block(64,64)
        self.Conv3_1 = conv_block(64,128)
        self.Conv4 = conv_block(128,128)
        self.Conv4_1 = conv_block(128,256)
        self.Conv5 = conv_block(512,256)  # projects transformer output back to 256 channels
        #self.Conv_x = conv_block(256,512)

        # channel-wise transformer acting on the four skip connections
        self.mtc = ChannelTransformer(channel_num=[32,64,128,256],
                                      patchSize=[32, 16, 8, 4])

        # decoder: upsample, cross-attend to the skip (CCA), two conv blocks
        self.Up5 = up_conv(256, 256)
        self.coatt5 = CCA(F_g=256, F_x=256)
        self.Up_conv5 = conv_block(512, 256)
        self.Up_conv5_1 = conv_block(256, 256)

        self.Up4 = up_conv(256, 128)
        self.coatt4 = CCA(F_g=128, F_x=128)
        self.Up_conv4 = conv_block(256, 128)
        self.Up_conv4_1 = conv_block(128, 128)

        self.Up3 = up_conv(128, 64)
        self.coatt3 = CCA(F_g=64, F_x=64)
        self.Up_conv3 = conv_block(128, 64)
        self.Up_conv3_1 = conv_block(64, 64)

        self.Up2 = up_conv(64, 32)
        self.coatt2 = CCA(F_g=32, F_x=32)
        self.Up_conv2 = conv_block(64, 32)
        self.Up_conv2_1 = conv_block(32, 32)

        self.Conv = nn.Conv2d(32, self.out_ch, kernel_size=1, stride=1, padding=0)
        # self.active = torch.nn.Sigmoid()

    def reshape_output(self,x):
        """Reshape transformer tokens back into a feature map.

        (B, num_patches, embedding_dim) -> (B, embedding_dim, H/p, W/p),
        i.e. B,16,16,512 -> B,512,16,16 with the default sizes.
        """
        x = x.view(
            x.size(0),
            int(self.img_dim / self.patch_dim),
            int(self.img_dim / self.patch_dim),
            self.embedding_dim,
        )  # B,16,16,512
        x = x.permute(0, 3, 1, 2).contiguous()
        return x

    def forward(self, x):
        """Run the generator.

        Args:
            x: input batch; assumes shape (B, in_ch, img_dim, img_dim)
               (the sizes quoted in comments below use img_dim=256).

        Returns:
            The full-resolution output image.  The three coarser scales
            (H/8, H/4, H/2) are computed into ``output[0..2]`` but not
            returned.
        """
        #print(x.shape)
        output=[]

        # downsampled copies of the input for multi-scale fusion
        x_1=self.Maxpool(x)
        x_2=self.Maxpool(x_1)
        x_3=self.Maxpool(x_2)

        e1 = self.Conv1(x)
        #print(e1.shape)
        e1 = self.Conv1_1(e1)

        e2 = self.Maxpool1(e1)
        # 32 channels, 128x128
        x_1=self.rgb_to_feature[0](x_1)
        #e2=torch.cat((x_1,e2), dim=1)
        e2=x_1+e2
        e2 = self.Conv2(e2)
        e2 = self.Conv2_1(e2)

        e3 = self.Maxpool2(e2)
        # 64 channels, 64x64
        x_2=self.rgb_to_feature[1](x_2)
        #e3=torch.cat((x_2,e3), dim=1)
        e3=x_2+e3
        e3 = self.Conv3(e3)
        e3 = self.Conv3_1(e3)

        e4 = self.Maxpool3(e3)
        # 128 channels, 32x32
        x_3=self.rgb_to_feature[2](x_3)
        #e4=torch.cat((x_3,e4), dim=1)
        e4=x_3+e4
        e4 = self.Conv4(e4)
        e4 = self.Conv4_1(e4)

        e5 = self.Maxpool4(e4)
        # 256 channels, 16x16

        # channel-wise transformer-based attention over the skips
        e1,e2,e3,e4,att_weights = self.mtc(e1,e2,e3,e4)

        # spatial-wise transformer-based attention on the bottleneck
        residual=e5
        # intermediate latent: Conv_x maps 256 -> 512 channels
        e5= self.bn(e5)
        e5=self.relu(e5)
        e5= self.Conv_x(e5)  # out 512*16*16, shape B,512,16,16
        e5= e5.permute(0, 2, 3, 1).contiguous()  # B,512,16,16 -> B,16,16,512
        e5= e5.view(e5.size(0), -1, self.embedding_dim)  # -> B,16*16,512 token sequence
        e5= self.position_encoding(e5)  # positional encoding
        e5= self.pe_dropout(e5)  # dropout before the transformer
        # apply transformer
        e5= self.transformer(e5)
        e5= self.pre_head_ln(e5)
        e5= self.reshape_output(e5)  # back to B,512,16,16
        e5=self.Conv5(e5)  # out B,256,16,16
        # NOTE(review): should the residual branch get BN+ReLU as well?
        e5=e5+residual

        d5 = self.Up5(e5)
        e4_att = self.coatt5(g=d5, x=e4)
        d5 = torch.cat((e4_att, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d5 = self.Up_conv5_1(d5)
        # 256 channels
        out3=self.feature_to_rgb[3](d5)
        output.append(out3)  # 32x32, i.e. H/8,W/8

        d4 = self.Up4(d5)
        e3_att = self.coatt4(g=d4, x=e3)
        d4 = torch.cat((e3_att, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d4 = self.Up_conv4_1(d4)
        # 128 channels
        out2=self.feature_to_rgb[2](d4)
        output.append(out2)  # 64x64, i.e. H/4,W/4

        d3 = self.Up3(d4)
        e2_att = self.coatt3(g=d3, x=e2)
        d3 = torch.cat((e2_att, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d3 = self.Up_conv3_1(d3)
        # 64 channels
        out1=self.feature_to_rgb[1](d3)
        output.append(out1)  # 128x128, i.e. H/2,W/2

        d2 = self.Up2(d3)
        e1_att = self.coatt2(g=d2, x=e1)
        d2 = torch.cat((e1_att, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d2 = self.Up_conv2_1(d2)
        # 32 channels
        out0=self.feature_to_rgb[0](d2)
        output.append(out0)  # 256x256, full resolution

        #out = self.Conv(d2)
        #d1 = self.active(out)
        #output=np.array(output)
        # only the full-resolution prediction is returned
        return output[3]
class Discriminator(nn.Module):
    """
    Multi-scale discriminator for the MSG-GAN setup.

    ``forward`` consumes two pyramids of RGB images — a condition pyramid
    and a candidate pyramid, each ordered small -> large with index 3 at
    full resolution — fuses them scale by scale into one feature stream,
    and scores the pair with a final block.
    """

    def __init__(self, in_channels=3, use_eql=True):
        """
        Args:
            in_channels: channels of a single input image.
            use_eql: passed through to the conv blocks (equalized
                learning-rate convolutions).
        """
        super(Discriminator, self).__init__()

        self.use_eql = use_eql
        self.in_channels = in_channels

        # per-scale RGB -> feature adapters, one ModuleList per pyramid
        self.rgb_to_feature1 = ModuleList([from_rgb(32), from_rgb(64), from_rgb(128)])
        self.rgb_to_feature2 = ModuleList([from_rgb(32), from_rgb(64), from_rgb(128)])

        # entry 1x1 conv over the concatenated full-resolution pair
        self.layer = _equalized_conv2d(self.in_channels * 2, 64, (1, 1), bias=True)
        # pixel-wise feature normalizer
        self.pixNorm = PixelwiseNorm()
        # leaky relu used after the entry conv
        self.lrelu = LeakyReLU(0.2)

        # downsampling tower (in-channels include the fused scale inputs)
        self.layer0 = DisGeneralConvBlock(64, 64, use_eql=self.use_eql)
        self.layer1 = DisGeneralConvBlock(128, 128, use_eql=self.use_eql)
        self.layer2 = DisGeneralConvBlock(256, 256, use_eql=self.use_eql)
        self.layer3 = DisGeneralConvBlock(512, 512, use_eql=self.use_eql)
        self.layer4 = DisFinalBlock(512, use_eql=self.use_eql)

    def forward(self, img_A, inputs):
        """Score a (condition, candidate) pair of image pyramids.

        Args:
            img_A: list of 4 condition images, small -> large.
            inputs: list of 4 candidate images, same ordering.

        Returns:
            The output of the final discriminator block.
        """
        # the full-resolution pair enters the stream first
        feat = torch.cat((img_A[3], inputs[3]), 1)
        feat = self.pixNorm(self.lrelu(self.layer(feat)))
        feat = self.layer0(feat)

        # fuse the remaining scales from large to small: adapter i handles
        # pyramid level 2 - i (the 128-, 64-, then 32-pixel images)
        for i, stage in enumerate((self.layer1, self.layer2, self.layer3)):
            level = 2 - i
            cond = self.rgb_to_feature1[i](img_A[level])
            cand = self.rgb_to_feature2[i](inputs[level])
            # single cat is equivalent to the chained two-step concat
            feat = stage(torch.cat((cond, cand, feat), 1))

        # final block produces the score map
        return self.layer4(feat)