|
import gc |
|
|
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
from controlnet_aux import ( |
|
CannyDetector, |
|
ContentShuffleDetector, |
|
HEDdetector, |
|
LineartAnimeDetector, |
|
LineartDetector, |
|
MidasDetector, |
|
MLSDdetector, |
|
NormalBaeDetector, |
|
OpenposeDetector, |
|
PidiNetDetector, |
|
) |
|
from controlnet_aux.util import HWC3 |
|
|
|
from cv_utils import resize_image |
|
from depth_estimator import DepthEstimator |
|
from image_segmentor import ImageSegmentor |
|
|
|
|
|
class Preprocessor:
    """Hold at most one ControlNet annotator and apply it to images.

    ``load(name)`` selects (and builds, if needed) the active annotator;
    calling the instance runs it on an image.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # active annotator; None until load() succeeds
        self.name = ""  # name of the active annotator; "" when none is loaded

    def _make_factory(self, name: str):
        """Return a zero-argument callable that builds the annotator *name*.

        Raises:
            ValueError: if *name* is not a supported annotator.
        """
        model_id = self.MODEL_ID
        # Values are zero-argument callables: nothing is constructed (or
        # downloaded) until the selected one is invoked.
        factories = {
            "HED": lambda: HEDdetector.from_pretrained(model_id),
            "Midas": lambda: MidasDetector.from_pretrained(model_id),
            "MLSD": lambda: MLSDdetector.from_pretrained(model_id),
            "Openpose": lambda: OpenposeDetector.from_pretrained(model_id),
            "PidiNet": lambda: PidiNetDetector.from_pretrained(model_id),
            "NormalBae": lambda: NormalBaeDetector.from_pretrained(model_id),
            "Lineart": lambda: LineartDetector.from_pretrained(model_id),
            "LineartAnime": lambda: LineartAnimeDetector.from_pretrained(model_id),
            "Canny": CannyDetector,
            "ContentShuffle": ContentShuffleDetector,
            "DPT": DepthEstimator,
            "UPerNet": ImageSegmentor,
        }
        try:
            return factories[name]
        except KeyError:
            raise ValueError(f"Unknown preprocessor name: {name!r}") from None

    def load(self, name: str) -> None:
        """Make *name* the active annotator, replacing any previous one.

        Does nothing if *name* is already the active annotator.

        Raises:
            ValueError: if *name* is not supported; the currently loaded
                model is kept in that case.

        If building the new model raises, no model is loaded afterwards
        (``self.model`` is None and ``self.name`` is ""), so a later
        ``load`` retries cleanly.
        """
        if name == self.name:
            return
        # Validate first so an unknown name never discards the current model.
        factory = self._make_factory(name)
        # Drop the old model before building the new one so both are never
        # resident at the same time. gc.collect() must precede empty_cache():
        # the caching allocator can only release blocks whose tensors have
        # already been freed.
        self.model = None
        self.name = ""
        gc.collect()
        torch.cuda.empty_cache()
        self.model = factory()
        self.name = name

    @staticmethod
    def _prepare(image, resolution):
        """Convert *image* to an array via HWC3; resize it if *resolution* is not None."""
        array = HWC3(np.array(image))
        if resolution is not None:
            array = resize_image(array, resolution=resolution)
        return array

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the active annotator on *image*.

        "Canny": the image is converted with HWC3, resized to
        ``detect_resolution`` only when that kwarg is given, and the result is
        returned as a PIL image.
        "Midas": the image is converted and resized to ``detect_resolution``
        (default 512); the output is passed through HWC3, resized to
        ``image_resolution`` (default 512) and returned as a PIL image.
        Any other annotator receives *image* and the remaining kwargs
        unchanged, and its return value is passed through as-is.

        Raises:
            RuntimeError: if no annotator has been loaded with ``load()``.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")

        if self.name == "Canny":
            # Always hand the detector an array; the original code passed the
            # raw PIL image through when detect_resolution was omitted.
            detect_resolution = kwargs.pop("detect_resolution", None)
            array = self._prepare(image, detect_resolution)
            return PIL.Image.fromarray(self.model(array, **kwargs))

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            array = self._prepare(image, detect_resolution)
            depth = self.model(array, **kwargs)
            depth = resize_image(HWC3(depth), resolution=image_resolution)
            return PIL.Image.fromarray(depth)

        return self.model(image, **kwargs)
|
|