import gc

import numpy as np
import PIL.Image
import torch
from controlnet_aux import NormalBaeDetector#, CannyDetector

from controlnet_aux.util import HWC3
import cv2
# from cv_utils import resize_image

class Preprocessor:
    """Lazily load a ControlNet annotator model and apply it to images.

    Call :meth:`load` with a model name first, then call the instance on an
    image. Only ``"NormalBae"`` can currently be loaded by :meth:`load`; the
    ``"Canny"`` and ``"Midas"`` branches of :meth:`__call__` are kept for when
    the corresponding loaders are re-enabled.
    """

    # Repository id handed to ``NormalBaeDetector.from_pretrained``.
    MODEL_ID = "lllyasviel/Annotators"

    @staticmethod
    def resize_image(input_image, resolution, interpolation=None):
        """Resize ``input_image`` so its longer side is about ``resolution``.

        Each output side is rounded to the nearest multiple of 64, with a
        minimum of 64 so that very thin images do not collapse to a zero-sized
        axis (which ``cv2.resize`` rejects).

        Args:
            input_image: Image array whose first two axes are height and width,
                i.e. ``(H, W)`` or ``(H, W, C)``.
            resolution: Target length of the longer side before rounding.
            interpolation: OpenCV interpolation flag. Defaults to
                ``cv2.INTER_LANCZOS4`` when upscaling and ``cv2.INTER_AREA``
                otherwise.

        Returns:
            The resized array produced by ``cv2.resize``.
        """
        height, width = input_image.shape[:2]
        scale = float(resolution) / max(height, width)
        # Snap both sides to multiples of 64, but never below one step.
        new_h = max(64, int(np.round(height * scale / 64.0)) * 64)
        new_w = max(64, int(np.round(width * scale / 64.0)) * 64)
        if interpolation is None:
            interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
        # cv2.resize expects the target size as (width, height).
        return cv2.resize(input_image, (new_w, new_h), interpolation=interpolation)

    def __init__(self):
        # No model is loaded until load() is called.
        self.model = None
        # Name of the currently loaded model; "" means nothing loaded.
        self.name = ""

    def load(self, name: str) -> None:
        """Load the annotator model called ``name`` onto the ``"cuda"`` device.

        Calling again with the name that is already loaded is a no-op.

        Args:
            name: Model to load. Only ``"NormalBae"`` is supported.

        Raises:
            ValueError: If ``name`` is not supported. The currently loaded
                model is left untouched in that case.
        """
        if name == self.name:
            return
        if name != "NormalBae":
            # "Canny" (controlnet_aux CannyDetector) is currently disabled.
            raise ValueError(f"Unsupported preprocessor: {name!r}")

        # Drop the previous model *before* loading the next one so both are
        # never resident at the same time. Collect garbage first so objects
        # kept alive only by reference cycles are freed before the CUDA
        # allocator cache is emptied.
        self.model = None
        self.name = ""
        gc.collect()
        torch.cuda.empty_cache()

        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded model on ``image``.

        NOTE: calling this before :meth:`load` tries to call ``None`` and fails.

        Args:
            image: Input image.
            **kwargs: Forwarded to the model. For ``"Canny"`` the optional
                ``detect_resolution`` key, and for ``"Midas"`` the
                ``detect_resolution`` / ``image_resolution`` keys (default
                512 each), are consumed here and not forwarded.

        Returns:
            For ``"Canny"`` and ``"Midas"``, a PIL image built from the model's
            array output; for any other model, whatever the model returns.
        """
        if self.name == "Canny":
            if "detect_resolution" in kwargs:
                detect_resolution = kwargs.pop("detect_resolution")
                image = HWC3(np.array(image))
                image = self.resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = HWC3(np.array(image))
            image = self.resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            # Normalise the model output to 3 channels before the final resize.
            image = HWC3(image)
            image = self.resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)

        return self.model(image, **kwargs)