import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=4,
        help="Batch size to process input images to events. Defaults to 4",
    )
    parser.add_argument(
        "-i",
        "--images_paths",
        type=str,
        required=True,
        help="Path to a directory with image files",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        default=None,
        help="Path to a directory were events should be written. "
        + "Will NOT write anything to disk if this flag is not used.",
    )
    parser.add_argument(
        "-s",
        "--save_input",
        action="store_true",
        default=False,
        help="Binary flag to include the input image to the model (after crop and"
        + " resize) in the images written or uploaded (depending on saving options.)",
    )
    parser.add_argument(
        "-r",
        "--resume_path",
        type=str,
        default=None,
        help="Path to a directory containing the trainer to resume."
        + " In particular it must contain `opts.yam` and `checkpoints/`."
        + " Typically this points to a Masker, which holds the path to a"
        + " Painter in its opts",
    )
    parser.add_argument(
        "--no_time",
        action="store_true",
        default=False,
        help="Binary flag to prevent the timing of operations.",
    )
    parser.add_argument(
        "-f",
        "--flood_mask_binarization",
        type=float,
        default=0.5,
        help="Value to use to binarize masks (mask > value). "
        + "Set to -1 to use soft masks (not binarized). Defaults to 0.5.",
    )
    parser.add_argument(
        "-t",
        "--target_size",
        type=int,
        default=640,
        help="Output image size (when not using `keep_ratio_128`): images are resized"
        + " such that their smallest side is `target_size` then cropped in the middle"
        + " of the largest side such that the resulting input image (and output images)"
        + " has height and width `target_size x target_size`. **Must** be a multiple of"
        + " 2^7=128 (up/downscaling inside the models). Defaults to 640.",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Binary flag to use half precision (float16). Defaults to False.",
    )
    parser.add_argument(
        "-n",
        "--n_images",
        default=-1,
        type=int,
        help="Limit the number of images processed (if you have 100 images in "
        + "a directory but n is 10 then only the first 10 images will be loaded"
        + " for processing)",
    )
    parser.add_argument(
        "--no_conf",
        action="store_true",
        default=False,
        help="disable writing the apply_events hash and command in the output folder",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Do not check for existing outdir, i.e. force overwrite"
        + " potentially existing files in the output path",
    )
    parser.add_argument(
        "--no_cloudy",
        action="store_true",
        default=False,
        help="Prevent the use of the cloudy intermediate"
        + " image to create the flood image. Rendering will"
        + " be more colorful but may seem less realistic",
    )
    parser.add_argument(
        "--keep_ratio_128",
        action="store_true",
        default=False,
        help="When loading the input images, resize and crop them in order for their "
        + "dimensions to match the closest multiples"
        + " of 128. Will force a batch size of 1 since images"
        + " now have different dimensions. "
        + "Use --max_im_width to cap the resulting dimensions.",
    )
    parser.add_argument(
        "--fuse",
        action="store_true",
        default=False,
        help="Use batch norm fusion to speed up inference",
    )
    parser.add_argument(
        "--save_masks",
        action="store_true",
        default=False,
        help="Save output masks along events",
    )
    parser.add_argument(
        "-m",
        "--max_im_width",
        type=int,
        default=-1,
        help="When using --keep_ratio_128, some images may still be too large. Use "
        + "--max_im_width to cap the resized image's width. Defaults to -1 (no cap).",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload to comet.ml in a project called `climategan-apply`",
    )
    parser.add_argument(
        "--zip_outdir",
        "-z",
        action="store_true",
        help="Zip the output directory as '{outdir.parent}/{outdir.name}.zip'",
    )
    return parser.parse_args()
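

# Example invocation (hypothetical paths):
#   python apply_events.py -i ./input_images -o ./outputs/_auto_ \
#       -r ./models/masker -b 4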


args = parse_args()


print("\n• Imports\n")
import time

import_time = time.time()
import sys
import shutil
from collections import OrderedDict
from pathlib import Path

import comet_ml  # noqa: F401
import torch
import numpy as np
import skimage.io as io
from skimage.color import rgba2rgb
from skimage.transform import resize
from tqdm import tqdm

from climategan.trainer import Trainer
from climategan.bn_fusion import bn_fuse
from climategan.tutils import print_num_parameters
from climategan.utils import Timer, find_images, get_git_revision_hash, to_128, resolve

import_time = time.time() - import_time


def to_m1_p1(img, i):
    """
    Rescales a [0, 1] image to [-1, +1]

    Args:
        img (np.array): float32 numpy array of an image in [0, 1]
        i (int): Index of the image being rescaled

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
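
    Example (illustrative):
        >>> x = np.full((2, 2, 3), 0.5, dtype=np.float32)
        >>> float(to_m1_p1(x, 0).max())
        0.0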
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image {i} : ({img.min()}, {img.max()})")


def uint8(array):
    """
    Converts an array to np.uint8 (only changes the dtype; no rescaling)

    Args:
        array (np.array): array to modify

    Returns:
        np.array(np.uint8): converted array
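
    Example (illustrative):
        >>> uint8(np.array([0.7, 255.0])).tolist()
        [0, 255]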
    """
    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest dimensions
    is `to`, then crops this resized image in its center so that the output is `to x to`
    without aspect ratio distortion

    Args:
        img (np.array): np.uint8 255 image

    Returns:
        np.array: [0, 1] np.float32 image
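
    Example (illustrative): a 480x640 input is resized to 640x853, then
    center-cropped:
        >>> resize_and_crop(np.zeros((480, 640, 3), dtype=np.uint8)).shape
        (640, 640, 3)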
    """
    # resize keeping the aspect ratio: smallest dim is `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]

    top = (H - to) // 2
    left = (W - to) // 2

    rc_img = r_img[top : top + to, left : left + to, :]

    return rc_img / 255.0


def print_time(text, time_series, purge=-1):
    """
    Prints a time series' mean and std with a label

    Args:
        text (str): label of the time series
        time_series (list): list of timings
        purge (int, optional): ignore first n values of time series. Defaults to -1.
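
    Example (illustrative output):
        >>> print_time("encode", [0.1, 0.2])
        Encode ...................  0.15000 +/- 0.05000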
    """
    if not time_series:
        return

    if purge > 0 and len(time_series) > purge:
        time_series = time_series[purge:]

    m = np.mean(time_series)
    s = np.std(time_series)

    print(
        f"{text.capitalize() + ' ':.<26}  {m:.5f}"
        + (f" +/- {s:.5f}" if len(time_series) > 1 else "")
    )


def print_store(store, purge=-1):
    """
    Pretty-prints a store of time series

    Args:
        store (dict): maps string keys to lists of times
        purge (int, optional): ignore first n values of time series. Defaults to -1.
    """
    singles = OrderedDict({k: v for k, v in store.items() if len(v) == 1})
    multiples = OrderedDict({k: v for k, v in store.items() if len(v) > 1})
    empties = {k: v for k, v in store.items() if len(v) == 0}

    if empties:
        print("Ignoring empty stores ", ", ".join(empties.keys()))
        print()

    for k in singles:
        print_time(k, singles[k], purge)

    print()
    print("Unit: s/batch")
    for k in multiples:
        print_time(k, multiples[k], purge)
    print()


def write_apply_config(out):
    """
    Saves the command used to run `apply_events.py` (in `command.txt`) and the
    current git revision hash (in `hash.txt`) for future reference
    """
    cwd = Path.cwd().expanduser().resolve()
    command = f"cd {str(cwd)}\n"
    command += " ".join(sys.argv)
    git_hash = get_git_revision_hash()
    with (out / "command.txt").open("w") as f:
        f.write(command)
    with (out / "hash.txt").open("w") as f:
        f.write(git_hash)


def get_outdir_name(half, keep_ratio, max_im_width, target_size, bin_value, cloudy):
    """
    Creates the output directory's name based on user-provided arguments
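
    Example (illustrative): with `--half` and the default size options
        >>> get_outdir_name(True, False, -1, 640, 0.5, True)
        'half-S-640'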
    """
    name_items = []
    if half:
        name_items.append("half")
    if keep_ratio:
        name_items.append("AR")
    if max_im_width and keep_ratio:
        name_items.append(f"{max_im_width}")
    if target_size and not keep_ratio:
        name_items.append("S")
        name_items.append(f"{target_size}")
    if bin_value != 0.5:
        name_items.append(f"bin{bin_value}")
    if not cloudy:
        name_items.append("no_cloudy")

    return "-".join(name_items)


def make_outdir(
    outdir, overwrite, half, keep_ratio, max_im_width, target_size, bin_value, cloudy
):
    """
    Creates the output directory if it does not exist. If it does exist,
    prompts the user for confirmation (except if `overwrite` is True).
    If the output directory's name is "_auto_" then it is created as:
        outdir.parent / get_outdir_name(...)
    """
    if outdir.name == "_auto_":
        outdir = outdir.parent / get_outdir_name(
            half, keep_ratio, max_im_width, target_size, bin_value, cloudy
        )
    if outdir.exists() and not overwrite:
        print(
            f"\nWARNING: outdir ({str(outdir)}) already exists."
            + " Files with existing names will be overwritten"
        )
        if "n" in input(">>> Continue anyway? [y / n] (default: y) : "):
            print("Interrupting execution from user input.")
            sys.exit()
        print()
    outdir.mkdir(exist_ok=True, parents=True)
    return outdir


def get_time_stores(import_time):
    """
    Initializes the timing stores: an ordered mapping from operation name to a
    list of durations in seconds, seeded with the already-measured import time
    """
    return OrderedDict(
        {
            "imports": [import_time],
            "setup": [],
            "data pre-processing": [],
            "encode": [],
            "mask": [],
            "flood": [],
            "depth": [],
            "segmentation": [],
            "smog": [],
            "wildfire": [],
            "all events": [],
            "numpy": [],
            "inference on all images": [],
            "write": [],
        }
    )


if __name__ == "__main__":

    # -----------------------------------------
    # -----  Initialize script variables  -----
    # -----------------------------------------
    print(
        "• Using args\n\n"
        + "\n".join(["{:25}: {}".format(k, v) for k, v in vars(args).items()]),
    )

    batch_size = args.batch_size
    bin_value = args.flood_mask_binarization
    cloudy = not args.no_cloudy
    fuse = args.fuse
    half = args.half
    save_masks = args.save_masks
    images_paths = resolve(args.images_paths)
    keep_ratio = args.keep_ratio_128
    max_im_width = args.max_im_width
    n_images = args.n_images
    outdir = resolve(args.output_path) if args.output_path is not None else None
    resume_path = args.resume_path
    target_size = args.target_size
    time_inference = not args.no_time
    upload = args.upload
    zip_outdir = args.zip_outdir

    # -------------------------------------
    # -----  Validate size arguments  -----
    # -------------------------------------
    if keep_ratio:
        if target_size != 640:
            print(
                "\nWARNING: using --keep_ratio_128 overwrites target_size"
                + " which is ignored."
            )
        if batch_size != 1:
            print("\nWARNING: batch_size overwritten to 1 when using keep_ratio_128")
            batch_size = 1
        if max_im_width > 0 and max_im_width % 128 != 0:
            new_im_width = int(max_im_width / 128) * 128
            print("\nWARNING: max_im_width should be <0 or a multiple of 128.")
            print(
                "            Was {} but is now overwritten to {}".format(
                    max_im_width, new_im_width
                )
            )
            max_im_width = new_im_width
    else:
        if target_size % 128 != 0:
            print(f"\nWARNING: target size {target_size} is not a multiple of 128.")
            # round down to the closest multiple of 128 (e.g. 700 -> 640)
            target_size = target_size - (target_size % 128)
            print(f"Setting target_size to {target_size}.")

    # -------------------------------------
    # -----  Create output directory  -----
    # -------------------------------------
    if outdir is not None:
        outdir = make_outdir(
            outdir,
            args.overwrite,
            half,
            keep_ratio,
            max_im_width,
            target_size,
            bin_value,
            cloudy,
        )

    # -------------------------------
    # -----  Create time store  -----
    # -------------------------------
    stores = get_time_stores(import_time)

    # -----------------------------------
    # -----  Load Trainer instance  -----
    # -----------------------------------
    with Timer(store=stores.get("setup", []), ignore=time_inference):
        print("\n• Initializing trainer\n")
        torch.set_grad_enabled(False)
        trainer = Trainer.resume_from_path(
            resume_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        print()
        print_num_parameters(trainer, True)
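        # batch-norm fusion folds BN layers into the preceding convolutions
        # (one-time rewrite of G's weights) to speed up inference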
        if fuse:
            trainer.G = bn_fuse(trainer.G)
        if half:
            trainer.G.half()

    # --------------------------------------------
    # -----  Read data from input directory  -----
    # --------------------------------------------
    print("\n• Reading & Pre-processing Data\n")

    # find all images
    data_paths = find_images(images_paths)
    base_data_paths = data_paths
    # filter images
    if 0 < n_images < len(data_paths):
        data_paths = data_paths[:n_images]
    # repeat the paths to reach n_images, then truncate to exactly n_images
    elif n_images > len(data_paths):
        repeats = n_images // len(data_paths) + 1
        data_paths = base_data_paths * repeats
        data_paths = data_paths[:n_images]

    with Timer(store=stores.get("data pre-processing", []), ignore=time_inference):
        # read images to numpy arrays
        data = [io.imread(str(d)) for d in data_paths]
        # rgba to rgb (rgba2rgb returns floats in [0, 1], hence the * 255)
        data = [im if im.shape[-1] == 3 else uint8(rgba2rgb(im) * 255) for im in data]
        # resize images, either:
        if keep_ratio:
            # to closest multiples of 128 <= max_im_width, keeping aspect ratio
            new_sizes = [to_128(d, max_im_width) for d in data]
            data = [resize(d, ns, anti_aliasing=True) for d, ns in zip(data, new_sizes)]
        else:
            # to args.target_size
            data = [resize_and_crop(d, target_size) for d in data]
            new_sizes = [(target_size, target_size) for _ in data]
        # resize() produces [0, 1] images, rescale to [-1, 1]
        data = [to_m1_p1(d, i) for i, d in enumerate(data)]

    # number of batches = ceil(len(data) / batch_size)
    n_batches = len(data) // batch_size
    if len(data) % batch_size != 0:
        n_batches += 1

    print("Found", len(base_data_paths), "images. Inferring on", len(data), "images.")

    # --------------------------------------------
    # -----  Batch-process images to events  -----
    # --------------------------------------------
    print(f"\n• Using device {str(trainer.device)}\n")

    all_events = []

    with Timer(store=stores.get("inference on all images", []), ignore=time_inference):
        for b in tqdm(range(n_batches), desc="Inferring events", unit="batch"):

            images = data[b * batch_size : (b + 1) * batch_size]
            if not images:
                continue

            # stack images into a batch: batch_size x height x width x 3
            images = np.stack(images)
            # Retrieve numpy events as a dict {event: array[BxHxWxC]}
            events = trainer.infer_all(
                images,
                numpy=True,
                stores=stores,
                bin_value=bin_value,
                half=half,
                cloudy=cloudy,
                return_masks=save_masks,
            )

            # save the resized and cropped input image,
            # mapped back from [-1, 1] to [0, 255]
            if args.save_input:
                events["input"] = uint8((images + 1) / 2 * 255)

            # store events to write after inference loop
            all_events.append(events)

    # --------------------------------------------
    # -----  Save (write/upload) inferences  -----
    # --------------------------------------------
    if outdir is not None or upload:

        if upload:
            print("\n• Creating comet Experiment")
            exp = comet_ml.Experiment(project_name="climategan-apply")
            exp.log_parameters(vars(args))

        # --------------------------------------------------------------
        # -----  Change inferred data structure to a list of dicts  -----
        # --------------------------------------------------------------
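        # all_events is a list of per-batch dicts {event: array[B, H, W, C]};
        # flatten it into one dict per input image: {event: array[H, W, C]}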
        to_write = []
        events_names = list(all_events[0].keys())
        for events_data in all_events:
            n_ims = len(events_data[events_names[0]])
            for i in range(n_ims):
                item = {event: events_data[event][i] for event in events_names}
                to_write.append(item)

        progress_bar_desc = ""
        if outdir is not None:
            print("\n• Output directory:\n")
            print(str(outdir), "\n")
            if upload:
                progress_bar_desc = "Writing & Uploading events"
            else:
                progress_bar_desc = "Writing events"
        else:
            if upload:
                progress_bar_desc = "Uploading events"

        # ------------------------------------
        # -----  Save individual images  -----
        # ------------------------------------
        with Timer(store=stores.get("write", []), ignore=time_inference):

            # for each image
            for t, event_dict in tqdm(
                enumerate(to_write),
                desc=progress_bar_desc,
                unit="input image",
                total=len(to_write),
            ):

                idx = t % len(base_data_paths)
                stem = Path(data_paths[idx]).stem
                width = new_sizes[idx][1]

                if keep_ratio:
                    ar = "_AR"
                else:
                    ar = ""

                # for each event type
                event_bar = tqdm(
                    enumerate(event_dict.items()),
                    leave=False,
                    total=len(events_names),
                    unit="event",
                )
                for e, (event, im_data) in event_bar:
                    event_bar.set_description(
                        f"  {event.capitalize():<{len(progress_bar_desc) - 2}}"
                    )

                    if args.no_cloudy:
                        suffix = ar + "_no_cloudy"
                    else:
                        suffix = ar

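                    # e.g. input `house.jpg` with event `flood`, width 640 and
                    # --keep_ratio_128 gives `house_flood_640_AR.png`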
                    im_path = Path(f"{stem}_{event}_{width}{suffix}.png")

                    if outdir is not None:
                        im_path = outdir / im_path
                        io.imsave(im_path, im_data)

                    if upload:
                        exp.log_image(im_data, name=im_path.name)
    if zip_outdir and outdir is not None:
        print("\n• Zipping output directory... ", end="", flush=True)
        archive_path = Path(shutil.make_archive(outdir.name, "zip", root_dir=outdir))
        archive_path = archive_path.rename(outdir.parent / archive_path.name)
        print("Done:\n")
        print(str(archive_path))

    # ---------------------------
    # -----  Print timings  -----
    # ---------------------------
    if time_inference:
        print("\n• Timings\n")
        print_store(stores)

    # ---------------------------------------------
    # -----  Save apply_events.py run config  -----
    # ---------------------------------------------
    if not args.no_conf and outdir is not None:
        write_apply_config(outdir)