import argparse
import math
import os
import time

import albumentations
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import tqdm
from albumentations import Resize
from torch.utils.data import DataLoader

from datasets.build_INR_dataset import Implicit2DGenerator
from model.build_model import build_model
from utils.misc import customRandomCrop, normalize, prepare_cooridinate_input

global_state = [1]  # For Gradio Stop Button.

class single_image_dataset(torch.utils.data.Dataset):
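    """Wrap a single composite image and its mask as a dataset of square patches.

    Each item contains the full image resized to `base_size`, one full-resolution patch,
    and the matching 1/2- and 1/4-scale crops of the same region, plus coordinate maps
    used by the INR decoder.
    """
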
    def __init__(self, opt, composite_image=None, mask=None):
        super().__init__()

        self.opt = opt

        if composite_image is None:
            composite_image = cv2.imread(opt.composite_image)
            composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
        self.composite_image = composite_image

        assert composite_image.shape[0] == composite_image.shape[1], "This faster script only supports square images."
        assert composite_image.shape[0] % 256 == 0, \
            "This faster script only supports images whose resolution is a multiple of 256."
        assert opt.split_resolution % (composite_image.shape[0] // 16) == 0, \
            f"The image resolution is {composite_image.shape[0]}, so `split_resolution` " \
            f"(currently {opt.split_resolution}) must be a multiple of {composite_image.shape[0] // 16}."

        if mask is None:
            mask = cv2.imread(opt.mask)
        mask = mask[:, :, 0].astype(np.float32) / 255.
        self.mask = mask

        self.torch_transforms = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize([.5, .5, .5], [.5, .5, .5])])
        self.INR_dataset = Implicit2DGenerator(opt, 'Val')

        self.split_width_resolution = self.split_height_resolution = opt.split_resolution

        self.num_w = math.ceil(composite_image.shape[1] / self.split_width_resolution)
        self.num_h = math.ceil(composite_image.shape[0] / self.split_height_resolution)

        self.split_start_point = []
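        # Patches tile the image left-to-right, top-to-bottom. The last row/column patch is
        # anchored to the image border so every patch keeps the full split resolution
        # (it may overlap its neighbour when the image size is not an exact multiple of it).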

        "Split the image into several parts."
        for i in range(self.num_h):
            for j in range(self.num_w):
                if i == composite_image.shape[0] // self.split_height_resolution:
                    if j == composite_image.shape[1] // self.split_width_resolution:
                        self.split_start_point.append((composite_image.shape[0] - self.split_height_resolution,
                                                       composite_image.shape[1] - self.split_width_resolution))
                    else:
                        self.split_start_point.append(
                            (composite_image.shape[0] - self.split_height_resolution, j * self.split_width_resolution))
                else:
                    if j == composite_image.shape[1] // self.split_width_resolution:
                        self.split_start_point.append(
                            (i * self.split_height_resolution, composite_image.shape[1] - self.split_width_resolution))
                    else:
                        self.split_start_point.append(
                            (i * self.split_height_resolution, j * self.split_width_resolution))

        assert len(self.split_start_point) == self.num_w * self.num_h

        print(
            f"The image will be split into {self.num_h} pieces in height and {self.num_w} pieces in width, "
            f"{self.num_h * self.num_w} patches in total.")
        print(f"The final resolution of each patch is {self.split_height_resolution} x {self.split_width_resolution}.")

    def __len__(self):
        return self.num_w * self.num_h

    def __getitem__(self, idx):
        composite_image = self.composite_image

        mask = self.mask

        full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)

        tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
                                               additional_targets={'object_mask': 'image'})
        transform_out = tmp_transform(image=self.composite_image, object_mask=self.mask)
        compos_list = [self.torch_transforms(transform_out['image'])]
        mask_list = [
            torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
        coord_map_list = []

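        # c_h / c_w are the normalized top-left offsets of this patch within the full image.
        # They are passed to customRandomCrop so the crop position is deterministic and the
        # same region is taken at every scale below.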
        if composite_image.shape[0] != self.split_height_resolution:
            c_h = self.split_start_point[idx][0] / (composite_image.shape[0] - self.split_height_resolution)
        else:
            c_h = 0
        if composite_image.shape[1] != self.split_width_resolution:
            c_w = self.split_start_point[idx][1] / (composite_image.shape[1] - self.split_width_resolution)
        else:
            c_w = 0
        transform_out, c_h, c_w = customRandomCrop([composite_image, mask, full_coord],
                                                   self.split_height_resolution, self.split_width_resolution, c_h, c_w)

        compos_list.append(self.torch_transforms(transform_out[0]))
        mask_list.append(
            torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
        coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
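        # Build a two-level pyramid: crop the same region from 1/2- and 1/4-scale versions
        # of the image and mask, halving the crop size accordingly.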
        for n in range(2):
            tmp_comp = cv2.resize(composite_image, (
                composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
            tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
            tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)

            transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_mask, tmp_coord],
                                                       self.split_height_resolution // 2 ** (n + 1),
                                                       self.split_width_resolution // 2 ** (n + 1), c_h, c_w)
            compos_list.append(self.torch_transforms(transform_out[0]))
            mask_list.append(
                torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
            coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        out_comp = compos_list
        out_mask = mask_list
        out_coord = coord_map_list

        fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
            self.torch_transforms, transform_out[0], transform_out[0], mask)

        return {
            'composite_image': out_comp,
            'mask': out_mask,
            'coordinate_map': out_coord,
            'composite_image0': out_comp[0],
            'mask0': out_mask[0],
            'coordinate_map0': out_coord[0],
            'composite_image1': out_comp[1],
            'mask1': out_mask[1],
            'coordinate_map1': out_coord[1],
            'composite_image2': out_comp[2],
            'mask2': out_mask[2],
            'coordinate_map2': out_coord[2],
            'composite_image3': out_comp[3],
            'mask3': out_mask[3],
            'coordinate_map3': out_coord[3],
            'fg_INR_coordinates': fg_INR_coordinates,
            'bg_INR_coordinates': bg_INR_coordinates,
            'fg_INR_RGB': fg_INR_RGB,
            'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
            'bg_INR_RGB': bg_INR_RGB,
            'start_point': self.split_start_point[idx],
            'start_proportion': [self.split_start_point[idx][0] / composite_image.shape[0],
                                 self.split_start_point[idx][1] / composite_image.shape[1],
                                 (self.split_start_point[idx][0] + self.split_height_resolution) / composite_image.shape[0],
                                 (self.split_start_point[idx][1] + self.split_width_resolution) / composite_image.shape[1]],
        }


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--split_resolution', type=int, default=2048,
                        help='The resolution of the patch split.')

    parser.add_argument('--composite_image', type=str, default=r'./demo/demo_2k_composite.jpg',
                        help='composite image path')

    parser.add_argument('--mask', type=str, default=r'./demo/demo_2k_mask.jpg',
                        help='mask path')

    parser.add_argument('--save_path', type=str, default=r'./demo/',
                        help='save path')

    parser.add_argument('--workers', type=int, default=8,
                        metavar='N', help='Dataloader threads.')

    parser.add_argument('--batch_size', type=int, default=1,
                        help='You can override the model batch size by specifying a positive number.')

    parser.add_argument('--device', type=str, default='cuda',
                        help="Device to run the model on: 'cuda' or 'cpu'.")

    parser.add_argument('--base_size', type=int, default=256,
                        help='Base size. Resolution of the image input into the Encoder')

    parser.add_argument('--input_size', type=int, default=256,
                        help='Input size. Resolution of the image to be generated by the Decoder.')

    parser.add_argument('--INR_input_size', type=int, default=256,
                        help='INR input size. Resolution of the image to be generated by the Decoder. '
                             'Should be the same as `input_size`.')

    parser.add_argument('--INR_MLP_dim', type=int, default=32,
                        help='Number of channels for INR linear layer.')

    parser.add_argument('--LUT_dim', type=int, default=7,
                        help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')

    parser.add_argument('--activation', type=str, default='leakyrelu_pe',
                        help='INR activation layer type: leakyrelu_pe, sine')

    parser.add_argument('--pretrained', type=str,
                        default=r'./pretrained_models/Resolution_RAW_iHarmony4.pth',
                        help='Pretrained weight path')

    parser.add_argument('--param_factorize_dim', type=int,
                        default=10,
                        help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
                             'Refer to https://arxiv.org/abs/2011.12026')

    parser.add_argument('--embedding_type', type=str,
                        default="CIPS_embed",
                        help='Which embedding_type to use.')

    parser.add_argument('--INRDecode', action="store_false",
                        help='Whether to use the INR decoder (enabled by default). Pass this flag to disable it and '
                             'test the baseline (https://github.com/SamsungLabs/image_harmonization).')

    parser.add_argument('--isMoreINRInput', action="store_false",
                        help='Whether to concatenate RGB and mask as additional INR inputs. See Section 3.4 in the paper.')

    parser.add_argument('--hr_train', action="store_false",
                        help='Whether to use hr_train. See Section 3.4 in the paper.')

    parser.add_argument('--isFullRes', action="store_true",
                        help='Whether to run at the original full resolution. See Section 3.4 in the paper.')

    opt = parser.parse_args()

    assert opt.batch_size == 1, 'This faster script only supports batch size 1 for inference.'

    return opt


@torch.no_grad()
def inference(model, opt, composite_image=None, mask=None):
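    """Harmonize the composite image patch by patch and stitch the results into one image."""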
    model.eval()

    "The dataset here actually consists of several patches of a single image."
    singledataset = single_image_dataset(opt, composite_image, mask)

    single_data_loader = DataLoader(singledataset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
                                    num_workers=opt.workers, persistent_workers=composite_image is None)

    "Initialize a pure black image with the same size as the input image."
    init_img = np.zeros_like(singledataset.composite_image)

    time_all = 0

    for step, batch in tqdm.tqdm(enumerate(single_data_loader), total=len(single_data_loader)):
        composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)]
        mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)]
        coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)]
        start_points = batch['start_point']
        start_proportion = batch['start_proportion']

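        # With batch_size 1, default collation turns the (h, w) start point into a list of
        # one-element tensors; concatenate them back into a single tensor per sample.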
        if opt.batch_size == 1:
            start_points = [torch.cat(start_points)]
            start_proportion = [torch.cat(start_proportion)]

        fg_INR_coordinates = coordinate_map[1:]

        try:
            if global_state[0] == 0:
                print("Stop Harmonizing...!")
                break

            if step == 0:  # CUDA kernel warm-up; otherwise the first inference step would be quite slow.
                fg_content_bg_appearance_construct, _, lut_transform_image = model(
                    composite_image,
                    mask,
                    fg_INR_coordinates, start_proportion[0]
                )
                print("Ready for harmonization...")
            if opt.device == "cuda":
                torch.cuda.reset_peak_memory_stats()  # replaces the deprecated reset_max_memory_* calls
                torch.cuda.synchronize()
                start_time = time.time()
            fg_content_bg_appearance_construct, _, lut_transform_image = model(
                composite_image,
                mask,
                fg_INR_coordinates, start_proportion[0]
            )
            if opt.device == "cuda":
                torch.cuda.synchronize()
                end_time = time.time()

                end_max_memory = torch.cuda.max_memory_allocated() // 1024 ** 2
                end_memory = torch.cuda.memory_allocated() // 1024 ** 2

                print(f'GPU max memory usage: {end_max_memory} MB')
                print(f'GPU memory usage: {end_memory} MB')
                time_all += (end_time - start_time)
            print(f'progress: {step + 1} / {len(single_data_loader)}')
        except Exception as e:
            raise Exception(
                f'The image resolution may be too large for the available memory. Please reduce the '
                f'`split_resolution` value (currently {opt.split_resolution}).') from e

        "Assemble every patch's harmonized result into the final whole image."
        for id in range(len(fg_INR_coordinates[0])):
            pred_fg_image = fg_content_bg_appearance_construct[-1][id]
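            # Keep the predicted foreground where the mask is on (> 100/255) and the original
            # composite image elsewhere.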
            pred_harmonized_image = pred_fg_image * (mask[1][id] > 100 / 255.) + composite_image[1][id] * (
                ~(mask[1][id] > 100 / 255.))

            pred_harmonized_tmp = cv2.cvtColor(
                normalize(pred_harmonized_image.unsqueeze(0), opt, 'inv')[0].permute(1, 2, 0).cpu().mul_(255.).clamp_(
                    0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)

            init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
                     start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp

    if opt.device == "cuda":
        print(f'Inference time: {time_all:.3f} s')
    if opt.save_path is not None:
        os.makedirs(opt.save_path, exist_ok=True)
        cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
    return init_img


def main_process(opt, composite_image=None, mask=None):
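    """Build the model, load the pretrained weights, and run patch-wise inference."""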
    cudnn.benchmark = True

    print("Preparing model...")
    model = build_model(opt).to(opt.device)

    load_dict = torch.load(opt.pretrained, map_location='cpu')['model']

    model.load_state_dict(load_dict, strict=False)

    return inference(model, opt, composite_image, mask)


if __name__ == '__main__':
    opt = parse_args()
    opt.transform_mean = [.5, .5, .5]
    opt.transform_var = [.5, .5, .5]
    main_process(opt)
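
# Example invocation (a sketch: the script filename below is an assumption; the demo image,
# mask, and pretrained weight are expected at the default paths defined above):
#   python efficient_inference_for_square_image.py \
#       --composite_image ./demo/demo_2k_composite.jpg \
#       --mask ./demo/demo_2k_mask.jpg \
#       --pretrained ./pretrained_models/Resolution_RAW_iHarmony4.pth \
#       --save_path ./demo/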