ysharma (HF staff) committed
Commit 918e6a9
1 Parent(s): 2310d22

upload updated files

lora_diffusion/cli_lora_add.py CHANGED
@@ -1,35 +1,73 @@
 from typing import Literal, Union, Dict
-
+import os
+import shutil
 import fire
 from diffusers import StableDiffusionPipeline
 
 import torch
 from .lora import tune_lora_scale, weight_apply_lora
+from .to_ckpt_v2 import convert_to_ckpt
+
+
+def _text_lora_path(path: str) -> str:
+    assert path.endswith(".pt"), "Only .pt files are supported"
+    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
 
 
 def add(
     path_1: str,
     path_2: str,
-    output_path: str = "./merged_lora.pt",
+    output_path: str,
     alpha: float = 0.5,
-    mode: Literal["lpl", "upl"] = "lpl",
+    mode: Literal[
+        "lpl",
+        "upl",
+        "upl-ckpt-v2",
+    ] = "lpl",
+    with_text_lora: bool = False,
 ):
+    print("Lora Add, mode " + mode)
     if mode == "lpl":
-        out_list = []
-        l1 = torch.load(path_1)
-        l2 = torch.load(path_2)
+        for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
+            [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
+            if with_text_lora
+            else []
+        ):
+            print("Loading", _path_1, _path_2)
+            out_list = []
+            if opt == "text_encoder":
+                if not os.path.exists(_path_1):
+                    print(f"No text encoder found in {_path_1}, skipping...")
+                    continue
+                if not os.path.exists(_path_2):
+                    print(f"No text encoder found in {_path_1}, skipping...")
+                    continue
+
+            l1 = torch.load(_path_1)
+            l2 = torch.load(_path_2)
 
-        l1pairs = zip(l1[::2], l1[1::2])
-        l2pairs = zip(l2[::2], l2[1::2])
+            l1pairs = zip(l1[::2], l1[1::2])
+            l2pairs = zip(l2[::2], l2[1::2])
 
-        for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
-            x1.data = alpha * x1.data + (1 - alpha) * x2.data
-            y1.data = alpha * y1.data + (1 - alpha) * y2.data
+            for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
+                # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
+                x1.data = alpha * x1.data + (1 - alpha) * x2.data
+                y1.data = alpha * y1.data + (1 - alpha) * y2.data
 
-            out_list.append(x1)
-            out_list.append(y1)
+                out_list.append(x1)
+                out_list.append(y1)
 
-        torch.save(out_list, output_path)
+            if opt == "unet":
+
+                print("Saving merged UNET to", output_path)
+                torch.save(out_list, output_path)
+
+            elif opt == "text_encoder":
+                print("Saving merged text encoder to", _text_lora_path(output_path))
+                torch.save(
+                    out_list,
+                    _text_lora_path(output_path),
+                )
 
     elif mode == "upl":
 
@@ -38,12 +76,43 @@ def add(
         ).to("cpu")
 
         weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
+        if with_text_lora:
 
-        if output_path.endswith(".pt"):
-            output_path = output_path[:-3]
+            weight_apply_lora(
+                loaded_pipeline.text_encoder,
+                torch.load(_text_lora_path(path_2)),
+                alpha=alpha,
+                target_replace_module=["CLIPAttention"],
+            )
 
         loaded_pipeline.save_pretrained(output_path)
 
+    elif mode == "upl-ckpt-v2":
+
+        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
+            path_1,
+        ).to("cpu")
+
+        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
+        if with_text_lora:
+            weight_apply_lora(
+                loaded_pipeline.text_encoder,
+                torch.load(_text_lora_path(path_2)),
+                alpha=alpha,
+                target_replace_module=["CLIPAttention"],
+            )
+
+        _tmp_output = output_path + ".tmp"
+
+        loaded_pipeline.save_pretrained(_tmp_output)
+        convert_to_ckpt(_tmp_output, output_path, as_half=True)
+        # remove the tmp_output folder
+        shutil.rmtree(_tmp_output)
+
+    else:
+        print("Unknown mode", mode)
+        raise ValueError(f"Unknown mode {mode}")
+
 
 def main():
     fire.Fire(add)
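The merge ("lpl") path can also be driven directly from Python instead of through fire; a minimal sketch, assuming two LoRA weight files produced by save_lora_weight (the file names below are hypothetical, not part of this commit):

from lora_diffusion.cli_lora_add import add

add(
    "lora_a.pt",           # path_1: first LoRA weight list (hypothetical file)
    "lora_b.pt",           # path_2: second LoRA weight list (hypothetical file)
    "merged_lora.pt",      # output_path is now a required positional argument
    alpha=0.7,             # 70% of lora_a, 30% of lora_b
    mode="lpl",            # merge .pt against .pt
    with_text_lora=False,  # set True to also merge *.text_encoder.pt companions
)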
lora_diffusion/lora.py CHANGED
@@ -10,14 +10,20 @@ import torch.nn as nn
 
 
 class LoraInjectedLinear(nn.Module):
-    def __init__(self, in_features, out_features, bias=False):
+    def __init__(self, in_features, out_features, bias=False, r=4):
         super().__init__()
+
+        if r > min(in_features, out_features):
+            raise ValueError(
+                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+            )
+
         self.linear = nn.Linear(in_features, out_features, bias)
-        self.lora_down = nn.Linear(in_features, 4, bias=False)
-        self.lora_up = nn.Linear(4, out_features, bias=False)
+        self.lora_down = nn.Linear(in_features, r, bias=False)
+        self.lora_up = nn.Linear(r, out_features, bias=False)
         self.scale = 1.0
 
-        nn.init.normal_(self.lora_down.weight, std=1 / 16)
+        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
         nn.init.zeros_(self.lora_up.weight)
 
     def forward(self, input):
@@ -25,7 +31,10 @@ class LoraInjectedLinear(nn.Module):
 
 
 def inject_trainable_lora(
-    model: nn.Module, target_replace_module: List[str] = ["CrossAttention", "Attention"]
+    model: nn.Module,
+    target_replace_module: List[str] = ["CrossAttention", "Attention"],
+    r: int = 4,
+    loras=None,  # path to lora .pt
 ):
     """
     inject lora into model, and returns lora parameter groups.
@@ -34,6 +43,9 @@ def inject_trainable_lora(
     require_grad_params = []
     names = []
 
+    if loras != None:
+        loras = torch.load(loras)
+
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
 
@@ -46,6 +58,7 @@ def inject_trainable_lora(
                         _child_module.in_features,
                         _child_module.out_features,
                         _child_module.bias is not None,
+                        r,
                     )
                     _tmp.linear.weight = weight
                     if bias is not None:
@@ -61,10 +74,13 @@ def inject_trainable_lora(
                         _module._modules[name].lora_down.parameters()
                     )
 
+                    if loras != None:
+                        _module._modules[name].lora_up.weight = loras.pop(0)
+                        _module._modules[name].lora_down.weight = loras.pop(0)
+
                     _module._modules[name].lora_up.weight.requires_grad = True
                     _module._modules[name].lora_down.weight.requires_grad = True
                     names.append(name)
-
     return require_grad_params, names
 
 
@@ -82,9 +98,13 @@ def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Atten
     return loras
 
 
-def save_lora_weight(model, path="./lora.pt"):
+def save_lora_weight(
+    model, path="./lora.pt", target_replace_module=["CrossAttention", "Attention"]
+):
     weights = []
-    for _up, _down in extract_lora_ups_down(model):
+    for _up, _down in extract_lora_ups_down(
+        model, target_replace_module=target_replace_module
+    ):
         weights.append(_up.weight)
         weights.append(_down.weight)
 
@@ -125,7 +145,7 @@ def weight_apply_lora(
 
 
 def monkeypatch_lora(
-    model, loras, target_replace_module=["CrossAttention", "Attention"]
+    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
 ):
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
@@ -138,6 +158,44 @@ def monkeypatch_lora(
                         _child_module.in_features,
                         _child_module.out_features,
                         _child_module.bias is not None,
+                        r=r,
+                    )
+                    _tmp.linear.weight = weight
+
+                    if bias is not None:
+                        _tmp.linear.bias = bias
+
+                    # switch the module
+                    _module._modules[name] = _tmp
+
+                    up_weight = loras.pop(0)
+                    down_weight = loras.pop(0)
+
+                    _module._modules[name].lora_up.weight = nn.Parameter(
+                        up_weight.type(weight.dtype)
+                    )
+                    _module._modules[name].lora_down.weight = nn.Parameter(
+                        down_weight.type(weight.dtype)
+                    )
+
+                    _module._modules[name].to(weight.device)
+
+
+def monkeypatch_replace_lora(
+    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
+):
+    for _module in model.modules():
+        if _module.__class__.__name__ in target_replace_module:
+            for name, _child_module in _module.named_modules():
+                if _child_module.__class__.__name__ == "LoraInjectedLinear":
+
+                    weight = _child_module.linear.weight
+                    bias = _child_module.linear.bias
+                    _tmp = LoraInjectedLinear(
+                        _child_module.linear.in_features,
+                        _child_module.linear.out_features,
+                        _child_module.linear.bias is not None,
+                        r=r,
                     )
                     _tmp.linear.weight = weight
 
@@ -160,7 +218,138 @@ def monkeypatch_lora(
                     _module._modules[name].to(weight.device)
 
 
+def monkeypatch_add_lora(
+    model,
+    loras,
+    target_replace_module=["CrossAttention", "Attention"],
+    alpha: float = 1.0,
+    beta: float = 1.0,
+):
+    for _module in model.modules():
+        if _module.__class__.__name__ in target_replace_module:
+            for name, _child_module in _module.named_modules():
+                if _child_module.__class__.__name__ == "LoraInjectedLinear":
+
+                    weight = _child_module.linear.weight
+
+                    up_weight = loras.pop(0)
+                    down_weight = loras.pop(0)
+
+                    _module._modules[name].lora_up.weight = nn.Parameter(
+                        up_weight.type(weight.dtype).to(weight.device) * alpha
+                        + _module._modules[name].lora_up.weight.to(weight.device) * beta
+                    )
+                    _module._modules[name].lora_down.weight = nn.Parameter(
+                        down_weight.type(weight.dtype).to(weight.device) * alpha
+                        + _module._modules[name].lora_down.weight.to(weight.device)
+                        * beta
+                    )
+
+                    _module._modules[name].to(weight.device)
+
+
 def tune_lora_scale(model, alpha: float = 1.0):
     for _module in model.modules():
         if _module.__class__.__name__ == "LoraInjectedLinear":
             _module.scale = alpha
+
+
+def _text_lora_path(path: str) -> str:
+    assert path.endswith(".pt"), "Only .pt files are supported"
+    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
+
+
+def _ti_lora_path(path: str) -> str:
+    assert path.endswith(".pt"), "Only .pt files are supported"
+    return ".".join(path.split(".")[:-1] + ["ti", "pt"])
+
+
+def load_learned_embed_in_clip(
+    learned_embeds_path, text_encoder, tokenizer, token=None, idempotent=False
+):
+    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
+
+    # separate token and the embeds
+    trained_token = list(loaded_learned_embeds.keys())[0]
+    embeds = loaded_learned_embeds[trained_token]
+
+    # cast to dtype of text_encoder
+    dtype = text_encoder.get_input_embeddings().weight.dtype
+
+    # add the token in tokenizer
+    token = token if token is not None else trained_token
+    num_added_tokens = tokenizer.add_tokens(token)
+    i = 1
+    if num_added_tokens == 0 and idempotent:
+        return token
+
+    while num_added_tokens == 0:
+        print(f"The tokenizer already contains the token {token}.")
+        token = f"{token[:-1]}-{i}>"
+        print(f"Attempting to add the token {token}.")
+        num_added_tokens = tokenizer.add_tokens(token)
+        i += 1
+
+    # resize the token embeddings
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    # get the id for the token and assign the embeds
+    token_id = tokenizer.convert_tokens_to_ids(token)
+    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
+    return token
+
+
+def patch_pipe(
+    pipe,
+    unet_path,
+    token,
+    alpha: float = 1.0,
+    r: int = 4,
+    patch_text=False,
+    patch_ti=False,
+    idempotent_token=True,
+):
+
+    ti_path = _ti_lora_path(unet_path)
+    text_path = _text_lora_path(unet_path)
+
+    unet_has_lora = False
+    text_encoder_has_lora = False
+
+    for _module in pipe.unet.modules():
+        if _module.__class__.__name__ == "LoraInjectedLinear":
+            unet_has_lora = True
+
+    for _module in pipe.text_encoder.modules():
+        if _module.__class__.__name__ == "LoraInjectedLinear":
+            text_encoder_has_lora = True
+
+    if not unet_has_lora:
+        monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
+    else:
+        monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)
+
+    if patch_text:
+        if not text_encoder_has_lora:
+            monkeypatch_lora(
+                pipe.text_encoder,
+                torch.load(text_path),
+                target_replace_module=["CLIPAttention"],
+                r=r,
+            )
+        else:
+
+            monkeypatch_replace_lora(
+                pipe.text_encoder,
+                torch.load(text_path),
+                target_replace_module=["CLIPAttention"],
+                r=r,
+            )
+    if patch_ti:
+        token = load_learned_embed_in_clip(
+            ti_path,
+            pipe.text_encoder,
+            pipe.tokenizer,
+            token,
+            idempotent=idempotent_token,
+        )
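For the new lora.py helpers, a minimal sketch of applying a trained LoRA to a Diffusers pipeline with monkeypatch_lora and the scale knob exposed by tune_lora_scale; the model id, weight file name and prompt are assumptions for illustration, not part of this commit:

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion.lora import monkeypatch_lora, tune_lora_scale

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# inject the saved up/down matrices into the UNet's attention Linears
monkeypatch_lora(pipe.unet, torch.load("lora_weight.pt"), r=4)
tune_lora_scale(pipe.unet, 0.8)  # weight of the LoRA residual at inference time

image = pipe("portrait in the learned style", num_inference_steps=50).images[0]
image.save("out.png")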
lora_diffusion/to_ckpt_v2.py ADDED
@@ -0,0 +1,232 @@
+# from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or any other thing.
+# Written by jachiam
+import argparse
+import os.path as osp
+
+import torch
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+    # (stable-diffusion, HF Diffusers)
+    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+    ("input_blocks.0.0.weight", "conv_in.weight"),
+    ("input_blocks.0.0.bias", "conv_in.bias"),
+    ("out.0.weight", "conv_norm_out.weight"),
+    ("out.0.bias", "conv_norm_out.bias"),
+    ("out.2.weight", "conv_out.weight"),
+    ("out.2.bias", "conv_out.bias"),
+]
+
+unet_conversion_map_resnet = [
+    # (stable-diffusion, HF Diffusers)
+    ("in_layers.0", "norm1"),
+    ("in_layers.2", "conv1"),
+    ("out_layers.0", "norm2"),
+    ("out_layers.3", "conv2"),
+    ("emb_layers.1", "time_emb_proj"),
+    ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(4):
+    # loop over downblocks/upblocks
+
+    for j in range(2):
+        # loop over resnets/attentions for downblocks
+        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+        if i < 3:
+            # no attention layers in down_blocks.3
+            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+    for j in range(3):
+        # loop over resnets/attentions for upblocks
+        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+        if i > 0:
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+    if i < 3:
+        # no downsample in down_blocks.3
+        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+        # no upsample in up_blocks.3
+        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+for j in range(2):
+    hf_mid_res_prefix = f"mid_block.resnets.{j}."
+    sd_mid_res_prefix = f"middle_block.{2*j}."
+    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+def convert_unet_state_dict(unet_state_dict):
+    # buyer beware: this is a *brittle* function,
+    # and correct output requires that all of these pieces interact in
+    # the exact order in which I have arranged them.
+    mapping = {k: k for k in unet_state_dict.keys()}
+    for sd_name, hf_name in unet_conversion_map:
+        mapping[hf_name] = sd_name
+    for k, v in mapping.items():
+        if "resnets" in k:
+            for sd_part, hf_part in unet_conversion_map_resnet:
+                v = v.replace(hf_part, sd_part)
+            mapping[k] = v
+    for k, v in mapping.items():
+        for sd_part, hf_part in unet_conversion_map_layer:
+            v = v.replace(hf_part, sd_part)
+        mapping[k] = v
+    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+    return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+    # (stable-diffusion, HF Diffusers)
+    ("nin_shortcut", "conv_shortcut"),
+    ("norm_out", "conv_norm_out"),
+    ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+    # down_blocks have two resnets
+    for j in range(2):
+        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+        sd_down_prefix = f"encoder.down.{i}.block.{j}."
+        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+    if i < 3:
+        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+        sd_downsample_prefix = f"down.{i}.downsample."
+        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+        sd_upsample_prefix = f"up.{3-i}.upsample."
+        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    # up_blocks have three resnets
+    # also, up blocks in hf are numbered in reverse from sd
+    for j in range(3):
+        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+    hf_mid_res_prefix = f"mid_block.resnets.{i}."
+    sd_mid_res_prefix = f"mid.block_{i+1}."
+    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+    # (stable-diffusion, HF Diffusers)
+    ("norm.", "group_norm."),
+    ("q.", "query."),
+    ("k.", "key."),
+    ("v.", "value."),
+    ("proj_out.", "proj_attn."),
+]
+
+
+def reshape_weight_for_sd(w):
+    # convert HF linear weights to SD conv2d weights
+    return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+    mapping = {k: k for k in vae_state_dict.keys()}
+    for k, v in mapping.items():
+        for sd_part, hf_part in vae_conversion_map:
+            v = v.replace(hf_part, sd_part)
+        mapping[k] = v
+    for k, v in mapping.items():
+        if "attentions" in k:
+            for sd_part, hf_part in vae_conversion_map_attn:
+                v = v.replace(hf_part, sd_part)
+            mapping[k] = v
+    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+    weights_to_convert = ["q", "k", "v", "proj_out"]
+    for k, v in new_state_dict.items():
+        for weight_name in weights_to_convert:
+            if f"mid.attn_1.{weight_name}.weight" in k:
+                print(f"Reshaping {k} for SD format")
+                new_state_dict[k] = reshape_weight_for_sd(v)
+    return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+# pretty much a no-op
+
+
+def convert_text_enc_state_dict(text_enc_dict):
+    return text_enc_dict
+
+
+def convert_to_ckpt(model_path, checkpoint_path, as_half):
+
+    assert model_path is not None, "Must provide a model path!"
+
+    assert checkpoint_path is not None, "Must provide a checkpoint path!"
+
+    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
+    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
+    text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
+
+    # Convert the UNet model
+    unet_state_dict = torch.load(unet_path, map_location="cpu")
+    unet_state_dict = convert_unet_state_dict(unet_state_dict)
+    unet_state_dict = {
+        "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
+    }
+
+    # Convert the VAE model
+    vae_state_dict = torch.load(vae_path, map_location="cpu")
+    vae_state_dict = convert_vae_state_dict(vae_state_dict)
+    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+    # Convert the text encoder model
+    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+    text_enc_dict = {
+        "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
+    }
+
+    # Put together new checkpoint
+    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+    if as_half:
+        state_dict = {k: v.half() for k, v in state_dict.items()}
+    state_dict = {"state_dict": state_dict}
+    torch.save(state_dict, checkpoint_path)
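The converter can also be called on its own, outside of cli_lora_add; a minimal sketch, assuming a pipeline directory previously written by save_pretrained (the paths below are hypothetical):

from lora_diffusion.to_ckpt_v2 import convert_to_ckpt

convert_to_ckpt(
    "./merged_pipeline",    # model_path: a Diffusers save_pretrained() directory
    "./merged_model.ckpt",  # checkpoint_path: single-file CompVis-style checkpoint
    as_half=True,           # store weights as fp16, as the upl-ckpt-v2 mode does
)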