import os
import gc

import torch
from diffusers import UniPCMultistepScheduler, AutoencoderKL
from safetensors.torch import load_file

from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
from models.unet import UNet2DConditionModel
from models.controlnet import ControlNetModel
from . import utils


# Static config for the SDXL base UNet, used to instantiate the model when
# loading weights from a single checkpoint file.
UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_time",
    "addition_embed_type_num_heads": 64,
    "addition_time_embed_dim": 256,
    "attention_head_dim": [5, 10, 20],
    "block_out_channels": [320, 640, 1280],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 2048,
    "cross_attention_norm": None,
    "down_block_types": [
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
    ],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": None,
    "encoder_hid_dim_type": None,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_attention_heads": None,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 4,
    "projection_class_embeddings_input_dim": 2816,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "default",
    "sample_size": 128,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "transformer_layers_per_block": [1, 2, 10],
    "up_block_types": [
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
    ],
    "upcast_attention": None,
    "use_linear_projection": True,
}


# Static config for the lightweight ControlNeXt ControlNet module.
CONTROLNET_CONFIG = {
    "in_channels": [128, 128],
    "out_channels": [128, 256],
    "groups": [4, 8],
    "time_embed_dim": 256,
    "final_out_channels": 320,
}
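
# A hedged sketch (not part of the original module): UNET_CONFIG mirrors the
# UNet config of the official SDXL base repo, so it can be regenerated with
# diffusers' ConfigMixin; the repo id below is an assumption.
#
#   cfg = UNet2DConditionModel.load_config(
#       "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
#   )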


def get_pipeline(
    pretrained_model_name_or_path,
    unet_model_name_or_path,
    controlnet_model_name_or_path,
    vae_model_name_or_path=None,
    lora_path=None,
    load_weight_increasement=False,
    enable_xformers_memory_efficient_attention=False,
    revision=None,
    variant=None,
    hf_cache_dir=None,
    use_safetensors=True,
    device=None,
):
    """Build a StableDiffusionXLControlNeXtPipeline from a base SDXL model,
    plus optional ControlNeXt UNet/ControlNet weights, VAE, and LoRA."""
    pipeline_init_kwargs = {}

    print(f"loading unet from {pretrained_model_name_or_path}")
    if os.path.isfile(pretrained_model_name_or_path):
        # Single-file checkpoint: extract the UNet weights, convert them to
        # the diffusers layout, and load them into a UNet built from UNET_CONFIG.
        if pretrained_model_name_or_path.endswith(".safetensors"):
            unet_sd = load_file(pretrained_model_name_or_path)
        else:
            unet_sd = torch.load(pretrained_model_name_or_path, map_location="cpu")
        unet_sd = utils.extract_unet_state_dict(unet_sd)
        unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
        unet = UNet2DConditionModel.from_config(UNET_CONFIG)
        unet.load_state_dict(unet_sd, strict=True)
    else:
        unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=hf_cache_dir,
            variant=variant,
            torch_dtype=torch.float16,
            use_safetensors=use_safetensors,
            subfolder="unet",
        )
    unet = unet.to(dtype=torch.float16)
    pipeline_init_kwargs["unet"] = unet

    if vae_model_name_or_path is not None:
        print(f"loading vae from {vae_model_name_or_path}")
        vae = AutoencoderKL.from_pretrained(
            vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16
        ).to(device)
        pipeline_init_kwargs["vae"] = vae

    if controlnet_model_name_or_path is not None:
        # Build the ControlNeXt module from its static config; the trained
        # weights are loaded further below. It is kept in float32.
        pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(
            CONTROLNET_CONFIG
        ).to(device, dtype=torch.float32)

    print(f"loading pipeline from {pretrained_model_name_or_path}")
    if os.path.isfile(pretrained_model_name_or_path):
        pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file(
            pretrained_model_name_or_path,
            use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"),
            local_files_only=True,
            cache_dir=hf_cache_dir,
            **pipeline_init_kwargs,
        )
    else:
        pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained(
            pretrained_model_name_or_path,
            revision=revision,
            variant=variant,
            use_safetensors=use_safetensors,
            cache_dir=hf_cache_dir,
            **pipeline_init_kwargs,
        )

    # Use UniPC as the default sampler.
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    if unet_model_name_or_path is not None:
        print(f"loading controlnext unet from {unet_model_name_or_path}")
        pipeline.load_controlnext_unet_weights(
            unet_model_name_or_path,
            load_weight_increasement=load_weight_increasement,
            use_safetensors=True,
            torch_dtype=torch.float16,
            cache_dir=hf_cache_dir,
        )
    if controlnet_model_name_or_path is not None:
        print(f"loading controlnext controlnet from {controlnet_model_name_or_path}")
        pipeline.load_controlnext_controlnet_weights(
            controlnet_model_name_or_path,
            use_safetensors=True,
            torch_dtype=torch.float32,
            cache_dir=hf_cache_dir,
        )
    pipeline.set_progress_bar_config()
    pipeline = pipeline.to(device, dtype=torch.float16)

    if lora_path is not None:
        pipeline.load_lora_weights(lora_path)
    if enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    # Free any intermediate state dicts before returning; also handle device
    # strings like "cuda:0", not just "cuda".
    gc.collect()
    if str(device).startswith("cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()

    return pipeline
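
# A minimal usage sketch (not part of the original module): the checkpoint
# paths are placeholders, and the `controlnet_image` keyword is assumed from
# the ControlNeXt pipeline's call signature.
#
#   pipe = get_pipeline(
#       "stabilityai/stable-diffusion-xl-base-1.0",
#       unet_model_name_or_path="path/to/controlnext-unet.safetensors",
#       controlnet_model_name_or_path="path/to/controlnext-controlnet.safetensors",
#       device="cuda",
#   )
#   image = pipe("a photo of a cat", controlnet_image=cond_image).images[0]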


def get_scheduler(
    scheduler_name,
    scheduler_config,
):
    """Return a diffusers scheduler built from `scheduler_config` by name."""
    if scheduler_name == 'Euler A':
        from diffusers.schedulers import EulerAncestralDiscreteScheduler
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    elif scheduler_name == 'UniPC':
        from diffusers.schedulers import UniPCMultistepScheduler
        scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
    elif scheduler_name == 'Euler':
        from diffusers.schedulers import EulerDiscreteScheduler
        scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
    elif scheduler_name == 'DDIM':
        from diffusers.schedulers import DDIMScheduler
        scheduler = DDIMScheduler.from_config(scheduler_config)
    elif scheduler_name == 'DDPM':
        from diffusers.schedulers import DDPMScheduler
        scheduler = DDPMScheduler.from_config(scheduler_config)
    else:
        raise ValueError(f"Unknown scheduler: {scheduler_name}")
    return scheduler
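
# Hedged example (an assumption, not in the original module): swapping the
# sampler on a pipeline built with get_pipeline above.
#
#   pipe.scheduler = get_scheduler("Euler A", pipe.scheduler.config)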