import gc

import numpy as np
import PIL.Image
import torch
from controlnet_aux import NormalBaeDetector  # , CannyDetector
from controlnet_aux.util import HWC3
from cv_utils import resize_image


def _release_memory() -> None:
    """Run Python garbage collection, then release unused cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are actually freed; ``torch.cuda.empty_cache()`` can only hand back
    cached blocks that are no longer occupied, so the reverse order would leave
    those tensors' memory in the cache.
    """
    gc.collect()
    torch.cuda.empty_cache()


def _to_resized_array(image, resolution) -> np.ndarray:
    """Convert *image* to an array, pass it through ``HWC3``, then ``resize_image``.

    NOTE(review): ``HWC3`` presumably normalizes to an H x W x 3 array and
    ``resize_image`` scales to *resolution*; both live outside this file.
    """
    array = HWC3(np.array(image))
    return resize_image(array, resolution=resolution)


class Preprocessor:
    """Load a ControlNet annotator model by name and apply it to images.

    Only ``"NormalBae"`` can currently be loaded via :meth:`load`; the
    ``"Canny"`` and ``"Midas"`` branches in :meth:`__call__` are kept for when
    those annotators are re-enabled.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # loaded annotator, or None before load()
        self.name = ""  # name of the loaded annotator; "" when none is loaded

    def load(self, name: str) -> None:
        """Load the annotator called *name*, replacing any previously loaded one.

        Does nothing when *name* is already loaded.

        :param name: annotator to load; only ``"NormalBae"`` is supported.
        :raises ValueError: if *name* is not a supported annotator.
        """
        if name == self.name:
            return
        if name != "NormalBae":
            # "Canny" (CannyDetector) support is currently disabled.
            raise ValueError(f"Unsupported preprocessor: {name!r}")

        # Drop the previous model *before* loading the new one so both are
        # never resident at the same time. The name is reset as well so that a
        # failed load below cannot leave self.name describing a model we no
        # longer hold.
        self.model = None
        self.name = ""
        self.manage_memory()

        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
        # Release any temporary allocations left over from loading.
        self.manage_memory()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on *image*.

        - ``"Canny"``: if ``detect_resolution`` is given, the image is first
          converted and resized to it; the model output is returned as a PIL
          image.
        - ``"Midas"``: the image is resized to ``detect_resolution`` (default
          512) before the model and the output is resized to
          ``image_resolution`` (default 512), then returned as a PIL image.
        - Anything else (e.g. ``"NormalBae"``): *image* and the remaining
          keyword arguments are passed to the model unchanged and its result
          is returned as-is.

        Remaining ``kwargs`` are forwarded to the model in every case.
        """
        if self.name == "Canny":
            detect_resolution = kwargs.pop("detect_resolution", None)
            if detect_resolution is not None:
                image = _to_resized_array(image, detect_resolution)
            result = self.model(image, **kwargs)
            return PIL.Image.fromarray(result)
        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            result = self.model(_to_resized_array(image, detect_resolution), **kwargs)
            return PIL.Image.fromarray(_to_resized_array(result, image_resolution))
        return self.model(image, **kwargs)

    def manage_memory(self) -> None:
        """Run Python garbage collection and release unused cached CUDA memory."""
        _release_memory()


def conditionally_manage_memory(memory_threshold: float = 0.8, device: int = 0) -> None:
    """Free cached GPU memory only when the reserved fraction exceeds a threshold.

    Intended for calling often (e.g. per request) without paying for a full
    cleanup every time.

    :param memory_threshold: fraction of the device's total memory that must be
        reserved by PyTorch's caching allocator before cleanup is triggered.
    :param device: CUDA device (index or ``torch.device``) to inspect;
        defaults to device 0.
    """
    if not torch.cuda.is_available():
        return
    total_memory = torch.cuda.get_device_properties(device).total_memory
    reserved_memory = torch.cuda.memory_reserved(device)
    if reserved_memory / total_memory > memory_threshold:
        _release_memory()