import torch
from PIL import Image
from diffusers import (
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
    ControlNetModel,
)


class GeoPainting:
    DEFAULT_CONTROLNET_MODEL = "lllyasviel/control_v11f1p_sd15_depth"
    DEFAULT_DIFFUSER_MODEL = "geospatial_diffuser"

    def __init__(self, controlnet_model_path=DEFAULT_CONTROLNET_MODEL,
                 diffuser_model=DEFAULT_DIFFUSER_MODEL):
        # Depth-conditioned ControlNet weights, loaded in half precision.
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_model_path, torch_dtype=torch.float16
        )
        # Fixed seed so generations are reproducible across runs.
        self.generator = torch.Generator(device="cpu").manual_seed(2)
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            diffuser_model,
            low_cpu_mem_usage=False,
            device_map=None,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        # Swap the default scheduler for UniPC, which needs fewer inference steps.
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        if torch.cuda.is_available():
            # Offload idle submodules to CPU and use xformers attention
            # to reduce GPU memory usage.
            self.pipe.enable_model_cpu_offload()
            self.pipe.enable_xformers_memory_efficient_attention()

    def generate_painting(self, input_prompt, control_image):
        # The control image arrives as a NumPy array (e.g. a depth map) and is
        # converted to a PIL image before being passed to the pipeline.
        image = Image.fromarray(control_image.astype("uint8"))
        output = self.pipe(
            input_prompt,
            image,
            negative_prompt="ugly, disfigured, low quality, blurry, nsfw",
            generator=self.generator,
            num_inference_steps=20,
        )
        return output.images[0]
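
# A minimal usage sketch, not part of the class above. The prompt and the
# synthetic depth gradient are illustrative assumptions; in practice the
# control image would be a real depth map (H x W x 3, uint8 values 0-255).
if __name__ == "__main__":
    import numpy as np

    painter = GeoPainting()
    # Hypothetical control input: a horizontal depth gradient replicated
    # across three channels to form a 512x512x3 pseudo depth map.
    gradient = np.tile(np.linspace(0, 255, 512, dtype=np.uint8), (512, 1))
    depth = np.stack([gradient] * 3, axis=-1)
    result = painter.generate_painting(
        "aerial view of a coastal town, detailed, photorealistic",
        depth,
    )
    result.save("geopainting_output.png")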