import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp_features
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DecoderBlock(nn.Module):
    def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0, block_nums=1,
                 out_masks=1, out_local_flows=32, out_msgs_flows=32, out_feat_flows=0):
        super().__init__()
        # two stride-2 convs: encode the concatenated inputs at 1/4 resolution
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_planes, c, 3, 2, 1),
            nn.PReLU(c),
            nn.Conv2d(c, c, 3, 2, 1),
            nn.PReLU(c),
        )

        self.convblocks = nn.ModuleList()
        for _ in range(block_nums):
            # each block: six 3x3 conv + PReLU layers at constant width c
            layers = []
            for _ in range(6):
                layers += [nn.Conv2d(c, c, 3, 1, 1), nn.PReLU(c)]
            self.convblocks.append(nn.Sequential(*layers))
        self.out_flows = 2
        self.out_msgs = out_msgs
        # auxiliary flow heads are only emitted alongside their payload
        self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
        self.out_locals = out_locals
        self.out_local_flows = out_local_flows if out_locals > 0 else 0
        self.out_masks = out_masks
        self.out_feat_flows = out_feat_flows

        # two transposed convs: upsample by 4 back to the input resolution and
        # emit every output head packed along the channel dimension
        self.conv_last = nn.Sequential(
            nn.ConvTranspose2d(c, c, 4, 2, 1),
            nn.PReLU(c),
            nn.ConvTranspose2d(c, self.out_flows+self.out_msgs+self.out_msgs_flows +
                               self.out_locals+self.out_local_flows+self.out_masks+self.out_feat_flows, 4, 2, 1),
        )

    def forward(self, accumulated_flow, *other):
        x = [accumulated_flow]
        for each in other:
            if each is not None:
                assert accumulated_flow.shape[-1] == each.shape[-1], "decoder expected {}, but got {}".format(
                    accumulated_flow.shape, each.shape)
                x.append(each)
        feat = self.conv0(torch.cat(x, dim=1))
        for convblock in self.convblocks:
            feat = convblock(feat) + feat  # residual refinement
        feat = self.conv_last(feat)
        # unpack the channel-packed output into its heads, in emission order
        prev = 0
        flow = feat[:, prev:prev+self.out_flows, :, :]
        prev += self.out_flows
        message = feat[:, prev:prev+self.out_msgs,
                       :, :] if self.out_msgs > 0 else None
        prev += self.out_msgs
        message_flow = feat[:, prev:prev + self.out_msgs_flows,
                            :, :] if self.out_msgs_flows > 0 else None
        prev += self.out_msgs_flows
        local_message = feat[:, prev:prev + self.out_locals,
                             :, :] if self.out_locals > 0 else None
        prev += self.out_locals
        local_message_flow = feat[:, prev:prev+self.out_local_flows,
                                  :, :] if self.out_local_flows > 0 else None
        prev += self.out_local_flows
        mask = torch.sigmoid(
            feat[:, prev:prev+self.out_masks, :, :]) if self.out_masks > 0 else None
        prev += self.out_masks
        feat_flow = feat[:, prev:prev+self.out_feat_flows,
                         :, :] if self.out_feat_flows > 0 else None
        prev += self.out_feat_flows
        return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow
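
# Illustrative sketch (channel sizes assumed, not from any configuration):
# a DecoderBlock(in_planes, out_msgs=16, out_locals=8) keeps the default
# out_msgs_flows=32, out_local_flows=32 and out_masks=1, so conv_last emits
# 2+16+32+8+32+1+0 = 91 channels, which forward() slices apart again in the
# order flow | message | message_flow | local_message | local_message_flow |
# mask | feat_flow.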


class CINN(nn.Module):
    def __init__(self, DIM_SHADER_REFERENCE, target_feature_chns=(512, 256, 128, 64, 64),
                 feature_chns=(2048, 1024, 512, 256, 64), out_msgs_chn=(2048, 1024, 512, 256, 64, 64),
                 out_locals_chn=(2048, 1024, 512, 256, 64, 0), block_num=(1, 1, 1, 1, 1, 2),
                 block_chn_num=(224, 224, 224, 224, 224, 224)):
        super().__init__()

        # level i consumes the message/local-message outputs of level i-1
        self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
        self.in_locals_chn = [0, *out_locals_chn[:-1]]

        self.decoder_blocks = nn.ModuleList()
        # when feed_weighted is set, every block additionally receives the
        # view-aggregated ("weighted") flow and warped reference
        self.feed_weighted = True
        if self.feed_weighted:
            in_planes = 2+2+DIM_SHADER_REFERENCE*2
        else:
            in_planes = 2+DIM_SHADER_REFERENCE
        for (each_target_feature_chns, each_feature_chns, each_out_msgs_chn, each_out_locals_chn,
             each_in_msgs_chn, each_in_locals_chn, each_block_num, each_block_chn_num) in zip(
                target_feature_chns, feature_chns, out_msgs_chn, out_locals_chn,
                self.in_msgs_chn, self.in_locals_chn, block_num, block_chn_num):
            # the mask head needs one channel for the reference, one for the
            # features, and one per message channel (see the per-channel
            # weighting in forward), hence out_masks=2+each_out_msgs_chn
            self.decoder_blocks.append(
                DecoderBlock(in_planes+each_target_feature_chns+each_feature_chns+each_in_locals_chn+each_in_msgs_chn,
                             c=each_block_chn_num, block_nums=each_block_num, out_msgs=each_out_msgs_chn,
                             out_locals=each_out_locals_chn, out_masks=2+each_out_msgs_chn))
        # levels past the feature pyramid refine from messages alone
        for i in range(len(feature_chns), len(out_locals_chn)):
            self.decoder_blocks.append(
                DecoderBlock(in_planes+self.in_msgs_chn[i]+self.in_locals_chn[i], c=block_chn_num[i],
                             block_nums=block_num[i], out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i],
                             out_masks=2+out_msgs_chn[i], out_feat_flows=0))

    def apply_flow(self, mask, message, message_flow, local_message, local_message_flow,
                   x_reference, accumulated_flow, each_x_reference_features=None,
                   each_x_reference_features_flow=None):
        # resample everything to the next feature level's resolution, falling
        # back to the reference image's resolution
        if each_x_reference_features is not None:
            size_from = each_x_reference_features
        else:
            size_from = x_reference
        f_size = (size_from.shape[2], size_from.shape[3])
        accumulated_flow = self.flow_rescale(
            accumulated_flow, size_from)
        mask = F.interpolate(
            mask, size=f_size, mode="bilinear") if mask is not None else None
        message = F.interpolate(
            message, size=f_size, mode="bilinear") if message is not None else None
        message_flow = self.flow_rescale(
            message_flow, size_from) if message_flow is not None else None
        message = warp_features(
            message, message_flow) if message_flow is not None else message

        local_message = F.interpolate(
            local_message, size=f_size, mode="bilinear") if local_message is not None else None
        local_message_flow = self.flow_rescale(
            local_message_flow, size_from) if local_message_flow is not None else None
        local_message = warp_features(
            local_message, local_message_flow) if local_message_flow is not None else local_message

        warp_x_reference = warp_features(F.interpolate(
            x_reference, size=f_size, mode="bilinear"), accumulated_flow)

        each_x_reference_features_flow = self.flow_rescale(
            each_x_reference_features_flow, size_from) if (each_x_reference_features is not None and each_x_reference_features_flow is not None) else None
        warp_each_x_reference_features = warp_features(
            each_x_reference_features, each_x_reference_features_flow) if each_x_reference_features_flow is not None else each_x_reference_features

        return mask, message, local_message, warp_x_reference, accumulated_flow, warp_each_x_reference_features, each_x_reference_features_flow
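
    # Illustrative call (shapes assumed): given x_reference of shape
    # (N, 3, H, W) and each_x_reference_features of shape (N, C, H/2, W/2),
    # apply_flow resamples every input to (H/2, W/2) and warps the reference
    # and the features by the correspondingly rescaled flows.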

    def forward(self, x_target_features=(), x_reference=None, x_reference_features=()):
        y_flow = []
        y_feat_flow = []

        y_local_message = []
        y_warp_x_reference = []
        y_warp_x_reference_features = []

        y_weighted_flow = []
        y_weighted_mask = []
        y_weighted_message = []
        y_weighted_x_reference = []
        y_weighted_x_reference_features = []

        for pyrlevel, ifblock in enumerate(self.decoder_blocks):
            stacked_wref = []
            stacked_feat = []
            stacked_flow = []
            stacked_mask = []
            stacked_mesg = []
            stacked_locm = []
            stacked_feat_flow = []
            for view_id in range(x_reference.shape[1]):  # x_reference is NMCHW; iterate over the M views

                if pyrlevel == 0:
                    # level 0: seed the accumulated flows with zeros sized
                    # from the coarsest reference feature map
                    feat_ev = x_reference_features[pyrlevel][:,
                                                             view_id, :, :, :] if pyrlevel < len(x_reference_features) else None

                    accumulated_flow = torch.zeros_like(
                        feat_ev[:, :2, :, :]).to(device)
                    accumulated_feat_flow = torch.zeros_like(
                        feat_ev[:, :32, :, :]).to(device)
                    # domestic inputs
                    warp_x_reference = F.interpolate(x_reference[:, view_id, :, :, :], size=(
                        feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
                    warp_x_reference_features = feat_ev

                    local_message = None
                    # federated inputs
                    weighted_flow = accumulated_flow if self.feed_weighted else None
                    weighted_wref = warp_x_reference if self.feed_weighted else None
                    weighted_message = None
                else:
                    # resume from last layer
                    accumulated_flow = y_flow[-1][:, view_id, :, :, :]
                    accumulated_feat_flow = y_feat_flow[-1][:,
                                                            view_id, :, :, :] if y_feat_flow[-1] is not None else None
                    # domestic inputs
                    warp_x_reference = y_warp_x_reference[-1][:,
                                                              view_id, :, :, :]
                    warp_x_reference_features = y_warp_x_reference_features[-1][:,
                                                                                view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
                    local_message = y_local_message[-1][:, view_id, :,
                                                        :, :] if len(y_local_message) > 0 else None

                    # federated inputs
                    weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
                    weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
                    weighted_message = y_weighted_message[-1] if len(
                        y_weighted_message) > 0 else None
                scaled_x_target = x_target_features[pyrlevel].detach() if pyrlevel < len(
                    x_target_features) else None
                # compute flow
                residual_flow, mask, message, message_flow, local_message, local_message_flow, residual_feat_flow = ifblock(
                    accumulated_flow, scaled_x_target, warp_x_reference, warp_x_reference_features, weighted_flow, weighted_wref, weighted_message, local_message)
                accumulated_flow = residual_flow + accumulated_flow
                # the feature flow tracks the accumulated image flow; both
                # block configurations set out_feat_flows=0, so
                # residual_feat_flow is always None here
                accumulated_feat_flow = accumulated_flow

                # advance to the next (finer) feature level and warp the
                # carried tensors to its resolution
                feat_ev = x_reference_features[pyrlevel+1][:,
                                                           view_id, :, :, :] if pyrlevel+1 < len(x_reference_features) else None
                mask, message, local_message, warp_x_reference, accumulated_flow,  warp_x_reference_features, accumulated_feat_flow = self.apply_flow(
                    mask, message, message_flow, local_message, local_message_flow, x_reference[:, view_id, :, :, :], accumulated_flow, feat_ev, accumulated_feat_flow)
                stacked_flow.append(accumulated_flow)
                if accumulated_feat_flow is not None:
                    stacked_feat_flow.append(accumulated_feat_flow)
                stacked_mask.append(mask)
                if message is not None:
                    stacked_mesg.append(message)
                if local_message is not None:
                    stacked_locm.append(local_message)
                stacked_wref.append(warp_x_reference)
                if warp_x_reference_features is not None:
                    stacked_feat.append(warp_x_reference_features)

            stacked_flow = torch.stack(stacked_flow, dim=1)  # M*NCHW -> NMCHW
            stacked_feat_flow = torch.stack(stacked_feat_flow, dim=1) if len(
                stacked_feat_flow) > 0 else None
            stacked_mask = torch.stack(stacked_mask, dim=1)

            stacked_mesg = torch.stack(stacked_mesg, dim=1) if len(
                stacked_mesg) > 0 else None
            stacked_locm = torch.stack(stacked_locm, dim=1) if len(
                stacked_locm) > 0 else None

            stacked_wref = torch.stack(stacked_wref, dim=1)
            stacked_feat = torch.stack(stacked_feat, dim=1) if len(
                stacked_feat) > 0 else None
            y_flow.append(stacked_flow)
            y_feat_flow.append(stacked_feat_flow)

            y_warp_x_reference.append(stacked_wref)
            y_warp_x_reference_features.append(stacked_feat)
            # per-view contribution weights, normalized across the M views
            stacked_contrib = F.softmax(stacked_mask, dim=1)

            # torch.sum removes the view dimension M (NMCHW -> NCHW); mask
            # channel 0 weights the flow and reference, channel 1 the
            # features, channels 2+ the message
            weighted_flow = torch.sum(
                stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
            weighted_mask = torch.sum(
                stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
            weighted_wref = torch.sum(
                stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref, dim=1) if stacked_wref is not None else None
            weighted_feat = torch.sum(
                stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat, dim=1) if stacked_feat is not None else None
            weighted_mesg = torch.sum(
                stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg, dim=1) if stacked_mesg is not None else None
            y_weighted_flow.append(weighted_flow)
            y_weighted_mask.append(weighted_mask)
            if weighted_mesg is not None:
                y_weighted_message.append(weighted_mesg)
            if stacked_locm is not None:
                y_local_message.append(stacked_locm)
            y_weighted_x_reference.append(weighted_wref)
            if weighted_feat is not None:
                y_weighted_x_reference_features.append(weighted_feat)
        return {
            # only the final level's view-aggregated message is exposed
            "y_last_remote_features": [weighted_mesg],
        }

    def flow_rescale(self, prev_flow, each_x_reference_features):
        # resample a flow field to the reference features' resolution and
        # scale its magnitude by the same factor; a missing flow becomes zeros
        if prev_flow is None:
            prev_flow = torch.zeros_like(
                each_x_reference_features[:, :2]).to(device)
        else:
            up_scale_factor = each_x_reference_features.shape[-1] / prev_flow.shape[-1]
            if up_scale_factor != 1:
                prev_flow = F.interpolate(prev_flow, scale_factor=up_scale_factor, mode="bilinear",
                                          align_corners=False, recompute_scale_factor=False) * up_scale_factor
        return prev_flow
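

if __name__ == "__main__":
    # Minimal smoke test; an illustrative sketch only, with channel and
    # spatial sizes assumed here rather than taken from any training config.
    # Because of the relative import above, run it as a module, e.g.
    # `python -m <package>.<this_module>`.
    block = DecoderBlock(in_planes=10, c=32)
    dummy_flow = torch.zeros(1, 2, 64, 64)
    dummy_feat = torch.zeros(1, 8, 64, 64)
    flow, mask, message, message_flow, local_message, local_message_flow, feat_flow = block(
        dummy_flow, dummy_feat)
    assert flow.shape == (1, 2, 64, 64) and mask.shape == (1, 1, 64, 64)
    assert message is None and local_message is None and feat_flow is None

    # flow_rescale doubles both the resolution and the magnitude of a flow
    # field when the target feature map is twice as large.
    cinn = CINN(DIM_SHADER_REFERENCE=3)
    small = torch.ones(1, 2, 32, 32)
    big = cinn.flow_rescale(small, torch.zeros(1, 8, 64, 64))
    assert big.shape == (1, 2, 64, 64)
    assert torch.allclose(big, torch.full_like(big, 2.0))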