# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import os, pdb
import numpy as np
from PIL import Image

from .dataset import Dataset, CatDataset
from tools.transforms import instanciate_transformation
from tools.transforms_tools import persp_apply


class PairDataset(Dataset):
    """A dataset that serves image pairs with ground-truth pixel correspondences.

    Subclasses are expected to set ``npairs`` and to implement ``get_pair``.
    ``get_paired_images`` additionally reads ``self.image_pairs``, which must
    be indexable by pair index and yield a pair of image indices.
    """

    def __init__(self):
        Dataset.__init__(self)
        self.npairs = 0

    def get_filename(self, img_idx, root=None):
        """Return the filename of one image, or a tuple of two filenames
        when `img_idx` is a pair of indices (see `is_pair`)."""
        if is_pair(img_idx):
            return tuple(Dataset.get_filename(self, i, root) for i in img_idx)
        return Dataset.get_filename(self, img_idx, root)

    def get_image(self, img_idx):
        """Return one image, or a tuple of two images when `img_idx` is a
        pair of indices (see `is_pair`)."""
        if is_pair(img_idx):
            return tuple(Dataset.get_image(self, i) for i in img_idx)
        return Dataset.get_image(self, img_idx)

    def get_corres_filename(self, pair_idx):
        """To be overloaded by subclasses that store correspondences on disk."""
        raise NotImplementedError()

    def get_homography_filename(self, pair_idx):
        """To be overloaded by subclasses that store homographies on disk."""
        raise NotImplementedError()

    def get_flow_filename(self, pair_idx):
        """To be overloaded by subclasses that store flow on disk."""
        raise NotImplementedError()

    def get_mask_filename(self, pair_idx):
        """To be overloaded by subclasses that store masks on disk."""
        raise NotImplementedError()

    def get_pair(self, idx, output=()):
        """returns (img1, img2, `metadata`)

        `metadata` is a dict() that can contain:
            flow: optical flow
            aflow: absolute flow
            corres: list of 2d-2d correspondences
            mask: boolean image of flow validity (in the first image)
            ...
        """
        raise NotImplementedError()

    def get_paired_images(self):
        """Return the set of filenames of all images that appear in a pair."""
        fns = set()
        for i in range(self.npairs):
            a, b = self.image_pairs[i]
            fns.add(self.get_filename(a))
            fns.add(self.get_filename(b))
        return fns

    def __len__(self):
        return self.npairs  # size should correspond to the number of pairs, not images

    def __repr__(self):
        res = "Dataset: %s\n" % self.__class__.__name__
        res += "  %d images," % self.nimg
        res += " %d image pairs" % self.npairs
        res += "\n  root: %s...\n" % self.root
        return res

    @staticmethod
    def _flow2png(flow, path):
        """Save `flow` to `path` as an image and return the quantized flow.

        Values are multiplied by 16, rounded and clipped to the int16 range,
        then the int16 buffer is reinterpreted as uint8 and written through
        PIL. The returned array is the flow as it will be read back (i.e.
        with 1/16 resolution and clipping applied).
        """
        # assumes flow is (H, W, 2) so that the uint8 view has 4 channels -- TODO confirm
        flow = np.clip(np.around(16 * flow), -(2**15), 2**15 - 1)
        raw = np.int16(flow).view(np.uint8)
        Image.fromarray(raw).save(path)
        return flow / 16

    @staticmethod
    def _png2flow(path):
        """Load a flow image written by `_flow2png` and return it as float32.

        Raises:
            IOError: if the file cannot be opened or decoded; the original
                exception is chained as the cause.
        """
        try:
            # Context manager closes the underlying file handle deterministically.
            with Image.open(path) as img:
                flow = np.asarray(img).view(np.int16)
                return np.float32(flow) / 16
        except Exception as err:
            # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
            raise IOError("Error loading flow for %s" % path) from err


class StillPairDataset(PairDataset):
    """A dataset of 'still' image pairs.

    Mixed into a regular image dataset, it provides get_pair(i), which loads
    the two images listed in ``self.image_pairs[i]`` and describes their
    relation as a pure per-axis rescaling.
    """

    def get_pair(self, pair_idx, output=()):
        """Return (img1, img2, meta) for the pair number `pair_idx`.

        `output` is a sequence of requested metadata names, or a single
        whitespace-separated string. Recognized names: "aflow"/"flow"
        (both are filled whenever either is requested), "mask" and
        "homography".
        """
        wanted = output.split() if isinstance(output, str) else output

        img1, img2 = (self.get_image(k) for k in self.image_pairs[pair_idx])

        # Scale factors mapping img1 pixel coordinates to img2 coordinates.
        W, H = img1.size
        size2 = img2.size
        sx = size2[0] / float(W)
        sy = size2[1] / float(H)

        meta = {}

        if "aflow" in wanted or "flow" in wanted:
            # (H, W, 2) grid holding the (x, y) coordinate of every pixel.
            yx = np.mgrid[0:H, 0:W]
            mgrid = yx[::-1].transpose(1, 2, 0).astype(np.float32)
            absolute = mgrid * (sx, sy)
            meta["aflow"] = absolute
            meta["flow"] = absolute - mgrid

        if "mask" in wanted:
            # Every pixel of the first image is marked valid.
            meta["mask"] = np.ones((H, W), np.uint8)

        if "homography" in wanted:
            meta["homography"] = np.diag(np.float32([sx, sy, 1]))

        return img1, img2, meta


class SyntheticPairDataset(PairDataset):
    """A synthetic generator of image pairs.

    Wraps a plain (non-pair) image dataset and builds each pair from a single
    image by applying the `scale` and `distort` transformations.
    """

    def __init__(self, dataset, scale="", distort=""):
        """Attach `dataset` and build both transformations through
        `instanciate_transformation`."""
        self.attach_dataset(dataset)
        self.distort = instanciate_transformation(distort)
        self.scale = instanciate_transformation(scale)

    def attach_dataset(self, dataset):
        """Bind this generator to `dataset`: one pair per image, and image
        accessors forwarded to the wrapped dataset."""
        assert isinstance(dataset, Dataset) and not isinstance(dataset, PairDataset)
        self.dataset = dataset
        self.root = None
        self.npairs = dataset.nimg
        # Accessors are delegated to the wrapped dataset.
        self.get_filename = dataset.get_filename
        self.get_key = dataset.get_key
        self.get_image = dataset.get_image

    def make_pair(self, img):
        """Return the two starting images of a pair (here, the same image twice)."""
        return img, img

    def get_pair(self, i, output=("aflow")):
        """Build the synthetic pair for image `i`.

        The image is scaled, duplicated with `make_pair`, and the second copy
        is passed through `distort` together with an 8-coefficient `persp`
        tuple; the `persp` value returned by `distort` is the ground truth.

        `output` is a sequence of metadata names or a whitespace-separated
        string (the default is the plain string "aflow"). Returns
        (first image, distorted image, meta).
        """
        wanted = output.split() if isinstance(output, str) else output

        source = self.dataset.get_image(i)
        first, second = self.make_pair(self.scale(source))

        warped = self.distort({"img": second, "persp": (1, 0, 0, 0, 1, 0, 0, 0)})
        W, H = first.size
        persp = warped["persp"]

        meta = {}

        if "aflow" in wanted or "flow" in wanted:
            # One (x, y) row per pixel of the first image.
            grid = np.mgrid[0:H, 0:W][::-1].reshape(2, H * W).T
            absolute = np.float32(persp_apply(persp, grid).reshape(H, W, 2))
            meta["flow"] = absolute - grid.reshape(H, W, 2)
            meta["aflow"] = absolute

        if "homography" in wanted:
            # The 8 coefficients plus a trailing 1 form the 3x3 matrix.
            meta["homography"] = np.float32(persp + (1,)).reshape(3, 3)

        return first, warped["img"], meta

    def __repr__(self):
        parts = (
            "Dataset: %s\n" % self.__class__.__name__,
            "  %d images and pairs" % self.npairs,
            "\n  root: %s..." % self.dataset.root,
            "\n  Scale: %s" % (repr(self.scale).replace("\n", "")),
            "\n  Distort: %s" % (repr(self.distort).replace("\n", "")),
            "\n",
        )
        return "".join(parts)


class TransformedPairs(PairDataset):
    """Automatic data augmentation for pre-existing image pairs.
    Given an image pair dataset, it generates synthetically jittered pairs
    using random transformations (e.g. homographies & noise).
    """

    def __init__(self, dataset, trf=""):
        """Attach the pair dataset and build the transformation through
        `instanciate_transformation`."""
        self.attach_dataset(dataset)
        self.trf = instanciate_transformation(trf)

    def attach_dataset(self, dataset):
        """Bind to `dataset` (must be a PairDataset): sizes are copied and
        image accessors are forwarded to it."""
        assert isinstance(dataset, PairDataset)
        self.dataset = dataset
        self.nimg = dataset.nimg
        self.npairs = dataset.npairs
        self.get_image = dataset.get_image
        self.get_key = dataset.get_key
        self.get_filename = dataset.get_filename
        self.root = None

    def get_pair(self, i, output=""):
        """Return pair `i` of the wrapped dataset with its second image jittered.

        Procedure:
        The pair and its metadata are fetched from the wrapped dataset, the
        transformation is applied to the second image only, and the
        resulting `persp` coefficients are used to update whichever of
        "aflow", "flow", "corres" and "homography" the metadata contains.
        Flow arrays and correspondences are updated in place.

        Returns (img_a, transformed img_b, metadata).
        """
        img_a, img_b_, metadata = self.dataset.get_pair(i, output)

        img_b = self.trf({"img": img_b_, "persp": (1, 0, 0, 0, 1, 0, 0, 0)})
        trf = img_b["persp"]

        has_aflow = "aflow" in metadata
        has_flow = "flow" in metadata
        if has_aflow or has_flow:
            W, H = img_a.size
            # (H, W, 2) grid of the (x, y) coordinates of img_a pixels.
            mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1, 2, 0).astype(np.float32)

            if has_aflow:
                aflow = metadata["aflow"]
                aflow[:] = persp_apply(trf, aflow.reshape(-1, 2)).reshape(aflow.shape)
            else:
                # Only relative flow was provided: rebuild the absolute
                # targets locally instead of failing on a missing key.
                targets = metadata["flow"] + mgrid
                aflow = persp_apply(trf, targets.reshape(-1, 2)).reshape(targets.shape)

            if has_flow:
                flow = metadata["flow"]
                flow[:] = aflow - mgrid

        if "corres" in metadata:
            corres = metadata["corres"]
            # Column 1 is transformed, i.e. the points lying in the second image.
            corres[:, 1] = persp_apply(trf, corres[:, 1])

        if "homography" in metadata:
            # p_b = homography * p_a
            trf_ = np.float32(trf + (1,)).reshape(3, 3)
            metadata["homography"] = np.float32(trf_ @ metadata["homography"])

        return img_a, img_b["img"], metadata

    def __repr__(self):
        res = "Transformed Pairs from %s\n" % type(self.dataset).__name__
        res += "  %d images and pairs" % self.npairs
        res += "\n  root: %s..." % self.dataset.root
        res += "\n  transform: %s" % (repr(self.trf).replace("\n", ""))
        return res + "\n"


class CatPairDataset(CatDataset):
    """Concatenation of several pair datasets.

    Pairs are numbered consecutively across the given datasets, in the order
    the datasets were passed; `pair_offsets[k]` is the global index of the
    first pair of dataset `k`.
    """

    def __init__(self, *datasets):
        CatDataset.__init__(self, *datasets)
        pair_offsets = [0]
        for db in datasets:
            pair_offsets.append(db.npairs)
        self.pair_offsets = np.cumsum(pair_offsets)
        # Plain int so that len() and string formatting get a native value.
        self.npairs = int(self.pair_offsets[-1])

    def __len__(self):
        return self.npairs

    def __repr__(self):
        fmt_str = "CatPairDataset("
        for db in self.datasets:
            fmt_str += str(db).replace("\n", " ") + ", "
        return fmt_str[:-2] + ")"

    def pair_which(self, i):
        """Map global pair index `i` to (dataset position, local pair index).

        Raises:
            IndexError: if `i` is not in ``[0, npairs)``.
        """
        # Validate the pair index itself; a negative index would otherwise
        # silently resolve to the last dataset with a negative local index.
        if not 0 <= i < self.npairs:
            raise IndexError(
                "Bad pair index %d (expected 0 <= index < %d)" % (i, self.npairs)
            )
        # side="right" skips over empty datasets that share the same offset.
        pos = np.searchsorted(self.pair_offsets, i, side="right") - 1
        return pos, i - self.pair_offsets[pos]

    def pair_call(self, func, i, *args, **kwargs):
        """Call method `func` of the dataset owning pair `i`, with the local index."""
        b, j = self.pair_which(i)
        return getattr(self.datasets[b], func)(j, *args, **kwargs)

    def get_pair(self, i, output=()):
        """Return pair `i` from whichever dataset owns it."""
        b, i = self.pair_which(i)
        return self.datasets[b].get_pair(i, output)

    def get_flow_filename(self, pair_idx, *args, **kwargs):
        return self.pair_call("get_flow_filename", pair_idx, *args, **kwargs)

    def get_mask_filename(self, pair_idx, *args, **kwargs):
        return self.pair_call("get_mask_filename", pair_idx, *args, **kwargs)

    def get_corres_filename(self, pair_idx, *args, **kwargs):
        return self.pair_call("get_corres_filename", pair_idx, *args, **kwargs)


def is_pair(x):
    """Tell whether `x` is a pair: a tuple or list of length 2, or a
    one-dimensional numpy array with exactly 2 elements."""
    if isinstance(x, np.ndarray):
        return x.ndim == 1 and x.shape[0] == 2
    return isinstance(x, (tuple, list)) and len(x) == 2