|
from skimage.filters import threshold_otsu |
|
from PIL import Image |
|
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel,UniPCMultistepScheduler |
|
from accelerate import Accelerator |
|
from accelerate.utils import ProjectConfiguration |
|
from transformers import CLIPTextModel,CLIPTokenizer |
|
import torch |
|
import numpy as np |
|
from skimage.filters import threshold_otsu |
|
from model.controlnet_SPADE import ControlNetModel |
|
from model.ControlnetPipeline import StableDiffusionControlNetPipeline_SPADE |
|
|
|
class Pipeline_demo:
    """Two-stage Stable Diffusion demo: prompt -> binary map -> image.

    Stage 1 (``pipeline1``) is a ``StableDiffusionPipeline`` built around the
    UNet loaded from ``unet_ckpt``. Its output is binarized with Otsu's
    threshold by :meth:`generate_semantic_map`.

    Stage 2 (``pipeline2``) is a ``StableDiffusionControlNetPipeline_SPADE``.
    It reuses the same VAE, text encoder, tokenizer and UNet objects, plus the
    ControlNet loaded from ``controlnet_ckpt``. :meth:`generate_image`
    conditions it on a binary map.
    """

    def __init__(
        self,
        unet_ckpt="pretrain_weights/unet",
        controlnet_ckpt="pretrain_weights/controlnet",
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        is_cpu=False,
        mixed_precision="fp16",
        enable_xformers_memory_efficient_attention=True,
        output_dir="output",
    ) -> None:
        """Load every model component and build both pipelines.

        Args:
            unet_ckpt: path / repo id given to
                ``UNet2DConditionModel.from_pretrained``. Both pipelines use
                this UNet.
            controlnet_ckpt: path / repo id given to
                ``ControlNetModel.from_pretrained``.
            pretrained_model_name_or_path: base Stable Diffusion repo. It
                provides the text encoder, tokenizer, VAE and the remaining
                pipeline components (e.g. the scheduler config).
            is_cpu: run on CPU. This forces ``mixed_precision="no"`` and
                skips xformers.
            mixed_precision: forwarded to ``Accelerator``. When the
                accelerator reports ``"fp16"`` (and not on CPU), both
                pipelines are built with ``torch_dtype=torch.float16``.
            enable_xformers_memory_efficient_attention: turn on xformers
                attention for both pipelines (ignored on CPU).
                NOTE(review): diffusers raises if xformers is not
                installed — confirm the deployment has it, or pass False.
            output_dir: project and logging directory for the ``Accelerator``.
        """
        self.is_cpu = is_cpu
        if is_cpu:
            # Half-precision mixed precision is not used on CPU.
            mixed_precision = "no"

        accelerator_project_config = ProjectConfiguration(
            project_dir=output_dir, logging_dir=output_dir
        )
        self.accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision=mixed_precision,
            log_with="tensorboard",
            cpu=bool(is_cpu),
            project_config=accelerator_project_config,
        )

        weight_dtype = torch.float32
        if not is_cpu and self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16

        self.text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="text_encoder", revision=None
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(unet_ckpt)
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path, subfolder="vae", revision=None
        )
        self.controlnet = ControlNetModel.from_pretrained(controlnet_ckpt)

        # Both pipelines receive the very same module instances, so the VAE,
        # text encoder and UNet are shared rather than duplicated.
        self.pipeline1 = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            vae=self.accelerator.unwrap_model(self.vae),
            text_encoder=self.accelerator.unwrap_model(self.text_encoder),
            tokenizer=self.tokenizer,
            unet=self.accelerator.unwrap_model(self.unet),
            safety_checker=None,
            revision=None,
            torch_dtype=weight_dtype,
        )

        self.pipeline2 = StableDiffusionControlNetPipeline_SPADE.from_pretrained(
            pretrained_model_name_or_path,
            vae=self.accelerator.unwrap_model(self.vae),
            text_encoder=self.accelerator.unwrap_model(self.text_encoder),
            tokenizer=self.tokenizer,
            unet=self.accelerator.unwrap_model(self.unet),
            controlnet=self.accelerator.unwrap_model(self.controlnet),
            safety_checker=None,
            revision=None,
            torch_dtype=weight_dtype,
        )

        self.pipeline1 = self.pipeline1.to(self.accelerator.device)
        self.pipeline1.set_progress_bar_config(disable=False)
        self.pipeline2 = self.pipeline2.to(self.accelerator.device)
        self.pipeline2.set_progress_bar_config(disable=False)

        # Only the ControlNet stage switches to UniPC. pipeline1 keeps the
        # scheduler that ships with the base model.
        self.pipeline2.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline2.scheduler.config
        )

        if not is_cpu and enable_xformers_memory_efficient_attention:
            self.pipeline1.enable_xformers_memory_efficient_attention()
            self.pipeline2.enable_xformers_memory_efficient_attention()

    def _make_generator(self, seed):
        """Return a seeded ``torch.Generator`` on the accelerator device.

        Returns ``None`` when ``seed`` is ``None``, so sampling is unseeded.
        """
        if seed is None:
            return None
        return torch.Generator(device=self.accelerator.device).manual_seed(seed)

    def _autocast_device_type(self):
        """Return the ``device_type`` string used for ``torch.autocast``."""
        # NOTE(review): on CPU, torch.autocast defaults to bfloat16 while the
        # weights stay float32 — confirm this is intended.
        return "cpu" if self.is_cpu else "cuda"

    def binarize(self, image_array):
        """Binarize an image with a single global Otsu threshold.

        Args:
            image_array: image as an array. A 2-D array is treated as
                grayscale. An array whose last axis has 3 or 4 entries is
                treated as RGB(A) and is reduced to ITU-R BT.601 luma first.

        Returns:
            PIL.Image.Image: an RGB image in which every pixel is either
            (0, 0, 0) or (255, 255, 255).
        """
        image_array = np.asarray(image_array)
        if image_array.ndim == 3 and image_array.shape[-1] in (3, 4):
            # threshold_otsu expects grayscale input. Thresholding the RGB
            # channels independently can leave coloured, non-binary pixels.
            # The weights match PIL's "L" conversion; any alpha is ignored.
            image_array = np.dot(image_array[..., :3], (0.299, 0.587, 0.114))

        otsu_thresh = threshold_otsu(image_array)
        binarized_otsu_array = np.where(image_array > otsu_thresh, 255, 0)
        return Image.fromarray(binarized_otsu_array.astype(np.uint8)).convert("RGB")

    def generate_semantic_map(self, semantic_map_prompt, seed=None, num_inference_steps=20):
        """Sample an image from ``pipeline1`` and binarize it.

        Args:
            semantic_map_prompt: text prompt passed to ``pipeline1``.
            seed: optional integer seed for reproducible sampling.
            num_inference_steps: number of denoising steps (default 20).

        Returns:
            PIL.Image.Image: the binarized map produced by :meth:`binarize`.
            It can be passed directly to :meth:`generate_image`.
        """
        generator = self._make_generator(seed)
        with torch.autocast(self._autocast_device_type()):
            image_anno = self.pipeline1(
                semantic_map_prompt,
                num_inference_steps=num_inference_steps,
                generator=generator,
            ).images[0]
        return self.binarize(np.array(image_anno))

    def generate_image(self, image_prompt, bin_img, seed=None, num_inference_steps=20):
        """Sample an image from ``pipeline2`` conditioned on a binary map.

        Args:
            image_prompt: text prompt passed to ``pipeline2``.
            bin_img: conditioning map. Accepts either a PIL image (such as
                the return value of :meth:`generate_semantic_map`) or an
                array that ``Image.fromarray`` accepts.
            seed: optional integer seed for reproducible sampling.
            num_inference_steps: number of denoising steps (default 20).

        Returns:
            The first image in the output of ``pipeline2``.
        """
        generator = self._make_generator(seed)

        # generate_semantic_map already returns a PIL image. Only convert
        # array-like input, so the two stages compose directly.
        if not isinstance(bin_img, Image.Image):
            bin_img = Image.fromarray(np.asarray(bin_img))

        with torch.autocast(self._autocast_device_type()):
            image_img = self.pipeline2(
                image_prompt,
                bin_img,
                num_inference_steps=num_inference_steps,
                generator=generator,
            ).images[0]
        return image_img
|
|
|
|
|
|
|
|
|
|