# up_unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sympy as sp
import wandb
from PIL import Image
from datasets import load_dataset
from torchvision import transforms
from down_unet import down_model
'''The network above needs to accept three inputs; the down/up-sampling blocks must be rewritten: after two conv blocks each shrink H/W by 2, the block fuses in three inputs, and the resnet block should incorporate time information.'''
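# The note above says the resnet blocks should eventually take time information.
# Below is a minimal sketch (not wired into conv_block yet) of a sinusoidal time
# embedding; the function name and the projection shown in the usage comment are
# illustrative assumptions, not part of the original model.
def sinusoidal_time_embedding(t, dim):
    # t: [B] timesteps (int or float) -> [B, dim] embedding
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)            # [B, half]
    return torch.cat((torch.sin(args), torch.cos(args)), dim=1)   # [B, 2*half]
# Possible injection point inside resnet_block (sketch only):
#   emb = time_proj(sinusoidal_time_embedding(t, self.in_channel))  # time_proj: an nn.Linear, hypothetical
#   out = out + emb[:, :, None, None]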
class conv_block(nn.Module):  # one sampling block holds two conv layers; channel depth grows like 64 -> 128 -> 256; [B,C,H,W] --> [B,channel_dim,H-2,W-2]
    def __init__(self,in_channel,num_heads,channel_dim,use="down"):
        super(conv_block,self).__init__()  # in_channel: input channel count, channel_dim: output channel count; each block shrinks H/W by 2
self.in_channel = in_channel
self.num_heads = num_heads
self.channel_dim = channel_dim
self.use = use
        self.GN = nn.GroupNorm(num_groups=4, num_channels=in_channel)  # num_channels here is the input channel count
        # num_groups is how many groups (e.g. 2, 4, 8) the input channels are split into for normalization; num_channels is the number of input channels
self.conv = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3,
stride=1, padding=1, bias=False)
self.silu = nn.SiLU()
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=self.num_heads,
                                               batch_first=True)  # batch_first so the [B, L, C] reshape below is interpreted correctly
if self.use == "down":
self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim, kernel_size=3,
stride=1, padding=0, bias=False)
elif self.use =="up":
self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim, kernel_size=3,
stride=1, padding=2, bias=False)
    def resnet_block(self,X):  # the hidden layers keep the same size as the input
out = self.GN(X)
out = self.conv(out)
        out = self.silu(out)  # TODO: time information should be injected here
out = self.GN(out)
out = self.conv(out)
out = self.silu(out)
return out + X
def attention_block(self,X):
B,C,H,W = X.size()
out = self.GN(X)
out = self.conv(out)
        out = out.view(B, self.in_channel, H * W).transpose(1, 2)  # reshape to [B, L, C], where L = H * W
out, weights = self.attention(out, out, out)
out = out.transpose(1, 2).view(B, self.in_channel, H, W)
out = self.conv(out)
return out+X
def forward(self,X):
out = self.resnet_block(X)
out = self.attention_block(out)
out = self.conv1(out)
return out
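# Shape sketch for conv_block (commented out, same style as the test snippets below);
# the shapes assume use="down", where the unpadded 3x3 conv1 trims H and W by 2.
'''
block = conv_block(in_channel=4, num_heads=4, channel_dim=64, use="down")
X = torch.randn(1, 4, 128, 128)
print(block(X).shape)  # torch.Size([1, 64, 126, 126])
'''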
class down_block(nn.Module):  # H/W shrinks by 4, two extra inputs are fused in, then H/W is halved
    def __init__(self,in_channel,channel_dim):  # e.g. in_channel=4 --> channel_dim=64
super(down_block,self).__init__()
self.channel_dim = channel_dim
self.in_channel = in_channel
self.block1 = conv_block(in_channel=self.in_channel,num_heads=4,
channel_dim=self.channel_dim,use="down")
self.block2 = conv_block(in_channel=self.channel_dim, num_heads=4,
channel_dim=self.channel_dim, use="down")
self.return_conv = nn.Conv2d(in_channels=self.channel_dim*2,out_channels=self.channel_dim,kernel_size=1,
stride=1,padding=0,bias=False)
        self.attention = nn.MultiheadAttention(embed_dim=self.channel_dim, num_heads=4,
                                               batch_first=True)  # batch_first matches the [B, L, C] reshape in caculate_attention
self.down_pool = nn.Conv2d(in_channels=self.channel_dim, out_channels=self.channel_dim, kernel_size=2,
stride=2, padding=0, bias=False)
def caculate_attention(self,X_q,Y_kv):
B,C,H,W = X_q.size()
        X_q = X_q.view(B, self.channel_dim, H * W).transpose(1, 2)  # reshape to [B, L, C], where L = H * W
Y_kv = Y_kv.view(B, self.channel_dim, H * W).transpose(1, 2)
out, weights = self.attention(X_q, Y_kv, Y_kv)
out = out.transpose(1, 2).view(B, self.channel_dim, H, W)
return out
    def forward(self,X,attention_out,pos_encoding):  # input [1,4,128,128], intermediate [1,64,124,124] --> output [1,64,62,62]
out = self.block1(X)
for_skip_connection = self.block2(out)
out = torch.cat((for_skip_connection,pos_encoding),dim=1)
out = self.return_conv(out)
out = self.caculate_attention(X_q=attention_out,Y_kv=out)
out = self.down_pool(out)
return out,for_skip_connection
'''
X = torch.randn(1,4,128,128)
attention_out = torch.randn(1,64,124,124)
pos_encoding = torch.randn(1,64,124,124)
model = down_block(4,64)
out, skip = model(X,attention_out,pos_encoding)
print(out.shape)  # torch.Size([1, 64, 62, 62])
'''
class up_block(nn.Module):
    def __init__(self,in_channel):  # in_channel is the per-branch channel count; the concatenated tensor has in_channel*2 channels
super(up_block,self).__init__()
self.in_channel = in_channel
self.block1 = conv_block(in_channel=in_channel*2, num_heads=4,
channel_dim=in_channel,use="up")
self.block2 = conv_block(in_channel=in_channel, num_heads=4,
channel_dim=in_channel,use="up")
self.up_pool = nn.ConvTranspose2d(self.in_channel*2, self.in_channel,
kernel_size=2, stride=2)
self.return_conv = nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.in_channel, kernel_size=1,
stride=1, padding=0, bias=False)
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=4,
                                               batch_first=True)  # batch_first matches the [B, L, C] reshape in caculate_attention
def caculate_attention(self,X_q,Y_kv):
B,C,H,W = X_q.size()
        X_q = X_q.view(B, self.in_channel, H * W).transpose(1, 2)  # reshape to [B, L, C], where L = H * W
Y_kv = Y_kv.view(B, self.in_channel, H * W).transpose(1, 2)
out, weights = self.attention(X_q, Y_kv, Y_kv)
out = out.transpose(1, 2).view(B, self.in_channel, H, W)
return out
    def forward(self,input,input_skip,attention_out,pos_encoding):  # first upsample the input, then concatenate with the skip connection, then run two blocks
        after_transposed = self.up_pool(input)  # size after transposed-conv upsampling
        after_cat = torch.cat((after_transposed, input_skip), dim=1)  # concatenate with the skip tensor
after_cat = self.return_conv(after_cat)
after_cat = torch.cat((after_cat, pos_encoding), dim=1)
after_cat = self.return_conv(after_cat)
out = self.caculate_attention(X_q=attention_out, Y_kv=after_cat)
        out = self.block2(out)  # no further channel reduction needed; each block2 pass grows H/W by 2 (padding=2)
out = self.block2(out)
return out
'''
X = torch.randn(1,128,62,62)
input_skip = torch.randn(1,64,124,124)
attention_out = torch.randn(1,64,124,124)
pos_encoding = torch.randn(1,64,124,124)
model = up_block(in_channel=64)
out = model(X,input_skip,attention_out,pos_encoding)
print(out.shape) # torch.Size([1, 64, 128, 128])
'''
class up_model(nn.Module):
def __init__(self):
super(up_model,self).__init__()
self.down_model = down_model()
self.start_conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1)
self.down_block1 = down_block(4,64)
self.down_block2 = down_block(64,128)
self.down_block3 = down_block(128,256)
self.down_block4 = down_block(256,512)
self.bottle_conv = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)
self.up_block4 = up_block(512)
self.up_block3 = up_block(256)
self.up_block2 = up_block(128)
self.up_block1 = up_block(64)
self.final_conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)
    def forward(self,input):  # the input H/W must stay evenly divisible through every downsampling stage
        (X, attention_out1, attention_out2, attention_out3, attention_out4,
         attention_out5, attention_out6, attention_out7, attention_out8,
         pos_encoding1, pos_encoding2, pos_encoding3, pos_encoding4,
         pos_encoding5, pos_encoding6, pos_encoding7, pos_encoding8) = self.down_model(input)
input = self.start_conv(input)
        out,for_skip1 = self.down_block1(input,attention_out8,pos_encoding8)
        out,for_skip2 = self.down_block2(out, attention_out7, pos_encoding7)
        out,for_skip3 = self.down_block3(out, attention_out6, pos_encoding6)
        out,for_skip4 = self.down_block4(out, attention_out5, pos_encoding5)
out = self.bottle_conv(out)
# print("bottle",out.shape)
        out = self.up_block4(out, for_skip4, attention_out4, pos_encoding4)
        out = self.up_block3(out, for_skip3, attention_out3, pos_encoding3)
        out = self.up_block2(out, for_skip2, attention_out2, pos_encoding2)
        out = self.up_block1(out, for_skip1, attention_out1, pos_encoding1)
out = self.final_conv(out)
return out
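# End-to-end sketch (commented out). The spatial size must be chosen so that every
# "-4 then /2" stage stays even (e.g. 124 -> 60 -> 28 -> 12 -> 4 at the bottleneck);
# whether this actually runs also depends on down_model (from down_unet) returning
# attention and positional maps whose channel/spatial sizes match each stage, which
# this file does not define.
'''
model = up_model()
x = torch.randn(1, 3, 124, 124)
print(model(x).shape)  # expected torch.Size([1, 3, 124, 124]) if down_model's outputs line up
'''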