# interior-ai-designer / preprocess_anime.py
# Last commit 838a1f4 by Bobby: "profiler part 2"
import gc
import numpy as np
import PIL.Image
import torch
from controlnet_aux import NormalBaeDetector#, CannyDetector
from controlnet_aux.util import HWC3
from cv_utils import resize_image
class Preprocessor:
    """Lazily loads one ControlNet annotator and applies it to images.

    Only ``"NormalBae"`` can currently be loaded through :meth:`load`. The
    ``"Canny"`` and ``"Midas"`` branches of :meth:`__call__` are kept for
    annotators that are not (yet) wired into :meth:`load`.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self, device: str = "cuda"):
        # Device that loaded models are moved to; "cuda" matches the
        # previously hard-coded target.
        self.device = device
        # The currently loaded annotator, or None before the first load().
        self.model = None
        # Name of the loaded annotator; "" means nothing is loaded.
        self.name = ""

    def load(self, name: str) -> None:
        """Load the annotator called *name*, replacing any previous one.

        Calling this again with the already-loaded name is a no-op.

        :param name: Annotator name; only ``"NormalBae"`` is supported.
        :raises ValueError: If *name* is not a supported annotator.
        """
        if name == self.name:
            return
        if name != "NormalBae":
            raise ValueError(f"Unsupported preprocessor: {name!r}")
        # Release the previous model before loading the next one so both are
        # never resident at the same time. Reset the name as well, so the
        # state stays consistent if the load below raises.
        self.model = None
        self.name = ""
        self.manage_memory()
        model = NormalBaeDetector.from_pretrained(self.MODEL_ID)
        self.model = model.to(self.device)
        # Drop any temporaries left over from loading.
        self.manage_memory()
        self.name = name

    @staticmethod
    def _to_resized_hwc3(image, resolution):
        """Convert *image* to an array, normalize it with HWC3, then resize it."""
        array = HWC3(np.array(image))
        return resize_image(array, resolution=resolution)

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on *image*.

        - ``"Canny"``: if ``detect_resolution`` is given, the image is first
          converted and resized to it. The model output is returned as a PIL
          image.
        - ``"Midas"``: the image is resized to ``detect_resolution`` (default
          512) before the model runs. The output is resized to
          ``image_resolution`` (default 512) and returned as a PIL image.
        - Otherwise the model output is returned unchanged.

        Remaining keyword arguments are forwarded to the model.

        :raises RuntimeError: If no annotator has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")
        if self.name == "Canny":
            if "detect_resolution" in kwargs:
                image = self._to_resized_hwc3(
                    image, kwargs.pop("detect_resolution")
                )
            return PIL.Image.fromarray(self.model(image, **kwargs))
        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            prepared = self._to_resized_hwc3(image, detect_resolution)
            result = HWC3(self.model(prepared, **kwargs))
            result = resize_image(result, resolution=image_resolution)
            return PIL.Image.fromarray(result)
        return self.model(image, **kwargs)

    def manage_memory(self):
        """Collect unreachable Python objects, then release cached CUDA memory."""
        # Collect first: tensors freed by gc go back to the caching allocator,
        # and only then can empty_cache() release those blocks.
        gc.collect()
        torch.cuda.empty_cache()
# Additional helper function to manage memory less frequently
def conditionally_manage_memory(memory_threshold=0.8, device=0):
    """
    Free cached GPU memory if reserved memory exceeds a fraction of the total.

    :param memory_threshold: Fraction (reserved / total) above which cleanup runs.
    :param device: CUDA device index whose usage is checked (default 0).
    :return: True if cleanup ran, False otherwise (including when CUDA is
        unavailable). Older callers that ignore the return value are unaffected.
    """
    if not torch.cuda.is_available():
        return False
    total_memory = torch.cuda.get_device_properties(device).total_memory
    reserved_memory = torch.cuda.memory_reserved(device)
    if reserved_memory / total_memory <= memory_threshold:
        return False
    # Collect first so tensors freed by gc are returned to the caching
    # allocator before empty_cache() releases cached blocks.
    gc.collect()
    torch.cuda.empty_cache()
    return True