Jinl committed on
Commit 1015380
1 Parent(s): 172aac1

[update] LCM support

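In short, this commit removes the manual checkpoint/LoRA conversion imports (and deletes utils/convert_from_ckpt.py) in favour of diffusers' built-in LoRA loading, adds an LCM-LoRA fast-sampling mode, and adds SAM-based mask creation to the Gradio UI. Below is a minimal sketch of the LCM-LoRA setup that the new GlobalText.load_lcm_lora() performs, using the stock diffusers StableDiffusionPipeline instead of the repo's MasaCtrlPipeline so the snippet stays self-contained; the model IDs are the ones used in app.py, the prompt is illustrative only.

import torch
from diffusers import StableDiffusionPipeline, LCMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

# Swap the default scheduler for LCMScheduler, then fuse the LCM-LoRA weights.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.fuse_lora()

# With LCM-LoRA fused, a handful of steps and a low guidance scale are enough.
image = pipe("a portrait photo, oil painting style",
             num_inference_steps=4, guidance_scale=1.0).images[0]
image.save("lcm_preview.png")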
.gitignore CHANGED
@@ -5,12 +5,5 @@
5
  ./row_results
6
  ./new_res
7
  ./cop
8
- examper
9
- results
10
- data
11
- results_ablation
12
- row_results
13
- new_res
14
- cop
15
  ./samples
16
  samples
 
5
  ./row_results
6
  ./new_res
7
  ./cop
8
  ./samples
9
  samples
app.py CHANGED
@@ -1,41 +1,25 @@
1
  import os
2
- os.system("pip uninstall -y gradio")
3
- os.system("pip install gradio==3.47")
4
-
5
- import json
6
- import re
7
- from turtle import width
8
  import torch
9
  import random
10
  import numpy as np
11
  import gradio as gr
12
  from glob import glob
13
- from omegaconf import OmegaConf
14
  from datetime import datetime
15
- from safetensors import safe_open
16
 
17
- from diffusers import AutoencoderKL,UNet2DConditionModel,StableDiffusionPipeline
18
- from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
19
- from diffusers.utils.import_utils import is_xformers_available
20
- from transformers import CLIPTextModel, CLIPTokenizer
21
- from utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
22
- from utils.convert_lora_safetensor_to_diffusers import convert_lora
23
 
24
  import torch.nn.functional as F
25
- from PIL import Image
26
-
27
- from utils.diffuser_utils import MasaCtrlPipeline
28
  from utils.masactrl_utils import (AttentionBase,
29
  regiter_attention_editor_diffusers)
30
  from utils.free_lunch_utils import register_upblock2d,register_crossattn_upblock2d,register_free_upblock2d, register_free_crossattn_upblock2d
31
-
32
  from utils.style_attn_control import MaskPromptedStyleAttentionControl
 
33
  from torchvision.utils import save_image
34
- from diffusers.models.attention_processor import AttnProcessor2_0
35
 
36
 
37
- # Install gradio 3.41 with pip from within Python
38
-
39
 
40
  css = """
41
  .toolbutton {
@@ -60,7 +44,7 @@ class GlobalText:
60
  self.savedir_mask = os.path.join(self.savedir, "mask")
61
 
62
  self.stable_diffusion_list = ["runwayml/stable-diffusion-v1-5",
63
- "stabilityai/stable-diffusion-2-1"]
64
  self.personalized_model_list = []
65
  self.lora_model_list = []
66
 
@@ -71,14 +55,21 @@ class GlobalText:
71
  self.unet = None
72
  self.pipeline = None
73
  self.lora_loaded = None
 
74
  self.personal_model_loaded = None
 
 
75
  self.lora_model_state_dict = {}
76
  self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
77
  # self.refresh_stable_diffusion()
78
  self.refresh_personalized_model()
79
 
 
 
80
  self.reset_start_code()
81
  def load_base_pipeline(self, model_path):
 
 
82
  print(f'loading {model_path} model')
83
  scheduler = DDIMScheduler.from_pretrained(model_path,subfolder="scheduler")
84
  self.pipeline = MasaCtrlPipeline.from_pretrained(model_path,
@@ -89,6 +80,7 @@ class GlobalText:
89
  self.load_base_pipeline(self.stable_diffusion_list[0])
90
  self.lora_loaded = None
91
  self.personal_model_loaded = None
 
92
  return self.stable_diffusion_list[0]
93
 
94
  def refresh_personalized_model(self):
@@ -99,11 +91,14 @@ class GlobalText:
99
  self.lora_model_list = {os.path.basename(file): file for file in lora_model_list}
100
 
101
  def update_stable_diffusion(self, stable_diffusion_dropdown):
102
-
103
- self.load_base_pipeline(stable_diffusion_dropdown)
 
 
 
104
  self.lora_loaded = None
105
  self.personal_model_loaded = None
106
- return gr.Dropdown.update()
107
 
108
  def update_base_model(self, base_model_dropdown):
109
  if self.pipeline is None:
@@ -132,26 +127,26 @@ class GlobalText:
132
  self.pipeline.unfuse_lora()
133
  self.pipeline.unload_lora_weights()
134
  self.lora_loaded = None
135
- # self.personal_model_loaded = None
136
  print("Restore lora.")
137
  else:
138
 
139
- lora_model_path = self.lora_model_list[lora_model_dropdown]#os.path.join(self.lora_model_dir, lora_model_dropdown)
140
- # self.lora_model_state_dict = {}
141
- # if lora_model_dropdown == "none": pass
142
- # else:
143
- # with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
144
- # for key in f.keys():
145
- # self.lora_model_state_dict[key] = f.get_tensor(key)
146
- # convert_lora(self.pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
147
- self.pipeline.unfuse_lora()
148
- self.pipeline.unload_lora_weights()
149
  self.pipeline.load_lora_weights(lora_model_path)
150
  self.pipeline.fuse_lora(lora_alpha_slider)
151
  self.lora_loaded = lora_model_dropdown.split('.')[0]
152
- print(f'load {lora_model_dropdown} model success!')
153
  return gr.Dropdown()
154
-
155
  def generate(self, source, style, source_mask, style_mask,
156
  start_step, start_layer, Style_attn_step,
157
  Method, Style_Guidance, ddim_steps, scale, seed, de_bug,
@@ -224,6 +219,7 @@ class GlobalText:
224
  de_bug=de_bug,
225
  )
226
  if freeu:
 
227
  print(f'++++++++++++++++++ Run with FreeU {b1}_{b2}_{s1}_{s2} ++++++++++++++++')
228
  if Method != "Without mask":
229
  register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=source_mask)
@@ -234,12 +230,14 @@ class GlobalText:
234
 
235
  else:
236
  print(f'++++++++++++++++++ Run without FreeU ++++++++++++++++')
 
237
  register_upblock2d(model)
238
  register_crossattn_upblock2d(model)
239
- regiter_attention_editor_diffusers(model, controller)
240
 
241
  regiter_attention_editor_diffusers(model, controller)
242
 
 
 
243
  # inference the synthesized image
244
  generate_image= model(prompts,
245
  width=width_slider,
@@ -249,7 +247,9 @@ class GlobalText:
249
  num_inference_steps=ddim_steps,
250
  ref_intermediate_latents=latents_list if inter_latents else None,
251
  neg_prompt=negative_prompt_textbox,
252
- return_intermediates=False,)
 
 
253
 
254
  # os.makedirs(os.path.join(output_dir, f"results_{sample_count}"))
255
  save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg"
@@ -283,9 +283,64 @@ class GlobalText:
283
  self.start_code = None
284
  self.latents_list = None
285
286
  global_text = GlobalText()
287
 
288
- gr.ImageMask()
289
  def load_mask_images(source,style,source_mask,style_mask,device,width,height,out_dir=None):
290
  # invert the image into noise map
291
  if isinstance(source['image'], np.ndarray):
@@ -306,8 +361,7 @@ def load_mask_images(source,style,source_mask,style_mask,device,width,height,out
306
  style['mask'].save(os.path.join(out_dir,'style_mask.jpg'))
307
  else:
308
  Image.fromarray(style_mask).save(os.path.join(out_dir,'style_mask.jpg'))
309
- # save source['mask']
310
- # import pdb;pdb.set_trace()
311
  source_mask = torch.from_numpy(np.array(source['mask']) if source_mask is None else source_mask).to(device) / 255.
312
  source_mask = source_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1]
313
  source_mask = F.interpolate(source_mask, (height//8,width//8))
@@ -327,14 +381,13 @@ def load_mask_images(source,style,source_mask,style_mask,device,width,height,out
327
  return source_image,style_image,source_mask,style_mask
328
 
329
 
330
-
331
  def ui():
332
  with gr.Blocks(css=css) as demo:
333
  gr.Markdown(
334
  """
335
  # [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/00000)
336
  Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)<br>
337
- [Arxiv Report](https://arxiv.org/abs/0000) | [Project Page](https://www.github.io/) | [Github](https://github.com/)
338
  """
339
  )
340
  with gr.Column(variant="panel"):
@@ -416,7 +469,7 @@ def ui():
416
  with gr.Tab('Base Configs'):
417
  with gr.Row():
418
  # sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
419
- ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=10, maximum=100, step=1)
420
 
421
  Style_attn_step = gr.Slider(label="Step of Style Attention Control",
422
  minimum=0,
@@ -484,13 +537,17 @@ def ui():
484
 
485
  with gr.Tab("SAM"):
486
  with gr.Column():
487
- add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
488
  with gr.Row():
489
- sam_source_btn = gr.Button(value="SAM Source")
490
- send_source_btn = gr.Button(value="Send Source")
491
 
492
- sam_style_btn = gr.Button(value="SAM Style")
493
- send_style_btn = gr.Button(value="Send Style")
494
  with gr.Row():
495
  source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
496
  style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
@@ -501,7 +558,19 @@ def ui():
501
 
502
  style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
503
  style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
504
-
505
  gr.Examples(
506
  [[os.path.join(os.path.dirname(__file__), "images/content/1.jpg"),
507
  os.path.join(os.path.dirname(__file__), "images/style/1.jpg")],
@@ -515,7 +584,7 @@ def ui():
515
  Method, Style_Guidance,ddim_steps, cfg_scale_slider, seed_textbox, de_bug,
516
  prompt_textbox, negative_prompt_textbox, inter_latents,
517
  freeu, b1, b2, s1, s2,
518
- width_slider,height_slider,
519
  ]
520
 
521
  generate_button.click(
@@ -530,4 +599,4 @@ def ui():
530
 
531
  if __name__ == "__main__":
532
  demo = ui()
533
- demo.launch()
 
1
  import os
2
  import torch
3
  import random
4
  import numpy as np
5
  import gradio as gr
6
  from glob import glob
 
7
  from datetime import datetime
 
8
 
9
+ from diffusers import StableDiffusionPipeline
10
+ from diffusers import DDIMScheduler, LCMScheduler
11
 
12
  import torch.nn.functional as F
13
+ from PIL import Image,ImageDraw
 
 
14
  from utils.masactrl_utils import (AttentionBase,
15
  regiter_attention_editor_diffusers)
16
  from utils.free_lunch_utils import register_upblock2d,register_crossattn_upblock2d,register_free_upblock2d, register_free_crossattn_upblock2d
 
17
  from utils.style_attn_control import MaskPromptedStyleAttentionControl
18
+ from utils.pipeline import MasaCtrlPipeline
19
  from torchvision.utils import save_image
20
+ from segment_anything import sam_model_registry, SamPredictor
21
 
22
 
 
 
23
 
24
  css = """
25
  .toolbutton {
 
44
  self.savedir_mask = os.path.join(self.savedir, "mask")
45
 
46
  self.stable_diffusion_list = ["runwayml/stable-diffusion-v1-5",
47
+ "latent-consistency/lcm-lora-sdv1-5"]
48
  self.personalized_model_list = []
49
  self.lora_model_list = []
50
 
 
55
  self.unet = None
56
  self.pipeline = None
57
  self.lora_loaded = None
58
+ self.lcm_lora_loaded = False
59
  self.personal_model_loaded = None
60
+ self.sam_predictor = None
61
+
62
  self.lora_model_state_dict = {}
63
  self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
64
  # self.refresh_stable_diffusion()
65
  self.refresh_personalized_model()
66
 
67
+
68
+
69
  self.reset_start_code()
70
  def load_base_pipeline(self, model_path):
71
+
72
+
73
  print(f'loading {model_path} model')
74
  scheduler = DDIMScheduler.from_pretrained(model_path,subfolder="scheduler")
75
  self.pipeline = MasaCtrlPipeline.from_pretrained(model_path,
 
80
  self.load_base_pipeline(self.stable_diffusion_list[0])
81
  self.lora_loaded = None
82
  self.personal_model_loaded = None
83
+ self.lcm_lora_loaded = False
84
  return self.stable_diffusion_list[0]
85
 
86
  def refresh_personalized_model(self):
 
91
  self.lora_model_list = {os.path.basename(file): file for file in lora_model_list}
92
 
93
  def update_stable_diffusion(self, stable_diffusion_dropdown):
94
+
95
+ if stable_diffusion_dropdown == 'latent-consistency/lcm-lora-sdv1-5':
96
+ self.load_lcm_lora()
97
+ else:
98
+ self.load_base_pipeline(stable_diffusion_dropdown)
99
  self.lora_loaded = None
100
  self.personal_model_loaded = None
101
+ return gr.Dropdown()
102
 
103
  def update_base_model(self, base_model_dropdown):
104
  if self.pipeline is None:
 
127
  self.pipeline.unfuse_lora()
128
  self.pipeline.unload_lora_weights()
129
  self.lora_loaded = None
 
130
  print("Restore lora.")
131
  else:
132
 
133
+ lora_model_path = self.lora_model_list[lora_model_dropdown]
134
  self.pipeline.load_lora_weights(lora_model_path)
135
  self.pipeline.fuse_lora(lora_alpha_slider)
136
  self.lora_loaded = lora_model_dropdown.split('.')[0]
137
+ print(f'Loaded {lora_model_dropdown} LoRA model successfully!')
138
  return gr.Dropdown()
139
+
140
+ def load_lcm_lora(self, lora_alpha_slider=1.0):
141
+ # set scheduler
142
+ self.pipeline = MasaCtrlPipeline.from_pretrained(self.stable_diffusion_list[0]).to(self.device)
143
+ self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
144
+ # load LCM-LoRA
145
+ self.pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
146
+ self.pipeline.fuse_lora(lora_alpha_slider)
147
+ self.lcm_lora_loaded = True
148
+ print('Loaded LCM-LoRA model successfully!')
149
+
150
  def generate(self, source, style, source_mask, style_mask,
151
  start_step, start_layer, Style_attn_step,
152
  Method, Style_Guidance, ddim_steps, scale, seed, de_bug,
 
219
  de_bug=de_bug,
220
  )
221
  if freeu:
222
+ # model.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
223
  print(f'++++++++++++++++++ Run with FreeU {b1}_{b2}_{s1}_{s2} ++++++++++++++++')
224
  if Method != "Without mask":
225
  register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=source_mask)
 
230
 
231
  else:
232
  print(f'++++++++++++++++++ Run without FreeU ++++++++++++++++')
233
+ # model.disable_freeu()
234
  register_upblock2d(model)
235
  register_crossattn_upblock2d(model)
 
236
 
237
  regiter_attention_editor_diffusers(model, controller)
238
 
239
+
240
+
241
  # inference the synthesized image
242
  generate_image= model(prompts,
243
  width=width_slider,
 
247
  num_inference_steps=ddim_steps,
248
  ref_intermediate_latents=latents_list if inter_latents else None,
249
  neg_prompt=negative_prompt_textbox,
250
+ return_intermediates=False,
251
+ lcm_lora=self.lcm_lora_loaded,
252
+ de_bug=de_bug,)
253
 
254
  # os.makedirs(os.path.join(output_dir, f"results_{sample_count}"))
255
  save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg"
 
283
  self.start_code = None
284
  self.latents_list = None
285
 
286
+ def lora_sam_predictor(self, sam_path):
287
+ sam_checkpoint = sam_path
288
+ model_type = "vit_h"
289
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
290
+ sam.to(device=self.device)
291
+ self.sam_predictor = SamPredictor(sam)
292
+ self.sam_point = []
293
+ self.sam_point_label = []
294
+
295
+ def get_points_with_draw(self, image, image_with_points, label, evt: gr.SelectData):
296
+
297
+ x, y = evt.index[0], evt.index[1]
298
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
299
+ self.sam_point.append([x, y])
300
+ self.sam_point_label.append(1 if label == 'Add Mask' else 0)
301
+
302
+ print(x, y, label == 'Add Mask')
303
+
304
+ if image_with_points is None:
305
+ draw = ImageDraw.Draw(image)
306
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
307
+ return image
308
+ else:
309
+
310
+ draw = ImageDraw.Draw(image_with_points)
311
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
312
+ return image_with_points
313
+ def reset_sam_points(self,):
314
+ self.sam_point = []
315
+ self.sam_point_label = []
316
+ print('reset all points')
317
+ return None
318
+ def obtain_mask(self, image,sam_path):
319
+ if self.sam_predictor is None:
320
+ self.lora_sam_predictor(sam_path)
321
+
322
+ print("+++++++++++++++++++ Obtain Mask by SAM ++++++++++++++++++++++")
323
+ input_point = np.array(self.sam_point)
324
+ input_label = np.array(self.sam_point_label)
325
+ predictor = self.sam_predictor
326
+ image = np.array(image)
327
+ predictor.set_image(image)
328
+
329
+ # input_point = np.array([[500, 375]])
330
+ # input_label = np.array([1])
331
+
332
+ masks, scores, logits = predictor.predict(point_coords=input_point,point_labels=input_label,multimask_output=False)
333
+
334
+ # import pdb; pdb.set_trace()
335
+ masks = masks.astype(np.uint8)
336
+ masks = masks * 255
337
+ masks = masks.transpose(1,2,0)
338
+ masks = masks.repeat(3, axis=2)
339
+ return masks
340
+
341
  global_text = GlobalText()
342
 
343
+
344
  def load_mask_images(source,style,source_mask,style_mask,device,width,height,out_dir=None):
345
  # invert the image into noise map
346
  if isinstance(source['image'], np.ndarray):
 
361
  style['mask'].save(os.path.join(out_dir,'style_mask.jpg'))
362
  else:
363
  Image.fromarray(style_mask).save(os.path.join(out_dir,'style_mask.jpg'))
364
+
 
365
  source_mask = torch.from_numpy(np.array(source['mask']) if source_mask is None else source_mask).to(device) / 255.
366
  source_mask = source_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1]
367
  source_mask = F.interpolate(source_mask, (height//8,width//8))
 
381
  return source_image,style_image,source_mask,style_mask
382
 
383
 
 
384
  def ui():
385
  with gr.Blocks(css=css) as demo:
386
  gr.Markdown(
387
  """
388
  # [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/00000)
389
  Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)<br>
390
+ [Arxiv Report](https://arxiv.org/abs/2312.02212) | [Github](https://github.com/liujin112/PortraitDiffusion)
391
  """
392
  )
393
  with gr.Column(variant="panel"):
 
469
  with gr.Tab('Base Configs'):
470
  with gr.Row():
471
  # sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
472
+ ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=0, maximum=100, step=1)
473
 
474
  Style_attn_step = gr.Slider(label="Step of Style Attention Control",
475
  minimum=0,
 
537
 
538
  with gr.Tab("SAM"):
539
  with gr.Column():
 
540
  with gr.Row():
541
+ add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
542
+ sam_path = gr.Textbox(label="SAM model path", value='')
543
+ load_sam_btn = gr.Button(value="Load SAM from path")
544
+ with gr.Row():
545
+
546
+ send_source_btn = gr.Button(value="Send Source Image from PD Tab")
547
+ sam_source_btn = gr.Button(value="Segment Source")
548
 
549
+ send_style_btn = gr.Button(value="Send Style Image from PD Tab")
550
+ sam_style_btn = gr.Button(value="Segment Style")
551
  with gr.Row():
552
  source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
553
  style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
 
558
 
559
  style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
560
  style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
561
+ load_sam_btn.click(global_text.lora_sam_predictor,inputs=[sam_path],outputs=[])
562
+ source_image_sam.select(global_text.get_points_with_draw, [source_image_sam, source_image_with_points, add_or_remove], source_image_with_points)
563
+ style_image_sam.select(global_text.get_points_with_draw, [style_image_sam, style_image_with_points, add_or_remove], style_image_with_points)
564
+ send_source_btn.click(lambda x: (x['image'], None), inputs=[source_image], outputs=[source_image_sam, source_image_with_points])
565
+ send_style_btn.click(lambda x: (x['image'], None), inputs=[style_image], outputs=[style_image_sam, style_image_with_points])
566
+
567
+ style_image_sam.change(global_text.reset_sam_points, inputs=[], outputs=[style_image_with_points])
568
+ source_image_sam.change(global_text.reset_sam_points, inputs=[], outputs=[source_image_with_points])
569
+
570
+
571
+ sam_source_btn.click(global_text.obtain_mask,[source_image_sam, sam_path],[source_mask])
572
+ sam_style_btn.click(global_text.obtain_mask,[style_image_sam, sam_path],[style_mask])
573
+
574
  gr.Examples(
575
  [[os.path.join(os.path.dirname(__file__), "images/content/1.jpg"),
576
  os.path.join(os.path.dirname(__file__), "images/style/1.jpg")],
 
584
  Method, Style_Guidance,ddim_steps, cfg_scale_slider, seed_textbox, de_bug,
585
  prompt_textbox, negative_prompt_textbox, inter_latents,
586
  freeu, b1, b2, s1, s2,
587
+ width_slider,height_slider
588
  ]
589
 
590
  generate_button.click(
 
599
 
600
  if __name__ == "__main__":
601
  demo = ui()
602
+ demo.launch(server_name="172.18.32.44")
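The other substantive addition in app.py is point-prompted masking with Segment Anything (lora_sam_predictor, get_points_with_draw, obtain_mask). A minimal sketch of that flow follows, assuming a locally downloaded ViT-H checkpoint and a single foreground click in place of the points gathered through the UI.

import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

# The checkpoint path is an assumption; in the app it comes from the "SAM model path" textbox.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")
predictor = SamPredictor(sam)

image = np.array(Image.open("images/content/1.jpg").convert("RGB"))
predictor.set_image(image)

# One foreground click (label 1); label 0 marks background, matching "Remove Area" in the UI.
point_coords = np.array([[256, 256]])
point_labels = np.array([1])
masks, _, _ = predictor.predict(point_coords=point_coords,
                                point_labels=point_labels,
                                multimask_output=False)

# As in obtain_mask(), convert the boolean (1, H, W) mask to the 3-channel uint8 image the mask fields expect.
mask_rgb = (masks.astype(np.uint8) * 255).transpose(1, 2, 0).repeat(3, axis=2)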
app.sh DELETED
@@ -1,7 +0,0 @@
1
- #!/bin/bash
2
-
3
- export CUDA_VISIBLE_DEVICES=$1
4
-
5
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
6
- # export CUDA_VISIBLE_DEVICES=5
7
- python app.py
gr4_test.py DELETED
@@ -1,15 +0,0 @@
1
- import gradio as gr
2
-
3
- cnt = 0
4
-
5
- def test():
6
- cnt += 1
7
- return f'triggered!{cnt}'
8
-
9
-
10
- with gr.Blocks() as demo:
11
- sketch_pad = gr.ImageEditor(type="pil")
12
- output_text = gr.Textbox(label='Output Text')
13
- sketch_pad.change(test, outputs=[output_text])
14
-
15
- demo.launch()
utils/convert_from_ckpt.py DELETED
@@ -1,959 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Conversion script for the Stable Diffusion checkpoints."""
16
-
17
- import re
18
- from io import BytesIO
19
- from typing import Optional
20
-
21
- import requests
22
- import torch
23
- from transformers import (
24
- AutoFeatureExtractor,
25
- BertTokenizerFast,
26
- CLIPImageProcessor,
27
- CLIPTextModel,
28
- CLIPTextModelWithProjection,
29
- CLIPTokenizer,
30
- CLIPVisionConfig,
31
- CLIPVisionModelWithProjection,
32
- )
33
-
34
- from diffusers.models import (
35
- AutoencoderKL,
36
- PriorTransformer,
37
- UNet2DConditionModel,
38
- )
39
- from diffusers.schedulers import (
40
- DDIMScheduler,
41
- DDPMScheduler,
42
- DPMSolverMultistepScheduler,
43
- EulerAncestralDiscreteScheduler,
44
- EulerDiscreteScheduler,
45
- HeunDiscreteScheduler,
46
- LMSDiscreteScheduler,
47
- PNDMScheduler,
48
- UnCLIPScheduler,
49
- )
50
- from diffusers.utils.import_utils import BACKENDS_MAPPING
51
-
52
-
53
- def shave_segments(path, n_shave_prefix_segments=1):
54
- """
55
- Removes segments. Positive values shave the first segments, negative shave the last segments.
56
- """
57
- if n_shave_prefix_segments >= 0:
58
- return ".".join(path.split(".")[n_shave_prefix_segments:])
59
- else:
60
- return ".".join(path.split(".")[:n_shave_prefix_segments])
61
-
62
-
63
- def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
64
- """
65
- Updates paths inside resnets to the new naming scheme (local renaming)
66
- """
67
- mapping = []
68
- for old_item in old_list:
69
- new_item = old_item.replace("in_layers.0", "norm1")
70
- new_item = new_item.replace("in_layers.2", "conv1")
71
-
72
- new_item = new_item.replace("out_layers.0", "norm2")
73
- new_item = new_item.replace("out_layers.3", "conv2")
74
-
75
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
76
- new_item = new_item.replace("skip_connection", "conv_shortcut")
77
-
78
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
79
-
80
- mapping.append({"old": old_item, "new": new_item})
81
-
82
- return mapping
83
-
84
-
85
- def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
86
- """
87
- Updates paths inside resnets to the new naming scheme (local renaming)
88
- """
89
- mapping = []
90
- for old_item in old_list:
91
- new_item = old_item
92
-
93
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
94
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
95
-
96
- mapping.append({"old": old_item, "new": new_item})
97
-
98
- return mapping
99
-
100
-
101
- def renew_attention_paths(old_list, n_shave_prefix_segments=0):
102
- """
103
- Updates paths inside attentions to the new naming scheme (local renaming)
104
- """
105
- mapping = []
106
- for old_item in old_list:
107
- new_item = old_item
108
-
109
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
110
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
111
-
112
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
113
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
114
-
115
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
116
-
117
- mapping.append({"old": old_item, "new": new_item})
118
-
119
- return mapping
120
-
121
-
122
- def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
123
- """
124
- Updates paths inside attentions to the new naming scheme (local renaming)
125
- """
126
- mapping = []
127
- for old_item in old_list:
128
- new_item = old_item
129
-
130
- new_item = new_item.replace("norm.weight", "group_norm.weight")
131
- new_item = new_item.replace("norm.bias", "group_norm.bias")
132
-
133
- new_item = new_item.replace("q.weight", "query.weight")
134
- new_item = new_item.replace("q.bias", "query.bias")
135
-
136
- new_item = new_item.replace("k.weight", "key.weight")
137
- new_item = new_item.replace("k.bias", "key.bias")
138
-
139
- new_item = new_item.replace("v.weight", "value.weight")
140
- new_item = new_item.replace("v.bias", "value.bias")
141
-
142
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
143
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
144
-
145
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
146
-
147
- mapping.append({"old": old_item, "new": new_item})
148
-
149
- return mapping
150
-
151
-
152
- def assign_to_checkpoint(
153
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
154
- ):
155
- """
156
- This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
157
- attention layers, and takes into account additional replacements that may arise.
158
-
159
- Assigns the weights to the new checkpoint.
160
- """
161
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
162
-
163
- # Splits the attention layers into three variables.
164
- if attention_paths_to_split is not None:
165
- for path, path_map in attention_paths_to_split.items():
166
- old_tensor = old_checkpoint[path]
167
- channels = old_tensor.shape[0] // 3
168
-
169
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
170
-
171
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
172
-
173
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
174
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
175
-
176
- checkpoint[path_map["query"]] = query.reshape(target_shape)
177
- checkpoint[path_map["key"]] = key.reshape(target_shape)
178
- checkpoint[path_map["value"]] = value.reshape(target_shape)
179
-
180
- for path in paths:
181
- new_path = path["new"]
182
-
183
- # These have already been assigned
184
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
185
- continue
186
-
187
- # Global renaming happens here
188
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
189
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
190
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
191
-
192
- if additional_replacements is not None:
193
- for replacement in additional_replacements:
194
- new_path = new_path.replace(replacement["old"], replacement["new"])
195
-
196
- # proj_attn.weight has to be converted from conv 1D to linear
197
- if "proj_attn.weight" in new_path:
198
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
199
- else:
200
- checkpoint[new_path] = old_checkpoint[path["old"]]
201
-
202
-
203
- def conv_attn_to_linear(checkpoint):
204
- keys = list(checkpoint.keys())
205
- attn_keys = ["query.weight", "key.weight", "value.weight"]
206
- for key in keys:
207
- if ".".join(key.split(".")[-2:]) in attn_keys:
208
- if checkpoint[key].ndim > 2:
209
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
210
- elif "proj_attn.weight" in key:
211
- if checkpoint[key].ndim > 2:
212
- checkpoint[key] = checkpoint[key][:, :, 0]
213
-
214
-
215
- def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
216
- """
217
- Creates a config for the diffusers based on the config of the LDM model.
218
- """
219
- if controlnet:
220
- unet_params = original_config.model.params.control_stage_config.params
221
- else:
222
- unet_params = original_config.model.params.unet_config.params
223
-
224
- vae_params = original_config.model.params.first_stage_config.params.ddconfig
225
-
226
- block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
227
-
228
- down_block_types = []
229
- resolution = 1
230
- for i in range(len(block_out_channels)):
231
- block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
232
- down_block_types.append(block_type)
233
- if i != len(block_out_channels) - 1:
234
- resolution *= 2
235
-
236
- up_block_types = []
237
- for i in range(len(block_out_channels)):
238
- block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
239
- up_block_types.append(block_type)
240
- resolution //= 2
241
-
242
- vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
243
-
244
- head_dim = unet_params.num_heads if "num_heads" in unet_params else None
245
- use_linear_projection = (
246
- unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
247
- )
248
- if use_linear_projection:
249
- # stable diffusion 2-base-512 and 2-768
250
- if head_dim is None:
251
- head_dim = [5, 10, 20, 20]
252
-
253
- class_embed_type = None
254
- projection_class_embeddings_input_dim = None
255
-
256
- if "num_classes" in unet_params:
257
- if unet_params.num_classes == "sequential":
258
- class_embed_type = "projection"
259
- assert "adm_in_channels" in unet_params
260
- projection_class_embeddings_input_dim = unet_params.adm_in_channels
261
- else:
262
- raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
263
-
264
- config = {
265
- "sample_size": image_size // vae_scale_factor,
266
- "in_channels": unet_params.in_channels,
267
- "down_block_types": tuple(down_block_types),
268
- "block_out_channels": tuple(block_out_channels),
269
- "layers_per_block": unet_params.num_res_blocks,
270
- "cross_attention_dim": unet_params.context_dim,
271
- "attention_head_dim": head_dim,
272
- "use_linear_projection": use_linear_projection,
273
- "class_embed_type": class_embed_type,
274
- "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
275
- }
276
-
277
- if not controlnet:
278
- config["out_channels"] = unet_params.out_channels
279
- config["up_block_types"] = tuple(up_block_types)
280
-
281
- return config
282
-
283
-
284
- def create_vae_diffusers_config(original_config, image_size: int):
285
- """
286
- Creates a config for the diffusers based on the config of the LDM model.
287
- """
288
- vae_params = original_config.model.params.first_stage_config.params.ddconfig
289
- _ = original_config.model.params.first_stage_config.params.embed_dim
290
-
291
- block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
292
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
293
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
294
-
295
- config = {
296
- "sample_size": image_size,
297
- "in_channels": vae_params.in_channels,
298
- "out_channels": vae_params.out_ch,
299
- "down_block_types": tuple(down_block_types),
300
- "up_block_types": tuple(up_block_types),
301
- "block_out_channels": tuple(block_out_channels),
302
- "latent_channels": vae_params.z_channels,
303
- "layers_per_block": vae_params.num_res_blocks,
304
- }
305
- return config
306
-
307
-
308
- def create_diffusers_schedular(original_config):
309
- schedular = DDIMScheduler(
310
- num_train_timesteps=original_config.model.params.timesteps,
311
- beta_start=original_config.model.params.linear_start,
312
- beta_end=original_config.model.params.linear_end,
313
- beta_schedule="scaled_linear",
314
- )
315
- return schedular
316
-
317
-
318
- def create_ldm_bert_config(original_config):
319
- bert_params = original_config.model.parms.cond_stage_config.params
320
- config = LDMBertConfig(
321
- d_model=bert_params.n_embed,
322
- encoder_layers=bert_params.n_layer,
323
- encoder_ffn_dim=bert_params.n_embed * 4,
324
- )
325
- return config
326
-
327
-
328
- def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
329
- """
330
- Takes a state dict and a config, and returns a converted checkpoint.
331
- """
332
-
333
- # extract state_dict for UNet
334
- unet_state_dict = {}
335
- keys = list(checkpoint.keys())
336
-
337
- if controlnet:
338
- unet_key = "control_model."
339
- else:
340
- unet_key = "model.diffusion_model."
341
-
342
- # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
343
- if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
344
- print(f"Checkpoint {path} has both EMA and non-EMA weights.")
345
- print(
346
- "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
347
- " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
348
- )
349
- for key in keys:
350
- if key.startswith("model.diffusion_model"):
351
- flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
352
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
353
- else:
354
- if sum(k.startswith("model_ema") for k in keys) > 100:
355
- print(
356
- "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
357
- " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
358
- )
359
-
360
- for key in keys:
361
- if key.startswith(unet_key):
362
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
363
-
364
- new_checkpoint = {}
365
-
366
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
367
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
368
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
369
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
370
-
371
- if config["class_embed_type"] is None:
372
- # No parameters to port
373
- ...
374
- elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
375
- new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
376
- new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
377
- new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
378
- new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
379
- else:
380
- raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
381
-
382
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
383
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
384
-
385
- if not controlnet:
386
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
387
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
388
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
389
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
390
-
391
- # Retrieves the keys for the input blocks only
392
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
393
- input_blocks = {
394
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
395
- for layer_id in range(num_input_blocks)
396
- }
397
-
398
- # Retrieves the keys for the middle blocks only
399
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
400
- middle_blocks = {
401
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
402
- for layer_id in range(num_middle_blocks)
403
- }
404
-
405
- # Retrieves the keys for the output blocks only
406
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
407
- output_blocks = {
408
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
409
- for layer_id in range(num_output_blocks)
410
- }
411
-
412
- for i in range(1, num_input_blocks):
413
- block_id = (i - 1) // (config["layers_per_block"] + 1)
414
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
415
-
416
- resnets = [
417
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
418
- ]
419
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
420
-
421
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
422
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
423
- f"input_blocks.{i}.0.op.weight"
424
- )
425
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
426
- f"input_blocks.{i}.0.op.bias"
427
- )
428
-
429
- paths = renew_resnet_paths(resnets)
430
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
431
- assign_to_checkpoint(
432
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
433
- )
434
-
435
- if len(attentions):
436
- paths = renew_attention_paths(attentions)
437
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
438
- assign_to_checkpoint(
439
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
440
- )
441
-
442
- resnet_0 = middle_blocks[0]
443
- attentions = middle_blocks[1]
444
- resnet_1 = middle_blocks[2]
445
-
446
- resnet_0_paths = renew_resnet_paths(resnet_0)
447
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
448
-
449
- resnet_1_paths = renew_resnet_paths(resnet_1)
450
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
451
-
452
- attentions_paths = renew_attention_paths(attentions)
453
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
454
- assign_to_checkpoint(
455
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
456
- )
457
-
458
- for i in range(num_output_blocks):
459
- block_id = i // (config["layers_per_block"] + 1)
460
- layer_in_block_id = i % (config["layers_per_block"] + 1)
461
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
462
- output_block_list = {}
463
-
464
- for layer in output_block_layers:
465
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
466
- if layer_id in output_block_list:
467
- output_block_list[layer_id].append(layer_name)
468
- else:
469
- output_block_list[layer_id] = [layer_name]
470
-
471
- if len(output_block_list) > 1:
472
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
473
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
474
-
475
- resnet_0_paths = renew_resnet_paths(resnets)
476
- paths = renew_resnet_paths(resnets)
477
-
478
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
479
- assign_to_checkpoint(
480
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
481
- )
482
-
483
- output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
484
- if ["conv.bias", "conv.weight"] in output_block_list.values():
485
- index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
486
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
487
- f"output_blocks.{i}.{index}.conv.weight"
488
- ]
489
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
490
- f"output_blocks.{i}.{index}.conv.bias"
491
- ]
492
-
493
- # Clear attentions as they have been attributed above.
494
- if len(attentions) == 2:
495
- attentions = []
496
-
497
- if len(attentions):
498
- paths = renew_attention_paths(attentions)
499
- meta_path = {
500
- "old": f"output_blocks.{i}.1",
501
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
502
- }
503
- assign_to_checkpoint(
504
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
505
- )
506
- else:
507
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
508
- for path in resnet_0_paths:
509
- old_path = ".".join(["output_blocks", str(i), path["old"]])
510
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
511
-
512
- new_checkpoint[new_path] = unet_state_dict[old_path]
513
-
514
- if controlnet:
515
- # conditioning embedding
516
-
517
- orig_index = 0
518
-
519
- new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
520
- f"input_hint_block.{orig_index}.weight"
521
- )
522
- new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
523
- f"input_hint_block.{orig_index}.bias"
524
- )
525
-
526
- orig_index += 2
527
-
528
- diffusers_index = 0
529
-
530
- while diffusers_index < 6:
531
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
532
- f"input_hint_block.{orig_index}.weight"
533
- )
534
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
535
- f"input_hint_block.{orig_index}.bias"
536
- )
537
- diffusers_index += 1
538
- orig_index += 2
539
-
540
- new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
541
- f"input_hint_block.{orig_index}.weight"
542
- )
543
- new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
544
- f"input_hint_block.{orig_index}.bias"
545
- )
546
-
547
- # down blocks
548
- for i in range(num_input_blocks):
549
- new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
550
- new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
551
-
552
- # mid block
553
- new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
554
- new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
555
-
556
- return new_checkpoint
557
-
558
-
559
- def convert_ldm_vae_checkpoint(checkpoint, config):
560
- # extract state dict for VAE
561
- vae_state_dict = {}
562
- vae_key = "first_stage_model."
563
- keys = list(checkpoint.keys())
564
- for key in keys:
565
- if key.startswith(vae_key):
566
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
567
-
568
- new_checkpoint = {}
569
-
570
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
571
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
572
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
573
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
574
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
575
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
576
-
577
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
578
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
579
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
580
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
581
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
582
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
583
-
584
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
585
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
586
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
587
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
588
-
589
- # Retrieves the keys for the encoder down blocks only
590
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
591
- down_blocks = {
592
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
593
- }
594
-
595
- # Retrieves the keys for the decoder up blocks only
596
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
597
- up_blocks = {
598
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
599
- }
600
-
601
- for i in range(num_down_blocks):
602
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
603
-
604
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
605
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
606
- f"encoder.down.{i}.downsample.conv.weight"
607
- )
608
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
609
- f"encoder.down.{i}.downsample.conv.bias"
610
- )
611
-
612
- paths = renew_vae_resnet_paths(resnets)
613
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
614
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
615
-
616
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
617
- num_mid_res_blocks = 2
618
- for i in range(1, num_mid_res_blocks + 1):
619
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
620
-
621
- paths = renew_vae_resnet_paths(resnets)
622
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
623
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
624
-
625
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
626
- paths = renew_vae_attention_paths(mid_attentions)
627
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
628
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
629
- conv_attn_to_linear(new_checkpoint)
630
-
631
- for i in range(num_up_blocks):
632
- block_id = num_up_blocks - 1 - i
633
- resnets = [
634
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
635
- ]
636
-
637
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
638
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
639
- f"decoder.up.{block_id}.upsample.conv.weight"
640
- ]
641
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
642
- f"decoder.up.{block_id}.upsample.conv.bias"
643
- ]
644
-
645
- paths = renew_vae_resnet_paths(resnets)
646
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
647
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
648
-
649
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
650
- num_mid_res_blocks = 2
651
- for i in range(1, num_mid_res_blocks + 1):
652
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
653
-
654
- paths = renew_vae_resnet_paths(resnets)
655
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
656
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
657
-
658
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
659
- paths = renew_vae_attention_paths(mid_attentions)
660
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
661
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
662
- conv_attn_to_linear(new_checkpoint)
663
- return new_checkpoint
664
-
665
-
666
- def convert_ldm_bert_checkpoint(checkpoint, config):
667
- def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
668
- hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
669
- hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
670
- hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
671
-
672
- hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
673
- hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
674
-
675
- def _copy_linear(hf_linear, pt_linear):
676
- hf_linear.weight = pt_linear.weight
677
- hf_linear.bias = pt_linear.bias
678
-
679
- def _copy_layer(hf_layer, pt_layer):
680
- # copy layer norms
681
- _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
682
- _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
683
-
684
- # copy attn
685
- _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
686
-
687
- # copy MLP
688
- pt_mlp = pt_layer[1][1]
689
- _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
690
- _copy_linear(hf_layer.fc2, pt_mlp.net[2])
691
-
692
- def _copy_layers(hf_layers, pt_layers):
693
- for i, hf_layer in enumerate(hf_layers):
694
- if i != 0:
695
- i += i
696
- pt_layer = pt_layers[i : i + 2]
697
- _copy_layer(hf_layer, pt_layer)
698
-
699
- hf_model = LDMBertModel(config).eval()
700
-
701
- # copy embeds
702
- hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
703
- hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
704
-
705
- # copy layer norm
706
- _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
707
-
708
- # copy hidden layers
709
- _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
710
-
711
- _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
712
-
713
- return hf_model
714
-
715
-
716
- def convert_ldm_clip_checkpoint(checkpoint):
717
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
718
- keys = list(checkpoint.keys())
719
-
720
- text_model_dict = {}
721
-
722
- for key in keys:
723
- if key.startswith("cond_stage_model.transformer"):
724
- text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
725
-
726
- text_model.load_state_dict(text_model_dict)
727
-
728
- return text_model
729
-
730
-
731
- textenc_conversion_lst = [
732
- ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
733
- ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
734
- ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
735
- ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
736
- ]
737
- textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
738
-
739
- textenc_transformer_conversion_lst = [
740
- # (stable-diffusion, HF Diffusers)
741
- ("resblocks.", "text_model.encoder.layers."),
742
- ("ln_1", "layer_norm1"),
743
- ("ln_2", "layer_norm2"),
744
- (".c_fc.", ".fc1."),
745
- (".c_proj.", ".fc2."),
746
- (".attn", ".self_attn"),
747
- ("ln_final.", "transformer.text_model.final_layer_norm."),
748
- ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
749
- ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
750
- ]
751
- protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
752
- textenc_pattern = re.compile("|".join(protected.keys()))
753
-
754
-
755
- def convert_paint_by_example_checkpoint(checkpoint):
756
- config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
757
- model = PaintByExampleImageEncoder(config)
758
-
759
- keys = list(checkpoint.keys())
760
-
761
- text_model_dict = {}
762
-
763
- for key in keys:
764
- if key.startswith("cond_stage_model.transformer"):
765
- text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
766
-
767
- # load clip vision
768
- model.model.load_state_dict(text_model_dict)
769
-
770
- # load mapper
771
- keys_mapper = {
772
- k[len("cond_stage_model.mapper.res") :]: v
773
- for k, v in checkpoint.items()
774
- if k.startswith("cond_stage_model.mapper")
775
- }
776
-
777
- MAPPING = {
778
- "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
779
- "attn.c_proj": ["attn1.to_out.0"],
780
- "ln_1": ["norm1"],
781
- "ln_2": ["norm3"],
782
- "mlp.c_fc": ["ff.net.0.proj"],
783
- "mlp.c_proj": ["ff.net.2"],
784
- }
785
-
786
- mapped_weights = {}
787
- for key, value in keys_mapper.items():
788
- prefix = key[: len("blocks.i")]
789
- suffix = key.split(prefix)[-1].split(".")[-1]
790
- name = key.split(prefix)[-1].split(suffix)[0][1:-1]
791
- mapped_names = MAPPING[name]
792
-
793
- num_splits = len(mapped_names)
794
- for i, mapped_name in enumerate(mapped_names):
795
- new_name = ".".join([prefix, mapped_name, suffix])
796
- shape = value.shape[0] // num_splits
797
- mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
798
-
799
- model.mapper.load_state_dict(mapped_weights)
800
-
801
- # load final layer norm
802
- model.final_layer_norm.load_state_dict(
803
- {
804
- "bias": checkpoint["cond_stage_model.final_ln.bias"],
805
- "weight": checkpoint["cond_stage_model.final_ln.weight"],
806
- }
807
- )
808
-
809
- # load final proj
810
- model.proj_out.load_state_dict(
811
- {
812
- "bias": checkpoint["proj_out.bias"],
813
- "weight": checkpoint["proj_out.weight"],
814
- }
815
- )
816
-
817
- # load uncond vector
818
- model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
819
- return model
820
-
821
-
822
- 	def convert_open_clip_checkpoint(checkpoint):
- 	    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-
- 	    keys = list(checkpoint.keys())
-
- 	    text_model_dict = {}
-
- 	    if "cond_stage_model.model.text_projection" in checkpoint:
- 	        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
- 	    else:
- 	        d_model = 1024
-
- 	    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
-
- 	    for key in keys:
- 	        if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
- 	            continue
- 	        if key in textenc_conversion_map:
- 	            text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
- 	        if key.startswith("cond_stage_model.model.transformer."):
- 	            new_key = key[len("cond_stage_model.model.transformer.") :]
- 	            if new_key.endswith(".in_proj_weight"):
- 	                new_key = new_key[: -len(".in_proj_weight")]
- 	                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
- 	                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
- 	                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
- 	                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
- 	            elif new_key.endswith(".in_proj_bias"):
- 	                new_key = new_key[: -len(".in_proj_bias")]
- 	                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
- 	                text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
- 	                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
- 	                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
- 	            else:
- 	                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
-
- 	                text_model_dict[new_key] = checkpoint[key]
-
- 	    text_model.load_state_dict(text_model_dict)
-
- 	    return text_model
-
-
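The `in_proj` handling above is the heart of the OpenCLIP-to-HF conversion: open_clip stores the attention's query/key/value projection as a single fused matrix of shape (3 * d_model, d_model), while the Diffusers/Transformers `CLIPTextModel` expects separate q/k/v projections, so the converter slices the fused tensor into thirds. A standalone sketch of the same split:

    import torch

    d_model = 1024  # SD 2.x text-encoder width, per the default above
    in_proj_weight = torch.randn(3 * d_model, d_model)  # fused qkv, as stored by open_clip
    q_w = in_proj_weight[:d_model, :]
    k_w = in_proj_weight[d_model : d_model * 2, :]
    v_w = in_proj_weight[d_model * 2 :, :]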
- 	def stable_unclip_image_encoder(original_config):
- 	    """
- 	    Returns the image processor and clip image encoder for the img2img unclip pipeline.
-
- 	    We currently know of two types of stable unclip models which separately use the clip and the openclip image
- 	    encoders.
- 	    """
-
- 	    image_embedder_config = original_config.model.params.embedder_config
-
- 	    sd_clip_image_embedder_class = image_embedder_config.target
- 	    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
-
- 	    if sd_clip_image_embedder_class == "ClipImageEmbedder":
- 	        clip_model_name = image_embedder_config.params.model
-
- 	        if clip_model_name == "ViT-L/14":
- 	            feature_extractor = CLIPImageProcessor()
- 	            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
- 	        else:
- 	            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
-
- 	    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
- 	        feature_extractor = CLIPImageProcessor()
- 	        image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
- 	    else:
- 	        raise NotImplementedError(
- 	            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
- 	        )
-
- 	    return feature_extractor, image_encoder
-
-
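A minimal usage sketch, assuming the original LDM-style YAML config is loaded with OmegaConf (the file name is hypothetical):

    from omegaconf import OmegaConf

    original_config = OmegaConf.load("configs/stable-unclip-inference.yaml")  # hypothetical path
    feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)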
- 	def stable_unclip_image_noising_components(
- 	    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
- 	):
- 	    """
- 	    Returns the noising components for the img2img and txt2img unclip pipelines.
-
- 	    Converts the stability noise augmentor into
- 	    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
- 	    2. a `DDPMScheduler` for holding the noise schedule
-
- 	    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
- 	    """
- 	    noise_aug_config = original_config.model.params.noise_aug_config
- 	    noise_aug_class = noise_aug_config.target
- 	    noise_aug_class = noise_aug_class.split(".")[-1]
-
- 	    if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
- 	        noise_aug_config = noise_aug_config.params
- 	        embedding_dim = noise_aug_config.timestep_dim
- 	        max_noise_level = noise_aug_config.noise_schedule_config.timesteps
- 	        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
-
- 	        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
- 	        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
-
- 	        if "clip_stats_path" in noise_aug_config:
- 	            if clip_stats_path is None:
- 	                raise ValueError("This stable unclip config requires a `clip_stats_path`")
-
- 	            clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
- 	            clip_mean = clip_mean[None, :]
- 	            clip_std = clip_std[None, :]
-
- 	            clip_stats_state_dict = {
- 	                "mean": clip_mean,
- 	                "std": clip_std,
- 	            }
-
- 	            image_normalizer.load_state_dict(clip_stats_state_dict)
- 	    else:
- 	        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
-
- 	    return image_normalizer, image_noising_scheduler
-
-
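Likewise, a hedged usage sketch for the noising components; `clip_stats_path` only matters when the checkpoint's noise-augmentor config references CLIP statistics (paths hypothetical):

    from omegaconf import OmegaConf

    original_config = OmegaConf.load("configs/stable-unclip-inference.yaml")  # hypothetical path
    image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
        original_config,
        clip_stats_path="checkpoints/clip_stats.pt",  # hypothetical; omit if the config has no clip_stats_path
        device="cpu",
    )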
- 	def convert_controlnet_checkpoint(
- 	    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
- 	):
- 	    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
- 	    ctrlnet_config["upcast_attention"] = upcast_attention
-
- 	    ctrlnet_config.pop("sample_size")
-
- 	    controlnet_model = ControlNetModel(**ctrlnet_config)
-
- 	    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
- 	        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
- 	    )
-
- 	    controlnet_model.load_state_dict(converted_ctrl_checkpoint)
-
- 	    return controlnet_model
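A sketch of how this ControlNet converter would be driven, assuming an original `.ckpt`/`.pth` plus its YAML config (paths hypothetical); newer Diffusers releases also expose a similar conversion path via `ControlNetModel.from_single_file`, which is likely the simpler route today.

    import torch
    from omegaconf import OmegaConf

    checkpoint_path = "models/control_sd15_canny.pth"          # hypothetical path
    original_config = OmegaConf.load("models/cldm_v15.yaml")   # hypothetical path
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    checkpoint = checkpoint.get("state_dict", checkpoint)

    controlnet = convert_controlnet_checkpoint(
        checkpoint, original_config, checkpoint_path,
        image_size=512, upcast_attention=False, extract_ema=False,
    )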
 
utils/convert_lora_safetensor_to_diffusers.py DELETED
@@ -1,154 +0,0 @@
- # coding=utf-8
- # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """ Conversion script for the LoRA's safetensors checkpoints. """
-
- import argparse
-
- import torch
- from safetensors.torch import load_file
-
- from diffusers import StableDiffusionPipeline
- import pdb
-
-
-
- def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
-     # directly update weight in diffusers model
-     for key in state_dict:
-         # only process lora down key
-         if "up." in key: continue
-
-         up_key = key.replace(".down.", ".up.")
-         model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
-         model_key = model_key.replace("to_out.", "to_out.0.")
-         layer_infos = model_key.split(".")[:-1]
-
-         curr_layer = pipeline.unet
-         while len(layer_infos) > 0:
-             temp_name = layer_infos.pop(0)
-             curr_layer = curr_layer.__getattr__(temp_name)
-
-         weight_down = state_dict[key]
-         weight_up = state_dict[up_key]
-         curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
-
-     return pipeline
-
-
-
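The update applied above is the standard LoRA merge: each down/up pair contributes a low-rank delta `alpha * up @ down` added directly into the frozen weight. A toy standalone sketch of the same arithmetic:

    import torch

    rank, dim = 4, 320
    weight = torch.zeros(dim, dim)        # frozen attention weight W0
    up = torch.randn(dim, rank) * 0.01    # lora_up
    down = torch.randn(rank, dim) * 0.01  # lora_down
    alpha = 1.0

    weight += alpha * torch.mm(up, down)  # W = W0 + alpha * (up @ down)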
- def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
-     # load base model
-     # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
-
-     # load LoRA weight from .safetensors
-     # state_dict = load_file(checkpoint_path)
-
-     visited = []
-
-     # directly update weight in diffusers model
-     for key in state_dict:
-         # it is suggested to print out the key, it usually will be something like below
-         # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-
-         # as we have set the alpha beforehand, so just skip
-         if ".alpha" in key or key in visited:
-             continue
-
-         if "text" in key:
-             layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
-             curr_layer = pipeline.text_encoder
-         else:
-             layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
-             curr_layer = pipeline.unet
-
-         # find the target layer
-         temp_name = layer_infos.pop(0)
-         while len(layer_infos) > -1:
-             try:
-                 curr_layer = curr_layer.__getattr__(temp_name)
-                 if len(layer_infos) > 0:
-                     temp_name = layer_infos.pop(0)
-                 elif len(layer_infos) == 0:
-                     break
-             except Exception:
-                 if len(temp_name) > 0:
-                     temp_name += "_" + layer_infos.pop(0)
-                 else:
-                     temp_name = layer_infos.pop(0)
-
-         pair_keys = []
-         if "lora_down" in key:
-             pair_keys.append(key.replace("lora_down", "lora_up"))
-             pair_keys.append(key)
-         else:
-             pair_keys.append(key)
-             pair_keys.append(key.replace("lora_up", "lora_down"))
-
-         # update weight
-         if len(state_dict[pair_keys[0]].shape) == 4:
-             weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
-             weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
-             curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
-         else:
-             weight_up = state_dict[pair_keys[0]].to(torch.float32)
-             weight_down = state_dict[pair_keys[1]].to(torch.float32)
-             curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
-
-         # update visited list
-         for item in pair_keys:
-             visited.append(item)
-
-     return pipeline
-
-
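This hand-rolled merge predates Diffusers' built-in LoRA loading; on reasonably recent Diffusers versions the common case is covered by `pipe.load_lora_weights`, roughly as sketched below (the LoRA file name is hypothetical):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.load_lora_weights("loras/my_style.safetensors")  # hypothetical local file
    # pipe.fuse_lora(lora_scale=0.6)                      # optional; plays the role of `alpha` above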
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-
-     parser.add_argument(
-         "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
-     )
-     parser.add_argument(
-         "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
-     )
-     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
-     parser.add_argument(
-         "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
-     )
-     parser.add_argument(
-         "--lora_prefix_text_encoder",
-         default="lora_te",
-         type=str,
-         help="The prefix of text encoder weight in safetensors",
-     )
-     parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
-     parser.add_argument(
-         "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
-     )
-     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
-
-     args = parser.parse_args()
-
-     base_model_path = args.base_model_path
-     checkpoint_path = args.checkpoint_path
-     dump_path = args.dump_path
-     lora_prefix_unet = args.lora_prefix_unet
-     lora_prefix_text_encoder = args.lora_prefix_text_encoder
-     alpha = args.alpha
-
-     pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
-
-     pipe = pipe.to(args.device)
-     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
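For completeness: the `__main__` block above still calls `convert(...)` from the upstream version of this script, which is not defined in this copy, so running the file directly would fail with a NameError; only the importable `convert_lora` / `convert_motion_lora_ckpt_to_diffusers` helpers were usable. A hedged sketch of driving `convert_lora` directly on a loaded pipeline (paths hypothetical):

    import torch
    from safetensors.torch import load_file
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32)
    lora_state_dict = load_file("loras/my_style.safetensors")  # hypothetical LoRA file
    pipe = convert_lora(pipe, lora_state_dict, alpha=0.6)
    pipe.save_pretrained("converted_model", safe_serialization=True)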