HighCWu commited on
Commit
efa09bd
·
1 Parent(s): de3d9dc

init control lora v3.

Browse files
Files changed (9) hide show
  1. README.md +6 -8
  2. app.py +34 -23
  3. app_tile.py +87 -0
  4. model.py +65 -42
  5. pipeline.py +1374 -0
  6. preprocessor.py +2 -1
  7. requirements.txt +3 -13
  8. settings.py +2 -2
  9. unet.py +299 -0
README.md CHANGED
@@ -1,15 +1,13 @@
1
  ---
2
- title: ControlNet V1.1
3
- emoji: 📉
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.45.2
8
- python_version: 3.10.11
9
  app_file: app.py
10
- pinned: false
11
  license: mit
12
- suggested_hardware: t4-medium
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Control Lora V3
3
+ emoji: 🖼
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.26.0
 
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -2,9 +2,13 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  import gradio as gr
6
  import torch
7
 
 
 
 
8
  from app_canny import create_demo as create_demo_canny
9
  from app_depth import create_demo as create_demo_depth
10
  from app_ip2p import create_demo as create_demo_ip2p
@@ -17,13 +21,18 @@ from app_scribble_interactive import create_demo as create_demo_scribble_interac
17
  from app_segmentation import create_demo as create_demo_segmentation
18
  from app_shuffle import create_demo as create_demo_shuffle
19
  from app_softedge import create_demo as create_demo_softedge
 
20
  from model import Model
21
  from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
22
 
23
- DESCRIPTION = "# ControlNet v1.1"
 
24
 
25
- if not torch.cuda.is_available():
26
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
 
 
27
 
28
  model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")
29
 
@@ -37,29 +46,31 @@ with gr.Blocks(css="style.css") as demo:
37
 
38
  with gr.Tabs():
39
  with gr.TabItem("Canny"):
40
- create_demo_canny(model.process_canny)
41
- with gr.TabItem("MLSD"):
42
- create_demo_mlsd(model.process_mlsd)
43
- with gr.TabItem("Scribble"):
44
- create_demo_scribble(model.process_scribble)
45
- with gr.TabItem("Scribble Interactive"):
46
- create_demo_scribble_interactive(model.process_scribble_interactive)
47
- with gr.TabItem("SoftEdge"):
48
- create_demo_softedge(model.process_softedge)
49
  with gr.TabItem("OpenPose"):
50
- create_demo_openpose(model.process_openpose)
51
  with gr.TabItem("Segmentation"):
52
- create_demo_segmentation(model.process_segmentation)
53
  with gr.TabItem("Depth"):
54
- create_demo_depth(model.process_depth)
55
  with gr.TabItem("Normal map"):
56
- create_demo_normal(model.process_normal)
57
- with gr.TabItem("Lineart"):
58
- create_demo_lineart(model.process_lineart)
59
- with gr.TabItem("Content Shuffle"):
60
- create_demo_shuffle(model.process_shuffle)
61
- with gr.TabItem("Instruct Pix2Pix"):
62
- create_demo_ip2p(model.process_ip2p)
 
 
63
 
64
  with gr.Accordion(label="Base model", open=False):
65
  with gr.Row():
@@ -72,7 +83,7 @@ with gr.Blocks(css="style.css") as demo:
72
  new_base_model_id = gr.Text(
73
  label="New base model",
74
  max_lines=1,
75
- placeholder="runwayml/stable-diffusion-v1-5",
76
  info="The base model must be compatible with Stable Diffusion v1.5.",
77
  interactive=ALLOW_CHANGING_BASE_MODEL,
78
  )
 
2
 
3
  from __future__ import annotations
4
 
5
+ import spaces
6
  import gradio as gr
7
  import torch
8
 
9
+ import PIL.Image
10
+ import numpy as np
11
+
12
  from app_canny import create_demo as create_demo_canny
13
  from app_depth import create_demo as create_demo_depth
14
  from app_ip2p import create_demo as create_demo_ip2p
 
21
  from app_segmentation import create_demo as create_demo_segmentation
22
  from app_shuffle import create_demo as create_demo_shuffle
23
  from app_softedge import create_demo as create_demo_softedge
24
+ from app_tile import create_demo as create_demo_tile
25
  from model import Model
26
  from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
27
 
28
+ DESCRIPTION = r"""
29
+ # ControlLoRA Version 3: LoRA Is All You Need to Control the Spatial Information of Stable Diffusion
30
 
31
+ <center>
32
+ <a href="https://huggingface.co/HighCWu/control-lora-v3">[Models]</a>
33
+ <a href="https://github.com/HighCWu/control-lora-v3">[Github]</a>
34
+ </center>
35
+ """
36
 
37
  model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")
38
 
 
46
 
47
  with gr.Tabs():
48
  with gr.TabItem("Canny"):
49
+ create_demo_canny(spaces.GPU(model.process_canny))
50
+ # with gr.TabItem("MLSD"):
51
+ # create_demo_mlsd(spaces.GPU(model.process_mlsd))
52
+ # with gr.TabItem("Scribble"):
53
+ # create_demo_scribble(spaces.GPU(model.process_scribble))
54
+ # with gr.TabItem("Scribble Interactive"):
55
+ # create_demo_scribble_interactive(spaces.GPU(model.process_scribble_interactive))
56
+ # with gr.TabItem("SoftEdge"):
57
+ # create_demo_softedge(spaces.GPU(model.process_softedge))
58
  with gr.TabItem("OpenPose"):
59
+ create_demo_openpose(spaces.GPU(model.process_openpose))
60
  with gr.TabItem("Segmentation"):
61
+ create_demo_segmentation(spaces.GPU(model.process_segmentation))
62
  with gr.TabItem("Depth"):
63
+ create_demo_depth(spaces.GPU(model.process_depth))
64
  with gr.TabItem("Normal map"):
65
+ create_demo_normal(spaces.GPU(model.process_normal))
66
+ # with gr.TabItem("Lineart"):
67
+ # create_demo_lineart(spaces.GPU(model.process_lineart))
68
+ # with gr.TabItem("Content Shuffle"):
69
+ # create_demo_shuffle(spaces.GPU(model.process_shuffle))
70
+ # with gr.TabItem("Instruct Pix2Pix"):
71
+ # create_demo_ip2p(spaces.GPU(model.process_ip2p))
72
+ with gr.TabItem("Tile"):
73
+ create_demo_tile(spaces.GPU(model.process_tile))
74
 
75
  with gr.Accordion(label="Base model", open=False):
76
  with gr.Row():
 
83
  new_base_model_id = gr.Text(
84
  label="New base model",
85
  max_lines=1,
86
+ placeholder="SG161222/Realistic_Vision_V4.0_noVAE",
87
  info="The base model must be compatible with Stable Diffusion v1.5.",
88
  interactive=ALLOW_CHANGING_BASE_MODEL,
89
  )
app_tile.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import gradio as gr
4
+
5
+ from settings import (
6
+ DEFAULT_IMAGE_RESOLUTION,
7
+ DEFAULT_NUM_IMAGES,
8
+ MAX_IMAGE_RESOLUTION,
9
+ MAX_NUM_IMAGES,
10
+ MAX_SEED,
11
+ )
12
+ from utils import randomize_seed_fn
13
+
14
+
15
+ def create_demo(process):
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ image = gr.Image()
20
+ prompt = gr.Textbox(label="Prompt")
21
+ run_button = gr.Button("Run")
22
+ with gr.Accordion("Advanced options", open=False):
23
+ num_samples = gr.Slider(
24
+ label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
25
+ )
26
+ image_resolution = gr.Slider(
27
+ label="Image resolution",
28
+ minimum=256,
29
+ maximum=MAX_IMAGE_RESOLUTION,
30
+ value=DEFAULT_IMAGE_RESOLUTION,
31
+ step=256,
32
+ )
33
+ num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
34
+ guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
35
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
36
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
37
+ a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
38
+ n_prompt = gr.Textbox(
39
+ label="Negative prompt",
40
+ value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
41
+ )
42
+ with gr.Column():
43
+ result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
44
+ inputs = [
45
+ image,
46
+ prompt,
47
+ a_prompt,
48
+ n_prompt,
49
+ num_samples,
50
+ image_resolution,
51
+ num_steps,
52
+ guidance_scale,
53
+ seed,
54
+ ]
55
+ prompt.submit(
56
+ fn=randomize_seed_fn,
57
+ inputs=[seed, randomize_seed],
58
+ outputs=seed,
59
+ queue=False,
60
+ api_name=False,
61
+ ).then(
62
+ fn=process,
63
+ inputs=inputs,
64
+ outputs=result,
65
+ api_name=False,
66
+ )
67
+ run_button.click(
68
+ fn=randomize_seed_fn,
69
+ inputs=[seed, randomize_seed],
70
+ outputs=seed,
71
+ queue=False,
72
+ api_name=False,
73
+ ).then(
74
+ fn=process,
75
+ inputs=inputs,
76
+ outputs=result,
77
+ api_name="tile",
78
+ )
79
+ return demo
80
+
81
+
82
+ if __name__ == "__main__":
83
+ from model import Model
84
+
85
+ model = Model(task_name="tile")
86
+ demo = create_demo(model.process_tile)
87
+ demo.queue().launch()
model.py CHANGED
@@ -7,59 +7,54 @@ import PIL.Image
7
  import torch
8
  from controlnet_aux.util import HWC3
9
  from diffusers import (
10
- ControlNetModel,
11
- DiffusionPipeline,
12
- StableDiffusionControlNetPipeline,
13
  UniPCMultistepScheduler,
14
  )
 
 
15
 
16
  from cv_utils import resize_image
17
  from preprocessor import Preprocessor
18
  from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
 
19
 
20
- CONTROLNET_MODEL_IDS = {
21
- "Openpose": "lllyasviel/control_v11p_sd15_openpose",
22
- "Canny": "lllyasviel/control_v11p_sd15_canny",
23
- "MLSD": "lllyasviel/control_v11p_sd15_mlsd",
24
- "scribble": "lllyasviel/control_v11p_sd15_scribble",
25
- "softedge": "lllyasviel/control_v11p_sd15_softedge",
26
- "segmentation": "lllyasviel/control_v11p_sd15_seg",
27
- "depth": "lllyasviel/control_v11f1p_sd15_depth",
28
- "NormalBae": "lllyasviel/control_v11p_sd15_normalbae",
29
- "lineart": "lllyasviel/control_v11p_sd15_lineart",
30
- "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
31
- "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
32
- "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
33
- "inpaint": "lllyasviel/control_v11e_sd15_inpaint",
34
- }
35
 
36
-
37
- def download_all_controlnet_weights() -> None:
38
- for model_id in CONTROLNET_MODEL_IDS.values():
39
- ControlNetModel.from_pretrained(model_id)
 
 
 
 
 
40
 
41
 
42
  class Model:
43
- def __init__(self, base_model_id: str = "runwayml/stable-diffusion-v1-5", task_name: str = "Canny"):
44
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
45
  self.base_model_id = ""
46
  self.task_name = ""
47
- self.pipe = self.load_pipe(base_model_id, task_name)
48
  self.preprocessor = Preprocessor()
49
 
50
- def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline:
51
  if (
52
  base_model_id == self.base_model_id
53
- and task_name == self.task_name
54
  and hasattr(self, "pipe")
55
  and self.pipe is not None
56
  ):
 
 
57
  return self.pipe
58
- model_id = CONTROLNET_MODEL_IDS[task_name]
59
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
60
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
61
- base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
 
 
62
  )
 
 
63
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
64
  if self.device.type == "cuda":
65
  pipe.enable_xformers_memory_efficient_attention()
@@ -74,7 +69,8 @@ class Model:
74
  if not base_model_id or base_model_id == self.base_model_id:
75
  return self.base_model_id
76
  del self.pipe
77
- torch.cuda.empty_cache()
 
78
  gc.collect()
79
  try:
80
  self.pipe = self.load_pipe(base_model_id, self.task_name)
@@ -85,16 +81,8 @@ class Model:
85
  def load_controlnet_weight(self, task_name: str) -> None:
86
  if task_name == self.task_name:
87
  return
88
- if self.pipe is not None and hasattr(self.pipe, "controlnet"):
89
- del self.pipe.controlnet
90
- torch.cuda.empty_cache()
91
- gc.collect()
92
- model_id = CONTROLNET_MODEL_IDS[task_name]
93
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
94
- controlnet.to(self.device)
95
- torch.cuda.empty_cache()
96
- gc.collect()
97
- self.pipe.controlnet = controlnet
98
  self.task_name = task_name
99
 
100
  def get_prompt(self, prompt: str, additional_prompt: str) -> str:
@@ -104,7 +92,7 @@ class Model:
104
  prompt = f"{prompt}, {additional_prompt}"
105
  return prompt
106
 
107
- @torch.autocast("cuda")
108
  def run_pipe(
109
  self,
110
  prompt: str,
@@ -672,3 +660,38 @@ class Model:
672
  seed=seed,
673
  )
674
  return [control_image] + results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch
8
  from controlnet_aux.util import HWC3
9
  from diffusers import (
 
 
 
10
  UniPCMultistepScheduler,
11
  )
12
+ from unet import UNet2DConditionModelEx
13
+ from pipeline import StableDiffusionControlLoraV3Pipeline
14
 
15
  from cv_utils import resize_image
16
  from preprocessor import Preprocessor
17
  from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
18
+ from collections import OrderedDict
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ CONTROL_LORA_V3_MODEL_IDS = OrderedDict([
22
+ ("Openpose", "sd-control-lora-v3-pose-half-rank128-conv_in-rank128"),
23
+ ("Canny", "sd-control-lora-v3-canny-half_skip_attn-rank16-conv_in-rank64"),
24
+ ("segmentation", "sd-control-lora-v3-segmentation-half_skip_attn-rank128-conv_in-rank128"),
25
+ ("depth", "lllyasviel/control_v11f1p_sd15_depth"),
26
+ ("NormalBae", "sd-control-lora-v3-normal-half-rank32-conv_in-rank128"),
27
+ ("depth", "sd-control-lora-v3-depth-half-rank8-conv_in-rank128"),
28
+ ("Tile", "sd-control-lora-v3-tile-half_skip_attn-rank16-conv_in-rank64"),
29
+ ])
30
 
31
 
32
  class Model:
33
+ def __init__(self, base_model_id: str = "SG161222/Realistic_Vision_V4.0_noVAE", task_name: str = "Canny"):
34
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
  self.base_model_id = ""
36
  self.task_name = ""
37
+ self.pipe: StableDiffusionControlLoraV3Pipeline = self.load_pipe(base_model_id, task_name)
38
  self.preprocessor = Preprocessor()
39
 
40
+ def load_pipe(self, base_model_id: str, task_name) -> StableDiffusionControlLoraV3Pipeline:
41
  if (
42
  base_model_id == self.base_model_id
 
43
  and hasattr(self, "pipe")
44
  and self.pipe is not None
45
  ):
46
+ unet: UNet2DConditionModelEx = self.pipe.unet
47
+ unet.activate_adapters([task_name])
48
  return self.pipe
49
+ unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
50
+ base_model_id, subfolder="unet", torch_dtype=torch.float16
51
+ )
52
+ unet.add_extra_conditions(["Placeholder"])
53
+ pipe: StableDiffusionControlLoraV3Pipeline = StableDiffusionControlLoraV3Pipeline.from_pretrained(
54
+ base_model_id, safety_checker=None, unet=unet, torch_dtype=torch.float16
55
  )
56
+ for _task_name, subfolder in CONTROL_LORA_V3_MODEL_IDS.items():
57
+ pipe.load_lora_weights("HighCWu/control-lora-v3", adapter_name=_task_name, subfolder=subfolder)
58
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
59
  if self.device.type == "cuda":
60
  pipe.enable_xformers_memory_efficient_attention()
 
69
  if not base_model_id or base_model_id == self.base_model_id:
70
  return self.base_model_id
71
  del self.pipe
72
+ if self.device.type == "cuda":
73
+ torch.cuda.empty_cache()
74
  gc.collect()
75
  try:
76
  self.pipe = self.load_pipe(base_model_id, self.task_name)
 
81
  def load_controlnet_weight(self, task_name: str) -> None:
82
  if task_name == self.task_name:
83
  return
84
+ unet: UNet2DConditionModelEx = self.pipe.unet
85
+ unet.activate_adapters([task_name])
 
 
 
 
 
 
 
 
86
  self.task_name = task_name
87
 
88
  def get_prompt(self, prompt: str, additional_prompt: str) -> str:
 
92
  prompt = f"{prompt}, {additional_prompt}"
93
  return prompt
94
 
95
+ # @torch.autocast("cuda")
96
  def run_pipe(
97
  self,
98
  prompt: str,
 
660
  seed=seed,
661
  )
662
  return [control_image] + results
663
+
664
+ @torch.inference_mode()
665
+ def process_tile(
666
+ self,
667
+ image: np.ndarray,
668
+ prompt: str,
669
+ additional_prompt: str,
670
+ negative_prompt: str,
671
+ num_images: int,
672
+ image_resolution: int,
673
+ num_steps: int,
674
+ guidance_scale: float,
675
+ seed: int,
676
+ ) -> list[PIL.Image.Image]:
677
+ if image is None:
678
+ raise ValueError
679
+ if image_resolution > MAX_IMAGE_RESOLUTION:
680
+ raise ValueError
681
+ if num_images > MAX_NUM_IMAGES:
682
+ raise ValueError
683
+
684
+ image = HWC3(image)
685
+ image = resize_image(image, resolution=image_resolution)
686
+ control_image = PIL.Image.fromarray(image)
687
+ self.load_controlnet_weight("Tile")
688
+ results = self.run_pipe(
689
+ prompt=self.get_prompt(prompt, additional_prompt),
690
+ negative_prompt=negative_prompt,
691
+ control_image=control_image,
692
+ num_images=num_images,
693
+ num_steps=num_steps,
694
+ guidance_scale=guidance_scale,
695
+ seed=seed,
696
+ )
697
+ return [control_image] + results
pipeline.py ADDED
@@ -0,0 +1,1374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
9
+
10
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
11
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
12
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
13
+ from diffusers.models import AutoencoderKL, ImageProjection
14
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from diffusers.utils import (
17
+ USE_PEFT_BACKEND,
18
+ deprecate,
19
+ logging,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
26
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
27
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
28
+ from model import UNet2DConditionModelEx
29
+
30
+
31
+ from huggingface_hub.utils import validate_hf_hub_args
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```py
40
+ >>> # !pip install opencv-python transformers accelerate
41
+ >>> from diffusers import UniPCMultistepScheduler
42
+ >>> from diffusers.utils import load_image
43
+ >>> from model import UNet2DConditionModelEx
44
+ >>> from pipeline import StableDiffusionControlLoraV3Pipeline
45
+ >>> import numpy as np
46
+ >>> import torch
47
+
48
+ >>> import cv2
49
+ >>> from PIL import Image
50
+
51
+ >>> # download an image
52
+ >>> image = load_image(
53
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
54
+ ... )
55
+ >>> image = np.array(image)
56
+
57
+ >>> # get canny image
58
+ >>> image = cv2.Canny(image, 100, 200)
59
+ >>> image = image[:, :, None]
60
+ >>> image = np.concatenate([image, image, image], axis=2)
61
+ >>> canny_image = Image.fromarray(image)
62
+
63
+ >>> # load stable diffusion v1-5 and control-lora-v3
64
+ >>> unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
65
+ ... "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
66
+ ... )
67
+ >>> unet = unet.add_extra_conditions(["canny"])
68
+ >>> pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
69
+ ... "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
70
+ ... )
71
+ >>> # load attention processors
72
+ >>> pipe.load_lora_weights("HighCWu/sd-control-lora-v3-canny")
73
+
74
+ >>> # speed up diffusion process with faster scheduler and memory optimization
75
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
76
+ >>> # remove following line if xformers is not installed
77
+ >>> pipe.enable_xformers_memory_efficient_attention()
78
+
79
+ >>> pipe.enable_model_cpu_offload()
80
+
81
+ >>> # generate image
82
+ >>> generator = torch.manual_seed(0)
83
+ >>> image = pipe(
84
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
85
+ ... ).images[0]
86
+ ```
87
+ """
88
+
89
+
90
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
91
+ def retrieve_timesteps(
92
+ scheduler,
93
+ num_inference_steps: Optional[int] = None,
94
+ device: Optional[Union[str, torch.device]] = None,
95
+ timesteps: Optional[List[int]] = None,
96
+ sigmas: Optional[List[float]] = None,
97
+ **kwargs,
98
+ ):
99
+ """
100
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
101
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
102
+
103
+ Args:
104
+ scheduler (`SchedulerMixin`):
105
+ The scheduler to get timesteps from.
106
+ num_inference_steps (`int`):
107
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
108
+ must be `None`.
109
+ device (`str` or `torch.device`, *optional*):
110
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
111
+ timesteps (`List[int]`, *optional*):
112
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
113
+ `num_inference_steps` and `sigmas` must be `None`.
114
+ sigmas (`List[float]`, *optional*):
115
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
116
+ `num_inference_steps` and `timesteps` must be `None`.
117
+
118
+ Returns:
119
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
120
+ second element is the number of inference steps.
121
+ """
122
+ if timesteps is not None and sigmas is not None:
123
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
124
+ if timesteps is not None:
125
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accepts_timesteps:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" timestep schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ elif sigmas is not None:
135
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
136
+ if not accept_sigmas:
137
+ raise ValueError(
138
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
139
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
140
+ )
141
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ num_inference_steps = len(timesteps)
144
+ else:
145
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
146
+ timesteps = scheduler.timesteps
147
+ return timesteps, num_inference_steps
148
+
149
+
150
+ class StableDiffusionControlLoraV3Pipeline(
151
+ DiffusionPipeline,
152
+ StableDiffusionMixin,
153
+ TextualInversionLoaderMixin,
154
+ LoraLoaderMixin,
155
+ IPAdapterMixin,
156
+ FromSingleFileMixin,
157
+ ):
158
+ r"""
159
+ Pipeline for text-to-image generation using Stable Diffusion with extra condition guidance.
160
+
161
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
162
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
163
+
164
+ The pipeline also inherits the following loading methods:
165
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
166
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
167
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
168
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
169
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
170
+
171
+ Args:
172
+ vae ([`AutoencoderKL`]):
173
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
174
+ text_encoder ([`~transformers.CLIPTextModel`]):
175
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
176
+ tokenizer ([`~transformers.CLIPTokenizer`]):
177
+ A `CLIPTokenizer` to tokenize text.
178
+ unet ([`UNet2DConditionModelEx`]):
179
+ A `UNet2DConditionModelEx` to denoise the encoded image latents with extra conditions.
180
+ scheduler ([`SchedulerMixin`]):
181
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
182
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
183
+ safety_checker ([`StableDiffusionSafetyChecker`]):
184
+ Classification module that estimates whether generated images could be considered offensive or harmful.
185
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
186
+ about a model's potential harms.
187
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
188
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
189
+ """
190
+
191
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
192
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
193
+ _exclude_from_cpu_offload = ["safety_checker"]
194
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
195
+
196
+ def __init__(
197
+ self,
198
+ vae: AutoencoderKL,
199
+ text_encoder: CLIPTextModel,
200
+ tokenizer: CLIPTokenizer,
201
+ unet: UNet2DConditionModelEx,
202
+ scheduler: KarrasDiffusionSchedulers,
203
+ safety_checker: StableDiffusionSafetyChecker,
204
+ feature_extractor: CLIPImageProcessor,
205
+ image_encoder: CLIPVisionModelWithProjection = None,
206
+ requires_safety_checker: bool = True,
207
+ ):
208
+ super().__init__()
209
+
210
+ if safety_checker is None and requires_safety_checker:
211
+ logger.warning(
212
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
213
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
214
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
215
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
216
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
217
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
218
+ )
219
+
220
+ if safety_checker is not None and feature_extractor is None:
221
+ raise ValueError(
222
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
223
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
224
+ )
225
+
226
+ self.register_modules(
227
+ vae=vae,
228
+ text_encoder=text_encoder,
229
+ tokenizer=tokenizer,
230
+ unet=unet,
231
+ scheduler=scheduler,
232
+ safety_checker=safety_checker,
233
+ feature_extractor=feature_extractor,
234
+ image_encoder=image_encoder,
235
+ )
236
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
237
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
238
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
239
+
240
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
241
+ def _encode_prompt(
242
+ self,
243
+ prompt,
244
+ device,
245
+ num_images_per_prompt,
246
+ do_classifier_free_guidance,
247
+ negative_prompt=None,
248
+ prompt_embeds: Optional[torch.Tensor] = None,
249
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
250
+ lora_scale: Optional[float] = None,
251
+ **kwargs,
252
+ ):
253
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
254
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
255
+
256
+ prompt_embeds_tuple = self.encode_prompt(
257
+ prompt=prompt,
258
+ device=device,
259
+ num_images_per_prompt=num_images_per_prompt,
260
+ do_classifier_free_guidance=do_classifier_free_guidance,
261
+ negative_prompt=negative_prompt,
262
+ prompt_embeds=prompt_embeds,
263
+ negative_prompt_embeds=negative_prompt_embeds,
264
+ lora_scale=lora_scale,
265
+ **kwargs,
266
+ )
267
+
268
+ # concatenate for backwards comp
269
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
270
+
271
+ return prompt_embeds
272
+
273
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
274
+ def encode_prompt(
275
+ self,
276
+ prompt,
277
+ device,
278
+ num_images_per_prompt,
279
+ do_classifier_free_guidance,
280
+ negative_prompt=None,
281
+ prompt_embeds: Optional[torch.Tensor] = None,
282
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
283
+ lora_scale: Optional[float] = None,
284
+ clip_skip: Optional[int] = None,
285
+ ):
286
+ r"""
287
+ Encodes the prompt into text encoder hidden states.
288
+
289
+ Args:
290
+ prompt (`str` or `List[str]`, *optional*):
291
+ prompt to be encoded
292
+ device: (`torch.device`):
293
+ torch device
294
+ num_images_per_prompt (`int`):
295
+ number of images that should be generated per prompt
296
+ do_classifier_free_guidance (`bool`):
297
+ whether to use classifier free guidance or not
298
+ negative_prompt (`str` or `List[str]`, *optional*):
299
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
300
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
301
+ less than `1`).
302
+ prompt_embeds (`torch.Tensor`, *optional*):
303
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
304
+ provided, text embeddings will be generated from `prompt` input argument.
305
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
306
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
307
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
308
+ argument.
309
+ lora_scale (`float`, *optional*):
310
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
311
+ clip_skip (`int`, *optional*):
312
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
313
+ the output of the pre-final layer will be used for computing the prompt embeddings.
314
+ """
315
+ # set lora scale so that monkey patched LoRA
316
+ # function of text encoder can correctly access it
317
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
318
+ self._lora_scale = lora_scale
319
+
320
+ # dynamically adjust the LoRA scale
321
+ if not USE_PEFT_BACKEND:
322
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
323
+ else:
324
+ scale_lora_layers(self.text_encoder, lora_scale)
325
+
326
+ if prompt is not None and isinstance(prompt, str):
327
+ batch_size = 1
328
+ elif prompt is not None and isinstance(prompt, list):
329
+ batch_size = len(prompt)
330
+ else:
331
+ batch_size = prompt_embeds.shape[0]
332
+
333
+ if prompt_embeds is None:
334
+ # textual inversion: process multi-vector tokens if necessary
335
+ if isinstance(self, TextualInversionLoaderMixin):
336
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
337
+
338
+ text_inputs = self.tokenizer(
339
+ prompt,
340
+ padding="max_length",
341
+ max_length=self.tokenizer.model_max_length,
342
+ truncation=True,
343
+ return_tensors="pt",
344
+ )
345
+ text_input_ids = text_inputs.input_ids
346
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
347
+
348
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
349
+ text_input_ids, untruncated_ids
350
+ ):
351
+ removed_text = self.tokenizer.batch_decode(
352
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
353
+ )
354
+ logger.warning(
355
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
356
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
357
+ )
358
+
359
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
360
+ attention_mask = text_inputs.attention_mask.to(device)
361
+ else:
362
+ attention_mask = None
363
+
364
+ if clip_skip is None:
365
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
366
+ prompt_embeds = prompt_embeds[0]
367
+ else:
368
+ prompt_embeds = self.text_encoder(
369
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
370
+ )
371
+ # Access the `hidden_states` first, that contains a tuple of
372
+ # all the hidden states from the encoder layers. Then index into
373
+ # the tuple to access the hidden states from the desired layer.
374
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
375
+ # We also need to apply the final LayerNorm here to not mess with the
376
+ # representations. The `last_hidden_states` that we typically use for
377
+ # obtaining the final prompt representations passes through the LayerNorm
378
+ # layer.
379
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
380
+
381
+ if self.text_encoder is not None:
382
+ prompt_embeds_dtype = self.text_encoder.dtype
383
+ elif self.unet is not None:
384
+ prompt_embeds_dtype = self.unet.dtype
385
+ else:
386
+ prompt_embeds_dtype = prompt_embeds.dtype
387
+
388
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
389
+
390
+ bs_embed, seq_len, _ = prompt_embeds.shape
391
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
392
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
393
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
394
+
395
+ # get unconditional embeddings for classifier free guidance
396
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
397
+ uncond_tokens: List[str]
398
+ if negative_prompt is None:
399
+ uncond_tokens = [""] * batch_size
400
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
401
+ raise TypeError(
402
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
403
+ f" {type(prompt)}."
404
+ )
405
+ elif isinstance(negative_prompt, str):
406
+ uncond_tokens = [negative_prompt]
407
+ elif batch_size != len(negative_prompt):
408
+ raise ValueError(
409
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
410
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
411
+ " the batch size of `prompt`."
412
+ )
413
+ else:
414
+ uncond_tokens = negative_prompt
415
+
416
+ # textual inversion: process multi-vector tokens if necessary
417
+ if isinstance(self, TextualInversionLoaderMixin):
418
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
419
+
420
+ max_length = prompt_embeds.shape[1]
421
+ uncond_input = self.tokenizer(
422
+ uncond_tokens,
423
+ padding="max_length",
424
+ max_length=max_length,
425
+ truncation=True,
426
+ return_tensors="pt",
427
+ )
428
+
429
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
430
+ attention_mask = uncond_input.attention_mask.to(device)
431
+ else:
432
+ attention_mask = None
433
+
434
+ negative_prompt_embeds = self.text_encoder(
435
+ uncond_input.input_ids.to(device),
436
+ attention_mask=attention_mask,
437
+ )
438
+ negative_prompt_embeds = negative_prompt_embeds[0]
439
+
440
+ if do_classifier_free_guidance:
441
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
442
+ seq_len = negative_prompt_embeds.shape[1]
443
+
444
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
445
+
446
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
447
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
448
+
449
+ if self.text_encoder is not None:
450
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
451
+ # Retrieve the original scale by scaling back the LoRA layers
452
+ unscale_lora_layers(self.text_encoder, lora_scale)
453
+
454
+ return prompt_embeds, negative_prompt_embeds
455
+
456
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
457
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
458
+ dtype = next(self.image_encoder.parameters()).dtype
459
+
460
+ if not isinstance(image, torch.Tensor):
461
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
462
+
463
+ image = image.to(device=device, dtype=dtype)
464
+ if output_hidden_states:
465
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
466
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
467
+ uncond_image_enc_hidden_states = self.image_encoder(
468
+ torch.zeros_like(image), output_hidden_states=True
469
+ ).hidden_states[-2]
470
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
471
+ num_images_per_prompt, dim=0
472
+ )
473
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
474
+ else:
475
+ image_embeds = self.image_encoder(image).image_embeds
476
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
477
+ uncond_image_embeds = torch.zeros_like(image_embeds)
478
+
479
+ return image_embeds, uncond_image_embeds
480
+
481
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
482
+ def prepare_ip_adapter_image_embeds(
483
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
484
+ ):
485
+ if ip_adapter_image_embeds is None:
486
+ if not isinstance(ip_adapter_image, list):
487
+ ip_adapter_image = [ip_adapter_image]
488
+
489
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
490
+ raise ValueError(
491
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
492
+ )
493
+
494
+ image_embeds = []
495
+ for single_ip_adapter_image, image_proj_layer in zip(
496
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
497
+ ):
498
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
499
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
500
+ single_ip_adapter_image, device, 1, output_hidden_state
501
+ )
502
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
503
+ single_negative_image_embeds = torch.stack(
504
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
505
+ )
506
+
507
+ if do_classifier_free_guidance:
508
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
509
+ single_image_embeds = single_image_embeds.to(device)
510
+
511
+ image_embeds.append(single_image_embeds)
512
+ else:
513
+ repeat_dims = [1]
514
+ image_embeds = []
515
+ for single_image_embeds in ip_adapter_image_embeds:
516
+ if do_classifier_free_guidance:
517
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
518
+ single_image_embeds = single_image_embeds.repeat(
519
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
520
+ )
521
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
522
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
523
+ )
524
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
525
+ else:
526
+ single_image_embeds = single_image_embeds.repeat(
527
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
528
+ )
529
+ image_embeds.append(single_image_embeds)
530
+
531
+ return image_embeds
532
+
533
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
534
+ def run_safety_checker(self, image, device, dtype):
535
+ if self.safety_checker is None:
536
+ has_nsfw_concept = None
537
+ else:
538
+ if torch.is_tensor(image):
539
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
540
+ else:
541
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
542
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
543
+ image, has_nsfw_concept = self.safety_checker(
544
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
545
+ )
546
+ return image, has_nsfw_concept
547
+
548
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
549
+ def decode_latents(self, latents):
550
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
551
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
552
+
553
+ latents = 1 / self.vae.config.scaling_factor * latents
554
+ image = self.vae.decode(latents, return_dict=False)[0]
555
+ image = (image / 2 + 0.5).clamp(0, 1)
556
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
557
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
558
+ return image
559
+
560
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
561
+ def prepare_extra_step_kwargs(self, generator, eta):
562
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
563
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
564
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
565
+ # and should be between [0, 1]
566
+
567
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
568
+ extra_step_kwargs = {}
569
+ if accepts_eta:
570
+ extra_step_kwargs["eta"] = eta
571
+
572
+ # check if the scheduler accepts generator
573
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
574
+ if accepts_generator:
575
+ extra_step_kwargs["generator"] = generator
576
+ return extra_step_kwargs
577
+
578
+ def check_inputs(
579
+ self,
580
+ prompt,
581
+ image,
582
+ callback_steps,
583
+ negative_prompt=None,
584
+ prompt_embeds=None,
585
+ negative_prompt_embeds=None,
586
+ ip_adapter_image=None,
587
+ ip_adapter_image_embeds=None,
588
+ extra_condition_scale=1.0,
589
+ control_guidance_start=0.0,
590
+ control_guidance_end=1.0,
591
+ callback_on_step_end_tensor_inputs=None,
592
+ ):
593
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
594
+ raise ValueError(
595
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
596
+ f" {type(callback_steps)}."
597
+ )
598
+
599
+ if callback_on_step_end_tensor_inputs is not None and not all(
600
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
601
+ ):
602
+ raise ValueError(
603
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
604
+ )
605
+
606
+ if prompt is not None and prompt_embeds is not None:
607
+ raise ValueError(
608
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
609
+ " only forward one of the two."
610
+ )
611
+ elif prompt is None and prompt_embeds is None:
612
+ raise ValueError(
613
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
614
+ )
615
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
616
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
617
+
618
+ if negative_prompt is not None and negative_prompt_embeds is not None:
619
+ raise ValueError(
620
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
621
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
622
+ )
623
+
624
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
625
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
626
+ raise ValueError(
627
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
628
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
629
+ f" {negative_prompt_embeds.shape}."
630
+ )
631
+
632
+ # Check `image`
633
+ unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
634
+ num_extra_conditions = len(unet.extra_condition_names)
635
+ if num_extra_conditions == 1:
636
+ self.check_image(image, prompt, prompt_embeds)
637
+ elif num_extra_conditions > 1:
638
+ if not isinstance(image, list):
639
+ raise TypeError("For multiple extra conditions: `image` must be type `list`")
640
+
641
+ # When `image` is a nested list:
642
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
643
+ elif any(isinstance(i, list) for i in image):
644
+ transposed_image = [list(t) for t in zip(*image)]
645
+ if len(transposed_image) != num_extra_conditions:
646
+ raise ValueError(
647
+ f"For multiple extra conditions: if you pass`image` as a list of list, each sublist must have the same length as the number of extra conditions, but the sublists in `image` got {len(transposed_image)} images and {num_extra_conditions} extra conditions."
648
+ )
649
+ for image_ in transposed_image:
650
+ self.check_image(image_, prompt, prompt_embeds)
651
+ elif len(image) != num_extra_conditions:
652
+ raise ValueError(
653
+ f"For multiple extra conditions: `image` must have the same length as the number of extra conditions, but got {len(image)} images and {num_extra_conditions} extra conditions."
654
+ )
655
+ else:
656
+ for image_ in image:
657
+ self.check_image(image_, prompt, prompt_embeds)
658
+ else:
659
+ assert False
660
+
661
+ # Check `extra_condition_scale`
662
+ if num_extra_conditions == 1:
663
+ if not isinstance(extra_condition_scale, float):
664
+ raise TypeError("For single extra condition: `extra_condition_scale` must be type `float`.")
665
+ elif num_extra_conditions >= 1:
666
+ if isinstance(extra_condition_scale, list):
667
+ if any(isinstance(i, list) for i in extra_condition_scale):
668
+ raise ValueError(
669
+ "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
670
+ "The conditioning scale must be fixed across the batch."
671
+ )
672
+ elif isinstance(extra_condition_scale, list) and len(extra_condition_scale) != num_extra_conditions:
673
+ raise ValueError(
674
+ "For multiple extra conditions: When `extra_condition_scale` is specified as `list`, it must have"
675
+ " the same length as the number of extra conditions"
676
+ )
677
+ else:
678
+ assert False
679
+
680
+ if not isinstance(control_guidance_start, (tuple, list)):
681
+ control_guidance_start = [control_guidance_start]
682
+
683
+ if not isinstance(control_guidance_end, (tuple, list)):
684
+ control_guidance_end = [control_guidance_end]
685
+
686
+ if len(control_guidance_start) != len(control_guidance_end):
687
+ raise ValueError(
688
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
689
+ )
690
+
691
+ if num_extra_conditions > 1:
692
+ if len(control_guidance_start) != num_extra_conditions:
693
+ raise ValueError(
694
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_extra_conditions} extra conditions available. Make sure to provide {num_extra_conditions}."
695
+ )
696
+
697
+ for start, end in zip(control_guidance_start, control_guidance_end):
698
+ if start >= end:
699
+ raise ValueError(
700
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
701
+ )
702
+ if start < 0.0:
703
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
704
+ if end > 1.0:
705
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
706
+
707
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
708
+ raise ValueError(
709
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
710
+ )
711
+
712
+ if ip_adapter_image_embeds is not None:
713
+ if not isinstance(ip_adapter_image_embeds, list):
714
+ raise ValueError(
715
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
716
+ )
717
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
718
+ raise ValueError(
719
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
720
+ )
721
+
722
+ def check_image(self, image, prompt, prompt_embeds):
723
+ image_is_pil = isinstance(image, PIL.Image.Image)
724
+ image_is_tensor = isinstance(image, torch.Tensor)
725
+ image_is_np = isinstance(image, np.ndarray)
726
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
727
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
728
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
729
+
730
+ if (
731
+ not image_is_pil
732
+ and not image_is_tensor
733
+ and not image_is_np
734
+ and not image_is_pil_list
735
+ and not image_is_tensor_list
736
+ and not image_is_np_list
737
+ ):
738
+ raise TypeError(
739
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
740
+ )
741
+
742
+ if image_is_pil:
743
+ image_batch_size = 1
744
+ else:
745
+ image_batch_size = len(image)
746
+
747
+ if prompt is not None and isinstance(prompt, str):
748
+ prompt_batch_size = 1
749
+ elif prompt is not None and isinstance(prompt, list):
750
+ prompt_batch_size = len(prompt)
751
+ elif prompt_embeds is not None:
752
+ prompt_batch_size = prompt_embeds.shape[0]
753
+
754
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
755
+ raise ValueError(
756
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
757
+ )
758
+
759
+ def prepare_image(
760
+ self,
761
+ image,
762
+ width,
763
+ height,
764
+ batch_size,
765
+ num_images_per_prompt,
766
+ device,
767
+ dtype,
768
+ do_classifier_free_guidance=False,
769
+ ):
770
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
771
+ image_batch_size = image.shape[0]
772
+
773
+ if image_batch_size == 1:
774
+ repeat_by = batch_size
775
+ else:
776
+ # image batch size is the same as prompt batch size
777
+ repeat_by = num_images_per_prompt
778
+
779
+ image = image.repeat_interleave(repeat_by, dim=0)
780
+
781
+ image = image.to(device=device, dtype=dtype)
782
+
783
+ if do_classifier_free_guidance:
784
+ image = torch.cat([image] * 2)
785
+
786
+ return image
787
+
788
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
789
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
790
+ shape = (
791
+ batch_size,
792
+ num_channels_latents,
793
+ int(height) // self.vae_scale_factor,
794
+ int(width) // self.vae_scale_factor,
795
+ )
796
+ if isinstance(generator, list) and len(generator) != batch_size:
797
+ raise ValueError(
798
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
799
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
800
+ )
801
+
802
+ if latents is None:
803
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
804
+ else:
805
+ latents = latents.to(device)
806
+
807
+ # scale the initial noise by the standard deviation required by the scheduler
808
+ latents = latents * self.scheduler.init_noise_sigma
809
+ return latents
810
+
811
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
812
+ def get_guidance_scale_embedding(
813
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
814
+ ) -> torch.Tensor:
815
+ """
816
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
817
+
818
+ Args:
819
+ w (`torch.Tensor`):
820
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
821
+ embedding_dim (`int`, *optional*, defaults to 512):
822
+ Dimension of the embeddings to generate.
823
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
824
+ Data type of the generated embeddings.
825
+
826
+ Returns:
827
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
828
+ """
829
+ assert len(w.shape) == 1
830
+ w = w * 1000.0
831
+
832
+ half_dim = embedding_dim // 2
833
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
834
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
835
+ emb = w.to(dtype)[:, None] * emb[None, :]
836
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
837
+ if embedding_dim % 2 == 1: # zero pad
838
+ emb = torch.nn.functional.pad(emb, (0, 1))
839
+ assert emb.shape == (w.shape[0], embedding_dim)
840
+ return emb
841
+
842
+ @property
843
+ def guidance_scale(self):
844
+ return self._guidance_scale
845
+
846
+ @property
847
+ def clip_skip(self):
848
+ return self._clip_skip
849
+
850
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
851
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
852
+ # corresponds to doing no classifier free guidance.
853
+ @property
854
+ def do_classifier_free_guidance(self):
855
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
856
+
857
+ @property
858
+ def cross_attention_kwargs(self):
859
+ return self._cross_attention_kwargs
860
+
861
+ @property
862
+ def num_timesteps(self):
863
+ return self._num_timesteps
864
+
865
+ @classmethod
866
+ @validate_hf_hub_args
867
+ def lora_state_dict(
868
+ cls,
869
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
870
+ **kwargs,
871
+ ):
872
+ # Override to add support for different LoRA alphas
873
+ state_dict, network_alphas = super(StableDiffusionControlLoraV3Pipeline, cls).lora_state_dict(
874
+ pretrained_model_name_or_path_or_dict, **kwargs
875
+ )
876
+ if network_alphas is None:
877
+ network_alphas = {}
878
+ for k, v in state_dict.items():
879
+ if ".lora_A." in k:
880
+ network_alphas[".".join(k.split(".lora_A.")[0].split(".") + ["alpha"])] = v.shape[0]
881
+ return state_dict, network_alphas
882
+
883
+ def load_lora_weights(
884
+ self,
885
+ pretrained_model_name_or_path_or_dict: Union[
886
+ Union[str, Dict[str, torch.Tensor]],
887
+ List[Union[str, Dict[str, torch.Tensor]]]
888
+ ],
889
+ adapter_name=None,
890
+ **kwargs
891
+ ):
892
+ unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
893
+ num_condition_names = len(unet.extra_condition_names)
894
+ in_channels = unet.config.in_channels
895
+
896
+ if adapter_name is not None and adapter_name not in unet.extra_condition_names:
897
+ return super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
898
+
899
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
900
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] * num_condition_names
901
+ pretrained_model_name_or_path_or_dict_list = pretrained_model_name_or_path_or_dict
902
+
903
+ assert len(pretrained_model_name_or_path_or_dict) == len(unet.extra_condition_names)
904
+
905
+ adapter_name_ori = adapter_name
906
+ for i, (pretrained_model_name_or_path_or_dict, adapter_name) in enumerate(zip(
907
+ pretrained_model_name_or_path_or_dict_list,
908
+ unet.extra_condition_names
909
+ )):
910
+ _kwargs = {**kwargs}
911
+ subfolder = _kwargs.pop("subfolder", None)
912
+ if isinstance(subfolder, list):
913
+ subfolder = subfolder[i]
914
+ weight_name = _kwargs.pop("weight_name", "pytorch_lora_weights.safetensors")
915
+
916
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
917
+ pretrained_model_name_or_path_or_dict, _ = self.lora_state_dict(
918
+ pretrained_model_name_or_path_or_dict,
919
+ subfolder=subfolder,
920
+ weight_name=weight_name,
921
+ **_kwargs
922
+ )
923
+
924
+ if adapter_name_ori is not None:
925
+ # only load lora of the input adapter name, then break the loop
926
+ i = unet.extra_condition_names.index(adapter_name_ori)
927
+ adapter_name = adapter_name_ori
928
+
929
+ unet_conv_in_lora_A_name, old_weight = ([
930
+ (k, v)
931
+ for k, v in pretrained_model_name_or_path_or_dict.items()
932
+ if "unet." in k and ".conv_in." in k and ".lora_A." in k
933
+ ] + [(None, None)])[0]
934
+ if unet_conv_in_lora_A_name is not None:
935
+ in_weight = old_weight[:,:in_channels]
936
+ cond_weight = old_weight[:,in_channels:]
937
+ zero_weight = torch.zeros_like(in_weight)
938
+ new_weight = torch.cat(
939
+ [in_weight] +
940
+ [zero_weight] * i +
941
+ [cond_weight] +
942
+ [zero_weight] * (num_condition_names - i - 1),
943
+ dim=1
944
+ )
945
+ pretrained_model_name_or_path_or_dict[unet_conv_in_lora_A_name] = new_weight
946
+
947
+ super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name, **_kwargs)
948
+
949
+ if adapter_name_ori is not None:
950
+ break
951
+
952
+ unet.activate_adapters()
953
+
954
+ @torch.no_grad()
955
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
956
+ def __call__(
957
+ self,
958
+ prompt: Union[str, List[str]] = None,
959
+ image: PipelineImageInput = None,
960
+ height: Optional[int] = None,
961
+ width: Optional[int] = None,
962
+ num_inference_steps: int = 50,
963
+ timesteps: List[int] = None,
964
+ sigmas: List[float] = None,
965
+ guidance_scale: float = 7.5,
966
+ negative_prompt: Optional[Union[str, List[str]]] = None,
967
+ num_images_per_prompt: Optional[int] = 1,
968
+ eta: float = 0.0,
969
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
970
+ latents: Optional[torch.Tensor] = None,
971
+ prompt_embeds: Optional[torch.Tensor] = None,
972
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
973
+ ip_adapter_image: Optional[PipelineImageInput] = None,
974
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
975
+ output_type: Optional[str] = "pil",
976
+ return_dict: bool = True,
977
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
978
+ extra_condition_scale: Union[float, List[float]] = 1.0,
979
+ control_guidance_start: Union[float, List[float]] = 0.0,
980
+ control_guidance_end: Union[float, List[float]] = 1.0,
981
+ clip_skip: Optional[int] = None,
982
+ callback_on_step_end: Optional[
983
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
984
+ ] = None,
985
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
986
+ **kwargs,
987
+ ):
988
+ r"""
989
+ The call function to the pipeline for generation.
990
+
991
+ Args:
992
+ prompt (`str` or `List[str]`, *optional*):
993
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
994
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
995
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
996
+ The extra input condition to provide guidance to the `unet` for generation after encoded by `vae`. If the type is
997
+ specified as `torch.Tensor`, its `vae` latent representation is passed to UNet. `PIL.Image.Image` can also be accepted
998
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
999
+ width are passed, `image` is resized accordingly. If multiple extra conditions are specified in `unet`,
1000
+ images must be passed as a list such that each element of the list can be correctly batched for input
1001
+ to `unet`. When `prompt` is a list, and if a list of images is passed for `unet`, each will be paired with each prompt
1002
+ in the `prompt` list. This also applies to multiple extra conditions, where a list of image lists can be
1003
+ passed to batch for each prompt and each extra condition.
1004
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1005
+ The height in pixels of the generated image.
1006
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1007
+ The width in pixels of the generated image.
1008
+ num_inference_steps (`int`, *optional*, defaults to 50):
1009
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1010
+ expense of slower inference.
1011
+ timesteps (`List[int]`, *optional*):
1012
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1013
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1014
+ passed will be used. Must be in descending order.
1015
+ sigmas (`List[float]`, *optional*):
1016
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1017
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1018
+ will be used.
1019
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1020
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1021
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1022
+ negative_prompt (`str` or `List[str]`, *optional*):
1023
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1024
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1025
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1026
+ The number of images to generate per prompt.
1027
+ eta (`float`, *optional*, defaults to 0.0):
1028
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1029
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1030
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1031
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1032
+ generation deterministic.
1033
+ latents (`torch.Tensor`, *optional*):
1034
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1035
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1036
+ tensor is generated by sampling using the supplied random `generator`.
1037
+ prompt_embeds (`torch.Tensor`, *optional*):
1038
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1039
+ provided, text embeddings are generated from the `prompt` input argument.
1040
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1041
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1042
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1043
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1044
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1045
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1046
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1047
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1048
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1049
+ output_type (`str`, *optional*, defaults to `"pil"`):
1050
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1051
+ return_dict (`bool`, *optional*, defaults to `True`):
1052
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1053
+ plain tuple.
1054
+ callback (`Callable`, *optional*):
1055
+ A function that calls every `callback_steps` steps during inference. The function is called with the
1056
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
1057
+ callback_steps (`int`, *optional*, defaults to 1):
1058
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
1059
+ every step.
1060
+ cross_attention_kwargs (`dict`, *optional*):
1061
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1062
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1063
+ extra_condition_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1064
+ The control lora scale of `unet`. If multiple extra conditions are specified in `unet`, you can set
1065
+ the corresponding scale as a list.
1066
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1067
+ The percentage of total steps at which the extra condtion starts applying.
1068
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1069
+ The percentage of total steps at which the extra condtion stops applying.
1070
+ clip_skip (`int`, *optional*):
1071
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1072
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1073
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1074
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1075
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1076
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1077
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1078
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1079
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1080
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1081
+ `._callback_tensor_inputs` attribute of your pipeline class.
1082
+
1083
+ Examples:
1084
+
1085
+ Returns:
1086
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1087
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1088
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1089
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1090
+ "not-safe-for-work" (nsfw) content.
1091
+ """
1092
+
1093
+ callback = kwargs.pop("callback", None)
1094
+ callback_steps = kwargs.pop("callback_steps", None)
1095
+
1096
+ if callback is not None:
1097
+ deprecate(
1098
+ "callback",
1099
+ "1.0.0",
1100
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1101
+ )
1102
+ if callback_steps is not None:
1103
+ deprecate(
1104
+ "callback_steps",
1105
+ "1.0.0",
1106
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1107
+ )
1108
+
1109
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1110
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1111
+
1112
+ unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
1113
+ num_extra_conditions = len(unet.extra_condition_names)
1114
+
1115
+ # align format for control guidance
1116
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1117
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1118
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1119
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1120
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1121
+ mult = num_extra_conditions
1122
+ control_guidance_start, control_guidance_end = (
1123
+ mult * [control_guidance_start],
1124
+ mult * [control_guidance_end],
1125
+ )
1126
+
1127
+ # 1. Check inputs. Raise error if not correct
1128
+ self.check_inputs(
1129
+ prompt,
1130
+ image,
1131
+ callback_steps,
1132
+ negative_prompt,
1133
+ prompt_embeds,
1134
+ negative_prompt_embeds,
1135
+ ip_adapter_image,
1136
+ ip_adapter_image_embeds,
1137
+ extra_condition_scale,
1138
+ control_guidance_start,
1139
+ control_guidance_end,
1140
+ callback_on_step_end_tensor_inputs,
1141
+ )
1142
+
1143
+ self._guidance_scale = guidance_scale
1144
+ self._clip_skip = clip_skip
1145
+ self._cross_attention_kwargs = cross_attention_kwargs
1146
+
1147
+ # 2. Define call parameters
1148
+ if prompt is not None and isinstance(prompt, str):
1149
+ batch_size = 1
1150
+ elif prompt is not None and isinstance(prompt, list):
1151
+ batch_size = len(prompt)
1152
+ else:
1153
+ batch_size = prompt_embeds.shape[0]
1154
+
1155
+ device = self._execution_device
1156
+
1157
+ if num_extra_conditions > 1 and isinstance(extra_condition_scale, float):
1158
+ extra_condition_scale = [extra_condition_scale] * num_extra_conditions
1159
+
1160
+ # 3. Encode input prompt
1161
+ text_encoder_lora_scale = (
1162
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1163
+ )
1164
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1165
+ prompt,
1166
+ device,
1167
+ num_images_per_prompt,
1168
+ self.do_classifier_free_guidance,
1169
+ negative_prompt,
1170
+ prompt_embeds=prompt_embeds,
1171
+ negative_prompt_embeds=negative_prompt_embeds,
1172
+ lora_scale=text_encoder_lora_scale,
1173
+ clip_skip=self.clip_skip,
1174
+ )
1175
+ # For classifier free guidance, we need to do two forward passes.
1176
+ # Here we concatenate the unconditional and text embeddings into a single batch
1177
+ # to avoid doing two forward passes
1178
+ if self.do_classifier_free_guidance:
1179
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1180
+
1181
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1182
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1183
+ ip_adapter_image,
1184
+ ip_adapter_image_embeds,
1185
+ device,
1186
+ batch_size * num_images_per_prompt,
1187
+ self.do_classifier_free_guidance,
1188
+ )
1189
+
1190
+ # 4. Prepare image
1191
+ if num_extra_conditions == 1:
1192
+ image = self.prepare_image(
1193
+ image=image,
1194
+ width=width,
1195
+ height=height,
1196
+ batch_size=batch_size * num_images_per_prompt,
1197
+ num_images_per_prompt=num_images_per_prompt,
1198
+ device=device,
1199
+ dtype=unet.dtype,
1200
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1201
+ )
1202
+ height, width = image.shape[-2:]
1203
+ image = (
1204
+ self.vae.encode(image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
1205
+ )
1206
+ elif num_extra_conditions >= 1:
1207
+ images = []
1208
+
1209
+ # Nested lists as extra condition
1210
+ if isinstance(image[0], list):
1211
+ # Transpose the nested image list
1212
+ image = [list(t) for t in zip(*image)]
1213
+
1214
+ for image_ in image:
1215
+ image_ = self.prepare_image(
1216
+ image=image_,
1217
+ width=width,
1218
+ height=height,
1219
+ batch_size=batch_size * num_images_per_prompt,
1220
+ num_images_per_prompt=num_images_per_prompt,
1221
+ device=device,
1222
+ dtype=unet.dtype,
1223
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1224
+ )
1225
+
1226
+ images.append(image_)
1227
+
1228
+ image = images
1229
+ height, width = image[0].shape[-2:]
1230
+ image = [
1231
+ self.vae.encode(image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
1232
+ for image in images
1233
+ ]
1234
+ else:
1235
+ assert False
1236
+
1237
+ # 5. Prepare timesteps
1238
+ timesteps, num_inference_steps = retrieve_timesteps(
1239
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1240
+ )
1241
+ self._num_timesteps = len(timesteps)
1242
+
1243
+ # 6. Prepare latent variables
1244
+ num_channels_latents = self.unet.config.in_channels
1245
+ latents = self.prepare_latents(
1246
+ batch_size * num_images_per_prompt,
1247
+ num_channels_latents,
1248
+ height,
1249
+ width,
1250
+ prompt_embeds.dtype,
1251
+ device,
1252
+ generator,
1253
+ latents,
1254
+ )
1255
+
1256
+ # 6.5 Optionally get Guidance Scale Embedding
1257
+ timestep_cond = None
1258
+ if self.unet.config.time_cond_proj_dim is not None:
1259
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1260
+ timestep_cond = self.get_guidance_scale_embedding(
1261
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1262
+ ).to(device=device, dtype=latents.dtype)
1263
+
1264
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1265
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1266
+
1267
+ # 7.1 Add image embeds for IP-Adapter
1268
+ added_cond_kwargs = (
1269
+ {"image_embeds": image_embeds}
1270
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1271
+ else None
1272
+ )
1273
+
1274
+ # 7.2 Create tensor stating which extra_conditions to keep
1275
+ extra_condition_keep = []
1276
+ for i in range(len(timesteps)):
1277
+ keeps = [
1278
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1279
+ for s, e in zip(control_guidance_start, control_guidance_end)
1280
+ ]
1281
+ extra_condition_keep.append(keeps[0] if num_extra_conditions == 1 else keeps)
1282
+
1283
+ # 8. Denoising loop
1284
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1285
+ is_unet_compiled = is_compiled_module(self.unet)
1286
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1287
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1288
+ for i, t in enumerate(timesteps):
1289
+ # Relevant thread:
1290
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1291
+ if is_unet_compiled and is_torch_higher_equal_2_1:
1292
+ torch._inductor.cudagraph_mark_step_begin()
1293
+ # expand the latents if we are doing classifier free guidance
1294
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1295
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1296
+
1297
+ if isinstance(extra_condition_keep[i], list):
1298
+ cond_scale = [c * s for c, s in zip(extra_condition_scale, extra_condition_keep[i])]
1299
+ else:
1300
+ extra_cond_scale = extra_condition_scale
1301
+ if isinstance(extra_cond_scale, list):
1302
+ extra_cond_scale = extra_cond_scale[0]
1303
+ cond_scale = extra_cond_scale * extra_condition_keep[i]
1304
+
1305
+ self.unet.set_extra_condition_scale(cond_scale)
1306
+
1307
+ # predict the noise residual
1308
+ noise_pred = self.unet(
1309
+ latent_model_input,
1310
+ t,
1311
+ encoder_hidden_states=prompt_embeds,
1312
+ timestep_cond=timestep_cond,
1313
+ cross_attention_kwargs=self.cross_attention_kwargs,
1314
+ added_cond_kwargs=added_cond_kwargs,
1315
+ extra_conditions=image,
1316
+ return_dict=False,
1317
+ )[0]
1318
+
1319
+ # perform guidance
1320
+ if self.do_classifier_free_guidance:
1321
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1322
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1323
+
1324
+ # compute the previous noisy sample x_t -> x_t-1
1325
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1326
+
1327
+ if callback_on_step_end is not None:
1328
+ callback_kwargs = {}
1329
+ for k in callback_on_step_end_tensor_inputs:
1330
+ callback_kwargs[k] = locals()[k]
1331
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1332
+
1333
+ latents = callback_outputs.pop("latents", latents)
1334
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1335
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1336
+
1337
+ # call the callback, if provided
1338
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1339
+ progress_bar.update()
1340
+ if callback is not None and i % callback_steps == 0:
1341
+ step_idx = i // getattr(self.scheduler, "order", 1)
1342
+ callback(step_idx, t, latents)
1343
+
1344
+ self.unet.set_extra_condition_scale(1.0)
1345
+
1346
+ # If we do sequential model offloading, let's offload unet
1347
+ # manually for max memory savings
1348
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1349
+ self.unet.to("cpu")
1350
+ torch.cuda.empty_cache()
1351
+
1352
+ if not output_type == "latent":
1353
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1354
+ 0
1355
+ ]
1356
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1357
+ else:
1358
+ image = latents
1359
+ has_nsfw_concept = None
1360
+
1361
+ if has_nsfw_concept is None:
1362
+ do_denormalize = [True] * image.shape[0]
1363
+ else:
1364
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1365
+
1366
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1367
+
1368
+ # Offload all models
1369
+ self.maybe_free_model_hooks()
1370
+
1371
+ if not return_dict:
1372
+ return (image, has_nsfw_concept)
1373
+
1374
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
preprocessor.py CHANGED
@@ -58,7 +58,8 @@ class Preprocessor:
58
  self.model = ImageSegmentor()
59
  else:
60
  raise ValueError
61
- torch.cuda.empty_cache()
 
62
  gc.collect()
63
  self.name = name
64
 
 
58
  self.model = ImageSegmentor()
59
  else:
60
  raise ValueError
61
+ if torch.cuda.is_available():
62
+ torch.cuda.empty_cache()
63
  gc.collect()
64
  self.name = name
65
 
requirements.txt CHANGED
@@ -1,13 +1,3 @@
1
- accelerate==0.21.0
2
- controlnet_aux==0.0.6
3
- diffusers==0.18.2
4
- einops==0.6.1
5
- gradio==3.45.2
6
- huggingface-hub==0.16.4
7
- mediapipe==0.10.1
8
- opencv-python-headless==4.8.0.74
9
- safetensors==0.3.1
10
- torch==2.0.1
11
- torchvision==0.15.2
12
- transformers==4.30.2
13
- xformers==0.0.20
 
1
+ controlnet_aux>=0.0.6
2
+ mediapipe>=0.10.1
3
+ opencv-python-headless>=4.8.0.74
 
 
 
 
 
 
 
 
 
 
settings.py CHANGED
@@ -2,14 +2,14 @@ import os
2
 
3
  import numpy as np
4
 
5
- DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "runwayml/stable-diffusion-v1-5")
6
 
7
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))
8
  DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "3")))
9
  MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "768"))
10
  DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "768")))
11
 
12
- ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
13
  SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
 
2
 
3
  import numpy as np
4
 
5
+ DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "SG161222/Realistic_Vision_V4.0_noVAE")
6
 
7
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))
8
  DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "3")))
9
  MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "768"))
10
  DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "768")))
11
 
12
+ ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "HighCWu/control-lora-v3"
13
  SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
unet.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import copy
4
+ import torch
5
+ from torch import nn, svd_lowrank
6
+
7
+ from peft.tuners.lora import LoraLayer, Conv2d as PeftConv2d
8
+ from diffusers.configuration_utils import register_to_config
9
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput, UNet2DConditionModel as UNet2DConditionModel
10
+
11
+
12
+ class UNet2DConditionModelEx(UNet2DConditionModel):
13
+ @register_to_config
14
+ def __init__(
15
+ self,
16
+ sample_size: Optional[int] = None,
17
+ in_channels: int = 4,
18
+ out_channels: int = 4,
19
+ center_input_sample: bool = False,
20
+ flip_sin_to_cos: bool = True,
21
+ freq_shift: int = 0,
22
+ down_block_types: Tuple[str] = (
23
+ "CrossAttnDownBlock2D",
24
+ "CrossAttnDownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "DownBlock2D",
27
+ ),
28
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
29
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
30
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
31
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
32
+ layers_per_block: Union[int, Tuple[int]] = 2,
33
+ downsample_padding: int = 1,
34
+ mid_block_scale_factor: float = 1,
35
+ dropout: float = 0.0,
36
+ act_fn: str = "silu",
37
+ norm_num_groups: Optional[int] = 32,
38
+ norm_eps: float = 1e-5,
39
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
40
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
41
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
42
+ encoder_hid_dim: Optional[int] = None,
43
+ encoder_hid_dim_type: Optional[str] = None,
44
+ attention_head_dim: Union[int, Tuple[int]] = 8,
45
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
46
+ dual_cross_attention: bool = False,
47
+ use_linear_projection: bool = False,
48
+ class_embed_type: Optional[str] = None,
49
+ addition_embed_type: Optional[str] = None,
50
+ addition_time_embed_dim: Optional[int] = None,
51
+ num_class_embeds: Optional[int] = None,
52
+ upcast_attention: bool = False,
53
+ resnet_time_scale_shift: str = "default",
54
+ resnet_skip_time_act: bool = False,
55
+ resnet_out_scale_factor: float = 1.0,
56
+ time_embedding_type: str = "positional",
57
+ time_embedding_dim: Optional[int] = None,
58
+ time_embedding_act_fn: Optional[str] = None,
59
+ timestep_post_act: Optional[str] = None,
60
+ time_cond_proj_dim: Optional[int] = None,
61
+ conv_in_kernel: int = 3,
62
+ conv_out_kernel: int = 3,
63
+ projection_class_embeddings_input_dim: Optional[int] = None,
64
+ attention_type: str = "default",
65
+ class_embeddings_concat: bool = False,
66
+ mid_block_only_cross_attention: Optional[bool] = None,
67
+ cross_attention_norm: Optional[str] = None,
68
+ addition_embed_type_num_heads: int = 64,
69
+ extra_condition_names: List[str] = [],
70
+ ):
71
+ num_extra_conditions = len(extra_condition_names)
72
+ super().__init__(
73
+ sample_size=sample_size,
74
+ in_channels=in_channels * (1 + num_extra_conditions),
75
+ out_channels=out_channels,
76
+ center_input_sample=center_input_sample,
77
+ flip_sin_to_cos=flip_sin_to_cos,
78
+ freq_shift=freq_shift,
79
+ down_block_types=down_block_types,
80
+ mid_block_type=mid_block_type,
81
+ up_block_types=up_block_types,
82
+ only_cross_attention=only_cross_attention,
83
+ block_out_channels=block_out_channels,
84
+ layers_per_block=layers_per_block,
85
+ downsample_padding=downsample_padding,
86
+ mid_block_scale_factor=mid_block_scale_factor,
87
+ dropout=dropout,
88
+ act_fn=act_fn,
89
+ norm_num_groups=norm_num_groups,
90
+ norm_eps=norm_eps,
91
+ cross_attention_dim=cross_attention_dim,
92
+ transformer_layers_per_block=transformer_layers_per_block,
93
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
94
+ encoder_hid_dim=encoder_hid_dim,
95
+ encoder_hid_dim_type=encoder_hid_dim_type,
96
+ attention_head_dim=attention_head_dim,
97
+ num_attention_heads=num_attention_heads,
98
+ dual_cross_attention=dual_cross_attention,
99
+ use_linear_projection=use_linear_projection,
100
+ class_embed_type=class_embed_type,
101
+ addition_embed_type=addition_embed_type,
102
+ addition_time_embed_dim=addition_time_embed_dim,
103
+ num_class_embeds=num_class_embeds,
104
+ upcast_attention=upcast_attention,
105
+ resnet_time_scale_shift=resnet_time_scale_shift,
106
+ resnet_skip_time_act=resnet_skip_time_act,
107
+ resnet_out_scale_factor=resnet_out_scale_factor,
108
+ time_embedding_type=time_embedding_type,
109
+ time_embedding_dim=time_embedding_dim,
110
+ time_embedding_act_fn=time_embedding_act_fn,
111
+ timestep_post_act=timestep_post_act,
112
+ time_cond_proj_dim=time_cond_proj_dim,
113
+ conv_in_kernel=conv_in_kernel,
114
+ conv_out_kernel=conv_out_kernel,
115
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
116
+ attention_type=attention_type,
117
+ class_embeddings_concat=class_embeddings_concat,
118
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
119
+ cross_attention_norm=cross_attention_norm,
120
+ addition_embed_type_num_heads=addition_embed_type_num_heads,)
121
+ self._internal_dict = copy.deepcopy(self._internal_dict)
122
+ self.config.in_channels = in_channels
123
+ self.config.extra_condition_names = extra_condition_names
124
+
125
+ @property
126
+ def extra_condition_names(self) -> List[str]:
127
+ return self.config.extra_condition_names
128
+
129
+ def add_extra_conditions(self, extra_condition_names: Union[str, List[str]]):
130
+ if isinstance(extra_condition_names, str):
131
+ extra_condition_names = [extra_condition_names]
132
+ conv_in_kernel = self.config.conv_in_kernel
133
+ conv_in_weight = self.conv_in.weight
134
+ self.config.extra_condition_names += extra_condition_names
135
+ full_in_channels = self.config.in_channels * (1 + len(self.config.extra_condition_names))
136
+ new_conv_in_weight = torch.zeros(
137
+ conv_in_weight.shape[0], full_in_channels, conv_in_kernel, conv_in_kernel,
138
+ dtype=conv_in_weight.dtype,
139
+ device=conv_in_weight.device,)
140
+ new_conv_in_weight[:,:conv_in_weight.shape[1]] = conv_in_weight
141
+ self.conv_in.weight = nn.Parameter(
142
+ new_conv_in_weight.data,
143
+ requires_grad=conv_in_weight.requires_grad,)
144
+ self.conv_in.in_channels = full_in_channels
145
+
146
+ return self
147
+
148
+ def activate_adapters(self, adapter_names: Union[List[str], None] = None):
149
+ lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
150
+ for lora_layer in lora_layers:
151
+ _adapter_names = adapter_names or list(lora_layer.scaling.keys())
152
+ lora_layer.set_adapter(_adapter_names)
153
+
154
+ def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0):
155
+ if isinstance(scale, float):
156
+ scale = [scale] * len(self.config.extra_condition_names)
157
+
158
+ lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
159
+ for s, n in zip(scale, self.config.extra_condition_names):
160
+ for lora_layer in lora_layers:
161
+ lora_layer.set_scale(n, s)
162
+
163
+ @property
164
+ def default_half_lora_target_modules(self) -> List[str]:
165
+ module_names = []
166
+ for name, module in self.named_modules():
167
+ if "conv_out" in name or "up_blocks" in name:
168
+ continue
169
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
170
+ module_names.append(name)
171
+ return list(set(module_names))
172
+
173
+ @property
174
+ def default_full_lora_target_modules(self) -> List[str]:
175
+ module_names = []
176
+ for name, module in self.named_modules():
177
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
178
+ module_names.append(name)
179
+ return list(set(module_names))
180
+
181
+ @property
182
+ def default_half_skip_attn_lora_target_modules(self) -> List[str]:
183
+ return [
184
+ module_name
185
+ for module_name in self.default_half_lora_target_modules
186
+ if all(
187
+ not module_name.endswith(attn_name)
188
+ for attn_name in
189
+ ["to_k", "to_q", "to_v", "to_out.0"]
190
+ )
191
+ ]
192
+
193
+ @property
194
+ def default_full_skip_attn_lora_target_modules(self) -> List[str]:
195
+ return [
196
+ module_name
197
+ for module_name in self.default_full_lora_target_modules
198
+ if all(
199
+ not module_name.endswith(attn_name)
200
+ for attn_name in
201
+ ["to_k", "to_q", "to_v", "to_out.0"]
202
+ )
203
+ ]
204
+
205
+ def forward(
206
+ self,
207
+ sample: torch.Tensor,
208
+ timestep: Union[torch.Tensor, float, int],
209
+ encoder_hidden_states: torch.Tensor,
210
+ class_labels: Optional[torch.Tensor] = None,
211
+ timestep_cond: Optional[torch.Tensor] = None,
212
+ attention_mask: Optional[torch.Tensor] = None,
213
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
214
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
215
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
216
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
217
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
218
+ encoder_attention_mask: Optional[torch.Tensor] = None,
219
+ extra_conditions: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
220
+ return_dict: bool = True,
221
+ ) -> Union[UNet2DConditionOutput, Tuple]:
222
+ if extra_conditions is not None:
223
+ if isinstance(extra_conditions, list):
224
+ extra_conditions = torch.cat(extra_conditions, dim=1)
225
+ sample = torch.cat([sample, extra_conditions], dim=1)
226
+ return super().forward(
227
+ sample=sample,
228
+ timestep=timestep,
229
+ encoder_hidden_states=encoder_hidden_states,
230
+ class_labels=class_labels,
231
+ timestep_cond=timestep_cond,
232
+ attention_mask=attention_mask,
233
+ cross_attention_kwargs=cross_attention_kwargs,
234
+ added_cond_kwargs=added_cond_kwargs,
235
+ down_block_additional_residuals=down_block_additional_residuals,
236
+ mid_block_additional_residual=mid_block_additional_residual,
237
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
238
+ encoder_attention_mask=encoder_attention_mask,
239
+ return_dict=return_dict,)
240
+
241
+
242
+ class PeftConv2dEx(PeftConv2d):
243
+ def reset_lora_parameters(self, adapter_name, init_lora_weights):
244
+ if init_lora_weights is False:
245
+ return
246
+
247
+ if isinstance(init_lora_weights, str) and "pissa" in init_lora_weights.lower():
248
+ if self.conv2d_pissa_init(adapter_name, init_lora_weights):
249
+ return
250
+ # Failed
251
+ init_lora_weights = "gaussian"
252
+
253
+ super(PeftConv2d, self).reset_lora_parameters(adapter_name, init_lora_weights)
254
+
255
+ def conv2d_pissa_init(self, adapter_name, init_lora_weights):
256
+ weight = weight_ori = self.get_base_layer().weight
257
+ weight = weight.flatten(start_dim=1)
258
+ if self.r[adapter_name] > weight.shape[0]:
259
+ return False
260
+ dtype = weight.dtype
261
+ if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
262
+ raise TypeError(
263
+ "Please initialize PiSSA under float32, float16, or bfloat16. "
264
+ "Subsequently, re-quantize the residual model to help minimize quantization errors."
265
+ )
266
+ weight = weight.to(torch.float32)
267
+
268
+ if init_lora_weights == "pissa":
269
+ # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
270
+ V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
271
+ Vr = V[:, : self.r[adapter_name]]
272
+ Sr = S[: self.r[adapter_name]]
273
+ Sr /= self.scaling[adapter_name]
274
+ Uhr = Uh[: self.r[adapter_name]]
275
+ elif len(init_lora_weights.split("_niter_")) == 2:
276
+ Vr, Sr, Ur = svd_lowrank(
277
+ weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
278
+ )
279
+ Sr /= self.scaling[adapter_name]
280
+ Uhr = Ur.t()
281
+ else:
282
+ raise ValueError(
283
+ f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
284
+ )
285
+
286
+ lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
287
+ lora_B = Vr @ torch.diag(torch.sqrt(Sr))
288
+ self.lora_A[adapter_name].weight.data = lora_A.view([-1] + list(weight_ori.shape[1:]))
289
+ self.lora_B[adapter_name].weight.data = lora_B.view([-1, self.r[adapter_name]] + [1] * (weight_ori.ndim - 2))
290
+ weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
291
+ weight = weight.to(dtype)
292
+ self.get_base_layer().weight.data = weight.view_as(weight_ori)
293
+
294
+ return True
295
+
296
+
297
+ # Patch peft conv2d
298
+ PeftConv2d.reset_lora_parameters = PeftConv2dEx.reset_lora_parameters
299
+ PeftConv2d.conv2d_pissa_init = PeftConv2dEx.conv2d_pissa_init