import logging
import math
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.fcn_arch import MultiHeadFCNHead
from models.archs.unet_arch import UNet
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
                                     VectorQuantizerSpatialTextureAware,
                                     VectorQuantizerTexture)
from models.losses.cross_entropy_loss import CrossEntropyLoss

logger = logging.getLogger('base')


class VQGANTextureAwareSpatialHierarchyInferenceModel:
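    """Predict bottom-level VQGAN token indices from top-level features.

    Two pretrained, frozen texture-aware VQGANs (a coarse top level and a
    finer bottom level) provide the latent features and the ground-truth
    token indices; only the UNet guidance encoder and the multi-head FCN
    index decoder are trained.
    """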

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

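        # pretrained top-level VQGAN (texture-aware codebook); loaded and
        # frozen in load_top_pretrain_models()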
        self.top_encoder = Encoder(
            ch=opt['top_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            double_z=opt['top_double_z'],
            dropout=opt['top_dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            ch=opt['top_ch'],
            out_ch=opt['top_out_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            dropout=opt['top_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.top_quantize = VectorQuantizerTexture(
            1024, opt['embed_dim'], beta=0.25).to(self.device)
        self.top_quant_conv = torch.nn.Conv2d(opt['top_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['top_z_channels'],
                                                   1).to(self.device)
        self.load_top_pretrain_models()

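        # pretrained bottom-level VQGAN (spatially texture-aware codebook)
        # that provides the ground-truth token indices; loaded and frozen in
        # load_bot_pretrain_network()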
        self.bot_encoder = Encoder(
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            attn_resolutions=opt['bot_attn_resolutions'],
            ch_mult=opt['bot_ch_mult'],
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            double_z=opt['bot_double_z'],
            dropout=opt['bot_dropout']).to(self.device)
        self.bot_decoder_res = DecoderRes(
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            ch_mult=opt['bot_ch_mult'],
            dropout=opt['bot_dropout'],
            give_pre_end=False).to(self.device)
        self.bot_quantize = VectorQuantizerSpatialTextureAware(
            opt['bot_n_embed'],
            opt['embed_dim'],
            beta=0.25,
            spatial_size=opt['codebook_spatial_size']).to(self.device)
        self.bot_quant_conv = torch.nn.Conv2d(opt['bot_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['bot_z_channels'],
                                                   1).to(self.device)

        self.load_bot_pretrain_network()

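        # trainable networks: UNet guidance encoder and multi-head FCN index
        # decoder with one head per texture codebook (18 in total)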
        self.guidance_encoder = UNet(
            in_channels=opt['encoder_in_channels']).to(self.device)
        self.index_decoder = MultiHeadFCNHead(
            in_channels=opt['fc_in_channels'],
            in_index=opt['fc_in_index'],
            channels=opt['fc_channels'],
            num_convs=opt['fc_num_convs'],
            concat_input=opt['fc_concat_input'],
            dropout_ratio=opt['fc_dropout_ratio'],
            num_classes=opt['fc_num_classes'],
            align_corners=opt['fc_align_corners'],
            num_head=18).to(self.device)

        self.init_training_settings()

    def init_training_settings(self):
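        """Collect trainable parameters and set up the optimizer and loss."""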
        optim_params = []
        for v in self.guidance_encoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        for v in self.index_decoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizers
        if self.opt['optimizer'] == 'Adam':
            self.optimizer = torch.optim.Adam(
                optim_params,
                self.opt['lr'],
                weight_decay=self.opt['weight_decay'])
        elif self.opt['optimizer'] == 'SGD':
            self.optimizer = torch.optim.SGD(
                optim_params,
                self.opt['lr'],
                momentum=self.opt['momentum'],
                weight_decay=self.opt['weight_decay'])
        else:
            raise ValueError(
                f"Unsupported optimizer: {self.opt['optimizer']}.")
        self.log_dict = OrderedDict()
        if self.opt['loss_function'] == 'cross_entropy':
            self.loss_func = CrossEntropyLoss().to(self.device)
        else:
            raise ValueError(
                f"Unsupported loss function: {self.opt['loss_function']}.")

    def load_top_pretrain_models(self):
        # load the pretrained top-level vqgan
        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
        self.top_encoder.load_state_dict(
            top_vae_checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(
            top_vae_checkpoint['decoder'], strict=True)
        self.top_quantize.load_state_dict(
            top_vae_checkpoint['quantize'], strict=True)
        self.top_quant_conv.load_state_dict(
            top_vae_checkpoint['quant_conv'], strict=True)
        self.top_post_quant_conv.load_state_dict(
            top_vae_checkpoint['post_quant_conv'], strict=True)
        self.top_encoder.eval()
        self.top_quantize.eval()
        self.top_quant_conv.eval()
        self.top_post_quant_conv.eval()

    def load_bot_pretrain_network(self):
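        # load the pretrained bottom-level vqgan (also restores the shared
        # decoder)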
        checkpoint = torch.load(self.opt['bot_vae_path'])
        self.bot_encoder.load_state_dict(
            checkpoint['bot_encoder'], strict=True)
        self.bot_decoder_res.load_state_dict(
            checkpoint['bot_decoder_res'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.bot_quantize.load_state_dict(
            checkpoint['bot_quantize'], strict=True)
        self.bot_quant_conv.load_state_dict(
            checkpoint['bot_quant_conv'], strict=True)
        self.bot_post_quant_conv.load_state_dict(
            checkpoint['bot_post_quant_conv'], strict=True)

        self.bot_encoder.eval()
        self.bot_decoder_res.eval()
        self.decoder.eval()
        self.bot_quantize.eval()
        self.bot_quant_conv.eval()
        self.bot_post_quant_conv.eval()

    def top_encode(self, x, mask):
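        """Encode an image into the top-level quantized latent.

        Returns the post-quantization latent twice: it serves both as the
        quantized code passed to the decoder and as the feature map fed to
        the guidance encoder.
        """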
        h = self.top_encoder(x)
        h = self.top_quant_conv(h)
        quant, _, _ = self.top_quantize(h, mask)
        quant = self.top_post_quant_conv(quant)

        return quant, quant

    def feed_data(self, data):
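        """Prepare a batch.

        Moves the image and texture mask to the device, computes the
        ground-truth indices, and downsamples the texture mask to a
        (32, 16) grid of per-token texture labels.
        """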
        self.image = data['image'].to(self.device)
        self.texture_mask = data['texture_mask'].float().to(self.device)
        self.get_gt_indices()

        self.texture_tokens = F.interpolate(
            self.texture_mask, size=(32, 16),
            mode='nearest').view(self.image.size(0), -1).long()

    def bot_encode(self, x, mask):
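        """Encode an image into per-codebook bottom-level token indices."""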
        h = self.bot_encoder(x)
        h = self.bot_quant_conv(h)
        _, _, (_, _, indices_list) = self.bot_quantize(h, mask)

        return indices_list

    def get_gt_indices(self):
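        """Compute top-level features and ground-truth bottom-level indices."""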
        self.quant_t, self.feature_t = self.top_encode(self.image,
                                                       self.texture_mask)
        self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)

    def index_to_image(self, index_bottom_list, texture_mask):
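        """Decode bottom-level token indices (combined with the cached
        top-level latent) back into an image."""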
        quant_b = self.bot_quantize.get_codebook_entry(
            index_bottom_list, texture_mask,
            (index_bottom_list[0].size(0), index_bottom_list[0].size(1),
             index_bottom_list[0].size(2),
             self.opt['bot_z_channels']))
        quant_b = self.bot_post_quant_conv(quant_b)
        bot_dec_res = self.bot_decoder_res(quant_b)

        dec = self.decoder(self.quant_t, bot_h=bot_dec_res)

        return dec

    def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
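        """Save a side-by-side visualization.

        Concatenates, along the width, the input image, the reconstruction
        from ground-truth indices, the top-level-only decode and the decode
        from predicted indices.
        """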
        rec_img = self.index_to_image(rec_img_index, texture_mask)
        pred_img = self.index_to_image(pred_img_index, texture_mask)

        base_img = self.decoder(self.quant_t)
        img_cat = torch.cat([
            self.image,
            rec_img,
            base_img,
            pred_img,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def optimize_parameters(self):
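        """Run one training step.

        Predicts the bottom-level indices from the top-level features and
        accumulates one cross-entropy term per codebook head.
        """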
        self.guidance_encoder.train()
        self.index_decoder.train()

        self.feature_enc = self.guidance_encoder(self.feature_t)
        self.memory_logits_list = self.index_decoder(self.feature_enc)

        loss = 0
        for i in range(18):
            loss += self.loss_func(
                self.memory_logits_list[i],
                self.gt_indices_list[i],
                ignore_index=-1)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # log a detached scalar so the computation graph is not kept alive
        self.log_dict['loss_total'] = loss.item()

    def inference(self, data_loader, save_dir):
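        """Predict bottom-level indices for every batch in the data loader.

        Saves a visualization per batch and returns the overall
        token-prediction accuracy.
        """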
        self.guidance_encoder.eval()
        self.index_decoder.eval()

        acc = 0
        num = 0

        for _, data in enumerate(data_loader):
            self.feed_data(data)
            img_name = data['img_name']

            num += self.image.size(0)

            texture_mask_flatten = self.texture_tokens.view(-1)
            min_encodings_indices_list = [
                torch.full(
                    texture_mask_flatten.size(),
                    fill_value=-1,
                    dtype=torch.long,
                    device=texture_mask_flatten.device) for _ in range(18)
            ]
            with torch.no_grad():
                self.feature_enc = self.guidance_encoder(self.feature_t)
                memory_logits_list = self.index_decoder(self.feature_enc)
            batch_acc = 0
            for codebook_idx, memory_logits in enumerate(memory_logits_list):
                region_of_interest = texture_mask_flatten == codebook_idx
                if torch.sum(region_of_interest) > 0:
                    memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
                    batch_acc += torch.sum(
                        memory_indices_pred[region_of_interest] ==
                        self.gt_indices_list[codebook_idx].view(
                            -1)[region_of_interest])
                    min_encodings_indices_list[codebook_idx][
                        region_of_interest] = memory_indices_pred[
                            region_of_interest]
            min_encodings_indices_return_list = [
                min_encodings_indices.view(self.gt_indices_list[0].size())
                for min_encodings_indices in min_encodings_indices_list
            ]
            # every per-codebook index map has the same number of elements
            batch_acc = (batch_acc / self.gt_indices_list[0].numel() *
                         self.image.size(0))
            acc += batch_acc
            self.get_vis(min_encodings_indices_return_list,
                         self.gt_indices_list, self.texture_mask,
                         f'{save_dir}/{img_name[0]}')

        self.guidance_encoder.train()
        self.index_decoder.train()
        return (acc / num).item()

    def load_network(self):
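        """Load pretrained guidance encoder and index decoder weights."""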
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.guidance_encoder.load_state_dict(
            checkpoint['guidance_encoder'], strict=True)
        self.guidance_encoder.eval()

        self.index_decoder.load_state_dict(
            checkpoint['index_decoder'], strict=True)
        self.index_decoder.eval()

    def save_network(self, save_path):
        """Save the trainable networks (guidance encoder and index decoder).

        Args:
            save_path (str): Path to save the checkpoint to.
        """

        save_dict = {}
        save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
        save_dict['index_decoder'] = self.index_decoder.state_dict()

        torch.save(save_dict, save_path)

    def update_learning_rate(self, epoch):
        """Update the learning rate according to the decay schedule.

        Args:
            epoch (int): Current epoch number.

        Returns:
            float: The updated learning rate.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # decay linearly so the lr drops to ~5% of its initial value
                # at the turning point (1 / 0.95 ~= 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def get_current_log(self):
        return self.log_dict