# tanlei0
# "v1"
# 8222fd4
# raw
# history blame contribute delete
# No virus
# 5.25 kB
from skimage.filters import threshold_otsu
from PIL import Image
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel,UniPCMultistepScheduler
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from transformers import CLIPTextModel,CLIPTokenizer
import torch
import numpy as np
from skimage.filters import threshold_otsu
from model.controlnet_SPADE import ControlNetModel
from model.ControlnetPipeline import StableDiffusionControlNetPipeline_SPADE
class Pipeline_demo:
    """Two-stage demo: text prompt -> binary semantic map -> image.

    Stage 1 (``pipeline1``) is a ``StableDiffusionPipeline`` using the UNet
    loaded from ``unet_ckpt``; its output is binarized with Otsu's threshold.
    Stage 2 (``pipeline2``) is a ``StableDiffusionControlNetPipeline_SPADE``
    that additionally takes the ControlNet loaded from ``controlnet_ckpt`` and
    is conditioned on the binary map.
    """

    def __init__(
        self,
        unet_ckpt="pretrain_weights/unet",
        controlnet_ckpt="pretrain_weights/controlnet",
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        is_cpu=False,
        mixed_precision="fp16",
        enable_xformers_memory_efficient_attention=True,
        output_dir="output",
    ) -> None:
        """Load all model components and build both pipelines.

        Args:
            unet_ckpt: Passed to ``UNet2DConditionModel.from_pretrained``.
            controlnet_ckpt: Passed to ``ControlNetModel.from_pretrained``.
            pretrained_model_name_or_path: Base Stable Diffusion repo; the text
                encoder, tokenizer and VAE are loaded from its subfolders.
            is_cpu: Run on CPU. Forces ``mixed_precision="no"`` and skips
                xformers.
            mixed_precision: Passed to ``Accelerator``. Only ``"fp16"`` (and
                not on CPU) changes the ``torch_dtype`` given to the pipelines.
            enable_xformers_memory_efficient_attention: Enable xformers
                attention on both pipelines (ignored when ``is_cpu``).
            output_dir: Used as both project dir and logging dir for
                ``Accelerator``.
        """
        accelerator_project_config = ProjectConfiguration(
            project_dir=output_dir, logging_dir=output_dir
        )
        self.is_cpu = is_cpu
        if is_cpu:
            mixed_precision = "no"
        self.accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision=mixed_precision,
            log_with="tensorboard",
            cpu=bool(is_cpu),
            project_config=accelerator_project_config,
        )

        weight_dtype = torch.float32
        if not is_cpu and self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16

        self.text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="text_encoder", revision=None
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(unet_ckpt)
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path, subfolder="vae", revision=None
        )
        self.controlnet = ControlNetModel.from_pretrained(controlnet_ckpt)

        # Both pipelines receive the same vae / text_encoder / unet objects,
        # so the weights are shared rather than loaded twice.
        # NOTE(review): these components are passed in already instantiated;
        # confirm whether `torch_dtype` is applied to them or only to
        # components the pipeline loads itself.
        self.pipeline1 = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            vae=self.accelerator.unwrap_model(self.vae),
            text_encoder=self.accelerator.unwrap_model(self.text_encoder),
            tokenizer=self.tokenizer,
            unet=self.accelerator.unwrap_model(self.unet),
            safety_checker=None,
            revision=None,
            torch_dtype=weight_dtype,
        )
        self.pipeline2 = StableDiffusionControlNetPipeline_SPADE.from_pretrained(
            pretrained_model_name_or_path,
            vae=self.accelerator.unwrap_model(self.vae),
            text_encoder=self.accelerator.unwrap_model(self.text_encoder),
            tokenizer=self.tokenizer,
            unet=self.accelerator.unwrap_model(self.unet),
            controlnet=self.accelerator.unwrap_model(self.controlnet),
            safety_checker=None,
            revision=None,
            torch_dtype=weight_dtype,
        )

        self.pipeline1 = self.pipeline1.to(self.accelerator.device)
        self.pipeline1.set_progress_bar_config(disable=False)
        self.pipeline2 = self.pipeline2.to(self.accelerator.device)
        self.pipeline2.set_progress_bar_config(disable=False)
        # Only the ControlNet stage switches to UniPC; pipeline1 keeps the
        # scheduler it was loaded with.
        self.pipeline2.scheduler = UniPCMultistepScheduler.from_config(
            self.pipeline2.scheduler.config
        )
        if not is_cpu and enable_xformers_memory_efficient_attention:
            self.pipeline1.enable_xformers_memory_efficient_attention()
            self.pipeline2.enable_xformers_memory_efficient_attention()

    def binarize(self, image_array):
        """Binarize ``image_array`` with Otsu's threshold.

        Values strictly above the threshold become 255 and all others 0. The
        result is returned as an RGB ``PIL.Image``.

        NOTE(review): the caller passes ``np.array`` of a pipeline output,
        presumably an RGB image. ``threshold_otsu`` then computes one
        threshold over all channels, and each channel is compared
        independently, so channels can disagree. Confirm this is intended.
        """
        otsu_thresh = threshold_otsu(image_array)
        binarized_otsu_array = np.where(image_array > otsu_thresh, 255, 0)
        return Image.fromarray(binarized_otsu_array.astype(np.uint8)).convert("RGB")

    def _make_generator(self, seed):
        """Return a ``torch.Generator`` on the accelerator device seeded with
        ``seed``, or ``None`` when ``seed`` is ``None`` (unseeded sampling)."""
        if seed is None:
            return None
        return torch.Generator(device=self.accelerator.device).manual_seed(seed)

    def _autocast(self):
        """Return the autocast context used around pipeline calls.

        On GPU this is ``torch.autocast("cuda")``. On CPU autocast is
        disabled: ``torch.autocast("cpu")`` defaults to bfloat16, which would
        silently reintroduce mixed precision although ``__init__`` forces
        ``mixed_precision="no"`` for CPU runs.
        """
        device_type = "cpu" if self.is_cpu else "cuda"
        return torch.autocast(device_type, enabled=not self.is_cpu)

    def generate_semantic_map(self, semantic_map_prompt, seed=None):
        """Generate a binary semantic map from a text prompt.

        Runs ``pipeline1`` for 20 inference steps and binarizes the first
        returned image (see ``binarize``).

        Args:
            semantic_map_prompt: Prompt passed to ``pipeline1``.
            seed: Optional integer seed; ``None`` leaves sampling unseeded.

        Returns:
            The binarized map as an RGB ``PIL.Image``.
        """
        generator = self._make_generator(seed)
        with self._autocast():
            image_anno = self.pipeline1(
                semantic_map_prompt, num_inference_steps=20, generator=generator
            ).images[0]
        return self.binarize(np.array(image_anno))

    def generate_image(self, image_prompt, bin_img, seed=None):
        """Generate an image conditioned on a binary semantic map.

        Args:
            image_prompt: Prompt passed to ``pipeline2``.
            bin_img: Conditioning map. Either a ``PIL.Image`` (e.g. the return
                value of ``generate_semantic_map``) or an array accepted by
                ``PIL.Image.fromarray``.
            seed: Optional integer seed; ``None`` leaves sampling unseeded.

        Returns:
            The first image produced by ``pipeline2`` (20 inference steps).
        """
        generator = self._make_generator(seed)
        # generate_semantic_map returns a PIL image, and Image.fromarray is not
        # guaranteed to accept one, so only arrays are converted.
        if not isinstance(bin_img, Image.Image):
            bin_img = Image.fromarray(bin_img)
        with self._autocast():
            return self.pipeline2(
                image_prompt, bin_img, num_inference_steps=20, generator=generator
            ).images[0]