"""Two-stage diffusion demo: prompt -> binarized semantic map -> final image."""

import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from diffusers import (
    AutoencoderKL,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from PIL import Image
from skimage.filters import threshold_otsu
from transformers import CLIPTextModel, CLIPTokenizer

from model.controlnet_SPADE import ControlNetModel
from model.ControlnetPipeline import StableDiffusionControlNetPipeline_SPADE


class Pipeline_demo:
    """Bundle a text-to-image pipeline and a SPADE ControlNet pipeline.

    ``pipeline1`` renders an image from a prompt, which ``binarize`` turns
    into a black/white map. ``pipeline2`` renders the final image from a
    second prompt plus that map as the conditioning image.
    """

    def __init__(self,
                 unet_ckpt="pretrain_weights/unet",
                 controlnet_ckpt="pretrain_weights/controlnet",
                 pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
                 is_cpu=False,
                 mixed_precision="fp16",
                 enable_xformers_memory_efficient_attention=True,
                 output_dir="output",
                 ) -> None:
        """Load all model components and assemble both pipelines.

        Args:
            unet_ckpt: Path or hub id the UNet weights are loaded from.
            controlnet_ckpt: Path or hub id the ControlNet weights are
                loaded from.
            pretrained_model_name_or_path: Base model providing the text
                encoder, tokenizer, VAE and remaining pipeline components.
            is_cpu: Run on CPU. Forces ``mixed_precision`` to ``"no"`` and
                skips xformers attention.
            mixed_precision: Mixed-precision mode handed to ``Accelerator``.
            enable_xformers_memory_efficient_attention: Enable xformers
                attention on both pipelines (ignored when ``is_cpu``).
            output_dir: Used as both project and logging directory of the
                accelerator project configuration.
        """
        accelerator_project_config = ProjectConfiguration(
            project_dir=output_dir, logging_dir=output_dir
        )
        self.is_cpu = is_cpu
        if is_cpu:
            mixed_precision = "no"

        self.accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision=mixed_precision,
            log_with="tensorboard",
            cpu=bool(is_cpu),
            project_config=accelerator_project_config,
        )

        # Half precision is requested only for non-CPU fp16 runs.
        weight_dtype = torch.float32
        if not is_cpu and self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16

        self.text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="text_encoder", revision=None
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(unet_ckpt)
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path, subfolder="vae", revision=None
        )
        self.controlnet = ControlNetModel.from_pretrained(controlnet_ckpt)

        # Stage 1: prompt -> image that is later binarized.
        self.pipeline1 = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            vae=self.accelerator.unwrap_model(self.vae),
            text_encoder=self.accelerator.unwrap_model(self.text_encoder),
            tokenizer=self.tokenizer,
            unet=self.accelerator.unwrap_model(self.unet),
            safety_checker=None,
            revision=None,
            torch_dtype=weight_dtype,
        )
        # Stage 2: prompt + conditioning image -> final image. Shares the
        # VAE, text encoder, tokenizer and UNet objects with stage 1.
        self.pipeline2 = StableDiffusionControlNetPipeline_SPADE.from_pretrained(
            pretrained_model_name_or_path,
            vae=self.accelerator.unwrap_model(self.vae),
            text_encoder=self.accelerator.unwrap_model(self.text_encoder),
            tokenizer=self.tokenizer,
            unet=self.accelerator.unwrap_model(self.unet),
            controlnet=self.accelerator.unwrap_model(self.controlnet),
            safety_checker=None,
            revision=None,
            torch_dtype=weight_dtype,
        )

        self.pipeline1 = self.pipeline1.to(self.accelerator.device)
        self.pipeline1.set_progress_bar_config(disable=False)
        self.pipeline2 = self.pipeline2.to(self.accelerator.device)
        self.pipeline2.set_progress_bar_config(disable=False)
        # Only the ControlNet pipeline switches to the UniPC scheduler.
        self.pipeline2.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline2.scheduler.config
        )

        if not is_cpu and enable_xformers_memory_efficient_attention:
            self.pipeline1.enable_xformers_memory_efficient_attention()
            self.pipeline2.enable_xformers_memory_efficient_attention()

    def _make_generator(self, seed):
        """Return a generator on the accelerator device seeded with *seed*, or None."""
        if seed is None:
            return None
        return torch.Generator(device=self.accelerator.device).manual_seed(seed)

    def _autocast_device_type(self):
        """Return the device-type string passed to ``torch.autocast``."""
        # NOTE(review): on CPU, autocast may still run in bfloat16 even
        # though mixed_precision was forced to "no" -- confirm this is intended.
        return "cpu" if self.is_cpu else "cuda"

    def binarize(self, image_array):
        """Binarize *image_array* with Otsu's threshold.

        A single threshold is computed from ``image_array`` as a whole;
        values above it become 255, all others 0.

        Args:
            image_array: Numeric array of pixel values.

        Returns:
            The binarized array as a PIL image converted to ``"RGB"``.
        """
        otsu_thresh = threshold_otsu(image_array)
        binarized_otsu_array = np.where(image_array > otsu_thresh, 255, 0)
        return Image.fromarray(binarized_otsu_array.astype(np.uint8)).convert("RGB")

    def generate_semantic_map(self, semantic_map_prompt, seed=None,
                              num_inference_steps=20):
        """Generate an image from the prompt and return its binarized form.

        Args:
            semantic_map_prompt: Prompt passed to the first pipeline.
            seed: Optional seed for reproducible sampling; ``None`` leaves
                the generator unset.
            num_inference_steps: Number of denoising steps (default 20).

        Returns:
            PIL ``"RGB"`` image produced by ``binarize``.
        """
        generator = self._make_generator(seed)
        with torch.autocast(self._autocast_device_type()):
            image_anno = self.pipeline1(
                semantic_map_prompt,
                num_inference_steps=num_inference_steps,
                generator=generator,
            ).images[0]
        return self.binarize(np.array(image_anno))

    def generate_image(self, image_prompt, bin_img, seed=None,
                       num_inference_steps=20):
        """Generate the final image from a prompt and a conditioning map.

        Args:
            image_prompt: Prompt passed to the ControlNet pipeline.
            bin_img: Conditioning image, either a PIL image (e.g. the result
                of ``generate_semantic_map``) or an array accepted by
                ``Image.fromarray``.
            seed: Optional seed for reproducible sampling; ``None`` leaves
                the generator unset.
            num_inference_steps: Number of denoising steps (default 20).

        Returns:
            The first image produced by the ControlNet pipeline.
        """
        generator = self._make_generator(seed)
        # A PIL image is passed through as-is; only array inputs are converted.
        if not isinstance(bin_img, Image.Image):
            bin_img = Image.fromarray(bin_img)
        with torch.autocast(self._autocast_device_type()):
            image_img = self.pipeline2(
                image_prompt,
                bin_img,
                num_inference_steps=num_inference_steps,
                generator=generator,
            ).images[0]
        return image_img