from typing import List

import cv2
import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image
from tqdm import gui

from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image


class ControlNet(AbstractPipeline):
    """ControlNet generation pipelines for canny, pose and tile upscaling.

    Exactly one ControlNet model is attached at a time. ``load_canny``,
    ``load_pose`` and ``load_tile_upscaler`` swap it into both pipelines, and
    each ``process_*`` method raises unless the matching model is the one
    currently attached.
    """

    # Name of the attached ControlNet ("canny", "pose", "tile_upscaler"),
    # or "" when none is attached (initially and after ``cleanup``).
    __current_task_name = ""

    def load(self, model_dir: str):
        """Build both pipelines from the base weights in ``model_dir``.

        ``self.pipe`` is the ``stable_diffusion_controlnet_img2img`` community
        pipeline (used by ``process_tile_upscaler``). ``self.pipe2`` is a
        ``StableDiffusionControlNetPipeline`` built from the same components
        (used by ``process_canny`` and ``process_pose``). The canny ControlNet
        is attached initially.
        """
        # we will load canny by default
        self.load_canny()

        # img2img controlnet pipeline, used for the tile upscaler
        # NOTE(review): the pipeline is moved to CUDA and then model CPU
        # offload is enabled on it; diffusers' docs advise against calling
        # ``.to("cuda")`` before ``enable_model_cpu_offload`` — confirm intent.
        pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            custom_pipeline="stable_diffusion_controlnet_img2img",
        ).to("cuda")
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        self.pipe = pipe

        # text2img controlnet pipeline sharing ``pipe``'s modules, used for
        # canny and pose
        pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
        pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
        pipe2.enable_xformers_memory_efficient_attention()
        self.pipe2 = pipe2

    def __load_controlnet(self, task_name: str, repo_id: str) -> None:
        """Attach the ControlNet ``repo_id`` to both pipelines as ``task_name``.

        No-op if ``task_name`` is already attached. Otherwise the weights are
        loaded in float16 onto CUDA, stored on ``self.controlnet``, swapped
        into ``self.pipe`` / ``self.pipe2`` when those exist, and
        ``clear_cuda_and_gc`` is called.
        """
        if self.__current_task_name == task_name:
            return
        # The new model is loaded before the previous one is dereferenced,
        # so both are briefly held on the GPU at the same time.
        model = ControlNetModel.from_pretrained(
            repo_id, torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = task_name
        self.controlnet = model
        # ``load`` calls ``load_canny`` before the pipelines are created.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = model
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = model
        clear_cuda_and_gc()

    def load_canny(self):
        """Attach the canny-edge ControlNet (no-op if already attached)."""
        self.__load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self):
        """Attach the OpenPose ControlNet (no-op if already attached)."""
        self.__load_controlnet("pose", "lllyasviel/control_v11p_sd15_openpose")

    def load_tile_upscaler(self):
        """Attach the tile ControlNet (no-op if already attached)."""
        self.__load_controlnet(
            "tile_upscaler", "lllyasviel/control_v11f1e_sd15_tile"
        )

    def cleanup(self):
        """Detach the current ControlNet from both pipelines and free memory.

        Safe to call before ``load``: pipelines that do not exist yet are
        skipped instead of raising ``AttributeError``.
        """
        for attr in ("pipe", "pipe2"):
            pipeline = getattr(self, attr, None)
            if pipeline is not None:
                pipeline.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        guidance_scale: float,
        height: int,
        width: int,
    ):
        """Generate images conditioned on the canny edges of ``imageUrl``.

        The downloaded image is resized to ``width`` x ``height`` before edge
        detection, and the result is wrapped with ``Result.from_result``.

        Raises:
            Exception: if the canny ControlNet is not the attached model.
        """
        if self.__current_task_name != "canny":
            raise Exception("ControlNet is not loaded with canny model")

        # NOTE(review): this seeds the global torch RNG; a per-call
        # ``generator`` would avoid interference between concurrent requests.
        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = self.__canny_detect_edge(init_image)

        result = self.pipe2(
            prompt=prompt,
            image=init_image,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        guidance_scale: float,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on already-computed pose image(s).

        ``image`` is passed to the pipeline as the control image unchanged
        (presumably output of ``detect_pose`` — confirm with callers).

        Raises:
            Exception: if the pose ControlNet is not the attached model.
        """
        if self.__current_task_name != "pose":
            raise Exception("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        result = self.pipe2(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float,
    ):
        """Upscale ``imageUrl`` with the tile ControlNet via the img2img pipe.

        The image is first resized to ``width`` x ``height``, then rescaled so
        its shorter side is about ``resize_dimension`` (snapped to multiples
        of 64). The output size follows that rescaled image, not
        ``width``/``height``.

        Raises:
            Exception: if the tile ControlNet is not the attached model.
        """
        if self.__current_task_name != "tile_upscaler":
            raise Exception("ControlNet is not loaded with tile_upscaler model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )

        # The same image serves as both the img2img input and the ControlNet
        # conditioning image; strength=1.0 is passed to the pipeline.
        result = self.pipe(
            image=condition_image,
            prompt=prompt,
            controlnet_conditioning_image=condition_image,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=condition_image.size[1],
            width=condition_image.size[0],
            strength=1.0,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Return the OpenPose rendering (with hands and face) of ``imageUrl``."""
        # NOTE(review): the detector is re-created from the hub on every call;
        # caching it would avoid repeated loads if memory allows.
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        image = detector(image, hand_and_face=True)
        return image

    def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
        """Return a 3-channel canny edge map of ``image``.

        The input is converted to RGB first, so greyscale, palette and
        alpha-channel images all yield a well-defined edge map.
        """
        image_array = np.array(image.convert("RGB"))

        low_threshold = 100
        high_threshold = 200

        edges = cv2.Canny(image_array, low_threshold, high_threshold)
        # Canny yields a single-channel (H, W) map; replicate it to (H, W, 3).
        edges = np.stack([edges, edges, edges], axis=2)
        return Image.fromarray(edges)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Scale ``image`` so its shorter side is about ``resolution`` pixels.

        Both sides are then rounded to the nearest multiple of 64 (never
        below 64), and the RGB-converted image is resampled with Lanczos.

        Raises:
            ValueError: if ``resolution`` is not positive.
        """
        if resolution <= 0:
            raise ValueError(f"resolution must be positive, got {resolution}")
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / min(W, H)
        # Clamp to 64 so a small ``resolution`` cannot round a side down to 0,
        # which PIL cannot resize to.
        H = max(64, int(round(H * k / 64.0)) * 64)
        W = max(64, int(round(W * k / 64.0)) * 64)
        return input_image.resize((W, H), resample=Image.LANCZOS)