Shanghua Gao commited on
Commit
78acc58
1 Parent(s): 5fb4baa
README.md CHANGED
@@ -8,7 +8,6 @@ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
-
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
 
14
  # Edit Anything by Segment-Anything
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
 
13
  # Edit Anything by Segment-Anything
annotator/util.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  import cv2
3
  import os
4
-
5
 
6
  annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
7
 
@@ -71,3 +71,25 @@ def get_bounding_box(mask):
71
 
72
  # Return as [xmin, ymin, xmax, ymax]
73
  return [rmin, cmin, rmax, cmax]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import cv2
3
  import os
4
+ import pickle
5
 
6
  annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
7
 
 
71
 
72
  # Return as [xmin, ymin, xmax, ymax]
73
  return [rmin, cmin, rmax, cmax]
74
+
75
+
76
+
77
+ def save_input_to_file(func):
78
+ def wrapper(self, *args, **kwargs):
79
+ # 创建不包含 self 的输入副本
80
+ input_data = {
81
+ 'args': args,
82
+ 'kwargs': kwargs
83
+ }
84
+
85
+ # 执行原始函数
86
+ result = func(self, *args, **kwargs)
87
+
88
+ # 将输入数据保存到文件
89
+ with open('input_data.pkl', 'wb') as f:
90
+ pickle.dump(input_data, f)
91
+
92
+ # 返回结果
93
+ return result
94
+
95
+ return wrapper
app.py CHANGED
@@ -68,4 +68,4 @@ with gr.Blocks() as demo:
68
  with gr.Tabs():
69
  gr.Markdown(SHARED_UI_WARNING)
70
 
71
- demo.queue(api_open=False).launch(server_name='0.0.0.0', share=False)
 
68
  with gr.Tabs():
69
  gr.Markdown(SHARED_UI_WARNING)
70
 
71
+ demo.queue(api_open=False).launch(server_name='0.0.0.0', share=False)
editany_demo.py CHANGED
@@ -1,6 +1,10 @@
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
  import gradio as gr
3
 
 
 
 
 
4
 
5
  def create_demo_template(
6
  process,
@@ -22,7 +26,7 @@ def create_demo_template(
22
  ref_click_mask = gr.State(None)
23
  with gr.Row():
24
  gr.Markdown(INFO)
25
- with gr.Row().style(equal_height=False):
26
  with gr.Column():
27
  with gr.Tab("Click🖱"):
28
  source_image_click = gr.Image(
@@ -40,12 +44,13 @@ def create_demo_template(
40
  interactive=True,
41
  show_label=False,
42
  )
43
- clear_button_click = gr.Button(
44
- value="Clear Click Points", interactive=True
45
- )
46
- clear_button_image = gr.Button(
47
- value="Clear Image", interactive=True
48
- )
 
49
  with gr.Row():
50
  run_button_click = gr.Button(
51
  label="Run EditAnying", interactive=True
@@ -56,63 +61,75 @@ def create_demo_template(
56
  label="Image: Upload an image and cover the region you want to edit with sketch",
57
  type="numpy",
58
  tool="sketch",
 
59
  )
60
  run_button = gr.Button(
61
  label="Run EditAnying", interactive=True)
62
- with gr.Column():
63
- enable_all_generate = gr.Checkbox(
64
- label="Auto generation on all region.", value=False
 
 
65
  )
 
 
 
 
 
 
66
  control_scale = gr.Slider(
67
- label="Mask Align strength",
68
- info="Large value -> strict alignment with SAM mask",
69
  minimum=0,
70
  maximum=1,
71
  value=0.5,
72
  step=0.1,
73
  )
 
 
 
 
 
 
 
 
 
 
 
74
  with gr.Column():
75
- enable_auto_prompt = gr.Checkbox(
76
- label="Auto generate text prompt from input image with BLIP2",
77
- info="Warning: Enable this may makes your prompt not working.",
78
- value=enable_auto_prompt_default,
79
- )
80
- a_prompt = gr.Textbox(
81
- label="Positive Prompt",
82
- info="Text in the expected things of edited region",
83
- value="best quality, extremely detailed,",
84
- )
85
- n_prompt = gr.Textbox(
86
- label="Negative Prompt",
87
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, NSFW",
88
- )
89
- with gr.Row():
90
- num_samples = gr.Slider(
91
- label="Images", minimum=1, maximum=12, value=2, step=1
92
- )
93
- seed = gr.Slider(
94
- label="Seed",
95
- minimum=-1,
96
- maximum=2147483647,
97
- step=1,
98
- randomize=True,
99
- )
100
  with gr.Row():
101
  enable_tile = gr.Checkbox(
102
- label="Tile refinement for high resolution generation",
103
  info="Slow inference",
104
  value=True,
105
  )
106
  refine_alignment_ratio = gr.Slider(
107
- label="Alignment Strength",
108
- info="Large value -> strict alignment with input image. Small value -> strong global consistency",
109
  minimum=0.0,
110
  maximum=1.0,
111
  value=0.95,
112
  step=0.05,
113
  )
114
 
115
- with gr.Accordion("Reference options", open=False):
116
  # ref_image = gr.Image(
117
  # source='upload', label="Upload a reference image", type="pil", value=None)
118
  ref_image = gr.Image(
@@ -120,8 +137,9 @@ def create_demo_template(
120
  label="Upload a reference image and cover the region you want to use with sketch",
121
  type="pil",
122
  tool="sketch",
 
123
  )
124
- with gr.Column():
125
  ref_auto_prompt = gr.Checkbox(
126
  label="Ref. Auto Prompt", value=True
127
  )
@@ -148,45 +166,25 @@ def create_demo_template(
148
  with gr.Row():
149
  reference_attn = gr.Checkbox(
150
  label="reference_attn", value=True)
151
- attention_auto_machine_weight = gr.Slider(
152
- label="attention_weight",
153
- minimum=0,
154
- maximum=1.0,
155
- value=0.8,
156
- step=0.01,
157
  )
158
  with gr.Row():
159
- reference_adain = gr.Checkbox(
160
- label="reference_adain", value=False
 
 
 
 
161
  )
162
- gn_auto_machine_weight = gr.Slider(
163
- label="gn_weight",
164
  minimum=0,
165
  maximum=1.0,
166
- value=0.1,
167
- step=0.01,
168
  )
169
- style_fidelity = gr.Slider(
170
- label="Style fidelity",
171
- minimum=0,
172
- maximum=1.0,
173
- value=0.5,
174
- step=0.01,
175
- )
176
- ref_sam_scale = gr.Slider(
177
- label="SAM Control Scale",
178
- minimum=0,
179
- maximum=1.0,
180
- value=0.3,
181
- step=0.1,
182
- )
183
- ref_inpaint_scale = gr.Slider(
184
- label="Inpaint Control Scale",
185
- minimum=0,
186
- maximum=1.0,
187
- value=0.2,
188
- step=0.1,
189
- )
190
  with gr.Row():
191
  ref_textinv = gr.Checkbox(
192
  label="Use textual inversion token", value=False
@@ -196,8 +194,37 @@ def create_demo_template(
196
  info="Text in the inversion token path",
197
  value=None,
198
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- with gr.Accordion("Advanced options", open=False):
201
  mask_image = gr.Image(
202
  source="upload",
203
  label="Upload a predefined mask of edit region: Switch to Brush mode when using this!",
@@ -244,19 +271,16 @@ def create_demo_template(
244
  )
245
  with gr.Column():
246
  result_gallery_refine = gr.Gallery(
247
- label="Output High quality", show_label=True, elem_id="gallery"
248
- ).style(grid=2, preview=False)
249
  result_gallery_init = gr.Gallery(
250
- label="Output Low quality", show_label=True, elem_id="gallery"
251
- ).style(grid=2, height="auto")
252
  result_gallery_ref = gr.Gallery(
253
- label="Output Ref", show_label=False, elem_id="gallery"
254
- ).style(grid=2, height="auto")
255
- result_text = gr.Text(label="BLIP2+Human Prompt Text")
256
 
257
  ips = [
258
  source_image_brush,
259
- enable_all_generate,
260
  mask_image,
261
  control_scale,
262
  enable_auto_prompt,
@@ -288,6 +312,7 @@ def create_demo_template(
288
  ref_auto_prompt,
289
  ref_textinv,
290
  ref_textinv_path,
 
291
  ]
292
  run_button.click(
293
  fn=process,
@@ -299,10 +324,56 @@ def create_demo_template(
299
  result_text,
300
  ],
301
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
  ip_click = [
304
  origin_image,
305
- enable_all_generate,
306
  click_mask,
307
  control_scale,
308
  enable_auto_prompt,
@@ -334,6 +405,7 @@ def create_demo_template(
334
  ref_auto_prompt,
335
  ref_textinv,
336
  ref_textinv_path,
 
337
  ]
338
 
339
  run_button_click.click(
 
1
  # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
  import gradio as gr
3
 
4
+ import numpy as np
5
+ import cv2
6
+ from cv2 import imencode
7
+ import base64
8
 
9
  def create_demo_template(
10
  process,
 
26
  ref_click_mask = gr.State(None)
27
  with gr.Row():
28
  gr.Markdown(INFO)
29
+ with gr.Row(equal_height=False):
30
  with gr.Column():
31
  with gr.Tab("Click🖱"):
32
  source_image_click = gr.Image(
 
44
  interactive=True,
45
  show_label=False,
46
  )
47
+ with gr.Row():
48
+ clear_button_click = gr.Button(
49
+ value="Clear Points", interactive=True
50
+ )
51
+ clear_button_image = gr.Button(
52
+ value="Reset Image", interactive=True
53
+ )
54
  with gr.Row():
55
  run_button_click = gr.Button(
56
  label="Run EditAnying", interactive=True
 
61
  label="Image: Upload an image and cover the region you want to edit with sketch",
62
  type="numpy",
63
  tool="sketch",
64
+ brush_color="#00FFBF"
65
  )
66
  run_button = gr.Button(
67
  label="Run EditAnying", interactive=True)
68
+ with gr.Tab("All region"):
69
+ source_image_clean = gr.Image(
70
+ source="upload",
71
+ label="Image: Upload an image",
72
+ type="numpy",
73
  )
74
+ run_button_allregion = gr.Button(
75
+ label="Run EditAnying", interactive=True)
76
+ with gr.Row():
77
+ # enable_all_generate = gr.Checkbox(
78
+ # label="All Region Generation", value=False
79
+ # )
80
  control_scale = gr.Slider(
81
+ label="SAM Mask Alignment Strength",
82
+ # info="Large value -> strict alignment with SAM mask",
83
  minimum=0,
84
  maximum=1,
85
  value=0.5,
86
  step=0.1,
87
  )
88
+ with gr.Row():
89
+ num_samples = gr.Slider(
90
+ label="Images", minimum=1, maximum=12, value=2, step=1
91
+ )
92
+ seed = gr.Slider(
93
+ label="Seed",
94
+ minimum=-1,
95
+ maximum=2147483647,
96
+ step=1,
97
+ randomize=True,
98
+ )
99
  with gr.Column():
100
+ with gr.Row():
101
+ enable_auto_prompt = gr.Checkbox(
102
+ label="Prompt Auto Generation (Enable this may makes your prompt not working)",
103
+ # info="",
104
+ value=enable_auto_prompt_default,
105
+ )
106
+ with gr.Row():
107
+ a_prompt = gr.Textbox(
108
+ label="Positive Prompt",
109
+ info="Text in the expected things of edited region",
110
+ value="best quality, extremely detailed,",
111
+ )
112
+ n_prompt = gr.Textbox(
113
+ label="Negative Prompt",
114
+ value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, NSFW",
115
+ )
116
+
 
 
 
 
 
 
 
 
117
  with gr.Row():
118
  enable_tile = gr.Checkbox(
119
+ label="High-resolution Refinement",
120
  info="Slow inference",
121
  value=True,
122
  )
123
  refine_alignment_ratio = gr.Slider(
124
+ label="Similarity with Initial Results",
125
+ # info="Large value -> strict alignment with input image. Small value -> strong global consistency",
126
  minimum=0.0,
127
  maximum=1.0,
128
  value=0.95,
129
  step=0.05,
130
  )
131
 
132
+ with gr.Accordion("Cross-image Drag Options", open=False):
133
  # ref_image = gr.Image(
134
  # source='upload', label="Upload a reference image", type="pil", value=None)
135
  ref_image = gr.Image(
 
137
  label="Upload a reference image and cover the region you want to use with sketch",
138
  type="pil",
139
  tool="sketch",
140
+ brush_color="#00FFBF",
141
  )
142
+ with gr.Row():
143
  ref_auto_prompt = gr.Checkbox(
144
  label="Ref. Auto Prompt", value=True
145
  )
 
166
  with gr.Row():
167
  reference_attn = gr.Checkbox(
168
  label="reference_attn", value=True)
169
+ reference_adain = gr.Checkbox(
170
+ label="reference_adain", value=True
 
 
 
 
171
  )
172
  with gr.Row():
173
+ ref_sam_scale = gr.Slider(
174
+ label="Pos Control Scale",
175
+ minimum=0,
176
+ maximum=1.0,
177
+ value=0.3,
178
+ step=0.1,
179
  )
180
+ ref_inpaint_scale = gr.Slider(
181
+ label="Content Control Scale",
182
  minimum=0,
183
  maximum=1.0,
184
+ value=0.2,
185
+ step=0.1,
186
  )
187
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  with gr.Row():
189
  ref_textinv = gr.Checkbox(
190
  label="Use textual inversion token", value=False
 
194
  info="Text in the inversion token path",
195
  value=None,
196
  )
197
+ with gr.Accordion("Advanced options", open=False):
198
+ style_fidelity = gr.Slider(
199
+ label="Style fidelity",
200
+ minimum=0,
201
+ maximum=1.,
202
+ value=0.,
203
+ step=0.1,
204
+ )
205
+ attention_auto_machine_weight = gr.Slider(
206
+ label="Attention Reference Weight",
207
+ minimum=0,
208
+ maximum=1.0,
209
+ value=1.0,
210
+ step=0.01,
211
+ )
212
+ gn_auto_machine_weight = gr.Slider(
213
+ label="GroupNorm Reference Weight",
214
+ minimum=0,
215
+ maximum=1.0,
216
+ value=1.0,
217
+ step=0.01,
218
+ )
219
+ ref_scale = gr.Slider(
220
+ label="Frequency Reference Guidance Scale",
221
+ minimum=0,
222
+ maximum=1.0,
223
+ value=0.0,
224
+ step=0.1,
225
+ )
226
 
227
+ with gr.Accordion("Advanced Options", open=False):
228
  mask_image = gr.Image(
229
  source="upload",
230
  label="Upload a predefined mask of edit region: Switch to Brush mode when using this!",
 
271
  )
272
  with gr.Column():
273
  result_gallery_refine = gr.Gallery(
274
+ label="Output High quality", show_label=True, elem_id="gallery", preview=False)
 
275
  result_gallery_init = gr.Gallery(
276
+ label="Output Low quality", show_label=True, elem_id="gallery", height="auto")
 
277
  result_gallery_ref = gr.Gallery(
278
+ label="Output Ref", show_label=False, elem_id="gallery", height="auto")
279
+ result_text = gr.Text(label="ALL Prompt Text")
 
280
 
281
  ips = [
282
  source_image_brush,
283
+ gr.State(False), # enable_all_generate
284
  mask_image,
285
  control_scale,
286
  enable_auto_prompt,
 
312
  ref_auto_prompt,
313
  ref_textinv,
314
  ref_textinv_path,
315
+ ref_scale,
316
  ]
317
  run_button.click(
318
  fn=process,
 
324
  result_text,
325
  ],
326
  )
327
+ ips_allregion = [
328
+ source_image_clean,
329
+ gr.State(True), # enable_all_generate
330
+ mask_image,
331
+ control_scale,
332
+ enable_auto_prompt,
333
+ a_prompt,
334
+ n_prompt,
335
+ num_samples,
336
+ image_resolution,
337
+ detect_resolution,
338
+ ddim_steps,
339
+ guess_mode,
340
+ scale,
341
+ seed,
342
+ eta,
343
+ enable_tile,
344
+ refine_alignment_ratio,
345
+ refine_image_resolution,
346
+ alpha_weight,
347
+ use_scale_map,
348
+ condition_model,
349
+ ref_image,
350
+ attention_auto_machine_weight,
351
+ gn_auto_machine_weight,
352
+ style_fidelity,
353
+ reference_attn,
354
+ reference_adain,
355
+ ref_prompt,
356
+ ref_sam_scale,
357
+ ref_inpaint_scale,
358
+ ref_auto_prompt,
359
+ ref_textinv,
360
+ ref_textinv_path,
361
+ ref_scale,
362
+ ]
363
+ run_button_allregion.click(
364
+ fn=process,
365
+ inputs=ips_allregion,
366
+ outputs=[
367
+ result_gallery_refine,
368
+ result_gallery_init,
369
+ result_gallery_ref,
370
+ result_text,
371
+ ],
372
+ )
373
 
374
  ip_click = [
375
  origin_image,
376
+ gr.State(False), # enable_all_generate
377
  click_mask,
378
  control_scale,
379
  enable_auto_prompt,
 
405
  ref_auto_prompt,
406
  ref_textinv,
407
  ref_textinv_path,
408
+ ref_scale,
409
  ]
410
 
411
  run_button_click.click(
editany_lora.py CHANGED
@@ -14,7 +14,7 @@ import random
14
  import os
15
  import requests
16
  from io import BytesIO
17
- from annotator.util import resize_image, HWC3, resize_points, get_bounding_box
18
 
19
  import torch
20
  from safetensors.torch import load_file
@@ -28,8 +28,7 @@ from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetI
28
  # need the latest transformers
29
  # pip install git+https://github.com/huggingface/transformers.git
30
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
31
- from diffusers import ControlNetModel, DiffusionPipeline
32
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
  import PIL.Image
34
 
35
  # Segment-Anything init.
@@ -119,16 +118,55 @@ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
119
  """
120
  max_length = pipeline.tokenizer.model_max_length
121
 
122
- # simple way to determine length of tokens
123
- count_prompt = len(re.split(r", ", prompt))
124
- count_negative_prompt = len(re.split(r", ", negative_prompt))
125
-
126
- # create the tensor based on which prompt is longer
127
- if count_prompt >= count_negative_prompt:
128
- input_ids = pipeline.tokenizer(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  prompt, return_tensors="pt", truncation=False
130
  ).input_ids.to(device)
131
- shape_max_length = input_ids.shape[-1]
 
 
 
 
 
 
 
132
  negative_ids = pipeline.tokenizer(
133
  negative_prompt,
134
  truncation=False,
@@ -137,23 +175,21 @@ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
137
  return_tensors="pt",
138
  ).input_ids.to(device)
139
  else:
140
- negative_ids = pipeline.tokenizer(
141
- negative_prompt, return_tensors="pt", truncation=False
142
- ).input_ids.to(device)
143
- shape_max_length = negative_ids.shape[-1]
144
  input_ids = pipeline.tokenizer(
145
- prompt,
146
- return_tensors="pt",
147
- truncation=False,
148
- padding="max_length",
149
- max_length=shape_max_length,
150
- ).input_ids.to(device)
151
 
152
  concat_embeds = []
153
  neg_embeds = []
154
  for i in range(0, shape_max_length, max_length):
155
- concat_embeds.append(pipeline.text_encoder(input_ids[:, i : i + max_length])[0])
156
- neg_embeds.append(pipeline.text_encoder(negative_ids[:, i : i + max_length])[0])
 
 
157
 
158
  return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
159
 
@@ -178,10 +214,12 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
178
  for layer, elems in updates.items():
179
 
180
  if "text" in layer:
181
- layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
 
182
  curr_layer = pipeline.text_encoder
183
  else:
184
- layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
 
185
  curr_layer = pipeline.unet
186
 
187
  # find the target layer
@@ -244,7 +282,8 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
244
  )
245
  curr_layer = pipeline.text_encoder
246
  else:
247
- layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
 
248
  curr_layer = pipeline.unet
249
 
250
  # find the target layer
@@ -489,7 +528,7 @@ class EditAnythingLoraModel:
489
  self.mask_predictor.set_image(image)
490
  # Separate the points and labels
491
  points, labels = zip(*[(point[:2], point[2])
492
- for point in clicked_points])
493
 
494
  # Convert the points and labels to numpy arrays
495
  input_point = np.array(points)
@@ -534,7 +573,8 @@ class EditAnythingLoraModel:
534
  mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
535
 
536
  mask_image = HWC3(mask_click_np.astype(np.uint8))
537
- mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
 
538
  # mask_image = Image.fromarray(mask_image_tmp)
539
 
540
  # Draw circles for all clicked points
@@ -567,6 +607,7 @@ class EditAnythingLoraModel:
567
  )
568
 
569
  @torch.inference_mode()
 
570
  def process(
571
  self,
572
  source_image,
@@ -602,6 +643,7 @@ class EditAnythingLoraModel:
602
  ref_auto_prompt=False,
603
  ref_textinv=True,
604
  ref_textinv_path=None,
 
605
  ):
606
 
607
  if condition_model is None or condition_model == "EditAnything":
@@ -624,14 +666,9 @@ class EditAnythingLoraModel:
624
  )
625
  self.defalut_enable_all_generate = enable_all_generate
626
  if enable_all_generate:
627
- print(
628
- "source_image",
629
- source_image["mask"].shape,
630
- input_image.shape,
631
- )
632
  mask_image = (
633
  np.ones((input_image.shape[0],
634
- input_image.shape[1], 3)) * 255
635
  )
636
  else:
637
  mask_image = source_image["mask"]
@@ -699,11 +736,13 @@ class EditAnythingLoraModel:
699
  except:
700
  print("No textinvert embeddings found.")
701
  ref_data_path = "./utils/tmp/textinv/img"
702
- if not os.path.exists(ref_data_path):
703
  os.makedirs(ref_data_path)
704
- cropped_ref_image.save(os.path.join(ref_data_path, 'ref.png'))
 
705
  print("Ref image region is save to:", ref_data_path)
706
- print("Plese finetune with run_texutal_inversion.sh in utils folder to get the textinvert embeddings.")
 
707
 
708
  else:
709
  ref_mask = None
@@ -735,7 +774,7 @@ class EditAnythingLoraModel:
735
  )
736
 
737
  control = torch.from_numpy(detected_map.copy()).float().cuda()
738
- control = torch.stack([control for _ in range(num_samples)], dim=0)
739
  control = einops.rearrange(control, "b h w c -> b c h w").clone()
740
 
741
  mask_imag_ori = HWC3(mask_image.astype(np.uint8))
@@ -753,14 +792,8 @@ class EditAnythingLoraModel:
753
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
754
  self.pipe, postive_prompt, negative_prompt, "cuda"
755
  )
756
- prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
757
- negative_prompt_embeds = torch.cat(
758
- [negative_prompt_embeds] * num_samples, dim=0
759
- )
760
 
761
  if enable_all_generate and self.extra_inpaint:
762
- self.pipe.safety_checker = lambda images, clip_input: (
763
- images, False)
764
  if ref_image is not None:
765
  print("Not support yet.")
766
  return
@@ -845,6 +878,7 @@ class EditAnythingLoraModel:
845
  reference_adain=reference_adain,
846
  ref_controlnet_conditioning_scale=ref_multi_condition_scale,
847
  guess_mode=guess_mode,
 
848
  ).images
849
  results = [x_samples[i] for i in range(num_samples)]
850
 
 
14
  import os
15
  import requests
16
  from io import BytesIO
17
+ from annotator.util import resize_image, HWC3, resize_points, get_bounding_box, save_input_to_file
18
 
19
  import torch
20
  from safetensors.torch import load_file
 
28
  # need the latest transformers
29
  # pip install git+https://github.com/huggingface/transformers.git
30
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
31
+ from diffusers import ControlNetModel
 
32
  import PIL.Image
33
 
34
  # Segment-Anything init.
 
118
  """
119
  max_length = pipeline.tokenizer.model_max_length
120
 
121
+ # # simple way to determine length of tokens
122
+ # count_prompt = len(re.split(r",", prompt))
123
+ # count_negative_prompt = len(re.split(r",", negative_prompt))
124
+
125
+ # # create the tensor based on which prompt is longer
126
+ # if count_prompt >= count_negative_prompt:
127
+ # input_ids = pipeline.tokenizer(
128
+ # prompt, return_tensors="pt", truncation=False
129
+ # ).input_ids.to(device)
130
+ # shape_max_length = input_ids.shape[-1]
131
+ # negative_ids = pipeline.tokenizer(
132
+ # negative_prompt,
133
+ # truncation=False,
134
+ # padding="max_length",
135
+ # max_length=shape_max_length,
136
+ # return_tensors="pt",
137
+ # ).input_ids.to(device)
138
+ # else:
139
+ # negative_ids = pipeline.tokenizer(
140
+ # negative_prompt, return_tensors="pt", truncation=False
141
+ # ).input_ids.to(device)
142
+ # shape_max_length = negative_ids.shape[-1]
143
+ # input_ids = pipeline.tokenizer(
144
+ # prompt,
145
+ # return_tensors="pt",
146
+ # truncation=False,
147
+ # padding="max_length",
148
+ # max_length=shape_max_length,
149
+ # ).input_ids.to(device)
150
+
151
+ # concat_embeds = []
152
+ # neg_embeds = []
153
+ # for i in range(0, shape_max_length, max_length):
154
+ # concat_embeds.append(pipeline.text_encoder(
155
+ # input_ids[:, i: i + max_length])[0])
156
+ # neg_embeds.append(pipeline.text_encoder(
157
+ # negative_ids[:, i: i + max_length])[0])
158
+
159
+ input_ids = pipeline.tokenizer(
160
  prompt, return_tensors="pt", truncation=False
161
  ).input_ids.to(device)
162
+
163
+ negative_ids = pipeline.tokenizer(
164
+ negative_prompt, return_tensors="pt", truncation=False
165
+ ).input_ids.to(device)
166
+
167
+ shape_max_length = max(input_ids.shape[-1],negative_ids.shape[-1])
168
+
169
+ if input_ids.shape[-1]>negative_ids.shape[-1]:
170
  negative_ids = pipeline.tokenizer(
171
  negative_prompt,
172
  truncation=False,
 
175
  return_tensors="pt",
176
  ).input_ids.to(device)
177
  else:
 
 
 
 
178
  input_ids = pipeline.tokenizer(
179
+ prompt,
180
+ return_tensors="pt",
181
+ truncation=False,
182
+ padding="max_length",
183
+ max_length=shape_max_length,
184
+ ).input_ids.to(device)
185
 
186
  concat_embeds = []
187
  neg_embeds = []
188
  for i in range(0, shape_max_length, max_length):
189
+ concat_embeds.append(pipeline.text_encoder(
190
+ input_ids[:, i: i + max_length])[0])
191
+ neg_embeds.append(pipeline.text_encoder(
192
+ negative_ids[:, i: i + max_length])[0])
193
 
194
  return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
195
 
 
214
  for layer, elems in updates.items():
215
 
216
  if "text" in layer:
217
+ layer_infos = layer.split(
218
+ LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
219
  curr_layer = pipeline.text_encoder
220
  else:
221
+ layer_infos = layer.split(
222
+ LORA_PREFIX_UNET + "_")[-1].split("_")
223
  curr_layer = pipeline.unet
224
 
225
  # find the target layer
 
282
  )
283
  curr_layer = pipeline.text_encoder
284
  else:
285
+ layer_infos = layer.split(
286
+ LORA_PREFIX_UNET + "_")[-1].split("_")
287
  curr_layer = pipeline.unet
288
 
289
  # find the target layer
 
528
  self.mask_predictor.set_image(image)
529
  # Separate the points and labels
530
  points, labels = zip(*[(point[:2], point[2])
531
+ for point in clicked_points])
532
 
533
  # Convert the points and labels to numpy arrays
534
  input_point = np.array(points)
 
573
  mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
574
 
575
  mask_image = HWC3(mask_click_np.astype(np.uint8))
576
+ mask_image = cv2.resize(
577
+ mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
578
  # mask_image = Image.fromarray(mask_image_tmp)
579
 
580
  # Draw circles for all clicked points
 
607
  )
608
 
609
  @torch.inference_mode()
610
+ @save_input_to_file # for debug use
611
  def process(
612
  self,
613
  source_image,
 
643
  ref_auto_prompt=False,
644
  ref_textinv=True,
645
  ref_textinv_path=None,
646
+ ref_scale=None,
647
  ):
648
 
649
  if condition_model is None or condition_model == "EditAnything":
 
666
  )
667
  self.defalut_enable_all_generate = enable_all_generate
668
  if enable_all_generate:
 
 
 
 
 
669
  mask_image = (
670
  np.ones((input_image.shape[0],
671
+ input_image.shape[1], 3)) * 255
672
  )
673
  else:
674
  mask_image = source_image["mask"]
 
736
  except:
737
  print("No textinvert embeddings found.")
738
  ref_data_path = "./utils/tmp/textinv/img"
739
+ if not os.path.exists(ref_data_path):
740
  os.makedirs(ref_data_path)
741
+ cropped_ref_image.save(
742
+ os.path.join(ref_data_path, 'ref.png'))
743
  print("Ref image region is save to:", ref_data_path)
744
+ print(
745
+ "Plese finetune with run_texutal_inversion.sh in utils folder to get the textinvert embeddings.")
746
 
747
  else:
748
  ref_mask = None
 
774
  )
775
 
776
  control = torch.from_numpy(detected_map.copy()).float().cuda()
777
+ control = control.unsqueeze(dim=0)
778
  control = einops.rearrange(control, "b h w c -> b c h w").clone()
779
 
780
  mask_imag_ori = HWC3(mask_image.astype(np.uint8))
 
792
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
793
  self.pipe, postive_prompt, negative_prompt, "cuda"
794
  )
 
 
 
 
795
 
796
  if enable_all_generate and self.extra_inpaint:
 
 
797
  if ref_image is not None:
798
  print("Not support yet.")
799
  return
 
878
  reference_adain=reference_adain,
879
  ref_controlnet_conditioning_scale=ref_multi_condition_scale,
880
  guess_mode=guess_mode,
881
+ ref_scale=ref_scale,
882
  ).images
883
  results = [x_samples[i] for i in range(num_samples)]
884
 
editany_nogradio.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from editany_lora import EditAnythingLoraModel
3
+ model = EditAnythingLoraModel(
4
+ base_model_path="runwayml/stable-diffusion-v1-5",
5
+ controlmodel_name='LAION Pretrained(v0-4)-SD15',
6
+ lora_model_path=None, use_blip=False, extra_inpaint=True,
7
+ )
8
+
9
+ with open('input_data.pkl', 'rb') as f:
10
+ input_data = pickle.load(f)
11
+
12
+ print(input_data)
13
+
14
+ refined, output, ref, text = model.process(*input_data['args'], **input_data['kwargs'])
15
+
16
+ output
17
+
18
+ # a woman in a tan suit and white shirt
19
+
20
+ # best quality, extremely detailed,iron man wallpaper
editany_test.py CHANGED
@@ -70,4 +70,4 @@ if __name__ == "__main__":
70
  lora_weight=0.5,
71
  )
72
  demo = create_demo(model.process, model.process_image_click)
73
- demo.queue().launch(server_name="0.0.0.0")
 
70
  lora_weight=0.5,
71
  )
72
  demo = create_demo(model.process, model.process_image_click)
73
+ demo.queue().launch(server_name="0.0.0.0", share=True)
environment.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: control
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.13.1
10
+ - torchvision=0.14.1
11
+ - numpy=1.23.1
12
+ - pip:
13
+ - gradio==3.35.2
14
+ - albumentations==1.3.0
15
+ - opencv-contrib-python==4.3.0.36
16
+ - imageio==2.9.0
17
+ - imageio-ffmpeg==0.4.2
18
+ - pytorch-lightning==1.5.0
19
+ - omegaconf==2.1.1
20
+ - test-tube>=0.7.5
21
+ - streamlit==1.12.1
22
+ - einops==0.3.0
23
+ - webdataset==0.2.5
24
+ - kornia==0.6
25
+ - open_clip_torch==2.0.2
26
+ - invisible-watermark>=0.1.5
27
+ - streamlit-drawable-canvas==0.8.0
28
+ - torchmetrics==0.6.0
29
+ - timm==0.6.12
30
+ - addict==2.4.0
31
+ - yapf==0.32.0
32
+ - prettytable==3.6.0
33
+ - safetensors==0.2.7
34
+ - basicsr==1.4.2
35
+ - diffusers==0.17.1
36
+ - accelerate==0.17.0
37
+ - transformers==4.30.2
38
+ - xformers
requirements.txt CHANGED
@@ -30,4 +30,4 @@ transformers==4.30.2
30
  xformers==0.0.16
31
  triton
32
  gradio==3.35.2
33
- gradio-client==0.2.7
 
30
  xformers==0.0.16
31
  triton
32
  gradio==3.35.2
33
+ gradio-client==0.2.7
utils/stable_diffusion_controlnet_inpaint.py CHANGED
@@ -1179,6 +1179,7 @@ class StableDiffusionControlNetInpaintPipeline(
1179
  style_fidelity: float = 0.5,
1180
  reference_attn: bool = True,
1181
  reference_adain: bool = True,
 
1182
  ):
1183
  r"""
1184
  Function invoked when calling the pipeline for generation.
@@ -1272,6 +1273,8 @@ class StableDiffusionControlNetInpaintPipeline(
1272
  Whether to use reference query for self attention's context.
1273
  reference_adain (`bool`):
1274
  Whether to use reference adain.
 
 
1275
 
1276
  Examples:
1277
 
@@ -1346,8 +1349,9 @@ class StableDiffusionControlNetInpaintPipeline(
1346
  ref_prompt_embeds = self._encode_prompt(
1347
  ref_prompt,
1348
  device,
1349
- num_images_per_prompt * 2,
1350
- do_classifier_free_guidance,
 
1351
  negative_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
1352
  prompt_embeds=None,
1353
  )
@@ -1414,13 +1418,13 @@ class StableDiffusionControlNetInpaintPipeline(
1414
  num_images_per_prompt=num_images_per_prompt,
1415
  device=device,
1416
  dtype=self.controlnet.dtype,
1417
- do_classifier_free_guidance=do_classifier_free_guidance,
1418
  )
1419
  ref_controlnet_conditioning_image = controlnet_conditioning_image.copy()
 
 
 
1420
  ref_controlnet_conditioning_image[-1] = ref_control_image
1421
- # ref_controlnet_conditioning_scale = controlnet_conditioning_scale.copy()
1422
- # ref_controlnet_conditioning_scale[0] = 1.0 # disable the first sam controlnet
1423
- # ref_controlnet_conditioning_scale[-1] = 0.2
1424
 
1425
  # 5. Prepare timesteps
1426
  self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1491,7 +1495,7 @@ class StableDiffusionControlNetInpaintPipeline(
1491
  prompt_embeds.dtype,
1492
  device,
1493
  generator,
1494
- do_classifier_free_guidance,
1495
  )
1496
 
1497
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -1511,6 +1515,7 @@ class StableDiffusionControlNetInpaintPipeline(
1511
  self.gn_auto_machine_weight = gn_auto_machine_weight
1512
  self.do_classifier_free_guidance = do_classifier_free_guidance
1513
  self.style_fidelity = style_fidelity
 
1514
  self.ref_mask = ref_mask
1515
  self.inpaint_mask = mask_image
1516
  attn_modules, gn_modules = self.redefine_ref_model(
@@ -1518,9 +1523,16 @@ class StableDiffusionControlNetInpaintPipeline(
1518
  )
1519
 
1520
  control_attn_modules, control_gn_modules = self.redefine_ref_model(
1521
- self.controlnet, reference_attn, False, model_type="controlnet"
 
 
 
 
 
 
 
 
1522
  )
1523
-
1524
  # 8. Denoising loop
1525
  num_warmup_steps = len(timesteps) - \
1526
  num_inference_steps * self.scheduler.order
@@ -1549,12 +1561,6 @@ class StableDiffusionControlNetInpaintPipeline(
1549
 
1550
  if ref_image is not None: # for ref_only mode
1551
  # ref only part
1552
- noise = randn_tensor(
1553
- ref_image_latents.shape,
1554
- generator=generator,
1555
- device=ref_image_latents.device,
1556
- dtype=ref_image_latents.dtype,
1557
- )
1558
  ref_xt = self.scheduler.add_noise(
1559
  ref_image_latents,
1560
  noise,
@@ -1566,8 +1572,8 @@ class StableDiffusionControlNetInpaintPipeline(
1566
 
1567
  MODE = "write"
1568
  self.change_module_mode(
1569
- MODE, control_attn_modules, control_gn_modules
1570
- )
1571
 
1572
  (
1573
  ref_down_block_res_samples,
@@ -1582,7 +1588,6 @@ class StableDiffusionControlNetInpaintPipeline(
1582
  return_dict=False,
1583
  )
1584
 
1585
- self.change_module_mode(MODE, attn_modules, gn_modules)
1586
  self.unet(
1587
  ref_xt,
1588
  t,
@@ -1595,7 +1600,10 @@ class StableDiffusionControlNetInpaintPipeline(
1595
 
1596
  # predict the noise residual
1597
  MODE = "read" # change to read mode for following noise_pred
 
 
1598
  self.change_module_mode(MODE, attn_modules, gn_modules)
 
1599
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1600
  non_inpainting_latent_model_input,
1601
  t,
 
1179
  style_fidelity: float = 0.5,
1180
  reference_attn: bool = True,
1181
  reference_adain: bool = True,
1182
+ ref_scale: float = 1.0,
1183
  ):
1184
  r"""
1185
  Function invoked when calling the pipeline for generation.
 
1273
  Whether to use reference query for self attention's context.
1274
  reference_adain (`bool`):
1275
  Whether to use reference adain.
1276
+ ref_scale (`float`):
1277
+ reference guidance scale.
1278
 
1279
  Examples:
1280
 
 
1349
  ref_prompt_embeds = self._encode_prompt(
1350
  ref_prompt,
1351
  device,
1352
+ # num_images_per_prompt * 2,
1353
+ num_images_per_prompt * 1,
1354
+ False,
1355
  negative_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
1356
  prompt_embeds=None,
1357
  )
 
1418
  num_images_per_prompt=num_images_per_prompt,
1419
  device=device,
1420
  dtype=self.controlnet.dtype,
1421
+ do_classifier_free_guidance=False,
1422
  )
1423
  ref_controlnet_conditioning_image = controlnet_conditioning_image.copy()
1424
+ for i in range(len(ref_controlnet_conditioning_image)):
1425
+ ref_controlnet_conditioning_image[i] = ref_controlnet_conditioning_image[i].chunk(
1426
+ 2)[0] # remove the extra guidance for cfg
1427
  ref_controlnet_conditioning_image[-1] = ref_control_image
 
 
 
1428
 
1429
  # 5. Prepare timesteps
1430
  self.scheduler.set_timesteps(num_inference_steps, device=device)
 
1495
  prompt_embeds.dtype,
1496
  device,
1497
  generator,
1498
+ False,
1499
  )
1500
 
1501
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
 
1515
  self.gn_auto_machine_weight = gn_auto_machine_weight
1516
  self.do_classifier_free_guidance = do_classifier_free_guidance
1517
  self.style_fidelity = style_fidelity
1518
+ self.ref_scale = ref_scale
1519
  self.ref_mask = ref_mask
1520
  self.inpaint_mask = mask_image
1521
  attn_modules, gn_modules = self.redefine_ref_model(
 
1523
  )
1524
 
1525
  control_attn_modules, control_gn_modules = self.redefine_ref_model(
1526
+ self.controlnet, reference_attn, reference_adain, model_type="controlnet"
1527
+ )
1528
+ if ref_image is not None:
1529
+ noise = randn_tensor(
1530
+ # ref_image_latents.shape,
1531
+ latents.shape,
1532
+ generator=generator,
1533
+ device=ref_image_latents.device,
1534
+ dtype=ref_image_latents.dtype,
1535
  )
 
1536
  # 8. Denoising loop
1537
  num_warmup_steps = len(timesteps) - \
1538
  num_inference_steps * self.scheduler.order
 
1561
 
1562
  if ref_image is not None: # for ref_only mode
1563
  # ref only part
 
 
 
 
 
 
1564
  ref_xt = self.scheduler.add_noise(
1565
  ref_image_latents,
1566
  noise,
 
1572
 
1573
  MODE = "write"
1574
  self.change_module_mode(
1575
+ MODE, control_attn_modules, control_gn_modules)
1576
+ self.change_module_mode(MODE, attn_modules, gn_modules)
1577
 
1578
  (
1579
  ref_down_block_res_samples,
 
1588
  return_dict=False,
1589
  )
1590
 
 
1591
  self.unet(
1592
  ref_xt,
1593
  t,
 
1600
 
1601
  # predict the noise residual
1602
  MODE = "read" # change to read mode for following noise_pred
1603
+ self.change_module_mode(
1604
+ MODE, control_attn_modules, control_gn_modules)
1605
  self.change_module_mode(MODE, attn_modules, gn_modules)
1606
+
1607
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1608
  non_inpainting_latent_model_input,
1609
  t,
utils/stable_diffusion_reference.py CHANGED
@@ -1,12 +1,12 @@
1
  # Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
2
  # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
 
3
  from typing import Any, Callable, Dict, List, Optional, Union, Tuple
4
 
5
  import numpy as np
6
  import PIL.Image
7
  import torch
8
 
9
- from diffusers import StableDiffusionPipeline
10
  from diffusers.models.attention import BasicTransformerBlock
11
  from diffusers.models.unet_2d_blocks import (
12
  CrossAttnDownBlock2D,
@@ -14,11 +14,9 @@ from diffusers.models.unet_2d_blocks import (
14
  DownBlock2D,
15
  UpBlock2D,
16
  )
17
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
18
  from diffusers.utils import PIL_INTERPOLATION, logging
19
  import torch.nn.functional as F
20
 
21
-
22
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
 
24
  EXAMPLE_DOC_STRING = """
@@ -56,6 +54,127 @@ def torch_dfs(model: torch.nn.Module):
56
  return result
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  class StableDiffusionReferencePipeline:
60
  def prepare_ref_image(
61
  self,
@@ -237,9 +356,8 @@ class StableDiffusionReferencePipeline:
237
  this_ref_mask = F.interpolate(
238
  this_ref_mask, scale_factor=ref_scale
239
  )
240
- # print("this_ref_mask",this_ref_mask.shape)
241
-
242
- # this_ref_mask = this_ref_mask.view(1,-1,1)
243
  this_ref_mask = this_ref_mask.repeat(
244
  resize_norm_hidden_states.shape[0],
245
  resize_norm_hidden_states.shape[1],
@@ -256,11 +374,14 @@ class StableDiffusionReferencePipeline:
256
  -1,
257
  )
258
  )
 
259
  masked_norm_hidden_states = masked_norm_hidden_states.permute(
260
  0, 2, 1
261
  )
262
  self.bank.append(masked_norm_hidden_states)
263
- # self.bank.append(norm_hidden_states.detach().clone())
 
 
264
  attn_output = self.attn1(
265
  norm_hidden_states,
266
  encoder_hidden_states=encoder_hidden_states
@@ -271,31 +392,27 @@ class StableDiffusionReferencePipeline:
271
  )
272
  if self.MODE == "read":
273
  if self.attention_auto_machine_weight > self.attn_weight:
274
- # scale_ratio = ((self.ref_mask.shape[2] * self.ref_mask.shape[3])/norm_hidden_states.shape[1])**0.5
275
- # print(scale_ratio)
276
- # this_ref_mask = F.interpolate(self.ref_mask.to(norm_hidden_states.device), scale_factor=1/scale_ratio).view(1,1,-1)
277
- # print("resized mask", this_ref_mask.shape, this_ref_mask.max(), this_ref_mask.min(), this_ref_mask.sum())
278
- # ref_hidden_states = torch.cat([norm_hidden_states] + self.bank, dim=1)
279
- # if attention_mask is None:
280
- # attention_mask = torch.ones(
281
- # norm_hidden_states.shape[0], norm_hidden_states.shape[1], ref_hidden_states.shape[1], dtype=norm_hidden_states.dtype, device=norm_hidden_states.device
282
- # )
283
- # this_ref_mask = this_ref_mask.repeat(norm_hidden_states.shape[0], norm_hidden_states.shape[1], 1)
284
- # this_ref_mask = torch.zeros(
285
- # norm_hidden_states.shape[0], norm_hidden_states.shape[1], this_ref_mask.shape[1], dtype=norm_hidden_states.dtype, device=norm_hidden_states.device
286
- # )
287
- # print(attention_mask.shape, this_ref_mask.shape)
288
- # attention_mask = torch.cat((attention_mask, this_ref_mask), dim=-1)
289
- # print("merge", attention_mask.shape)
290
  ref_hidden_states = torch.cat(
291
- [norm_hidden_states] + self.bank, dim=1
292
  )
 
 
 
293
  attn_output_uc = self.attn1(
294
- norm_hidden_states,
295
  encoder_hidden_states=ref_hidden_states,
296
- # attention_mask=attention_mask,
297
  **cross_attention_kwargs,
298
  )
 
299
  attn_output_c = attn_output_uc.clone()
300
  if self.do_classifier_free_guidance and self.style_fidelity > 0:
301
  attn_output_c[self.uc_mask] = self.attn1(
@@ -308,6 +425,9 @@ class StableDiffusionReferencePipeline:
308
  + (1.0 - self.style_fidelity) * attn_output_uc
309
  )
310
  self.bank.clear()
 
 
 
311
  else:
312
  attn_output = self.attn1(
313
  norm_hidden_states,
@@ -317,6 +437,9 @@ class StableDiffusionReferencePipeline:
317
  attention_mask=attention_mask,
318
  **cross_attention_kwargs,
319
  )
 
 
 
320
  if self.use_ada_layer_norm_zero:
321
  attn_output = gate_msa.unsqueeze(1) * attn_output
322
  hidden_states = attn_output + hidden_states
@@ -365,6 +488,10 @@ class StableDiffusionReferencePipeline:
365
  this_ref_mask = F.interpolate(
366
  self.ref_mask.to(x.device), scale_factor=1 / scale_ratio
367
  )
 
 
 
 
368
  this_ref_mask = this_ref_mask.repeat(
369
  x.shape[0], x.shape[1], 1, 1
370
  ).bool()
@@ -378,8 +505,8 @@ class StableDiffusionReferencePipeline:
378
  masked_x, dim=(2, 3), keepdim=True, correction=0
379
  )
380
 
381
- self.mean_bank.append(mean)
382
- self.var_bank.append(var)
383
  if self.MODE == "read":
384
  if (
385
  self.gn_auto_machine_weight >= self.gn_weight
@@ -387,37 +514,12 @@ class StableDiffusionReferencePipeline:
387
  and len(self.var_bank) > 0
388
  ):
389
  # print("hacked_mid_forward")
390
- scale_ratio = self.inpaint_mask.shape[2] / x.shape[2]
391
- this_inpaint_mask = F.interpolate(
392
- self.inpaint_mask.to(x.device), scale_factor=1 / scale_ratio
393
- )
394
- this_inpaint_mask = this_inpaint_mask.repeat(
395
- x.shape[0], x.shape[1], 1, 1
396
- ).bool()
397
- masked_x = (
398
- x[this_inpaint_mask]
399
- .detach()
400
- .clone()
401
- .view(x.shape[0], x.shape[1], -1, 1)
402
- )
403
- var, mean = torch.var_mean(
404
- masked_x, dim=(2, 3), keepdim=True, correction=0
405
- )
406
- std = torch.maximum(
407
- var, torch.zeros_like(var) + eps) ** 0.5
408
- mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
409
- var_acc = sum(self.var_bank) / float(len(self.var_bank))
410
- std_acc = (
411
- torch.maximum(var_acc, torch.zeros_like(
412
- var_acc) + eps) ** 0.5
413
- )
414
- x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
415
- x_c = x_uc.clone()
416
- if self.do_classifier_free_guidance and self.style_fidelity > 0:
417
- x_c[self.uc_mask] = masked_x[self.uc_mask]
418
- masked_x = self.style_fidelity * x_c + \
419
- (1.0 - self.style_fidelity) * x_uc
420
- x[this_inpaint_mask] = masked_x.view(-1)
421
  self.mean_bank = []
422
  self.var_bank = []
423
  return x
@@ -448,6 +550,8 @@ class StableDiffusionReferencePipeline:
448
  self.ref_mask.to(hidden_states.device),
449
  scale_factor=1 / scale_ratio,
450
  )
 
 
451
  this_ref_mask = this_ref_mask.repeat(
452
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
453
  ).bool()
@@ -460,8 +564,8 @@ class StableDiffusionReferencePipeline:
460
  var, mean = torch.var_mean(
461
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
462
  )
463
- self.mean_bank0.append(mean)
464
- self.var_bank0.append(var)
465
  if self.MODE == "read":
466
  if (
467
  self.gn_auto_machine_weight >= self.gn_weight
@@ -469,54 +573,17 @@ class StableDiffusionReferencePipeline:
469
  and len(self.var_bank0) > 0
470
  ):
471
  # print("hacked_CrossAttnDownBlock2D_forward0")
472
- scale_ratio = self.inpaint_mask.shape[2] / \
473
- hidden_states.shape[2]
474
- this_inpaint_mask = F.interpolate(
475
- self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
476
- )
477
- this_inpaint_mask = this_inpaint_mask.repeat(
478
- hidden_states.shape[0], hidden_states.shape[1], 1, 1
479
- ).bool()
480
- masked_hidden_states = (
481
- hidden_states[this_inpaint_mask]
482
- .detach()
483
- .clone()
484
- .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
485
- )
486
- var, mean = torch.var_mean(
487
- masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
488
- )
489
- std = torch.maximum(
490
- var, torch.zeros_like(var) + eps) ** 0.5
491
- mean_acc = sum(self.mean_bank0[i]) / float(
492
- len(self.mean_bank0[i])
493
- )
494
- var_acc = sum(
495
- self.var_bank0[i]) / float(len(self.var_bank0[i]))
496
- std_acc = (
497
- torch.maximum(
498
- var_acc, torch.zeros_like(var_acc) + eps)
499
- ** 0.5
500
- )
501
- hidden_states_uc = (
502
- ((masked_hidden_states - mean) / std) * std_acc
503
- ) + mean_acc
504
- hidden_states_c = hidden_states_uc.clone()
505
- if self.do_classifier_free_guidance and self.style_fidelity > 0:
506
- hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
507
- masked_hidden_states = (
508
- self.style_fidelity * hidden_states_c
509
- + (1.0 - self.style_fidelity) * hidden_states_uc
510
- )
511
- hidden_states[this_inpaint_mask] = masked_hidden_states.view(
512
- -1)
513
 
514
  hidden_states = attn(
515
  hidden_states,
516
  encoder_hidden_states=encoder_hidden_states,
517
  cross_attention_kwargs=cross_attention_kwargs,
518
- # attention_mask=attention_mask,
519
- # encoder_attention_mask=encoder_attention_mask,
520
  return_dict=False,
521
  )[0]
522
  if self.MODE == "write":
@@ -528,6 +595,8 @@ class StableDiffusionReferencePipeline:
528
  self.ref_mask.to(hidden_states.device),
529
  scale_factor=1 / scale_ratio,
530
  )
 
 
531
  this_ref_mask = this_ref_mask.repeat(
532
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
533
  ).bool()
@@ -540,8 +609,8 @@ class StableDiffusionReferencePipeline:
540
  var, mean = torch.var_mean(
541
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
542
  )
543
- self.mean_bank.append(mean)
544
- self.var_bank.append(var)
545
  if self.MODE == "read":
546
  if (
547
  self.gn_auto_machine_weight >= self.gn_weight
@@ -549,48 +618,12 @@ class StableDiffusionReferencePipeline:
549
  and len(self.var_bank) > 0
550
  ):
551
  # print("hack_CrossAttnDownBlock2D_forward")
552
- scale_ratio = self.inpaint_mask.shape[2] / \
553
- hidden_states.shape[2]
554
- this_inpaint_mask = F.interpolate(
555
- self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
556
- )
557
- this_inpaint_mask = this_inpaint_mask.repeat(
558
- hidden_states.shape[0], hidden_states.shape[1], 1, 1
559
- ).bool()
560
- masked_hidden_states = (
561
- hidden_states[this_inpaint_mask]
562
- .detach()
563
- .clone()
564
- .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
565
- )
566
- var, mean = torch.var_mean(
567
- masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
568
- )
569
- std = torch.maximum(
570
- var, torch.zeros_like(var) + eps) ** 0.5
571
- mean_acc = sum(self.mean_bank[i]) / float(
572
- len(self.mean_bank[i])
573
- )
574
- var_acc = sum(
575
- self.var_bank[i]) / float(len(self.var_bank[i]))
576
- std_acc = (
577
- torch.maximum(
578
- var_acc, torch.zeros_like(var_acc) + eps)
579
- ** 0.5
580
- )
581
- hidden_states_uc = (
582
- ((masked_hidden_states - mean) / std) * std_acc
583
- ) + mean_acc
584
- hidden_states_c = hidden_states_uc.clone()
585
- if self.do_classifier_free_guidance and self.style_fidelity > 0:
586
- hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
587
- masked_hidden_states = (
588
- self.style_fidelity * hidden_states_c
589
- + (1.0 - self.style_fidelity) * hidden_states_uc
590
- )
591
- hidden_states[this_inpaint_mask] = masked_hidden_states.view(
592
- -1)
593
 
 
 
 
594
  output_states = output_states + (hidden_states,)
595
 
596
  if self.MODE == "read":
@@ -598,6 +631,8 @@ class StableDiffusionReferencePipeline:
598
  self.var_bank0 = []
599
  self.mean_bank = []
600
  self.var_bank = []
 
 
601
 
602
  if self.downsamplers is not None:
603
  for downsampler in self.downsamplers:
@@ -625,6 +660,8 @@ class StableDiffusionReferencePipeline:
625
  self.ref_mask.to(hidden_states.device),
626
  scale_factor=1 / scale_ratio,
627
  )
 
 
628
  this_ref_mask = this_ref_mask.repeat(
629
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
630
  ).bool()
@@ -637,8 +674,8 @@ class StableDiffusionReferencePipeline:
637
  var, mean = torch.var_mean(
638
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
639
  )
640
- self.mean_bank.append(mean)
641
- self.var_bank.append(var)
642
  if self.MODE == "read":
643
  if (
644
  self.gn_auto_machine_weight >= self.gn_weight
@@ -646,53 +683,19 @@ class StableDiffusionReferencePipeline:
646
  and len(self.var_bank) > 0
647
  ):
648
  # print("hacked_DownBlock2D_forward")
649
- scale_ratio = self.inpaint_mask.shape[2] / \
650
- hidden_states.shape[2]
651
- this_inpaint_mask = F.interpolate(
652
- self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
653
- )
654
- this_inpaint_mask = this_inpaint_mask.repeat(
655
- hidden_states.shape[0], hidden_states.shape[1], 1, 1
656
- ).bool()
657
- masked_hidden_states = (
658
- hidden_states[this_inpaint_mask]
659
- .detach()
660
- .clone()
661
- .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
662
- )
663
- var, mean = torch.var_mean(
664
- masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
665
- )
666
- std = torch.maximum(
667
- var, torch.zeros_like(var) + eps) ** 0.5
668
- mean_acc = sum(self.mean_bank[i]) / float(
669
- len(self.mean_bank[i])
670
- )
671
- var_acc = sum(
672
- self.var_bank[i]) / float(len(self.var_bank[i]))
673
- std_acc = (
674
- torch.maximum(
675
- var_acc, torch.zeros_like(var_acc) + eps)
676
- ** 0.5
677
- )
678
- hidden_states_uc = (
679
- ((masked_hidden_states - mean) / std) * std_acc
680
- ) + mean_acc
681
- hidden_states_c = hidden_states_uc.clone()
682
- if self.do_classifier_free_guidance and self.style_fidelity > 0:
683
- hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
684
- masked_hidden_states = (
685
- self.style_fidelity * hidden_states_c
686
- + (1.0 - self.style_fidelity) * hidden_states_uc
687
- )
688
- hidden_states[this_inpaint_mask] = masked_hidden_states.view(
689
- -1)
690
 
691
  output_states = output_states + (hidden_states,)
692
 
693
  if self.MODE == "read":
694
  self.mean_bank = []
695
  self.var_bank = []
 
696
 
697
  if self.downsamplers is not None:
698
  for downsampler in self.downsamplers:
@@ -733,6 +736,8 @@ class StableDiffusionReferencePipeline:
733
  self.ref_mask.to(hidden_states.device),
734
  scale_factor=1 / scale_ratio,
735
  )
 
 
736
  this_ref_mask = this_ref_mask.repeat(
737
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
738
  ).bool()
@@ -745,8 +750,8 @@ class StableDiffusionReferencePipeline:
745
  var, mean = torch.var_mean(
746
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
747
  )
748
- self.mean_bank0.append(mean)
749
- self.var_bank0.append(var)
750
  if self.MODE == "read":
751
  if (
752
  self.gn_auto_machine_weight >= self.gn_weight
@@ -754,47 +759,12 @@ class StableDiffusionReferencePipeline:
754
  and len(self.var_bank0) > 0
755
  ):
756
  # print("hacked_CrossAttnUpBlock2D_forward1")
757
- scale_ratio = self.inpaint_mask.shape[2] / \
758
- hidden_states.shape[2]
759
- this_inpaint_mask = F.interpolate(
760
- self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
761
- )
762
- this_inpaint_mask = this_inpaint_mask.repeat(
763
- hidden_states.shape[0], hidden_states.shape[1], 1, 1
764
- ).bool()
765
- masked_hidden_states = (
766
- hidden_states[this_inpaint_mask]
767
- .detach()
768
- .clone()
769
- .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
770
- )
771
- var, mean = torch.var_mean(
772
- masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
773
- )
774
- std = torch.maximum(
775
- var, torch.zeros_like(var) + eps) ** 0.5
776
- mean_acc = sum(self.mean_bank0[i]) / float(
777
- len(self.mean_bank0[i])
778
- )
779
- var_acc = sum(
780
- self.var_bank0[i]) / float(len(self.var_bank0[i]))
781
- std_acc = (
782
- torch.maximum(
783
- var_acc, torch.zeros_like(var_acc) + eps)
784
- ** 0.5
785
- )
786
- hidden_states_uc = (
787
- ((masked_hidden_states - mean) / std) * std_acc
788
- ) + mean_acc
789
- hidden_states_c = hidden_states_uc.clone()
790
- if self.do_classifier_free_guidance and self.style_fidelity > 0:
791
- hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
792
- masked_hidden_states = (
793
- self.style_fidelity * hidden_states_c
794
- + (1.0 - self.style_fidelity) * hidden_states_uc
795
- )
796
- hidden_states[this_inpaint_mask] = masked_hidden_states.view(
797
- -1)
798
 
799
  hidden_states = attn(
800
  hidden_states,
@@ -815,6 +785,8 @@ class StableDiffusionReferencePipeline:
815
  self.ref_mask.to(hidden_states.device),
816
  scale_factor=1 / scale_ratio,
817
  )
 
 
818
  this_ref_mask = this_ref_mask.repeat(
819
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
820
  ).bool()
@@ -827,8 +799,8 @@ class StableDiffusionReferencePipeline:
827
  var, mean = torch.var_mean(
828
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
829
  )
830
- self.mean_bank.append(mean)
831
- self.var_bank.append(var)
832
  if self.MODE == "read":
833
  if (
834
  self.gn_auto_machine_weight >= self.gn_weight
@@ -836,53 +808,20 @@ class StableDiffusionReferencePipeline:
836
  and len(self.var_bank) > 0
837
  ):
838
  # print("hacked_CrossAttnUpBlock2D_forward")
839
- scale_ratio = self.inpaint_mask.shape[2] / \
840
- hidden_states.shape[2]
841
- this_inpaint_mask = F.interpolate(
842
- self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
843
- )
844
- this_inpaint_mask = this_inpaint_mask.repeat(
845
- hidden_states.shape[0], hidden_states.shape[1], 1, 1
846
- ).bool()
847
- masked_hidden_states = (
848
- hidden_states[this_inpaint_mask]
849
- .detach()
850
- .clone()
851
- .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
852
- )
853
- var, mean = torch.var_mean(
854
- masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
855
- )
856
- std = torch.maximum(
857
- var, torch.zeros_like(var) + eps) ** 0.5
858
- mean_acc = sum(self.mean_bank[i]) / float(
859
- len(self.mean_bank[i])
860
- )
861
- var_acc = sum(
862
- self.var_bank[i]) / float(len(self.var_bank[i]))
863
- std_acc = (
864
- torch.maximum(
865
- var_acc, torch.zeros_like(var_acc) + eps)
866
- ** 0.5
867
- )
868
- hidden_states_uc = (
869
- ((masked_hidden_states - mean) / std) * std_acc
870
- ) + mean_acc
871
- hidden_states_c = hidden_states_uc.clone()
872
- if self.do_classifier_free_guidance and self.style_fidelity > 0:
873
- hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
874
- masked_hidden_states = (
875
- self.style_fidelity * hidden_states_c
876
- + (1.0 - self.style_fidelity) * hidden_states_uc
877
- )
878
- hidden_states[this_inpaint_mask] = masked_hidden_states.view(
879
- -1)
880
 
881
  if self.MODE == "read":
882
  self.mean_bank0 = []
883
  self.var_bank0 = []
884
  self.mean_bank = []
885
  self.var_bank = []
 
 
886
 
887
  if self.upsamplers is not None:
888
  for upsampler in self.upsamplers:
@@ -912,6 +851,8 @@ class StableDiffusionReferencePipeline:
912
  self.ref_mask.to(hidden_states.device),
913
  scale_factor=1 / scale_ratio,
914
  )
 
 
915
  this_ref_mask = this_ref_mask.repeat(
916
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
917
  ).bool()
@@ -924,8 +865,8 @@ class StableDiffusionReferencePipeline:
924
  var, mean = torch.var_mean(
925
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
926
  )
927
- self.mean_bank.append(mean)
928
- self.var_bank.append(var)
929
  if self.MODE == "read":
930
  if (
931
  self.gn_auto_machine_weight >= self.gn_weight
@@ -933,51 +874,17 @@ class StableDiffusionReferencePipeline:
933
  and len(self.var_bank) > 0
934
  ):
935
  # print("hacked_UpBlock2D_forward")
936
- scale_ratio = self.inpaint_mask.shape[2] / \
937
- hidden_states.shape[2]
938
- this_inpaint_mask = F.interpolate(
939
- self.inpaint_mask.to(hidden_states.device), scale_factor=1 / scale_ratio
940
- )
941
- this_inpaint_mask = this_inpaint_mask.repeat(
942
- hidden_states.shape[0], hidden_states.shape[1], 1, 1
943
- ).bool()
944
- masked_hidden_states = (
945
- hidden_states[this_inpaint_mask]
946
- .detach()
947
- .clone()
948
- .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
949
- )
950
- var, mean = torch.var_mean(
951
- masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
952
- )
953
- std = torch.maximum(
954
- var, torch.zeros_like(var) + eps) ** 0.5
955
- mean_acc = sum(self.mean_bank[i]) / float(
956
- len(self.mean_bank[i])
957
- )
958
- var_acc = sum(
959
- self.var_bank[i]) / float(len(self.var_bank[i]))
960
- std_acc = (
961
- torch.maximum(
962
- var_acc, torch.zeros_like(var_acc) + eps)
963
- ** 0.5
964
- )
965
- hidden_states_uc = (
966
- ((masked_hidden_states - mean) / std) * std_acc
967
- ) + mean_acc
968
- hidden_states_c = hidden_states_uc.clone()
969
- if self.do_classifier_free_guidance and self.style_fidelity > 0:
970
- hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
971
- masked_hidden_states = (
972
- self.style_fidelity * hidden_states_c
973
- + (1.0 - self.style_fidelity) * hidden_states_uc
974
- )
975
- hidden_states[this_inpaint_mask] = masked_hidden_states.view(
976
- -1)
977
 
978
  if self.MODE == "read":
979
  self.mean_bank = []
980
  self.var_bank = []
 
981
 
982
  if self.upsamplers is not None:
983
  for upsampler in self.upsamplers:
@@ -1003,6 +910,7 @@ class StableDiffusionReferencePipeline:
1003
  module, BasicTransformerBlock
1004
  )
1005
  module.bank = []
 
1006
  module.attn_weight = float(i) / float(len(attn_modules))
1007
  module.attention_auto_machine_weight = (
1008
  self.attention_auto_machine_weight
@@ -1017,6 +925,7 @@ class StableDiffusionReferencePipeline:
1017
  module.uc_mask = self.uc_mask
1018
  module.style_fidelity = self.style_fidelity
1019
  module.ref_mask = self.ref_mask
 
1020
  else:
1021
  attn_modules = None
1022
  if reference_adain:
@@ -1043,12 +952,14 @@ class StableDiffusionReferencePipeline:
1043
  module.forward = hacked_mid_forward.__get__(
1044
  module, torch.nn.Module
1045
  )
1046
- elif isinstance(module, CrossAttnDownBlock2D):
1047
- module.forward = hack_CrossAttnDownBlock2D_forward.__get__(
1048
- module, CrossAttnDownBlock2D
1049
- )
1050
- module.mean_bank0 = []
1051
- module.var_bank0 = []
 
 
1052
  elif isinstance(module, DownBlock2D):
1053
  module.forward = hacked_DownBlock2D_forward.__get__(
1054
  module, DownBlock2D
@@ -1057,14 +968,17 @@ class StableDiffusionReferencePipeline:
1057
  # module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
1058
  # module.mean_bank0 = []
1059
  # module.var_bank0 = []
 
1060
  elif isinstance(module, UpBlock2D):
1061
  module.forward = hacked_UpBlock2D_forward.__get__(
1062
  module, UpBlock2D
1063
  )
1064
  module.mean_bank0 = []
1065
  module.var_bank0 = []
 
1066
  module.mean_bank = []
1067
  module.var_bank = []
 
1068
  module.attention_auto_machine_weight = (
1069
  self.attention_auto_machine_weight
1070
  )
@@ -1079,6 +993,7 @@ class StableDiffusionReferencePipeline:
1079
  module.style_fidelity = self.style_fidelity
1080
  module.ref_mask = self.ref_mask
1081
  module.inpaint_mask = self.inpaint_mask
 
1082
  else:
1083
  gn_modules = None
1084
  elif model_type == "controlnet":
@@ -1098,6 +1013,7 @@ class StableDiffusionReferencePipeline:
1098
  module, BasicTransformerBlock
1099
  )
1100
  module.bank = []
 
1101
  # float(i) / float(len(attn_modules))
1102
  module.attn_weight = 0.0
1103
  module.attention_auto_machine_weight = (
@@ -1113,9 +1029,61 @@ class StableDiffusionReferencePipeline:
1113
  module.uc_mask = self.uc_mask
1114
  module.style_fidelity = self.style_fidelity
1115
  module.ref_mask = self.ref_mask
 
1116
  else:
1117
  attn_modules = None
1118
- gn_modules = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1119
 
1120
  return attn_modules, gn_modules
1121
 
@@ -1123,6 +1091,7 @@ class StableDiffusionReferencePipeline:
1123
  if attn_modules is not None:
1124
  for i, module in enumerate(attn_modules):
1125
  module.MODE = mode
 
1126
  if gn_modules is not None:
1127
  for i, module in enumerate(gn_modules):
1128
  module.MODE = mode
 
1
  # Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
2
  # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
3
+ import torch.fft as fft
4
  from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
 
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
 
 
10
  from diffusers.models.attention import BasicTransformerBlock
11
  from diffusers.models.unet_2d_blocks import (
12
  CrossAttnDownBlock2D,
 
14
  DownBlock2D,
15
  UpBlock2D,
16
  )
 
17
  from diffusers.utils import PIL_INTERPOLATION, logging
18
  import torch.nn.functional as F
19
 
 
20
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
 
22
  EXAMPLE_DOC_STRING = """
 
54
  return result
55
 
56
 
57
+ @torch.no_grad()
58
+ def add_freq_feature(feature1, feature2, ref_ratio):
59
+ """
60
+ feature1: reference feature
61
+ feature2: target feature
62
+ ref_ratio: larger ratio means larger reference frequency
63
+ """
64
+ # Convert features to float32 (if not already) for compatibility with fft operations
65
+ data_type = feature2.dtype
66
+ feature1 = feature1.to(torch.float32)
67
+ feature2 = feature2.to(torch.float32)
68
+
69
+ # Compute the Fourier transforms of both features
70
+ spectrum1 = fft.fftn(feature1, dim=(-2, -1))
71
+ spectrum2 = fft.fftn(feature2, dim=(-2, -1))
72
+
73
+ # Extract high-frequency magnitude and phase from feature1
74
+ magnitude1 = torch.abs(spectrum1)
75
+ # phase1 = torch.angle(spectrum1)
76
+
77
+ # Extract magnitude and phase from feature2
78
+ magnitude2 = torch.abs(spectrum2)
79
+ phase2 = torch.angle(spectrum2)
80
+
81
+ magnitude2.mul_((1-ref_ratio)).add_(magnitude1 * ref_ratio)
82
+ # phase2.mul_(1.0).add_(phase1 * 0.0)
83
+
84
+ # Combine magnitude and phase information
85
+ mixed_spectrum = torch.polar(magnitude2, phase2)
86
+
87
+ # Compute the inverse Fourier transform to get the mixed feature
88
+ mixed_feature = fft.ifftn(mixed_spectrum, dim=(-2, -1))
89
+
90
+ del feature1, feature2, spectrum1, spectrum2, magnitude1, magnitude2, phase2, mixed_spectrum
91
+
92
+ # Convert back to the original data type and return the result
93
+ return mixed_feature.to(data_type)
94
+
95
+
96
+ @torch.no_grad()
97
+ def save_ref_feature(feature, mask):
98
+ """
99
+ feature: n,c,h,w
100
+ mask: n,1,h,w
101
+
102
+ return n,c,h,w
103
+ """
104
+ return feature * mask
105
+
106
+
107
+ @torch.no_grad()
108
+ def mix_ref_feature(feature, ref_fea_bank, cfg=True, ref_scale=0.0, dim3=False):
109
+ """
110
+ feature: n,l,c or n,c,h,w
111
+ ref_fea_bank: [(n,c,h,w)]
112
+ cfg: True/False
113
+
114
+ return n,l,c or n,c,h,w
115
+ """
116
+ if cfg:
117
+ ref_fea = torch.cat(
118
+ (ref_fea_bank+ref_fea_bank), dim=0)
119
+ else:
120
+ ref_fea = ref_fea_bank
121
+
122
+ if dim3:
123
+ feature = feature.permute(0, 2, 1).view(ref_fea.shape)
124
+
125
+ mixed_feature = add_freq_feature(ref_fea, feature, ref_scale)
126
+
127
+ if dim3:
128
+ mixed_feature = mixed_feature.view(
129
+ ref_fea.shape[0], ref_fea.shape[1], -1).permute(0, 2, 1)
130
+
131
+ del ref_fea
132
+ del feature
133
+ return mixed_feature
134
+
135
+
136
+ def mix_norm_feature(x, inpaint_mask, mean_bank, var_bank, do_classifier_free_guidance, style_fidelity, uc_mask, eps=1e-6):
137
+ """
138
+ x: input feature n,c,h,w
139
+ inpaint_mask: mask region to inpain
140
+ """
141
+
142
+ # get the inpainting region and only mix this region.
143
+ scale_ratio = inpaint_mask.shape[2] / x.shape[2]
144
+ this_inpaint_mask = F.interpolate(
145
+ inpaint_mask.to(x.device), scale_factor=1 / scale_ratio
146
+ )
147
+ this_inpaint_mask = this_inpaint_mask.repeat(
148
+ x.shape[0], x.shape[1], 1, 1
149
+ ).bool()
150
+ masked_x = (
151
+ x[this_inpaint_mask]
152
+ .detach()
153
+ .clone()
154
+ .view(x.shape[0], x.shape[1], -1, 1)
155
+ )
156
+ var, mean = torch.var_mean(
157
+ masked_x, dim=(2, 3), keepdim=True, correction=0
158
+ )
159
+ std = torch.maximum(
160
+ var, torch.zeros_like(var) + eps) ** 0.5
161
+ mean_acc = sum(mean_bank) / float(len(mean_bank))
162
+ var_acc = sum(var_bank) / float(len(var_bank))
163
+ std_acc = (
164
+ torch.maximum(var_acc, torch.zeros_like(
165
+ var_acc) + eps) ** 0.5
166
+ )
167
+
168
+ x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
169
+ x_c = x_uc.clone()
170
+ if do_classifier_free_guidance and style_fidelity > 0:
171
+ x_c[uc_mask] = masked_x[uc_mask]
172
+ masked_x = style_fidelity * x_c + \
173
+ (1.0 - style_fidelity) * x_uc
174
+ x[this_inpaint_mask] = masked_x.view(-1)
175
+ return x
176
+
177
+
178
  class StableDiffusionReferencePipeline:
179
  def prepare_ref_image(
180
  self,
 
356
  this_ref_mask = F.interpolate(
357
  this_ref_mask, scale_factor=ref_scale
358
  )
359
+ self.fea_bank.append(save_ref_feature(
360
+ resize_norm_hidden_states, this_ref_mask))
 
361
  this_ref_mask = this_ref_mask.repeat(
362
  resize_norm_hidden_states.shape[0],
363
  resize_norm_hidden_states.shape[1],
 
374
  -1,
375
  )
376
  )
377
+
378
  masked_norm_hidden_states = masked_norm_hidden_states.permute(
379
  0, 2, 1
380
  )
381
  self.bank.append(masked_norm_hidden_states)
382
+ del masked_norm_hidden_states
383
+ del this_ref_mask
384
+ del resize_norm_hidden_states
385
  attn_output = self.attn1(
386
  norm_hidden_states,
387
  encoder_hidden_states=encoder_hidden_states
 
392
  )
393
  if self.MODE == "read":
394
  if self.attention_auto_machine_weight > self.attn_weight:
395
+ freq_norm_hidden_states = mix_ref_feature(
396
+ norm_hidden_states,
397
+ self.fea_bank,
398
+ cfg=self.do_classifier_free_guidance,
399
+ ref_scale=self.ref_scale,
400
+ dim3=True)
401
+ self.fea_bank.clear()
402
+
403
+ this_bank = torch.cat(self.bank+self.bank, dim=0)
 
 
 
 
 
 
 
404
  ref_hidden_states = torch.cat(
405
+ (freq_norm_hidden_states, this_bank), dim=1
406
  )
407
+ del this_bank
408
+ self.bank.clear()
409
+
410
  attn_output_uc = self.attn1(
411
+ freq_norm_hidden_states,
412
  encoder_hidden_states=ref_hidden_states,
 
413
  **cross_attention_kwargs,
414
  )
415
+ del ref_hidden_states
416
  attn_output_c = attn_output_uc.clone()
417
  if self.do_classifier_free_guidance and self.style_fidelity > 0:
418
  attn_output_c[self.uc_mask] = self.attn1(
 
425
  + (1.0 - self.style_fidelity) * attn_output_uc
426
  )
427
  self.bank.clear()
428
+ self.fea_bank.clear()
429
+ del attn_output_c
430
+ del attn_output_uc
431
  else:
432
  attn_output = self.attn1(
433
  norm_hidden_states,
 
437
  attention_mask=attention_mask,
438
  **cross_attention_kwargs,
439
  )
440
+ self.bank.clear()
441
+ self.fea_bank.clear()
442
+
443
  if self.use_ada_layer_norm_zero:
444
  attn_output = gate_msa.unsqueeze(1) * attn_output
445
  hidden_states = attn_output + hidden_states
 
488
  this_ref_mask = F.interpolate(
489
  self.ref_mask.to(x.device), scale_factor=1 / scale_ratio
490
  )
491
+
492
+ self.fea_bank.append(save_ref_feature(
493
+ x, this_ref_mask))
494
+
495
  this_ref_mask = this_ref_mask.repeat(
496
  x.shape[0], x.shape[1], 1, 1
497
  ).bool()
 
505
  masked_x, dim=(2, 3), keepdim=True, correction=0
506
  )
507
 
508
+ self.mean_bank.append(torch.cat([mean]*2, dim=0))
509
+ self.var_bank.append(torch.cat([var]*2, dim=0))
510
  if self.MODE == "read":
511
  if (
512
  self.gn_auto_machine_weight >= self.gn_weight
 
514
  and len(self.var_bank) > 0
515
  ):
516
  # print("hacked_mid_forward")
517
+ x = mix_ref_feature(
518
+ x, self.fea_bank, cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
519
+ self.fea_bank = []
520
+ x = mix_norm_feature(x, self.inpaint_mask, self.mean_bank, self.var_bank,
521
+ self.do_classifier_free_guidance,
522
+ self.style_fidelity, self.uc_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  self.mean_bank = []
524
  self.var_bank = []
525
  return x
 
550
  self.ref_mask.to(hidden_states.device),
551
  scale_factor=1 / scale_ratio,
552
  )
553
+ self.fea_bank0.append(save_ref_feature(
554
+ hidden_states, this_ref_mask))
555
  this_ref_mask = this_ref_mask.repeat(
556
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
557
  ).bool()
 
564
  var, mean = torch.var_mean(
565
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
566
  )
567
+ self.mean_bank0.append(torch.cat([mean]*2, dim=0))
568
+ self.var_bank0.append(torch.cat([var]*2, dim=0))
569
  if self.MODE == "read":
570
  if (
571
  self.gn_auto_machine_weight >= self.gn_weight
 
573
  and len(self.var_bank0) > 0
574
  ):
575
  # print("hacked_CrossAttnDownBlock2D_forward0")
576
+ hidden_states = mix_ref_feature(
577
+ hidden_states, [self.fea_bank0[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
578
+
579
+ hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank0[i], self.var_bank0[i],
580
+ self.do_classifier_free_guidance,
581
+ self.style_fidelity, self.uc_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
  hidden_states = attn(
584
  hidden_states,
585
  encoder_hidden_states=encoder_hidden_states,
586
  cross_attention_kwargs=cross_attention_kwargs,
 
 
587
  return_dict=False,
588
  )[0]
589
  if self.MODE == "write":
 
595
  self.ref_mask.to(hidden_states.device),
596
  scale_factor=1 / scale_ratio,
597
  )
598
+ self.fea_bank.append(save_ref_feature(
599
+ hidden_states, this_ref_mask))
600
  this_ref_mask = this_ref_mask.repeat(
601
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
602
  ).bool()
 
609
  var, mean = torch.var_mean(
610
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
611
  )
612
+ self.mean_bank.append(torch.cat([mean]*2, dim=0))
613
+ self.var_bank.append(torch.cat([var]*2, dim=0))
614
  if self.MODE == "read":
615
  if (
616
  self.gn_auto_machine_weight >= self.gn_weight
 
618
  and len(self.var_bank) > 0
619
  ):
620
  # print("hack_CrossAttnDownBlock2D_forward")
621
+ hidden_states = mix_ref_feature(
622
+ hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
 
624
+ hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
625
+ self.do_classifier_free_guidance,
626
+ self.style_fidelity, self.uc_mask)
627
  output_states = output_states + (hidden_states,)
628
 
629
  if self.MODE == "read":
 
631
  self.var_bank0 = []
632
  self.mean_bank = []
633
  self.var_bank = []
634
+ self.fea_bank0 = []
635
+ self.fea_bank = []
636
 
637
  if self.downsamplers is not None:
638
  for downsampler in self.downsamplers:
 
660
  self.ref_mask.to(hidden_states.device),
661
  scale_factor=1 / scale_ratio,
662
  )
663
+ self.fea_bank.append(save_ref_feature(
664
+ hidden_states, this_ref_mask))
665
  this_ref_mask = this_ref_mask.repeat(
666
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
667
  ).bool()
 
674
  var, mean = torch.var_mean(
675
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
676
  )
677
+ self.mean_bank.append(torch.cat([mean]*2, dim=0))
678
+ self.var_bank.append(torch.cat([var]*2, dim=0))
679
  if self.MODE == "read":
680
  if (
681
  self.gn_auto_machine_weight >= self.gn_weight
 
683
  and len(self.var_bank) > 0
684
  ):
685
  # print("hacked_DownBlock2D_forward")
686
+ hidden_states = mix_ref_feature(
687
+ hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
688
+
689
+ hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
690
+ self.do_classifier_free_guidance,
691
+ self.style_fidelity, self.uc_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
 
693
  output_states = output_states + (hidden_states,)
694
 
695
  if self.MODE == "read":
696
  self.mean_bank = []
697
  self.var_bank = []
698
+ self.fea_bank = []
699
 
700
  if self.downsamplers is not None:
701
  for downsampler in self.downsamplers:
 
736
  self.ref_mask.to(hidden_states.device),
737
  scale_factor=1 / scale_ratio,
738
  )
739
+ self.fea_bank0.append(save_ref_feature(
740
+ hidden_states, this_ref_mask))
741
  this_ref_mask = this_ref_mask.repeat(
742
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
743
  ).bool()
 
750
  var, mean = torch.var_mean(
751
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
752
  )
753
+ self.mean_bank0.append(torch.cat([mean]*2, dim=0))
754
+ self.var_bank0.append(torch.cat([var]*2, dim=0))
755
  if self.MODE == "read":
756
  if (
757
  self.gn_auto_machine_weight >= self.gn_weight
 
759
  and len(self.var_bank0) > 0
760
  ):
761
  # print("hacked_CrossAttnUpBlock2D_forward1")
762
+ hidden_states = mix_ref_feature(
763
+ hidden_states, [self.fea_bank0[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
764
+
765
+ hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank0[i], self.var_bank0[i],
766
+ self.do_classifier_free_guidance,
767
+ self.style_fidelity, self.uc_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
 
769
  hidden_states = attn(
770
  hidden_states,
 
785
  self.ref_mask.to(hidden_states.device),
786
  scale_factor=1 / scale_ratio,
787
  )
788
+ self.fea_bank.append(save_ref_feature(
789
+ hidden_states, this_ref_mask))
790
  this_ref_mask = this_ref_mask.repeat(
791
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
792
  ).bool()
 
799
  var, mean = torch.var_mean(
800
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
801
  )
802
+ self.mean_bank.append(torch.cat([mean]*2, dim=0))
803
+ self.var_bank.append(torch.cat([var]*2, dim=0))
804
  if self.MODE == "read":
805
  if (
806
  self.gn_auto_machine_weight >= self.gn_weight
 
808
  and len(self.var_bank) > 0
809
  ):
810
  # print("hacked_CrossAttnUpBlock2D_forward")
811
+ hidden_states = mix_ref_feature(
812
+ hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
813
+
814
+ hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
815
+ self.do_classifier_free_guidance,
816
+ self.style_fidelity, self.uc_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
817
 
818
  if self.MODE == "read":
819
  self.mean_bank0 = []
820
  self.var_bank0 = []
821
  self.mean_bank = []
822
  self.var_bank = []
823
+ self.fea_bank = []
824
+ self.fea_bank0 = []
825
 
826
  if self.upsamplers is not None:
827
  for upsampler in self.upsamplers:
 
851
  self.ref_mask.to(hidden_states.device),
852
  scale_factor=1 / scale_ratio,
853
  )
854
+ self.fea_bank.append(save_ref_feature(
855
+ hidden_states, this_ref_mask))
856
  this_ref_mask = this_ref_mask.repeat(
857
  hidden_states.shape[0], hidden_states.shape[1], 1, 1
858
  ).bool()
 
865
  var, mean = torch.var_mean(
866
  masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
867
  )
868
+ self.mean_bank.append(torch.cat([mean]*2, dim=0))
869
+ self.var_bank.append(torch.cat([var]*2, dim=0))
870
  if self.MODE == "read":
871
  if (
872
  self.gn_auto_machine_weight >= self.gn_weight
 
874
  and len(self.var_bank) > 0
875
  ):
876
  # print("hacked_UpBlock2D_forward")
877
+ hidden_states = mix_ref_feature(
878
+ hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
879
+
880
+ hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
881
+ self.do_classifier_free_guidance,
882
+ self.style_fidelity, self.uc_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
 
884
  if self.MODE == "read":
885
  self.mean_bank = []
886
  self.var_bank = []
887
+ self.fea_bank = []
888
 
889
  if self.upsamplers is not None:
890
  for upsampler in self.upsamplers:
 
910
  module, BasicTransformerBlock
911
  )
912
  module.bank = []
913
+ module.fea_bank = []
914
  module.attn_weight = float(i) / float(len(attn_modules))
915
  module.attention_auto_machine_weight = (
916
  self.attention_auto_machine_weight
 
925
  module.uc_mask = self.uc_mask
926
  module.style_fidelity = self.style_fidelity
927
  module.ref_mask = self.ref_mask
928
+ module.ref_scale = self.ref_scale
929
  else:
930
  attn_modules = None
931
  if reference_adain:
 
952
  module.forward = hacked_mid_forward.__get__(
953
  module, torch.nn.Module
954
  )
955
+ # elif isinstance(module, CrossAttnDownBlock2D):
956
+ # module.forward = hack_CrossAttnDownBlock2D_forward.__get__(
957
+ # module, CrossAttnDownBlock2D
958
+ # )
959
+ # module.mean_bank0 = []
960
+ # module.var_bank0 = []
961
+ # module.fea_bank0 = []
962
+
963
  elif isinstance(module, DownBlock2D):
964
  module.forward = hacked_DownBlock2D_forward.__get__(
965
  module, DownBlock2D
 
968
  # module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
969
  # module.mean_bank0 = []
970
  # module.var_bank0 = []
971
+ # module.fea_bank0 = []
972
  elif isinstance(module, UpBlock2D):
973
  module.forward = hacked_UpBlock2D_forward.__get__(
974
  module, UpBlock2D
975
  )
976
  module.mean_bank0 = []
977
  module.var_bank0 = []
978
+ module.fea_bank0 = []
979
  module.mean_bank = []
980
  module.var_bank = []
981
+ module.fea_bank = []
982
  module.attention_auto_machine_weight = (
983
  self.attention_auto_machine_weight
984
  )
 
993
  module.style_fidelity = self.style_fidelity
994
  module.ref_mask = self.ref_mask
995
  module.inpaint_mask = self.inpaint_mask
996
+ module.ref_scale = self.ref_scale
997
  else:
998
  gn_modules = None
999
  elif model_type == "controlnet":
 
1013
  module, BasicTransformerBlock
1014
  )
1015
  module.bank = []
1016
+ module.fea_bank = []
1017
  # float(i) / float(len(attn_modules))
1018
  module.attn_weight = 0.0
1019
  module.attention_auto_machine_weight = (
 
1029
  module.uc_mask = self.uc_mask
1030
  module.style_fidelity = self.style_fidelity
1031
  module.ref_mask = self.ref_mask
1032
+ module.ref_scale = self.ref_scale
1033
  else:
1034
  attn_modules = None
1035
+ # gn_modules = None
1036
+ if reference_adain:
1037
+ gn_modules = [model.mid_block]
1038
+ model.mid_block.gn_weight = 0
1039
+
1040
+ down_blocks = model.down_blocks
1041
+ for w, module in enumerate(down_blocks):
1042
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
1043
+ gn_modules.append(module)
1044
+ # print(module.__class__.__name__,module.gn_weight)
1045
+
1046
+
1047
+ for i, module in enumerate(gn_modules):
1048
+ if getattr(module, "original_forward", None) is None:
1049
+ module.original_forward = module.forward
1050
+ if i == 0:
1051
+ # mid_block
1052
+ module.forward = hacked_mid_forward.__get__(
1053
+ module, torch.nn.Module
1054
+ )
1055
+ # elif isinstance(module, CrossAttnDownBlock2D):
1056
+ # module.forward = hack_CrossAttnDownBlock2D_forward.__get__(
1057
+ # module, CrossAttnDownBlock2D
1058
+ # )
1059
+ # module.mean_bank0 = []
1060
+ # module.var_bank0 = []
1061
+ # module.fea_bank0 = []
1062
+
1063
+ elif isinstance(module, DownBlock2D):
1064
+ module.forward = hacked_DownBlock2D_forward.__get__(
1065
+ module, DownBlock2D
1066
+ )
1067
+ module.mean_bank = []
1068
+ module.var_bank = []
1069
+ module.fea_bank = []
1070
+ module.attention_auto_machine_weight = (
1071
+ self.attention_auto_machine_weight
1072
+ )
1073
+ module.gn_auto_machine_weight = self.gn_auto_machine_weight
1074
+ module.do_classifier_free_guidance = (
1075
+ self.do_classifier_free_guidance
1076
+ )
1077
+ module.do_classifier_free_guidance = (
1078
+ self.do_classifier_free_guidance
1079
+ )
1080
+ module.uc_mask = self.uc_mask
1081
+ module.style_fidelity = self.style_fidelity
1082
+ module.ref_mask = self.ref_mask
1083
+ module.inpaint_mask = self.inpaint_mask
1084
+ module.ref_scale = self.ref_scale
1085
+ else:
1086
+ gn_modules = None
1087
 
1088
  return attn_modules, gn_modules
1089
 
 
1091
  if attn_modules is not None:
1092
  for i, module in enumerate(attn_modules):
1093
  module.MODE = mode
1094
+
1095
  if gn_modules is not None:
1096
  for i, module in enumerate(gn_modules):
1097
  module.MODE = mode