Eugeoter committed on
Commit
76be739
1 Parent(s): 4074997
Files changed (5)
  1. app.py +4 -24
  2. models/unet.py +0 -70
  3. pipeline/pipeline_controlnext.py +271 -1
  4. utils/tools.py +107 -52
  5. utils/utils.py +0 -68
app.py CHANGED
@@ -2,16 +2,12 @@ import gradio as gr
 import torch
 import numpy as np
 import spaces
-from PIL import Image
-from huggingface_hub import hf_hub_download
 from utils import utils, tools, preprocess
 
 BASE_MODEL_REPO_ID = "neta-art/neta-xl-2.0"
 BASE_MODEL_FILENAME = "neta-xl-v2.fp16.safetensors"
 VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"
-CONTROLNEXT_REPO_ID = "Pbihao/ControlNeXt"
-UNET_FILENAME = "ControlAny-SDXL/anime_canny/unet.safetensors"
-CONTROLNET_FILENAME = "ControlAny-SDXL/anime_canny/controlnet.safetensors"
+CONTROLNEXT_REPO_ID = "Eugeoter/controlnext-sdxl-anime-canny"
 CACHE_DIR = None
 
 DEFAULT_PROMPT = ""
@@ -20,26 +16,10 @@ DEFAULT_NEGATIVE_PROMPT = "worst quality, abstract, clumsy pose, deformed hand,
 
 def ui():
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model_file = hf_hub_download(
-        repo_id=BASE_MODEL_REPO_ID,
-        filename=BASE_MODEL_FILENAME,
-        cache_dir=CACHE_DIR,
-    )
-    unet_file = hf_hub_download(
-        repo_id=CONTROLNEXT_REPO_ID,
-        filename=UNET_FILENAME,
-        cache_dir=CACHE_DIR,
-    )
-    controlnet_file = hf_hub_download(
-        repo_id=CONTROLNEXT_REPO_ID,
-        filename=CONTROLNET_FILENAME,
-        cache_dir=CACHE_DIR,
-    )
-
     pipeline = tools.get_pipeline(
-        pretrained_model_name_or_path=model_file,
-        unet_model_name_or_path=unet_file,
-        controlnet_model_name_or_path=controlnet_file,
+        pretrained_model_name_or_path=BASE_MODEL_REPO_ID,
+        unet_model_name_or_path=CONTROLNEXT_REPO_ID,
+        controlnet_model_name_or_path=CONTROLNEXT_REPO_ID,
         vae_model_name_or_path=VAE_PATH,
         load_weight_increasement=True,
         device=device,
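
With this change app.py no longer pre-downloads checkpoint files itself; it passes Hub repo ids straight through and lets tools.get_pipeline resolve the weights. A minimal sketch of the new call, using the constants defined above (other get_pipeline parameters keep defaults this diff does not show):

import torch
from utils import tools

BASE_MODEL_REPO_ID = "neta-art/neta-xl-2.0"
CONTROLNEXT_REPO_ID = "Eugeoter/controlnext-sdxl-anime-canny"
VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"

pipeline = tools.get_pipeline(
    pretrained_model_name_or_path=BASE_MODEL_REPO_ID,   # base SDXL weights, resolved on the Hub
    unet_model_name_or_path=CONTROLNEXT_REPO_ID,        # ControlNeXt UNet weights (unet.safetensors)
    controlnet_model_name_or_path=CONTROLNEXT_REPO_ID,  # ControlNeXt ControlNet (controlnet.safetensors)
    vae_model_name_or_path=VAE_PATH,
    load_weight_increasement=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)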
models/unet.py CHANGED
@@ -53,76 +53,6 @@ from diffusers.models.unets.unet_2d_blocks import (
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-UNET_CONFIG = {
-    "_class_name": "UNet2DConditionModel",
-    "_diffusers_version": "0.19.0.dev0",
-    "act_fn": "silu",
-    "addition_embed_type": "text_time",
-    "addition_embed_type_num_heads": 64,
-    "addition_time_embed_dim": 256,
-    "attention_head_dim": [
-        5,
-        10,
-        20
-    ],
-    "block_out_channels": [
-        320,
-        640,
-        1280
-    ],
-    "center_input_sample": False,
-    "class_embed_type": None,
-    "class_embeddings_concat": False,
-    "conv_in_kernel": 3,
-    "conv_out_kernel": 3,
-    "cross_attention_dim": 2048,
-    "cross_attention_norm": None,
-    "down_block_types": [
-        "DownBlock2D",
-        "CrossAttnDownBlock2D",
-        "CrossAttnDownBlock2D"
-    ],
-    "downsample_padding": 1,
-    "dual_cross_attention": False,
-    "encoder_hid_dim": None,
-    "encoder_hid_dim_type": None,
-    "flip_sin_to_cos": True,
-    "freq_shift": 0,
-    "in_channels": 4,
-    "layers_per_block": 2,
-    "mid_block_only_cross_attention": None,
-    "mid_block_scale_factor": 1,
-    "mid_block_type": "UNetMidBlock2DCrossAttn",
-    "norm_eps": 1e-05,
-    "norm_num_groups": 32,
-    "num_attention_heads": None,
-    "num_class_embeds": None,
-    "only_cross_attention": False,
-    "out_channels": 4,
-    "projection_class_embeddings_input_dim": 2816,
-    "resnet_out_scale_factor": 1.0,
-    "resnet_skip_time_act": False,
-    "resnet_time_scale_shift": "default",
-    "sample_size": 128,
-    "time_cond_proj_dim": None,
-    "time_embedding_act_fn": None,
-    "time_embedding_dim": None,
-    "time_embedding_type": "positional",
-    "timestep_post_act": None,
-    "transformer_layers_per_block": [
-        1,
-        2,
-        10
-    ],
-    "up_block_types": [
-        "CrossAttnUpBlock2D",
-        "CrossAttnUpBlock2D",
-        "UpBlock2D"
-    ],
-    "upcast_attention": None,
-    "use_linear_projection": True
-}
-
 
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
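
The UNET_CONFIG dict removed here is re-declared in utils/tools.py (see below), where it is consumed via UNet2DConditionModel.from_config. A minimal sketch of that pattern, assuming the class exported by models/unet.py keeps the standard diffusers ConfigMixin API:

from models.unet import UNet2DConditionModel
from utils.tools import UNET_CONFIG

# Builds the SDXL UNet architecture only; the weights stay randomly
# initialized until a state dict is loaded on top of it.
unet = UNet2DConditionModel.from_config(UNET_CONFIG)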
pipeline/pipeline_controlnext.py CHANGED
@@ -14,7 +14,6 @@
 
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from packaging import version
 import torch
 from transformers import (
     CLIPImageProcessor,
@@ -57,6 +56,7 @@ from diffusers.utils import (
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from huggingface_hub.utils import validate_hf_hub_args
 
 if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
@@ -87,8 +87,128 @@ EXAMPLE_DOC_STRING = """
     ```
 """
 
+CONTROLNEXT_WEIGHT_NAME = "controlnet.bin"
+CONTROLNEXT_WEIGHT_NAME_SAFE = "controlnet.safetensors"
+UNET_WEIGHT_NAME = "unet.bin"
+UNET_WEIGHT_NAME_SAFE = "unet.safetensors"
+
+
+# Copied from https://github.com/kohya-ss/sd-scripts/blob/main/library/sdxl_model_util.py
+
+def is_sdxl_state_dict(state_dict):
+    return any(key.startswith('input_blocks') for key in state_dict.keys())
+
+
+def convert_sdxl_unet_state_dict_to_diffusers(sd):
+    unet_conversion_map = make_unet_conversion_map()
+
+    conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
+    return convert_unet_state_dict(sd, conversion_dict)
+
+
+def convert_unet_state_dict(src_sd, conversion_map):
+    converted_sd = {}
+    for src_key, value in src_sd.items():
+        # search for the longest matching prefix, dropping one trailing
+        # fragment at a time
+        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
+        while len(src_key_fragments) > 0:
+            src_key_prefix = ".".join(src_key_fragments) + "."
+            if src_key_prefix in conversion_map:
+                converted_prefix = conversion_map[src_key_prefix]
+                converted_key = converted_prefix + src_key[len(src_key_prefix):]
+                converted_sd[converted_key] = value
+                break
+            src_key_fragments.pop(-1)
+        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
+
+    return converted_sd
+
+
+def make_unet_conversion_map():
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # if i > 0: commentout for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    return unet_conversion_map
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+
+
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -280,6 +400,156 @@ class StableDiffusionXLControlNeXtPipeline(
         else:
             self.watermark = None
 
+    def load_controlnext_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        load_weight_increasement: bool = False,
+        **kwargs,
+    ):
+        self.load_controlnext_unet_weights(pretrained_model_name_or_path_or_dict, load_weight_increasement, **kwargs)
+        kwargs['torch_dtype'] = torch.float32
+        self.load_controlnext_controlnet_weights(pretrained_model_name_or_path_or_dict, **kwargs)
+
+    def load_controlnext_unet_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        load_weight_increasement: bool = False,
+        **kwargs,
+    ):
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        state_dict = self.controlnext_unet_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        if is_sdxl_state_dict(state_dict):
+            state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
+
+        logger.info("Loading ControlNeXt UNet" + (" with weight increasement." if load_weight_increasement else "."))
+        if load_weight_increasement:
+            # the shipped tensors are deltas: add them onto the base UNet weights
+            unet_sd = self.unet.state_dict()
+            for k in state_dict.keys():
+                state_dict[k] = state_dict[k] + unet_sd[k]
+        self.unet.load_state_dict(state_dict, strict=False)
+
+    @classmethod
+    @validate_hf_hub_args
+    def controlnext_unet_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        if 'weight_name' not in kwargs:
+            kwargs['weight_name'] = UNET_WEIGHT_NAME_SAFE if kwargs.get('use_safetensors', False) else UNET_WEIGHT_NAME
+        return cls.controlnext_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+    def load_controlnext_controlnet_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        if self.controlnet is None:
+            raise ValueError("No ControlNeXt ControlNet found in the pipeline.")
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        state_dict = self.controlnext_controlnet_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        logger.info("Loading ControlNeXt ControlNet")
+        self.controlnet.load_state_dict(state_dict, strict=True)
+
+    @classmethod
+    @validate_hf_hub_args
+    def controlnext_controlnet_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        if 'weight_name' not in kwargs:
+            kwargs['weight_name'] = CONTROLNEXT_WEIGHT_NAME_SAFE if kwargs.get('use_safetensors', False) else CONTROLNEXT_WEIGHT_NAME
+        return cls.controlnext_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+    @classmethod
+    @validate_hf_hub_args
+    def controlnext_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for controlnext weights.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        unet_config = kwargs.pop("unet_config", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = cls._fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        return state_dict
+
     def prepare_image(
         self,
         image,
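
The conversion helpers added above rewrite stable-diffusion-style keys ("input_blocks...") into diffusers-style keys ("down_blocks...") by longest-prefix matching. A self-contained toy sketch of that matching step (one hand-picked map entry, dummy tensor shapes):

import torch

conversion_map = {"input_blocks.1.0.in_layers.2.": "down_blocks.0.resnets.0.conv1."}
src = {"input_blocks.1.0.in_layers.2.weight": torch.zeros(1)}

converted = {}
for key, value in src.items():
    fragments = key.split(".")[:-1]  # drop the trailing "weight"/"bias"
    while fragments:
        prefix = ".".join(fragments) + "."
        if prefix in conversion_map:
            # rewrite the matched prefix, keep the remainder of the key
            converted[conversion_map[prefix] + key[len(prefix):]] = value
            break
        fragments.pop()

print(list(converted))  # ['down_blocks.0.resnets.0.conv1.weight']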
utils/tools.py CHANGED
@@ -1,14 +1,90 @@
 import os
-import torch
 import gc
-from torch import nn
-from diffusers import UniPCMultistepScheduler, AutoencoderKL
+import torch
+from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel
 from safetensors.torch import load_file
 from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
-from models.unet import UNet2DConditionModel, UNET_CONFIG
+from models.unet import UNet2DConditionModel
 from models.controlnet import ControlNetModel
 from . import utils
 
+UNET_CONFIG = {
+    "act_fn": "silu",
+    "addition_embed_type": "text_time",
+    "addition_embed_type_num_heads": 64,
+    "addition_time_embed_dim": 256,
+    "attention_head_dim": [
+        5,
+        10,
+        20
+    ],
+    "block_out_channels": [
+        320,
+        640,
+        1280
+    ],
+    "center_input_sample": False,
+    "class_embed_type": None,
+    "class_embeddings_concat": False,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
+    "cross_attention_dim": 2048,
+    "cross_attention_norm": None,
+    "down_block_types": [
+        "DownBlock2D",
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D"
+    ],
+    "downsample_padding": 1,
+    "dual_cross_attention": False,
+    "encoder_hid_dim": None,
+    "encoder_hid_dim_type": None,
+    "flip_sin_to_cos": True,
+    "freq_shift": 0,
+    "in_channels": 4,
+    "layers_per_block": 2,
+    "mid_block_only_cross_attention": None,
+    "mid_block_scale_factor": 1,
+    "mid_block_type": "UNetMidBlock2DCrossAttn",
+    "norm_eps": 1e-05,
+    "norm_num_groups": 32,
+    "num_attention_heads": None,
+    "num_class_embeds": None,
+    "only_cross_attention": False,
+    "out_channels": 4,
+    "projection_class_embeddings_input_dim": 2816,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
+    "resnet_time_scale_shift": "default",
+    "sample_size": 128,
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "transformer_layers_per_block": [
+        1,
+        2,
+        10
+    ],
+    "up_block_types": [
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+        "UpBlock2D"
+    ],
+    "upcast_attention": None,
+    "use_linear_projection": True
+}
+
+CONTROLNET_CONFIG = {
+    'in_channels': [128, 128],
+    'out_channels': [128, 256],
+    'groups': [4, 8],
+    'time_embed_dim': 256,
+    'final_out_channels': 320,
+    '_use_default_values': ['time_embed_dim', 'groups', 'in_channels', 'final_out_channels', 'out_channels']
+}
+
 
 def get_pipeline(
     pretrained_model_name_or_path,
@@ -26,20 +102,6 @@ def get_pipeline(
 ):
     pipeline_init_kwargs = {}
 
-    if controlnet_model_name_or_path is not None:
-        print(f"loading controlnet from {controlnet_model_name_or_path}")
-        controlnet = ControlNetModel()
-        if controlnet_model_name_or_path is not None:
-            utils.load_safetensors(controlnet, controlnet_model_name_or_path)
-        else:
-            controlnet.scale = nn.Parameter(torch.tensor(0.), requires_grad=False)
-        controlnet.to(device, dtype=torch.float32)
-        pipeline_init_kwargs["controlnet"] = controlnet
-
-        utils.log_model_info(controlnet, "controlnext")
-    else:
-        print(f"no controlnet")
-
     print(f"loading unet from {pretrained_model_name_or_path}")
     if os.path.isfile(pretrained_model_name_or_path):
         # load unet from local checkpoint
@@ -49,42 +111,15 @@ def get_pipeline(
         unet = UNet2DConditionModel.from_config(UNET_CONFIG)
         unet.load_state_dict(unet_sd, strict=True)
     else:
-        from huggingface_hub import hf_hub_download
-        filename = "diffusion_pytorch_model"
-        if variant == "fp16":
-            filename += ".fp16"
-        if use_safetensors:
-            filename += ".safetensors"
-        else:
-            filename += ".pt"
-        unet_file = hf_hub_download(
-            repo_id=pretrained_model_name_or_path,
-            filename="unet" + '/' + filename,
+        unet = UNet2DConditionModel.from_pretrained(
+            pretrained_model_name_or_path,
             cache_dir=hf_cache_dir,
+            variant=variant,
+            torch_dtype=torch.float16,
+            use_safetensors=use_safetensors,
+            subfolder="unet",
         )
-        unet_sd = load_file(unet_file) if unet_file.endswith(".safetensors") else torch.load(pretrained_model_name_or_path)
-        unet_sd = utils.extract_unet_state_dict(unet_sd)
-        unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
-        unet = UNet2DConditionModel.from_config(UNET_CONFIG)
-        unet.load_state_dict(unet_sd, strict=True)
     unet = unet.to(dtype=torch.float16)
-    utils.log_model_info(unet, "unet")
-
-    if unet_model_name_or_path is not None:
-        print(f"loading controlnext unet from {unet_model_name_or_path}")
-        controlnext_unet_sd = load_file(unet_model_name_or_path)
-        controlnext_unet_sd = utils.convert_to_controlnext_unet_state_dict(controlnext_unet_sd)
-        unet_sd = unet.state_dict()
-        assert all(
-            k in unet_sd for k in controlnext_unet_sd), \
-            f"controlnext unet state dict is not compatible with unet state dict, missing keys: {set(controlnext_unet_sd.keys()) - set(unet_sd.keys())}, extra keys: {set(unet_sd.keys()) - set(controlnext_unet_sd.keys())}"
-        if load_weight_increasement:
-            print("loading weight increasement")
-            for k in controlnext_unet_sd.keys():
-                controlnext_unet_sd[k] = controlnext_unet_sd[k] + unet_sd[k]
-        unet.load_state_dict(controlnext_unet_sd, strict=False)
-        utils.log_model_info(controlnext_unet_sd, "controlnext unet")
-
     pipeline_init_kwargs["unet"] = unet
 
     if vae_model_name_or_path is not None:
@@ -92,6 +127,9 @@ def get_pipeline(
         vae = AutoencoderKL.from_pretrained(vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16).to(device)
         pipeline_init_kwargs["vae"] = vae
 
+    if controlnet_model_name_or_path is not None:
+        pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(CONTROLNET_CONFIG).to(device, dtype=torch.float32)  # init
+
     print(f"loading pipeline from {pretrained_model_name_or_path}")
     if os.path.isfile(pretrained_model_name_or_path):
         pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file(
@@ -112,6 +150,23 @@ def get_pipeline(
     )
 
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+    if unet_model_name_or_path is not None:
+        print(f"loading controlnext unet from {unet_model_name_or_path}")
+        pipeline.load_controlnext_unet_weights(
+            unet_model_name_or_path,
+            load_weight_increasement=load_weight_increasement,
+            use_safetensors=True,
+            torch_dtype=torch.float16,
+            cache_dir=hf_cache_dir,
+        )
+    if controlnet_model_name_or_path is not None:
+        print(f"loading controlnext controlnet from {controlnet_model_name_or_path}")
+        pipeline.load_controlnext_controlnet_weights(
+            controlnet_model_name_or_path,
+            use_safetensors=True,
+            torch_dtype=torch.float32,
+            cache_dir=hf_cache_dir,
+        )
     pipeline.set_progress_bar_config()
     pipeline = pipeline.to(device, dtype=torch.float16)
 
@@ -121,7 +176,7 @@ def get_pipeline(
         pipeline.enable_xformers_memory_efficient_attention()
 
     gc.collect()
-    if torch.cuda.is_available():
+    if str(device) == 'cuda' and torch.cuda.is_available():
        torch.cuda.empty_cache()
 
     return pipeline
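
load_weight_increasement=True (passed through to load_controlnext_unet_weights above) treats the downloaded ControlNeXt UNet tensors as residual deltas rather than full weights. A minimal self-contained sketch of that merge, using a toy module in place of the UNet:

import torch
from torch import nn

base = nn.Linear(4, 4)  # stand-in for the base SDXL UNet
# stand-in for the ControlNeXt "weight increasement" state dict (deltas)
delta_sd = {k: torch.full_like(v, 0.1) for k, v in base.state_dict().items()}

# Same arithmetic as load_controlnext_unet_weights: delta + base, then a
# non-strict load so keys absent from the delta file keep their base values.
base_sd = base.state_dict()
merged = {k: delta_sd[k] + base_sd[k] for k in delta_sd}
base.load_state_dict(merged, strict=False)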
utils/utils.py CHANGED
@@ -1,52 +1,5 @@
 import math
 from typing import Tuple, Union, Optional
-from safetensors.torch import load_file
-from transformers import PretrainedConfig
-
-
-def count_num_parameters_of_safetensors_model(safetensors_path):
-    state_dict = load_file(safetensors_path)
-    return sum(p.numel() for p in state_dict.values())
-
-
-def import_model_class_from_model_name_or_path(
-    pretrained_model_name_or_path: str, revision: str, subfolder: str = None
-):
-    text_encoder_config = PretrainedConfig.from_pretrained(
-        pretrained_model_name_or_path, revision=revision, subfolder=subfolder
-    )
-    model_class = text_encoder_config.architectures[0]
-    if model_class == "CLIPTextModel":
-        from transformers import CLIPTextModel
-        return CLIPTextModel
-    elif model_class == "CLIPTextModelWithProjection":
-        from transformers import CLIPTextModelWithProjection
-        return CLIPTextModelWithProjection
-    else:
-        raise ValueError(f"{model_class} is not supported.")
-
-
-def fix_clip_text_encoder_position_ids(text_encoder):
-    if hasattr(text_encoder.text_model.embeddings, "position_ids"):
-        text_encoder.text_model.embeddings.position_ids = text_encoder.text_model.embeddings.position_ids.long()
-
-
-def load_controlnext_unet_state_dict(unet_sd, controlnext_unet_sd):
-    assert all(
-        k in unet_sd for k in controlnext_unet_sd), f"controlnext unet state dict is not compatible with unet state dict, missing keys: {set(controlnext_unet_sd.keys()) - set(unet_sd.keys())}, extra keys: {set(unet_sd.keys()) - set(controlnext_unet_sd.keys())}"
-    for k in controlnext_unet_sd.keys():
-        unet_sd[k] = controlnext_unet_sd[k]
-    return unet_sd
-
-
-def convert_to_controlnext_unet_state_dict(state_dict):
-    import re
-    pattern = re.compile(r'.*attn2.*to_out.*')
-    state_dict = {k: v for k, v in state_dict.items() if pattern.match(k)}
-    # state_dict = extract_unet_state_dict(state_dict)
-    if is_sdxl_state_dict(state_dict):
-        state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
-    return state_dict
 
 
 def make_unet_conversion_map():
@@ -166,27 +119,6 @@ def extract_unet_state_dict(state_dict):
     return unet_sd
 
 
-def is_sdxl_state_dict(state_dict):
-    return any(key.startswith('input_blocks') for key in state_dict.keys())
-
-
-def contains_unet_keys(state_dict):
-    UNET_KEY_PREFIX = "model.diffusion_model."
-    return any(k.startswith(UNET_KEY_PREFIX) for k in state_dict.keys())
-
-
-def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
-    if not load_weight_increasement:
-        state_dict = load_file(safetensors_path)
-        model.load_state_dict(state_dict, strict=strict)
-    else:
-        state_dict = load_file(safetensors_path)
-        pretrained_state_dict = model.state_dict()
-        for k in state_dict.keys():
-            state_dict[k] = state_dict[k] + pretrained_state_dict[k]
-        model.load_state_dict(state_dict, strict=False)
-
-
 def log_model_info(model, name):
     sd = model.state_dict() if hasattr(model, "state_dict") else model
     print(