import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sympy as sp
import wandb
from PIL import Image
from datasets import load_dataset
from torchvision import transforms
from down_unet import down_model

'''The network above has to take three inputs. The down/up-sampling blocks need to be rewritten:
after shrinking H/W by 2 twice they accept three inputs, and the ResNet block should also take in
time information.'''


class conv_block(nn.Module):
    # One sampling block holds two conv layers; the channel depth grows like 1-64-128-256.
    # Shape: [B, C, H, W] --> [B, channel_dim, H-2, W-2] (for use="down").
    def __init__(self, in_channel, num_heads, channel_dim, use="down"):
        super(conv_block, self).__init__()
        # in_channel: input channels, channel_dim: output channels; one block shrinks H/W by 2.
        self.in_channel = in_channel
        self.num_heads = num_heads
        self.channel_dim = channel_dim
        self.use = use
        # num_groups: how many groups (2, 4, 8) the input channels are split into for normalisation;
        # num_channels: the number of input channels.
        self.GN = nn.GroupNorm(num_groups=4, num_channels=in_channel)
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=in_channel,
                              kernel_size=3, stride=1, padding=1, bias=False)
        self.silu = nn.SiLU()
        # batch_first=True so the attention input can stay in the [B, L, C] layout used below.
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=self.num_heads,
                                               batch_first=True)
        if self.use == "down":
            # padding=0 shrinks H/W by 2.
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=0, bias=False)
        elif self.use == "up":
            # padding=2 grows H/W by 2.
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=2, bias=False)

    def resnet_block(self, X):
        # The hidden layers keep the same size as the input (self.GN and self.conv are shared by both halves).
        out = self.GN(X)
        out = self.conv(out)
        out = self.silu(out)
        # TODO: inject the time information here.
        out = self.GN(out)
        out = self.conv(out)
        out = self.silu(out)
        return out + X

    def attention_block(self, X):
        B, C, H, W = X.size()
        out = self.GN(X)
        out = self.conv(out)
        # Reshape to [B, L, C] with L = H * W.
        out = out.view(B, self.in_channel, H * W).transpose(1, 2)
        out, _ = self.attention(out, out, out)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)
        out = self.conv(out)
        return out + X

    def forward(self, X):
        out = self.resnet_block(X)
        out = self.attention_block(out)
        out = self.conv1(out)
        return out


class down_block(nn.Module):
    # Shrinks H/W by 4, mixes in the two extra inputs, then halves H/W.
    def __init__(self, in_channel, channel_dim):
        # e.g. in_channel=4 --> channel_dim=64
        super(down_block, self).__init__()
        self.channel_dim = channel_dim
        self.in_channel = in_channel
        self.block1 = conv_block(in_channel=self.in_channel, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")
        self.block2 = conv_block(in_channel=self.channel_dim, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")
        self.return_conv = nn.Conv2d(in_channels=self.channel_dim * 2, out_channels=self.channel_dim,
                                     kernel_size=1, stride=1, padding=0, bias=False)
        self.attention = nn.MultiheadAttention(embed_dim=self.channel_dim, num_heads=4, batch_first=True)
        self.down_pool = nn.Conv2d(in_channels=self.channel_dim, out_channels=self.channel_dim,
                                   kernel_size=2, stride=2, padding=0, bias=False)

    def caculate_attention(self, X_q, Y_kv):
        B, C, H, W = X_q.size()
        # Reshape both tensors to [B, L, C] with L = H * W.
        X_q = X_q.view(B, self.channel_dim, H * W).transpose(1, 2)
        Y_kv = Y_kv.view(B, self.channel_dim, H * W).transpose(1, 2)
        out, _ = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.channel_dim, H, W)
        return out

    def forward(self, X, attention_out, pos_encoding):
        # Input [1, 4, 128, 128] --> [1, 64, 124, 124] --> output [1, 64, 62, 62]
        out = self.block1(X)
        for_skip_connection = self.block2(out)
        out = torch.cat((for_skip_connection, pos_encoding), dim=1)
        out = self.return_conv(out)
        out = self.caculate_attention(X_q=attention_out, Y_kv=out)
        out = self.down_pool(out)
        return out, for_skip_connection


'''
X = torch.randn(1, 4, 128, 128)
attention_out = torch.randn(1, 64, 124, 124)
pos_encoding = torch.randn(1, 64, 124, 124)
model = down_block(4, 64)
out, for_skip = model(X, attention_out, pos_encoding)
print(out.shape)  # torch.Size([1, 64, 62, 62])
'''
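# A minimal, commented-out sketch of how the time information flagged in the TODO inside
# conv_block.resnet_block could be produced: a standard sinusoidal timestep embedding plus a
# small MLP, as commonly used in diffusion U-Nets. It is NOT wired into the model; the names
# time_embedding, self.time_proj and time_emb are illustrative assumptions, not original code.
'''
class time_embedding(nn.Module):
    # Sinusoidal timestep embedding followed by a small MLP.
    def __init__(self, dim):
        super(time_embedding, self).__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t):
        # t: [B] tensor of timesteps
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device).float() / half)
        args = t.float()[:, None] * freqs[None, :]                    # [B, half]
        emb = torch.cat((torch.sin(args), torch.cos(args)), dim=-1)   # [B, dim]
        return self.mlp(emb)

# Inside conv_block.resnet_block, after the first conv, the embedding could then be added per channel:
#     t_emb = self.time_proj(time_emb)        # nn.Linear(dim, in_channel), one projection per block
#     out = out + t_emb[:, :, None, None]     # broadcast over H and W
'''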
class up_block(nn.Module):
    def __init__(self, in_channel):
        # in_channel is the target channel count of this stage (the concatenated inputs have in_channel * 2 channels).
        super(up_block, self).__init__()
        self.in_channel = in_channel
        # Note: block1 is defined but the forward pass only uses block2 twice, because the feature
        # map already has in_channel channels by the time it reaches the conv blocks.
        self.block1 = conv_block(in_channel=in_channel * 2, num_heads=4, channel_dim=in_channel, use="up")
        self.block2 = conv_block(in_channel=in_channel, num_heads=4, channel_dim=in_channel, use="up")
        self.up_pool = nn.ConvTranspose2d(self.in_channel * 2, self.in_channel, kernel_size=2, stride=2)
        self.return_conv = nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.in_channel,
                                     kernel_size=1, stride=1, padding=0, bias=False)
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=4, batch_first=True)

    def caculate_attention(self, X_q, Y_kv):
        B, C, H, W = X_q.size()
        # Reshape both tensors to [B, L, C] with L = H * W.
        X_q = X_q.view(B, self.in_channel, H * W).transpose(1, 2)
        Y_kv = Y_kv.view(B, self.in_channel, H * W).transpose(1, 2)
        out, _ = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)
        return out

    def forward(self, input, input_skip, attention_out, pos_encoding):
        # Upsample the input first, concatenate it with the skip connection, then run it through two blocks.
        after_transposed = self.up_pool(input)                         # size after transposed-conv upsampling
        after_cat = torch.cat((after_transposed, input_skip), dim=1)   # concatenate with the skip tensor
        after_cat = self.return_conv(after_cat)
        after_cat = torch.cat((after_cat, pos_encoding), dim=1)
        after_cat = self.return_conv(after_cat)
        out = self.caculate_attention(X_q=attention_out, Y_kv=after_cat)
        out = self.block2(out)
        # The channel count does not need to be reduced any further.
        out = self.block2(out)
        return out


'''
X = torch.randn(1, 128, 62, 62)
input_skip = torch.randn(1, 64, 124, 124)
attention_out = torch.randn(1, 64, 124, 124)
pos_encoding = torch.randn(1, 64, 124, 124)
model = up_block(in_channel=64)
out = model(X, input_skip, attention_out, pos_encoding)
print(out.shape)  # torch.Size([1, 64, 128, 128])
'''


class up_model(nn.Module):
    def __init__(self):
        super(up_model, self).__init__()
        self.down_model = down_model()
        self.start_conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1)
        self.down_block1 = down_block(4, 64)
        self.down_block2 = down_block(64, 128)
        self.down_block3 = down_block(128, 256)
        self.down_block4 = down_block(256, 512)
        self.bottle_conv = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)
        self.up_block4 = up_block(512)
        self.up_block3 = up_block(256)
        self.up_block2 = up_block(128)
        self.up_block1 = up_block(64)
        self.final_conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)

    def forward(self, input):
        # The spatial size of the input must divide evenly through every pooling stage.
        (X, attention_out1, attention_out2, attention_out3, attention_out4,
         attention_out5, attention_out6, attention_out7, attention_out8,
         pos_encoding1, pos_encoding2, pos_encoding3, pos_encoding4,
         pos_encoding5, pos_encoding6, pos_encoding7, pos_encoding8) = self.down_model(input)
        input = self.start_conv(input)
        out, for_skip1 = self.down_block1(input, attention_out8, pos_encoding8)
        out, for_skip2 = self.down_block2(out, attention_out7, pos_encoding7)
        out, for_skip3 = self.down_block3(out, attention_out6, pos_encoding6)
        out, for_skip4 = self.down_block4(out, attention_out5, pos_encoding5)
        out = self.bottle_conv(out)
        # print("bottle", out.shape)
        out = self.up_block4(out, for_skip4, attention_out4, pos_encoding4)
        out = self.up_block3(out, for_skip3, attention_out3, pos_encoding3)
        out = self.up_block2(out, for_skip2, attention_out2, pos_encoding2)
        out = self.up_block1(out, for_skip1, attention_out1, pos_encoding1)
        out = self.final_conv(out)
        return out
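# A commented-out smoke test for the full up_model, as a sketch only. It assumes that
# down_unet.down_model(X) returns the 17 tensors unpacked in up_model.forward with channel and
# spatial sizes matching each stage (e.g. attention_out8 / pos_encoding8 shaped like the first
# skip connection), which cannot be verified from this file alone. The input size must also keep
# every pre-pooling map even; with these kernel sizes 140x140 appears to work while 128x128 does
# not (the 25x25 map before the third down_pool is odd and no longer matches its skip on the way up).
'''
model = up_model()
X = torch.randn(1, 3, 140, 140)
out = model(X)
print(out.shape)   # expected: torch.Size([1, 3, 140, 140])
'''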