import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sympy as sp
import wandb
from PIL import Image
from datasets import load_dataset
from torchvision import transforms

from down_unet import down_model

'''The network above needs to take three inputs; the down/up-sampling blocks need to be rewritten:
after reducing height and width by 2 twice they take the three inputs, and the resnet block
should incorporate the time information.'''
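
# The note above asks for time information in the resnet block; that is not implemented in this
# file. Below is a minimal, hedged sketch of one common approach (the names timestep_embedding and
# time_resnet_block are illustrative only and are not used by the classes that follow): embed the
# timestep sinusoidally, project it with a linear layer, and add it to the feature map between the
# two conv stages of a residual block.
def timestep_embedding(t, dim):
    # t: (B,) integer timesteps -> (B, dim) sinusoidal embedding
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=1)


class time_resnet_block(nn.Module):
    def __init__(self, channels, time_dim):
        super(time_resnet_block, self).__init__()
        self.GN = nn.GroupNorm(num_groups=4, num_channels=channels)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.silu = nn.SiLU()
        self.time_proj = nn.Linear(time_dim, channels)

    def forward(self, X, t_emb):
        out = self.silu(self.conv(self.GN(X)))
        # broadcast the projected time embedding over the spatial dimensions
        out = out + self.time_proj(t_emb)[:, :, None, None]
        out = self.silu(self.conv(self.GN(out)))
        return out + X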

class conv_block(nn.Module):

    def __init__(self, in_channel, num_heads, channel_dim, use="down"):
        super(conv_block, self).__init__()

        self.in_channel = in_channel
        self.num_heads = num_heads
        self.channel_dim = channel_dim
        self.use = use

        self.GN = nn.GroupNorm(num_groups=4, num_channels=in_channel)

        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.silu = nn.SiLU()
        # batch_first=True so the (B, H*W, C) layout built in attention_block is interpreted correctly
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=self.num_heads,
                                               batch_first=True)

        # "down" uses an unpadded 3x3 conv (H and W shrink by 2); "up" pads by 2 (H and W grow by 2)
        if self.use == "down":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim, kernel_size=3,
                                   stride=1, padding=0, bias=False)
        elif self.use == "up":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim, kernel_size=3,
                                   stride=1, padding=2, bias=False)

    def resnet_block(self, X):
        # two GroupNorm -> conv -> SiLU stages with a residual connection; shape is unchanged
        out = self.GN(X)
        out = self.conv(out)
        out = self.silu(out)

        out = self.GN(out)
        out = self.conv(out)
        out = self.silu(out)

        return out + X

    def attention_block(self, X):
        # spatial self-attention: flatten H*W into a sequence, attend, then fold back to (B, C, H, W)
        B, C, H, W = X.size()

        out = self.GN(X)
        out = self.conv(out)

        out = out.view(B, self.in_channel, H * W).transpose(1, 2)
        out, weights = self.attention(out, out, out)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)

        out = self.conv(out)

        return out + X

    def forward(self, X):
        out = self.resnet_block(X)
        out = self.attention_block(out)
        out = self.conv1(out)

        return out
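
# A quick shape check for conv_block, kept commented out like the example blocks elsewhere in this
# file (tensor and variable names here are illustrative only): with use="down" the unpadded 3x3
# conv trims H and W by 2, while use="up" pads by 2 and grows them by 2.
'''
block_down = conv_block(in_channel=4, num_heads=4, channel_dim=64, use="down")
block_up = conv_block(in_channel=64, num_heads=4, channel_dim=64, use="up")
X = torch.randn(1, 4, 128, 128)
out = block_down(X)
print(out.shape)  # torch.Size([1, 64, 126, 126])
out = block_up(out)
print(out.shape)  # torch.Size([1, 64, 128, 128])
'''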

class down_block(nn.Module):

    def __init__(self, in_channel, channel_dim):
        super(down_block, self).__init__()

        self.channel_dim = channel_dim
        self.in_channel = in_channel

        self.block1 = conv_block(in_channel=self.in_channel, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")
        self.block2 = conv_block(in_channel=self.channel_dim, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")

        # 1x1 conv that fuses the concatenated feature map and positional encoding back to channel_dim
        self.return_conv = nn.Conv2d(in_channels=self.channel_dim * 2, out_channels=self.channel_dim, kernel_size=1,
                                     stride=1, padding=0, bias=False)

        self.attention = nn.MultiheadAttention(embed_dim=self.channel_dim, num_heads=4, batch_first=True)

        # strided conv instead of pooling: halves H and W
        self.down_pool = nn.Conv2d(in_channels=self.channel_dim, out_channels=self.channel_dim, kernel_size=2,
                                   stride=2, padding=0, bias=False)

    def calculate_attention(self, X_q, Y_kv):
        # cross-attention: queries come from the external attention map, keys/values from this block's features
        B, C, H, W = X_q.size()

        X_q = X_q.view(B, self.channel_dim, H * W).transpose(1, 2)
        Y_kv = Y_kv.view(B, self.channel_dim, H * W).transpose(1, 2)

        out, weights = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.channel_dim, H, W)

        return out

    def forward(self, X, attention_out, pos_encoding):
        out = self.block1(X)
        for_skip_connection = self.block2(out)

        out = torch.cat((for_skip_connection, pos_encoding), dim=1)
        out = self.return_conv(out)

        out = self.calculate_attention(X_q=attention_out, Y_kv=out)

        out = self.down_pool(out)

        return out, for_skip_connection

'''
X = torch.randn(1, 4, 128, 128)
attention_out = torch.randn(1, 64, 124, 124)
pos_encoding = torch.randn(1, 64, 124, 124)
model = down_block(4, 64)
out, for_skip = model(X, attention_out, pos_encoding)
print(out.shape)       # torch.Size([1, 64, 62, 62])
print(for_skip.shape)  # torch.Size([1, 64, 124, 124])
'''

class up_block(nn.Module):

    def __init__(self, in_channel):
        super(up_block, self).__init__()
        self.in_channel = in_channel

        self.block1 = conv_block(in_channel=in_channel * 2, num_heads=4,
                                 channel_dim=in_channel, use="up")
        self.block2 = conv_block(in_channel=in_channel, num_heads=4,
                                 channel_dim=in_channel, use="up")
        # transposed conv doubles H and W and halves the channels coming up from the deeper level
        self.up_pool = nn.ConvTranspose2d(self.in_channel * 2, self.in_channel,
                                          kernel_size=2, stride=2)

        # 1x1 conv used twice in forward: once after the skip concat, once after the positional-encoding concat
        self.return_conv = nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.in_channel, kernel_size=1,
                                     stride=1, padding=0, bias=False)

        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=4, batch_first=True)

    def calculate_attention(self, X_q, Y_kv):
        # cross-attention: queries come from the external attention map, keys/values from this block's features
        B, C, H, W = X_q.size()

        X_q = X_q.view(B, self.in_channel, H * W).transpose(1, 2)
        Y_kv = Y_kv.view(B, self.in_channel, H * W).transpose(1, 2)

        out, weights = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)

        return out

    def forward(self, input, input_skip, attention_out, pos_encoding):
        after_transposed = self.up_pool(input)

        after_cat = torch.cat((after_transposed, input_skip), dim=1)
        after_cat = self.return_conv(after_cat)
        after_cat = torch.cat((after_cat, pos_encoding), dim=1)
        after_cat = self.return_conv(after_cat)

        out = self.calculate_attention(X_q=attention_out, Y_kv=after_cat)

        # two "up" conv_blocks, each growing H and W by 2 (block1 is currently unused)
        out = self.block2(out)
        out = self.block2(out)

        return out

'''
X = torch.randn(1, 128, 62, 62)
input_skip = torch.randn(1, 64, 124, 124)
attention_out = torch.randn(1, 64, 124, 124)
pos_encoding = torch.randn(1, 64, 124, 124)
model = up_block(in_channel=64)
out = model(X, input_skip, attention_out, pos_encoding)
print(out.shape)  # torch.Size([1, 64, 128, 128])
'''
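
# How one resolution level composes, under the same 128x128 sizing the commented examples above
# assume: a down_block trims 4 pixels and then halves the resolution, a 1x1 conv doubles the
# channels (a stand-in here for the deeper levels and the bottleneck conv of up_model below),
# and the matching up_block restores the 128x128 map. The variable names are illustrative only.
'''
down = down_block(4, 64)
up = up_block(64)
double_channels = nn.Conv2d(64, 128, kernel_size=1)

X = torch.randn(1, 4, 128, 128)
attention_out = torch.randn(1, 64, 124, 124)
pos_encoding = torch.randn(1, 64, 124, 124)

out, skip = down(X, attention_out, pos_encoding)  # out: (1, 64, 62, 62), skip: (1, 64, 124, 124)
out = double_channels(out)                        # (1, 128, 62, 62)
out = up(out, skip, attention_out, pos_encoding)  # (1, 64, 128, 128)
print(out.shape)
'''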

class up_model(nn.Module):

    def __init__(self):
        super(up_model, self).__init__()

        self.down_model = down_model()

        self.start_conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1)

        self.down_block1 = down_block(4, 64)
        self.down_block2 = down_block(64, 128)
        self.down_block3 = down_block(128, 256)
        self.down_block4 = down_block(256, 512)

        self.bottle_conv = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)

        self.up_block4 = up_block(512)
        self.up_block3 = up_block(256)
        self.up_block2 = up_block(128)
        self.up_block1 = up_block(64)

        self.final_conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)

    def forward(self, input):

        # the conditioning branch provides per-level attention maps and positional encodings
        (X,
         attention_out1, attention_out2, attention_out3, attention_out4,
         attention_out5, attention_out6, attention_out7, attention_out8,
         pos_encoding1, pos_encoding2, pos_encoding3, pos_encoding4,
         pos_encoding5, pos_encoding6, pos_encoding7, pos_encoding8) = self.down_model(input)

        input = self.start_conv(input)

        # encoder path: each level gets its matching attention map and positional encoding
        out, for_skip1 = self.down_block1(input, attention_out8, pos_encoding8)

        out, for_skip2 = self.down_block2(out, attention_out7, pos_encoding7)

        out, for_skip3 = self.down_block3(out, attention_out6, pos_encoding6)

        out, for_skip4 = self.down_block4(out, attention_out5, pos_encoding5)

        out = self.bottle_conv(out)

        # decoder path, consuming the skip connections in reverse order
        out = self.up_block4(out, for_skip4, attention_out4, pos_encoding4)

        out = self.up_block3(out, for_skip3, attention_out3, pos_encoding3)

        out = self.up_block2(out, for_skip2, attention_out2, pos_encoding2)

        out = self.up_block1(out, for_skip1, attention_out1, pos_encoding1)

        out = self.final_conv(out)

        return out