import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sympy as sp
import wandb
from PIL import Image
from datasets import load_dataset
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Initialize the wandb run
wandb.init(
    # set the wandb project where this run will be logged
    project="unet-try",
)

'''
conv_block = resnet_block -- attention_block -- conv layer
    input:  [B, C, H, W]
    output: [B, channel_dim, H(+/-)2, W(+/-)2]

down block = 2 blocks |--> for_skip_connection
                      |
                      down_sample --> result_after_pool
    input:  [B, C, H, W]
    output: [B, channel_dim, (H-4)//2, (W-4)//2]

up block = --> concat --> 2 blocks
            |
            -- up_sample
    input:      [B, C, H, W]
    input_skip: [B, C/2, 2H, 2W]
    output:     [B, C/2, 2H+4, 2W+4]

LR ---------------------------- MSE LOSS ---------------------------- LR
 |-- down block ------------ skip connection ------------ up block --|
     |-- down block                                   up block --|
         |-----------------------------------------------|
'''
# ----------------------------------------------------------------------------------------------------
class conv_block(nn.Module):
    # One block = ResNet block -> attention block -> output conv.
    # On the down path the channel depth grows 4-64-128-256-...; [B,C,H,W] -> [B,channel_dim,H-2,W-2]
    def __init__(self, in_channel, num_heads, channel_dim, use="down"):
        super(conv_block, self).__init__()
        # in_channel: input channels, channel_dim: output channels; each block changes H/W by 2
        self.in_channel = in_channel
        self.num_heads = num_heads
        self.channel_dim = channel_dim
        self.use = use
        # num_groups: how many groups (e.g. 2/4/8) the input channels are split into for
        # normalization; num_channels: the number of input channels
        self.GN = nn.GroupNorm(num_groups=4, num_channels=in_channel)
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=in_channel,
                              kernel_size=3, stride=1, padding=1, bias=False)
        self.silu = nn.SiLU()
        # batch_first=True so the attention consumes the [B, L, C] layout built in attention_block
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel,
                                               num_heads=self.num_heads, batch_first=True)
        if self.use == "down":
            # padding=0 shrinks H and W by 2
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=0, bias=False)
        elif self.use == "up":
            # padding=2 grows H and W by 2
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=2, bias=False)

    def resnet_block(self, X):
        # Hidden layers keep the same shape as the input; the GN and conv weights are shared
        # between the two passes
        out = self.GN(X)
        out = self.conv(out)
        out = self.silu(out)
        out = self.GN(out)
        out = self.conv(out)
        out = self.silu(out)
        return out + X

    def attention_block(self, X):
        B, C, H, W = X.size()
        out = self.GN(X)
        out = self.conv(out)
        # Reshape to [B, L, C] with L = H * W for multi-head self-attention
        out = out.view(B, self.in_channel, H * W).transpose(1, 2)
        out, weights = self.attention(out, out, out)
        out = out.transpose(1, 2).reshape(B, self.in_channel, H, W)
        out = self.conv(out)
        return out + X

    def forward(self, X):
        out = self.resnet_block(X)
        out = self.attention_block(out)
        out = self.conv1(out)
        return out

'''
model = conv_block(in_channel=4, num_heads=4, channel_dim=64, use="down")
in_put = torch.randn(1, 4, 256, 256)  # note: as in the SR3 code, the hidden layers keep the input size
output = model(in_put)
print(output.shape)  # torch.Size([1, 64, 254, 254])
'''
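# A minimal sanity check (illustration only, not part of the model) of the attention layout
# used in attention_block: with batch_first=True, nn.MultiheadAttention expects [B, L, C],
# which matches the [B, H*W, C] tensor built above. All tensor sizes here are made up.
'''
mha = nn.MultiheadAttention(embed_dim=4, num_heads=4, batch_first=True)
x = torch.randn(1, 4, 8, 8)                    # [B, C, H, W]
seq = x.view(1, 4, 64).transpose(1, 2)         # [B, L, C] with L = H * W
attn_out, attn_weights = mha(seq, seq, seq)
print(attn_out.shape)                          # torch.Size([1, 64, 4])
print(attn_out.transpose(1, 2).reshape(1, 4, 8, 8).shape)  # back to [B, C, H, W]
'''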
# -------------------------------------------------------------------------------------------------
class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        # Apply convolution to generate a single-channel attention map
        attention_map = self.conv(x)
        # Generate attention scores; sigmoid is used because the map has only one channel
        # (softmax over a single channel would always return 1)
        attention_scores = torch.sigmoid(attention_map)
        # Apply attention scores
        out = x * attention_scores
        return out


class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),
            nn.ReLU()
        )

    def forward(self, x):
        # Average pooling to generate a channel descriptor
        avg_out = self.avg_pool(x).view(x.size(0), -1)
        # Apply fully connected layers to generate channel attention
        attn = self.fc(avg_out)
        # Reshape attention to match the input
        attn = attn.view(x.size(0), -1, 1, 1)
        return x * attn


def calculate_attention(X, num_heads, use):
    # Note: the modules below are created on the fly, so their weights are random and
    # are not updated by the optimizer.
    X = X.to(device)
    B, C, H, W = X.size()
    if use == "down":
        # Apply channel attention
        channel_attention = ChannelAttention(C).to(device)
        out = channel_attention(X)
    elif use == "up":
        # Reshape and transpose to [B, H*W, C] for multi-head attention
        up = X.view(B, C, H * W).transpose(1, 2)
        spatial_attention = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads,
                                                  batch_first=True).to(device)
        out, weights = spatial_attention(up, up, up)
        # Reshape back to the original layout and apply spatial attention
        out = out.transpose(1, 2).reshape(B, C, H, W)
        spatial_attention_module = SpatialAttention(in_channels=C).to(device)
        out = spatial_attention_module(out)
    return out

'''
# Example usage
X = torch.randn(1, 4, 572, 572)  # Example input tensor
num_heads = 4
attention_out = calculate_attention(X, num_heads, use="up")
print("attention out", attention_out.shape)
'''
'''
X = torch.randn(1, 64, 254, 254)
output = calculate_attention(X, num_heads=8, use="down")
print("attention", output.shape)  # should print torch.Size([1, 64, 254, 254])
'''

# -----------------------------------------------------------------------------------
def generate_positional_encoding(X):
    X = X.to(device)
    B, C, H, W = X.size()
    # Initialize the positional-encoding tensor on the same device as X
    pos_encoding = torch.zeros(B, C, H, W, device=X.device)
    # Build the row/column position indices, shape [H, W]
    y_positions = torch.arange(0, H, dtype=torch.float32, device=X.device).unsqueeze(1).repeat(1, W)
    x_positions = torch.arange(0, W, dtype=torch.float32, device=X.device).unsqueeze(0).repeat(H, 1)
    # Scale the positions by the square root of the spatial size
    y_positions = y_positions / (H ** 0.5)
    x_positions = x_positions / (W ** 0.5)
    # Fill even channels with the sine of the column positions and odd channels with the
    # cosine of the row positions
    for i in range(0, C, 2):
        pos_encoding[:, i, :, :] = torch.sin(x_positions)
        pos_encoding[:, i + 1, :, :] = torch.cos(y_positions)
    return pos_encoding

'''
X = torch.randn(1, 128, 512, 512)
pos_encoding = generate_positional_encoding(X)
print("Positional Encoding shape:", pos_encoding.shape)  # should print torch.Size([1, 128, 512, 512])
'''

class down_block(nn.Module):
    # H and W shrink by 4 (two "down" conv_blocks), then are halved by the strided conv
    def __init__(self, in_channel, channel_dim):
        super(down_block, self).__init__()
        self.channel_dim = channel_dim
        self.block1 = conv_block(in_channel=in_channel, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")
        self.block2 = conv_block(in_channel=self.channel_dim, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")
        self.down_pool = nn.Conv2d(in_channels=self.channel_dim, out_channels=self.channel_dim,
                                   kernel_size=2, stride=2, padding=0, bias=False)

    def forward(self, X):
        # e.g. input [1,4,128,128] -> [1,64,124,124] -> [1,64,62,62]
        out = self.block1(X)
        for_skip_connection = self.block2(out)  # kept for the skip connection
        result_after_pool = self.down_pool(for_skip_connection)
        return result_after_pool, for_skip_connection

'''
model1 = down_block(in_channel=64, channel_dim=128)
input = torch.randn(1, 64, 284, 284)
res, out = model1(input)
print(res.shape, out.shape)
'''
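# A small, self-contained sketch (illustration only, not used by the model) of the down-path
# geometry documented above: each down_block maps H to (H - 4) // 2, so a 268x268 input
# reaches the bottleneck at 13x13 after four stages. The helper name `down_path_sizes` is
# made up for this example.
'''
def down_path_sizes(h, num_stages=4):
    sizes = [h]
    for _ in range(num_stages):
        h = (h - 4) // 2
        sizes.append(h)
    return sizes

print(down_path_sizes(268))  # [268, 132, 64, 30, 13]
'''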
# --------------------------------------------------------------------------------------------------
class up_block(nn.Module):
    def __init__(self, in_channel):
        super(up_block, self).__init__()
        self.in_channel = in_channel
        self.block1 = conv_block(in_channel=in_channel * 2, num_heads=4,
                                 channel_dim=in_channel, use="up")
        self.block2 = conv_block(in_channel=in_channel, num_heads=4,
                                 channel_dim=in_channel, use="up")
        self.up_pool = nn.ConvTranspose2d(self.in_channel * 2, self.in_channel,
                                          kernel_size=2, stride=2)

    def forward(self, input, input_skip):
        # Upsample the input, concatenate with the skip connection, then run two "up" blocks
        after_transposed = self.up_pool(input)                     # size after upsampling
        after_cat = torch.cat((after_transposed, input_skip), dim=1)  # concatenate the tensors
        out = self.block1(after_cat)
        out = self.block2(out)
        return out, after_transposed

'''
model2 = up_block(in_channel=128)
input = torch.randn(1, 256, 140, 140)
input_skip = torch.randn(1, 128, 280, 280)
out, after = model2(input, input_skip)
print("up block", out.shape)  # torch.Size([1, 128, 284, 284])
'''
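# Illustration only: the mirror of the down-path arithmetic above. ConvTranspose2d with
# kernel_size=2, stride=2 doubles H so it matches the stored skip tensor, and the two "up"
# conv_blocks (padding=2) add 4, i.e. H -> 2H + 4. Starting from the 13x13 bottleneck this
# returns exactly to the 268x268 input, which is why image_size = 268 below is valid.
# The helper name `up_path_sizes` is made up for this example.
'''
def up_path_sizes(h, num_stages=4):
    sizes = [h]
    for _ in range(num_stages):
        h = 2 * h + 4
        sizes.append(h)
    return sizes

print(up_path_sizes(13))  # [13, 30, 64, 132, 268]
'''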
class down_model(nn.Module):
    # Full U-Net: down path, bottleneck, up path with skip connections
    def __init__(self):
        super(down_model, self).__init__()
        self.start_conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1)
        self.down_block1 = down_block(4, 64)
        self.down_block2 = down_block(64, 128)
        self.down_block3 = down_block(128, 256)
        self.down_block4 = down_block(256, 512)
        self.bottle_conv = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)
        self.up_block4 = up_block(512)
        self.up_block3 = up_block(256)
        self.up_block2 = up_block(128)
        self.up_block1 = up_block(64)
        self.final_conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)

    def forward(self, input):
        # The input height/width must survive four rounds of (H - 4) // 2 exactly and come
        # back to the same size on the way up (see the size sketches above)
        input = self.start_conv(input)

        result_after_pool1, for_skip_connection1 = self.down_block1(input)
        attention_out1 = calculate_attention(for_skip_connection1, num_heads=4, use="down")
        pos_encoding1 = generate_positional_encoding(for_skip_connection1)
        # print("1", result_after_pool1.shape, for_skip_connection1.shape)
        result_after_pool2, for_skip_connection2 = self.down_block2(result_after_pool1)
        attention_out2 = calculate_attention(for_skip_connection2, num_heads=4, use="down")
        pos_encoding2 = generate_positional_encoding(for_skip_connection2)
        # print("2", result_after_pool2.shape, for_skip_connection2.shape)
        result_after_pool3, for_skip_connection3 = self.down_block3(result_after_pool2)
        attention_out3 = calculate_attention(for_skip_connection3, num_heads=4, use="down")
        pos_encoding3 = generate_positional_encoding(for_skip_connection3)
        # print("3", result_after_pool3.shape, for_skip_connection3.shape)
        result_after_pool4, for_skip_connection4 = self.down_block4(result_after_pool3)
        attention_out4 = calculate_attention(for_skip_connection4, num_heads=4, use="down")
        pos_encoding4 = generate_positional_encoding(for_skip_connection4)
        # print("4", result_after_pool4.shape, for_skip_connection4.shape)

        result_after_pool4 = self.bottle_conv(result_after_pool4)
        # print("bottle", result_after_pool4.shape)

        out, after_transposed1 = self.up_block4(result_after_pool4, for_skip_connection4)
        attention_out5 = calculate_attention(after_transposed1, num_heads=4, use="up")
        pos_encoding5 = generate_positional_encoding(after_transposed1)
        # print("5", out.shape, after_transposed1.shape)
        out, after_transposed2 = self.up_block3(out, for_skip_connection3)
        attention_out6 = calculate_attention(after_transposed2, num_heads=4, use="up").to(device)
        pos_encoding6 = generate_positional_encoding(after_transposed2).to(device)
        # print("6", out.shape, after_transposed2.shape)
        out, after_transposed3 = self.up_block2(out, for_skip_connection2)
        attention_out7 = calculate_attention(after_transposed3, num_heads=4, use="up").to(device)
        pos_encoding7 = generate_positional_encoding(after_transposed3).to(device)
        # print("7", out.shape, after_transposed3.shape)
        out, after_transposed4 = self.up_block1(out, for_skip_connection1)
        attention_out8 = calculate_attention(after_transposed4, num_heads=4, use="up").to(device)
        pos_encoding8 = generate_positional_encoding(after_transposed4).to(device)
        # print("8", out.shape, after_transposed4.shape)

        out = self.final_conv(out)
        return (out,
                attention_out1, attention_out2, attention_out3, attention_out4,
                attention_out5, attention_out6, attention_out7, attention_out8,
                pos_encoding1, pos_encoding2, pos_encoding3, pos_encoding4,
                pos_encoding5, pos_encoding6, pos_encoding7, pos_encoding8)

'''
all_model = down_model()
input = torch.randn(1, 3, 268, 268)
output = all_model(input)[0]
print(output.shape)
'''

all_model = down_model().to(device)
loss_function = nn.MSELoss().to(device)                        # 2. define the loss
optimizer = torch.optim.Adam(all_model.parameters(), lr=1e-6)  # 3. define the optimizer
epoch = 3
batch_size = 10
image_size = 268  # batches are [10, 3, 268, 268]

ds = load_dataset("bitmind/ffhq-256", split="train")
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),            # Randomly flip (data augmentation)
        transforms.ToTensor(),                        # Convert to tensor in (0, 1)
        transforms.Normalize([0.5], [0.5]),           # Map to (-1, 1)
    ]
)

def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

ds.set_transform(transform)
dataloader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True)

for i in range(epoch):
    for idx, batch_x in enumerate(dataloader):
        images = batch_x["images"].to(device)
        # print(images.shape)  # (10, 3, 268, 268)
        output = all_model(images)[0]  # the first element of the returned tuple is the reconstruction
        loss = loss_function(output, images)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(all_model.parameters(), 1.)
        optimizer.step()
        print("epoch:", i, "loss:", loss.item())
        wandb.log({'epoch': i, 'batch': idx, 'loss': loss.item()})

# torch.save(all_model.state_dict(), 'model_weights.pth')
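# A minimal follow-up sketch (assumptions: training above has completed and `dataloader`
# still yields batches in the same format) showing a quick reconstruction check in eval
# mode and closing the wandb run. Everything here is illustrative, not part of training.
'''
all_model.eval()
with torch.no_grad():
    batch = next(iter(dataloader))
    images = batch["images"].to(device)
    reconstruction = all_model(images)[0]           # first tuple element is the output image
    print("reconstruction:", reconstruction.shape)  # [batch_size, 3, image_size, image_size]
    print("mse:", F.mse_loss(reconstruction, images).item())

wandb.finish()
'''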