|
import gc
|
|
|
|
import numpy as np
|
|
import PIL.Image
|
|
import torch
|
|
import torchvision
|
|
from controlnet_aux import (
|
|
CannyDetector,
|
|
ContentShuffleDetector,
|
|
HEDdetector,
|
|
LineartAnimeDetector,
|
|
LineartDetector,
|
|
MidasDetector,
|
|
MLSDdetector,
|
|
NormalBaeDetector,
|
|
OpenposeDetector,
|
|
PidiNetDetector,
|
|
)
|
|
from controlnet_aux.util import HWC3
|
|
|
|
from cv_utils import resize_image
|
|
from depth_estimator import DepthEstimator
|
|
from image_segmentor import ImageSegmentor
|
|
|
|
from kornia.core import Tensor
|
|
from kornia.filters import canny
|
|
|
|
|
|
class Canny:
|
|
|
|
def __call__(
|
|
self,
|
|
images: np.array,
|
|
low_threshold: float = 0.1,
|
|
high_threshold: float = 0.2,
|
|
kernel_size: tuple[int, int] | int = (5, 5),
|
|
sigma: tuple[float, float] | Tensor = (1, 1),
|
|
hysteresis: bool = True,
|
|
eps: float = 1e-6
|
|
) -> torch.Tensor:
|
|
|
|
assert low_threshold is not None, "low_threshold must be provided"
|
|
assert high_threshold is not None, "high_threshold must be provided"
|
|
|
|
images = torch.from_numpy(images).permute(2, 0, 1).unsqueeze(0) / 255.0
|
|
|
|
images_tensor = canny(images, low_threshold, high_threshold, kernel_size, sigma, hysteresis, eps)[1]
|
|
images_tensor = (images_tensor[0][0].numpy() * 255).astype(np.uint8)
|
|
return images_tensor
|
|
|
|
|
|
class Preprocessor:
    """Holds one lazily loaded image preprocessor and applies it on call.

    ``load(name)`` selects the backend; calling the instance runs it on an
    image. Only one backend is kept at a time.
    """

    # Not referenced by the code in this class; presumably the Hugging Face
    # repo holding annotator weights -- TODO confirm before relying on it.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        """Create an empty preprocessor with no backend loaded."""
        # Callable backend set by load(); None until a load succeeds.
        self.model = None
        # Name given to the last successful load(); "" when nothing is loaded.
        self.name = ""

    def load(self, name: str) -> None:
        """Load the backend called ``name`` unless it is already active.

        Args:
            name: ``"Canny"`` or ``"DPT"``.

        Raises:
            ValueError: If ``name`` is not a supported backend. The currently
                loaded backend (if any) is left untouched in that case.
        """
        if name == self.name:
            return
        if name == "Canny":
            self.model = Canny()
        elif name == "DPT":
            self.model = DepthEstimator()
        else:
            raise ValueError(
                f"Unsupported preprocessor {name!r}; expected 'Canny' or 'DPT'"
            )
        # Collect garbage first so memory held by the replaced backend is
        # unreferenced before the CUDA caching allocator is asked to release
        # its unused blocks.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded backend on ``image``.

        Args:
            image: Input image.
            **kwargs: Backend options. ``detect_resolution`` (Canny, Midas)
                and ``image_resolution`` (Midas) are consumed here; everything
                else is forwarded to the backend unchanged.

        Returns:
            For ``"Canny"``, the edge map converted to an RGB PIL image. For
            ``"Midas"``, a PIL image built from the backend output. For any
            other backend, whatever the backend itself returns.

        Raises:
            RuntimeError: If no backend has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first")

        if self.name == "Canny":
            # The Canny backend needs an HWC NumPy array, so always convert;
            # resizing is optional and only happens when a resolution is given.
            detect_resolution = kwargs.pop("detect_resolution", None)
            image = HWC3(np.array(image))
            if detect_resolution is not None:
                image = resize_image(image, resolution=detect_resolution)
            edges = self.model(image, **kwargs)
            return PIL.Image.fromarray(edges).convert("RGB")

        if self.name == "Midas":
            # NOTE(review): load() never sets name to "Midas", so this branch
            # is reachable only if name/model are assigned from outside.
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = HWC3(np.array(image))
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)

        # Every other backend receives the image and options untouched.
        return self.model(image, **kwargs)
|
|
|