from typing import Optional

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from PIL import Image

import internals.util.image as ImageUtil
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline
from internals.util.config import get_base_dimension, get_hf_cache_dir, get_is_sdxl


class RealtimeDraw(AbstractPipeline):
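    """Realtime drawing pipeline.

    On SDXL this wraps a single ControlNet-LLite img2img pipeline; otherwise it
    builds two SD 1.5 ControlNet img2img pipelines (segmentation only, and
    scribble + segmentation) on top of the shared base pipeline.
    """
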
    def load(self, pipeline: AbstractPipeline):
        if hasattr(self, "pipe"):
            return

        if get_is_sdxl():
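            # SDXL path: a single ControlNet-LLite img2img pipeline loaded on
            # top of the already-initialized base pipeline.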
            lite_pipe = SDXLLLiteImg2ImgPipeline()
            lite_pipe.load(
                pipeline,
                [
                    "https://s3.ap-south-1.amazonaws.com/autodraft.model.assets/models/replicate-xl-llite.safetensors"
                ],
            )
            self.pipe = lite_pipe
        else:
            self.__controlnet_scribble = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_scribble",
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            )

            self.__controlnet_seg = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_seg",
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            )

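            # Reuse the base pipeline's components so the ControlNet pipelines
            # share model weights instead of loading new copies.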
            kwargs = {**pipeline.pipe.components}  # pyright: ignore
            kwargs.pop("image_encoder", None)
            self.pipe = StableDiffusionControlNetImg2ImgPipeline(
                **kwargs, controlnet=self.__controlnet_seg
            ).to("cuda")
            self.pipe.safety_checker = None
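            # Second pipeline stacks the scribble and segmentation ControlNets
            # for the sketch + segmentation mode used by process_img.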
            self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
                **kwargs, controlnet=[self.__controlnet_scribble, self.__controlnet_seg]
            ).to("cuda")
            self.pipe2.safety_checker = None

    def process_seg(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: str,
        seed: int,
    ) -> Image.Image:
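        """Img2img pass where `image` doubles as the segmentation control map.

        Only available on SD 1.5; raises for SDXL.
        """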
        if get_is_sdxl():
            raise NotImplementedError("SDXL is not supported for this method")

        torch.manual_seed(seed)

        image = ImageUtil.resize_image(image, 512)

        img = self.pipe(
            image=image,
            control_image=image,
            prompt=prompt,
            num_inference_steps=15,
            negative_prompt=negative_prompt,
            guidance_scale=10,
            strength=0.8,
        ).images[0]

        return img

    def process_img(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        image: Optional[Image.Image] = None,
        image2: Optional[Image.Image] = None,
    ) -> Image.Image:
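        """Img2img pass from a user drawing.

        On SDXL, `image` is used directly as the condition image. On SD 1.5, a
        scribble derived from `image` and `image2` (a segmentation map) guide
        the two stacked ControlNets. Missing inputs default to black canvases.
        """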
        torch.manual_seed(seed)

        b_dimen = get_base_dimension()

        if image is None:
            size = (b_dimen, b_dimen)
            if image2 is not None:
                size = image2.size
            image = Image.new("RGB", size, color=0)

        if image2 is None:
            size = (b_dimen, b_dimen)
            if image is not None:
                size = image.size
            image2 = Image.new("RGB", size, color=0)

        if get_is_sdxl():
            size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1])
            image = image.resize(size)

            images = self.pipe(
                image=image,
                condition_image=image,
                negative_prompt=negative_prompt,
                prompt=prompt,
                seed=seed,
                num_inference_steps=10,
                width=image.size[0],
                height=image.size[1],
            )
            img = images[0]
        else:
            image = ImageUtil.resize_image(image, b_dimen)

            scribble = ControlNet.scribble_image(image)

            image2 = ImageUtil.resize_image(image2, b_dimen)

            img = self.pipe2(
                image=image,
                control_image=[scribble, image2],
                prompt=prompt,
                num_inference_steps=15,
                negative_prompt=negative_prompt,
                guidance_scale=10,
                strength=0.9,
                width=image.size[0],
                height=image.size[1],
                controlnet_conditioning_scale=[1.0, 0.8],
            ).images[0]

        img = ImageUtil.resize_image(img, 512)

        return img
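

# Minimal usage sketch (assumes a base `AbstractPipeline` has already been
# loaded elsewhere as `base_pipeline`; the variable names below are
# illustrative only):
#
#   drawer = RealtimeDraw()
#   drawer.load(base_pipeline)
#   result = drawer.process_img(
#       prompt="a watercolor landscape",
#       negative_prompt="blurry, low quality",
#       seed=42,
#       image=user_scribble,        # PIL.Image sketch from the canvas
#       image2=segmentation_image,  # optional segmentation map
#   )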