import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sympy as sp
import wandb
from PIL import Image
from datasets import load_dataset
from torchvision import transforms

from down_unet import down_model
'''The network below needs to take three inputs: the down-/up-sampling blocks must be rewritten so that, after two (H-2, W-2) conv stages, each block fuses in the three inputs, and the resnet block should also receive the time information.'''
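
# A minimal sketch (not wired into the blocks yet) of the sinusoidal timestep
# embedding that the TODO above and in resnet_block refer to. The function name,
# the projection layer, and the injection point are assumptions, not part of the
# original code.
def timestep_embedding(t, dim):
    """t: [B] integer timesteps -> [B, dim] sinusoidal embedding (dim assumed even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]                    # [B, half]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)   # [B, dim]

# Hypothetical wiring inside conv_block (not implemented here):
#   in __init__:      self.time_proj = nn.Linear(emb_dim, in_channel)
#   in resnet_block:  out = out + self.time_proj(timestep_embedding(t, emb_dim))[:, :, None, None]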

class conv_block(nn.Module):    # one sampling block has two conv stages; channel depth grows like 4-64-128-256: [B,C,H,W] --> [B,C_DIM,H-2,W-2] (down) or [B,C_DIM,H+2,W+2] (up)
    def __init__(self, in_channel, num_heads, channel_dim, use="down"):
        super(conv_block,self).__init__()      # in_channel: input channels, channel_dim: output channels; each block changes H and W by 2


        self.in_channel = in_channel
        self.num_heads = num_heads
        self.channel_dim = channel_dim
        self.use = use

        self.GN = nn.GroupNorm(num_groups=4, num_channels=in_channel)  # num_channels is the input channel count
        # num_groups (e.g. 2, 4, 8) is how many groups the input channels are split into for normalization
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3,
                          stride=1, padding=1, bias=False)
        self.silu = nn.SiLU()
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=self.num_heads, batch_first=True)  # batch_first so inputs are [B, L, C]

        if self.use == "down":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim, kernel_size=3,
                             stride=1, padding=0, bias=False)
        elif self.use =="up":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim, kernel_size=3,
                         stride=1, padding=2, bias=False)

    def resnet_block(self,X):    # hidden layers keep the same shape as the input

        out = self.GN(X)
        out = self.conv(out)
        out = self.silu(out)    # TODO: the time embedding should be injected here (see the sketch near the top of the file)

        out = self.GN(out)
        out = self.conv(out)
        out = self.silu(out)

        return out + X

    def attention_block(self,X):

        B,C,H,W = X.size()

        out = self.GN(X)
        out = self.conv(out)

        out = out.view(B, self.in_channel, H * W).transpose(1, 2)  # reshape to [B, L, C], where L = H * W
        out, weights = self.attention(out, out, out)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)

        out = self.conv(out)

        return out+X

    def forward(self,X):

        out = self.resnet_block(X)
        out = self.attention_block(out)
        out = self.conv1(out)

        return out
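
'''
# Minimal shape check for conv_block (mirrors the commented tests below);
# with use="down" the 3x3 conv with padding=0 shrinks H and W by 2.
X = torch.randn(1, 4, 128, 128)
model = conv_block(in_channel=4, num_heads=4, channel_dim=64, use="down")
out = model(X)
print(out.shape)  # torch.Size([1, 64, 126, 126])
'''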



class down_block(nn.Module):  # H and W shrink by 4, the two extra inputs are fused in, then H and W are halved
    def __init__(self,in_channel,channel_dim):   # e.g. in_channel=4 --> channel_dim=64
        super(down_block,self).__init__()

        self.channel_dim = channel_dim

        self.in_channel = in_channel

        self.block1 = conv_block(in_channel=self.in_channel,num_heads=4,
                           channel_dim=self.channel_dim,use="down")
        self.block2 = conv_block(in_channel=self.channel_dim, num_heads=4,
                                channel_dim=self.channel_dim, use="down")

        self.return_conv = nn.Conv2d(in_channels=self.channel_dim*2,out_channels=self.channel_dim,kernel_size=1,
                                     stride=1,padding=0,bias=False)

        self.attention = nn.MultiheadAttention(embed_dim=self.channel_dim, num_heads=4, batch_first=True)

        self.down_pool = nn.Conv2d(in_channels=self.channel_dim, out_channels=self.channel_dim, kernel_size=2,
                              stride=2, padding=0, bias=False)

    def calculate_attention(self,X_q,Y_kv):   # cross-attention: query from the pre-computed attention features, key/value from the fused feature map

        B,C,H,W = X_q.size()

        X_q = X_q.view(B, self.channel_dim, H * W).transpose(1, 2)  # reshape to [B, L, C], where L = H * W
        Y_kv = Y_kv.view(B, self.channel_dim, H * W).transpose(1, 2)

        out, weights = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.channel_dim, H, W)

        return out

    def forward(self,X,attention_out,pos_encoding):  # input [1,4,128,128], intermediate [1,64,124,124] --> output [1,64,62,62]

        out = self.block1(X)
        for_skip_connection = self.block2(out)

        out = torch.cat((for_skip_connection,pos_encoding),dim=1)
        out = self.return_conv(out)

        out = self.calculate_attention(X_q=attention_out,Y_kv=out)

        out = self.down_pool(out)

        return out,for_skip_connection

'''
X = torch.randn(1,4,128,128)
attention_out = torch.randn(1,64,124,124)
pos_encoding = torch.randn(1,64,124,124)
model = down_block(4,64)
out, for_skip = model(X,attention_out,pos_encoding)
print(out.shape)       # torch.Size([1, 64, 62, 62])
print(for_skip.shape)  # torch.Size([1, 64, 124, 124])
'''

class up_block(nn.Module):
    def __init__(self,in_channel):  # in_channel is the per-branch channel count; the concatenated tensors have in_channel*2 channels
        super(up_block,self).__init__()
        self.in_channel = in_channel



        self.block1 = conv_block(in_channel=in_channel*2, num_heads=4,
                           channel_dim=in_channel,use="up")   # note: block1 is currently unused in forward()
        self.block2 = conv_block(in_channel=in_channel, num_heads=4,
                           channel_dim=in_channel,use="up")
        self.up_pool = nn.ConvTranspose2d(self.in_channel*2, self.in_channel,
                                          kernel_size=2, stride=2)

        self.return_conv = nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.in_channel, kernel_size=1,
                                     stride=1, padding=0, bias=False)

        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=4, batch_first=True)

    def calculate_attention(self,X_q,Y_kv):   # cross-attention: query from the pre-computed attention features, key/value from the fused feature map

        B,C,H,W = X_q.size()

        X_q = X_q.view(B, self.in_channel, H * W).transpose(1, 2)  # reshape to [B, L, C], where L = H * W
        Y_kv = Y_kv.view(B, self.in_channel, H * W).transpose(1, 2)

        out, weights = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)

        return out

    def forward(self,input,input_skip,attention_out,pos_encoding):  # upsample the input, concatenate with the skip tensor, then run two blocks


        after_transposed = self.up_pool(input)   # spatial size doubles after the transposed conv

        after_cat = torch.cat((after_transposed, input_skip), dim=1)  # concatenate with the skip tensor along channels
        after_cat = self.return_conv(after_cat)
        after_cat = torch.cat((after_cat, pos_encoding), dim=1)
        after_cat = self.return_conv(after_cat)

        out = self.calculate_attention(X_q=attention_out, Y_kv=after_cat)

        out = self.block2(out)      # channel count stays the same; each "up" block grows H and W by 2
        out = self.block2(out)      # note: block2 is applied twice (shared weights)

        return out

'''
X = torch.randn(1,128,62,62)
input_skip = torch.randn(1,64,124,124)
attention_out = torch.randn(1,64,124,124)
pos_encoding = torch.randn(1,64,124,124)
model = up_block(in_channel=64)
out = model(X,input_skip,attention_out,pos_encoding)
print(out.shape)  # torch.Size([1, 64, 128, 128])
'''

class up_model(nn.Module):
    def __init__(self):
        super(up_model,self).__init__()

        self.down_model = down_model()

        self.start_conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1)

        self.down_block1 = down_block(4,64)
        self.down_block2 = down_block(64,128)
        self.down_block3 = down_block(128,256)
        self.down_block4 = down_block(256,512)

        self.bottle_conv = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)

        self.up_block4 = up_block(512)
        self.up_block3 = up_block(256)
        self.up_block2 = up_block(128)
        self.up_block1 = up_block(64)

        self.final_conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)

    def forward(self,input):   # the input spatial size must survive the repeated (-4, then /2) downsampling exactly, or the skip shapes will not match

        (X, attention_out1, attention_out2, attention_out3, attention_out4,
         attention_out5, attention_out6, attention_out7, attention_out8,
         pos_encoding1, pos_encoding2, pos_encoding3, pos_encoding4,
         pos_encoding5, pos_encoding6, pos_encoding7, pos_encoding8) = self.down_model(input)

        input = self.start_conv(input)

        out,for_skip1 = self.down_block1(input, attention_out8, pos_encoding8)

        out,for_skip2 = self.down_block2(out, attention_out7, pos_encoding7)

        out,for_skip3 = self.down_block3(out, attention_out6, pos_encoding6)

        out,for_skip4 = self.down_block4(out, attention_out5, pos_encoding5)

        out = self.bottle_conv(out)
        # print("bottle",out.shape)

        out = self.up_block4(out, for_skip4, attention_out4, pos_encoding4)

        out = self.up_block3(out, for_skip3, attention_out3, pos_encoding3)

        out = self.up_block2(out, for_skip2, attention_out2, pos_encoding2)

        out = self.up_block1(out, for_skip1, attention_out1, pos_encoding1)

        out = self.final_conv(out)

        return out
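
'''
# Hypothetical smoke test for up_model (commented out, not runnable as-is): it
# assumes down_model() from down_unet.py returns attention_out / pos_encoding
# tensors whose channel counts and spatial sizes match each stage
# (64 ch at 120x120, 128 ch at 56x56, 256 ch at 24x24, 512 ch at 8x8 for the
# example size below). The spatial size must survive four rounds of
# "-4 then /2" downsampling exactly; sizes of the form 16*k + 60
# (e.g. 76, 92, 124) do.
X = torch.randn(1, 3, 124, 124)
model = up_model()
out = model(X)
print(out.shape)  # expected torch.Size([1, 3, 124, 124])
'''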