JadenFK commited on
Commit
640a27b
1 Parent(s): 69190f3

Init for demo

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. app.py +72 -0
  3. convertModels.py +907 -0
  4. requirements.txt +7 -0
  5. test.py +19 -0
  6. train_esd.py +324 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
1
+ __pycache__
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from train_esd import train_esd
3
+
4
+
5
+ ckpt_path = "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
6
+ config_path = "stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
7
+ diffusers_config_path = "stable-diffusion/config.json"
8
+
9
+ def train(prompt, train_method, neg_guidance, iterations, lr):
10
+
11
+ train_esd(prompt,
12
+ train_method,
13
+ 3,
14
+ neg_guidance,
15
+ iterations,
16
+ lr,
17
+ config_path,
18
+ ckpt_path,
19
+ diffusers_config_path,
20
+ ['cuda']
21
+ )
22
+
23
+
24
+ with gr.Blocks() as demo:
25
+
26
+ prompt_input = gr.Text(
27
+ placeholder="Enter prompt...",
28
+ label="Prompt",
29
+ info="Prompt corresponding to concept to erase"
30
+ )
31
+ train_method_input = gr.Dropdown(
32
+ choices=['noxattn', 'selfattn', 'xattn', 'full'],
33
+ value='xattn',
34
+ label='Train Method',
35
+ info='Method of training'
36
+ )
37
+
38
+ neg_guidance_input = gr.Number(
39
+ value=1,
40
+ label="Negative Guidance",
41
+ info='Guidance of negative training used to train'
42
+ )
43
+
44
+ iterations_input = gr.Number(
45
+ value=1000,
46
+ precision=0,
47
+ label="Iterations",
48
+ info='iterations used to train'
49
+ )
50
+
51
+ lr_input = gr.Number(
52
+ value=1e-5,
53
+ label="Iterations",
54
+ info='Learning rate used to train'
55
+ )
56
+
57
+ train_button = gr.Button(
58
+ value="Train",
59
+ )
60
+ train_button.click(train, inputs = [
61
+ prompt_input,
62
+ train_method_input,
63
+ neg_guidance_input,
64
+ iterations_input,
65
+ lr_input
66
+ ]
67
+ )
68
+
69
+
70
+
71
+
72
+ demo.launch()
convertModels.py ADDED
@@ -0,0 +1,907 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the LDM checkpoints. """
16
+
17
+ import argparse
18
+ import os
19
+ import re
20
+
21
+ import torch
22
+
23
+
24
+
25
+ try:
26
+ from omegaconf import OmegaConf
27
+ except ImportError:
28
+ raise ImportError(
29
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
30
+ )
31
+
32
+ from diffusers import (
33
+ AutoencoderKL,
34
+ DDIMScheduler,
35
+ DPMSolverMultistepScheduler,
36
+ EulerAncestralDiscreteScheduler,
37
+ EulerDiscreteScheduler,
38
+ HeunDiscreteScheduler,
39
+ LDMTextToImagePipeline,
40
+ LMSDiscreteScheduler,
41
+ PNDMScheduler,
42
+ StableDiffusionPipeline,
43
+ UNet2DConditionModel,
44
+ )
45
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
46
+ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
47
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
48
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
49
+
50
+
51
+ def shave_segments(path, n_shave_prefix_segments=1):
52
+ """
53
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
54
+ """
55
+ if n_shave_prefix_segments >= 0:
56
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
57
+ else:
58
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
59
+
60
+
61
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
62
+ """
63
+ Updates paths inside resnets to the new naming scheme (local renaming)
64
+ """
65
+ mapping = []
66
+ for old_item in old_list:
67
+ new_item = old_item.replace("in_layers.0", "norm1")
68
+ new_item = new_item.replace("in_layers.2", "conv1")
69
+
70
+ new_item = new_item.replace("out_layers.0", "norm2")
71
+ new_item = new_item.replace("out_layers.3", "conv2")
72
+
73
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
74
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
75
+
76
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
77
+
78
+ mapping.append({"old": old_item, "new": new_item})
79
+
80
+ return mapping
81
+
82
+
83
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
84
+ """
85
+ Updates paths inside resnets to the new naming scheme (local renaming)
86
+ """
87
+ mapping = []
88
+ for old_item in old_list:
89
+ new_item = old_item
90
+
91
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
92
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
93
+
94
+ mapping.append({"old": old_item, "new": new_item})
95
+
96
+ return mapping
97
+
98
+
99
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
100
+ """
101
+ Updates paths inside attentions to the new naming scheme (local renaming)
102
+ """
103
+ mapping = []
104
+ for old_item in old_list:
105
+ new_item = old_item
106
+
107
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
108
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
109
+
110
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
111
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
112
+
113
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
114
+
115
+ mapping.append({"old": old_item, "new": new_item})
116
+
117
+ return mapping
118
+
119
+
120
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
121
+ """
122
+ Updates paths inside attentions to the new naming scheme (local renaming)
123
+ """
124
+ mapping = []
125
+ for old_item in old_list:
126
+ new_item = old_item
127
+
128
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
129
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
130
+
131
+ new_item = new_item.replace("q.weight", "query.weight")
132
+ new_item = new_item.replace("q.bias", "query.bias")
133
+
134
+ new_item = new_item.replace("k.weight", "key.weight")
135
+ new_item = new_item.replace("k.bias", "key.bias")
136
+
137
+ new_item = new_item.replace("v.weight", "value.weight")
138
+ new_item = new_item.replace("v.bias", "value.bias")
139
+
140
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
141
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
142
+
143
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
144
+
145
+ mapping.append({"old": old_item, "new": new_item})
146
+
147
+ return mapping
148
+
149
+
150
+ def assign_to_checkpoint(
151
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
152
+ ):
153
+ """
154
+ This does the final conversion step: take locally converted weights and apply a global renaming
155
+ to them. It splits attention layers, and takes into account additional replacements
156
+ that may arise.
157
+
158
+ Assigns the weights to the new checkpoint.
159
+ """
160
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
161
+
162
+ # Splits the attention layers into three variables.
163
+ if attention_paths_to_split is not None:
164
+ for path, path_map in attention_paths_to_split.items():
165
+ old_tensor = old_checkpoint[path]
166
+ channels = old_tensor.shape[0] // 3
167
+
168
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
169
+
170
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
171
+
172
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
173
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
174
+
175
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
176
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
177
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
178
+
179
+ for path in paths:
180
+ new_path = path["new"]
181
+
182
+ # These have already been assigned
183
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
184
+ continue
185
+
186
+ # Global renaming happens here
187
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
188
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
189
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
190
+
191
+ if additional_replacements is not None:
192
+ for replacement in additional_replacements:
193
+ new_path = new_path.replace(replacement["old"], replacement["new"])
194
+
195
+ # proj_attn.weight has to be converted from conv 1D to linear
196
+ if "proj_attn.weight" in new_path:
197
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
198
+ else:
199
+ checkpoint[new_path] = old_checkpoint[path["old"]]
200
+
201
+
202
+ def conv_attn_to_linear(checkpoint):
203
+ keys = list(checkpoint.keys())
204
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
205
+ for key in keys:
206
+ if ".".join(key.split(".")[-2:]) in attn_keys:
207
+ if checkpoint[key].ndim > 2:
208
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
209
+ elif "proj_attn.weight" in key:
210
+ if checkpoint[key].ndim > 2:
211
+ checkpoint[key] = checkpoint[key][:, :, 0]
212
+
213
+
214
+ def create_unet_diffusers_config(original_config, image_size: int):
215
+ """
216
+ Creates a config for the diffusers based on the config of the LDM model.
217
+ """
218
+ unet_params = original_config.model.params.unet_config.params
219
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
220
+
221
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
222
+
223
+ down_block_types = []
224
+ resolution = 1
225
+ for i in range(len(block_out_channels)):
226
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
227
+ down_block_types.append(block_type)
228
+ if i != len(block_out_channels) - 1:
229
+ resolution *= 2
230
+
231
+ up_block_types = []
232
+ for i in range(len(block_out_channels)):
233
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
234
+ up_block_types.append(block_type)
235
+ resolution //= 2
236
+
237
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
238
+
239
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
240
+ use_linear_projection = (
241
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
242
+ )
243
+ if use_linear_projection:
244
+ # stable diffusion 2-base-512 and 2-768
245
+ if head_dim is None:
246
+ head_dim = [5, 10, 20, 20]
247
+
248
+ config = dict(
249
+ sample_size=image_size // vae_scale_factor,
250
+ in_channels=unet_params.in_channels,
251
+ out_channels=unet_params.out_channels,
252
+ down_block_types=tuple(down_block_types),
253
+ up_block_types=tuple(up_block_types),
254
+ block_out_channels=tuple(block_out_channels),
255
+ layers_per_block=unet_params.num_res_blocks,
256
+ cross_attention_dim=unet_params.context_dim,
257
+ attention_head_dim=head_dim,
258
+ use_linear_projection=use_linear_projection,
259
+ )
260
+
261
+ return config
262
+
263
+
264
+ def create_vae_diffusers_config(original_config, image_size: int):
265
+ """
266
+ Creates a config for the diffusers based on the config of the LDM model.
267
+ """
268
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
269
+ _ = original_config.model.params.first_stage_config.params.embed_dim
270
+
271
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
272
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
273
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
274
+
275
+ config = dict(
276
+ sample_size=image_size,
277
+ in_channels=vae_params.in_channels,
278
+ out_channels=vae_params.out_ch,
279
+ down_block_types=tuple(down_block_types),
280
+ up_block_types=tuple(up_block_types),
281
+ block_out_channels=tuple(block_out_channels),
282
+ latent_channels=vae_params.z_channels,
283
+ layers_per_block=vae_params.num_res_blocks,
284
+ )
285
+ return config
286
+
287
+
288
+ def create_diffusers_schedular(original_config):
289
+ schedular = DDIMScheduler(
290
+ num_train_timesteps=original_config.model.params.timesteps,
291
+ beta_start=original_config.model.params.linear_start,
292
+ beta_end=original_config.model.params.linear_end,
293
+ beta_schedule="scaled_linear",
294
+ )
295
+ return schedular
296
+
297
+
298
+ def create_ldm_bert_config(original_config):
299
+ bert_params = original_config.model.parms.cond_stage_config.params
300
+ config = LDMBertConfig(
301
+ d_model=bert_params.n_embed,
302
+ encoder_layers=bert_params.n_layer,
303
+ encoder_ffn_dim=bert_params.n_embed * 4,
304
+ )
305
+ return config
306
+
307
+
308
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
309
+ """
310
+ Takes a state dict and a config, and returns a converted checkpoint.
311
+ """
312
+
313
+ # extract state_dict for UNet
314
+ unet_state_dict = {}
315
+ keys = list(checkpoint.keys())
316
+
317
+ unet_key = "model.diffusion_model."
318
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
319
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
320
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
321
+ print(
322
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
323
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
324
+ )
325
+ for key in keys:
326
+ if key.startswith("model.diffusion_model"):
327
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
328
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
329
+ else:
330
+ if sum(k.startswith("model_ema") for k in keys) > 100:
331
+ print(
332
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
333
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
334
+ )
335
+
336
+ for key in keys:
337
+ if key.startswith(unet_key):
338
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
339
+
340
+ new_checkpoint = {}
341
+
342
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
343
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
344
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
345
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
346
+
347
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
348
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
349
+
350
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
351
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
352
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
353
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
354
+
355
+ # Retrieves the keys for the input blocks only
356
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
357
+ input_blocks = {
358
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
359
+ for layer_id in range(num_input_blocks)
360
+ }
361
+
362
+ # Retrieves the keys for the middle blocks only
363
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
364
+ middle_blocks = {
365
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
366
+ for layer_id in range(num_middle_blocks)
367
+ }
368
+
369
+ # Retrieves the keys for the output blocks only
370
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
371
+ output_blocks = {
372
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
373
+ for layer_id in range(num_output_blocks)
374
+ }
375
+
376
+ for i in range(1, num_input_blocks):
377
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
378
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
379
+
380
+ resnets = [
381
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
382
+ ]
383
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
384
+
385
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
386
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
387
+ f"input_blocks.{i}.0.op.weight"
388
+ )
389
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
390
+ f"input_blocks.{i}.0.op.bias"
391
+ )
392
+
393
+ paths = renew_resnet_paths(resnets)
394
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
395
+ assign_to_checkpoint(
396
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
397
+ )
398
+
399
+ if len(attentions):
400
+ paths = renew_attention_paths(attentions)
401
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
402
+ assign_to_checkpoint(
403
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
404
+ )
405
+
406
+ resnet_0 = middle_blocks[0]
407
+ attentions = middle_blocks[1]
408
+ resnet_1 = middle_blocks[2]
409
+
410
+ resnet_0_paths = renew_resnet_paths(resnet_0)
411
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
412
+
413
+ resnet_1_paths = renew_resnet_paths(resnet_1)
414
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
415
+
416
+ attentions_paths = renew_attention_paths(attentions)
417
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
418
+ assign_to_checkpoint(
419
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
420
+ )
421
+
422
+ for i in range(num_output_blocks):
423
+ block_id = i // (config["layers_per_block"] + 1)
424
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
425
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
426
+ output_block_list = {}
427
+
428
+ for layer in output_block_layers:
429
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
430
+ if layer_id in output_block_list:
431
+ output_block_list[layer_id].append(layer_name)
432
+ else:
433
+ output_block_list[layer_id] = [layer_name]
434
+
435
+ if len(output_block_list) > 1:
436
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
437
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
438
+
439
+ resnet_0_paths = renew_resnet_paths(resnets)
440
+ paths = renew_resnet_paths(resnets)
441
+
442
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
443
+ assign_to_checkpoint(
444
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
445
+ )
446
+
447
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
448
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
449
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
450
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
451
+ f"output_blocks.{i}.{index}.conv.weight"
452
+ ]
453
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
454
+ f"output_blocks.{i}.{index}.conv.bias"
455
+ ]
456
+
457
+ # Clear attentions as they have been attributed above.
458
+ if len(attentions) == 2:
459
+ attentions = []
460
+
461
+ if len(attentions):
462
+ paths = renew_attention_paths(attentions)
463
+ meta_path = {
464
+ "old": f"output_blocks.{i}.1",
465
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
466
+ }
467
+ assign_to_checkpoint(
468
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
469
+ )
470
+ else:
471
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
472
+ for path in resnet_0_paths:
473
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
474
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
475
+
476
+ new_checkpoint[new_path] = unet_state_dict[old_path]
477
+
478
+ return new_checkpoint
479
+
480
+
481
+ def convert_ldm_vae_checkpoint(checkpoint, config):
482
+ # extract state dict for VAE
483
+ vae_state_dict = {}
484
+ vae_key = "first_stage_model."
485
+ keys = list(checkpoint.keys())
486
+ for key in keys:
487
+ if key.startswith(vae_key):
488
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
489
+
490
+ new_checkpoint = {}
491
+
492
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
493
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
494
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
495
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
496
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
497
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
498
+
499
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
500
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
501
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
502
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
503
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
504
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
505
+
506
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
507
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
508
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
509
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
510
+
511
+ # Retrieves the keys for the encoder down blocks only
512
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
513
+ down_blocks = {
514
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
515
+ }
516
+
517
+ # Retrieves the keys for the decoder up blocks only
518
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
519
+ up_blocks = {
520
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
521
+ }
522
+
523
+ for i in range(num_down_blocks):
524
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
525
+
526
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
527
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
528
+ f"encoder.down.{i}.downsample.conv.weight"
529
+ )
530
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
531
+ f"encoder.down.{i}.downsample.conv.bias"
532
+ )
533
+
534
+ paths = renew_vae_resnet_paths(resnets)
535
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
536
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
537
+
538
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
539
+ num_mid_res_blocks = 2
540
+ for i in range(1, num_mid_res_blocks + 1):
541
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
542
+
543
+ paths = renew_vae_resnet_paths(resnets)
544
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
545
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
546
+
547
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
548
+ paths = renew_vae_attention_paths(mid_attentions)
549
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
550
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
551
+ conv_attn_to_linear(new_checkpoint)
552
+
553
+ for i in range(num_up_blocks):
554
+ block_id = num_up_blocks - 1 - i
555
+ resnets = [
556
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
557
+ ]
558
+
559
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
560
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
561
+ f"decoder.up.{block_id}.upsample.conv.weight"
562
+ ]
563
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
564
+ f"decoder.up.{block_id}.upsample.conv.bias"
565
+ ]
566
+
567
+ paths = renew_vae_resnet_paths(resnets)
568
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
569
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
570
+
571
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
572
+ num_mid_res_blocks = 2
573
+ for i in range(1, num_mid_res_blocks + 1):
574
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
575
+
576
+ paths = renew_vae_resnet_paths(resnets)
577
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
578
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
579
+
580
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
581
+ paths = renew_vae_attention_paths(mid_attentions)
582
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
583
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
584
+ conv_attn_to_linear(new_checkpoint)
585
+ return new_checkpoint
586
+
587
+
588
+ def convert_ldm_bert_checkpoint(checkpoint, config):
589
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
590
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
591
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
592
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
593
+
594
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
595
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
596
+
597
+ def _copy_linear(hf_linear, pt_linear):
598
+ hf_linear.weight = pt_linear.weight
599
+ hf_linear.bias = pt_linear.bias
600
+
601
+ def _copy_layer(hf_layer, pt_layer):
602
+ # copy layer norms
603
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
604
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
605
+
606
+ # copy attn
607
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
608
+
609
+ # copy MLP
610
+ pt_mlp = pt_layer[1][1]
611
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
612
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
613
+
614
+ def _copy_layers(hf_layers, pt_layers):
615
+ for i, hf_layer in enumerate(hf_layers):
616
+ if i != 0:
617
+ i += i
618
+ pt_layer = pt_layers[i : i + 2]
619
+ _copy_layer(hf_layer, pt_layer)
620
+
621
+ hf_model = LDMBertModel(config).eval()
622
+
623
+ # copy embeds
624
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
625
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
626
+
627
+ # copy layer norm
628
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
629
+
630
+ # copy hidden layers
631
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
632
+
633
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
634
+
635
+ return hf_model
636
+
637
+
638
+ def convert_ldm_clip_checkpoint(checkpoint):
639
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
640
+
641
+ keys = list(checkpoint.keys())
642
+
643
+ text_model_dict = {}
644
+
645
+ for key in keys:
646
+ if key.startswith("cond_stage_model.transformer"):
647
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
648
+
649
+ text_model.load_state_dict(text_model_dict)
650
+
651
+ return text_model
652
+
653
+
654
+ textenc_conversion_lst = [
655
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
656
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
657
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
658
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
659
+ ]
660
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
661
+
662
+ textenc_transformer_conversion_lst = [
663
+ # (stable-diffusion, HF Diffusers)
664
+ ("resblocks.", "text_model.encoder.layers."),
665
+ ("ln_1", "layer_norm1"),
666
+ ("ln_2", "layer_norm2"),
667
+ (".c_fc.", ".fc1."),
668
+ (".c_proj.", ".fc2."),
669
+ (".attn", ".self_attn"),
670
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
671
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
672
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
673
+ ]
674
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
675
+ textenc_pattern = re.compile("|".join(protected.keys()))
676
+
677
+
678
+ def convert_paint_by_example_checkpoint(checkpoint):
679
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
680
+ model = PaintByExampleImageEncoder(config)
681
+
682
+ keys = list(checkpoint.keys())
683
+
684
+ text_model_dict = {}
685
+
686
+ for key in keys:
687
+ if key.startswith("cond_stage_model.transformer"):
688
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
689
+
690
+ # load clip vision
691
+ model.model.load_state_dict(text_model_dict)
692
+
693
+ # load mapper
694
+ keys_mapper = {
695
+ k[len("cond_stage_model.mapper.res") :]: v
696
+ for k, v in checkpoint.items()
697
+ if k.startswith("cond_stage_model.mapper")
698
+ }
699
+
700
+ MAPPING = {
701
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
702
+ "attn.c_proj": ["attn1.to_out.0"],
703
+ "ln_1": ["norm1"],
704
+ "ln_2": ["norm3"],
705
+ "mlp.c_fc": ["ff.net.0.proj"],
706
+ "mlp.c_proj": ["ff.net.2"],
707
+ }
708
+
709
+ mapped_weights = {}
710
+ for key, value in keys_mapper.items():
711
+ prefix = key[: len("blocks.i")]
712
+ suffix = key.split(prefix)[-1].split(".")[-1]
713
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
714
+ mapped_names = MAPPING[name]
715
+
716
+ num_splits = len(mapped_names)
717
+ for i, mapped_name in enumerate(mapped_names):
718
+ new_name = ".".join([prefix, mapped_name, suffix])
719
+ shape = value.shape[0] // num_splits
720
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
721
+
722
+ model.mapper.load_state_dict(mapped_weights)
723
+
724
+ # load final layer norm
725
+ model.final_layer_norm.load_state_dict(
726
+ {
727
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
728
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
729
+ }
730
+ )
731
+
732
+ # load final proj
733
+ model.proj_out.load_state_dict(
734
+ {
735
+ "bias": checkpoint["proj_out.bias"],
736
+ "weight": checkpoint["proj_out.weight"],
737
+ }
738
+ )
739
+
740
+ # load uncond vector
741
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
742
+ return model
743
+
744
+
745
+ def convert_open_clip_checkpoint(checkpoint):
746
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
747
+
748
+ keys = list(checkpoint.keys())
749
+
750
+ text_model_dict = {}
751
+
752
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
753
+
754
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
755
+
756
+ for key in keys:
757
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
758
+ continue
759
+ if key in textenc_conversion_map:
760
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
761
+ if key.startswith("cond_stage_model.model.transformer."):
762
+ new_key = key[len("cond_stage_model.model.transformer.") :]
763
+ if new_key.endswith(".in_proj_weight"):
764
+ new_key = new_key[: -len(".in_proj_weight")]
765
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
766
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
767
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
768
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
769
+ elif new_key.endswith(".in_proj_bias"):
770
+ new_key = new_key[: -len(".in_proj_bias")]
771
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
772
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
773
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
774
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
775
+ else:
776
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
777
+
778
+ text_model_dict[new_key] = checkpoint[key]
779
+
780
+ text_model.load_state_dict(text_model_dict)
781
+
782
+ return text_model
783
+
784
+
785
+ def savemodelDiffusers(name, compvis_config_file, diffusers_config_file, device='cpu'):
786
+ checkpoint_path = f'models/{name}/{name}.pt'
787
+
788
+ original_config_file = compvis_config_file
789
+ config_file = diffusers_config_file
790
+ num_in_channels = 4
791
+ scheduler_type = 'ddim'
792
+ pipeline_type = None
793
+ image_size = 512
794
+ prediction_type = 'epsilon'
795
+ extract_ema = False
796
+ dump_path = f"models/{name}/{name.replace('compvis','diffusers')}.pt"
797
+ upcast_attention = False
798
+
799
+
800
+ if device is None:
801
+ device = "cuda" if torch.cuda.is_available() else "cpu"
802
+ checkpoint = torch.load(checkpoint_path, map_location=device)
803
+ else:
804
+ checkpoint = torch.load(checkpoint_path, map_location=device)
805
+
806
+ # Sometimes models don't have the global_step item
807
+ if "global_step" in checkpoint:
808
+ global_step = checkpoint["global_step"]
809
+ else:
810
+ print("global_step key not found in model")
811
+ global_step = None
812
+
813
+ if "state_dict" in checkpoint:
814
+ checkpoint = checkpoint["state_dict"]
815
+ upcast_attention = upcast_attention
816
+ if original_config_file is None:
817
+ key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
818
+
819
+ if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
820
+ if not os.path.isfile("v2-inference-v.yaml"):
821
+ # model_type = "v2"
822
+ os.system(
823
+ "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
824
+ " -O v2-inference-v.yaml"
825
+ )
826
+ original_config_file = "./v2-inference-v.yaml"
827
+
828
+ if global_step == 110000:
829
+ # v2.1 needs to upcast attention
830
+ upcast_attention = True
831
+ else:
832
+ if not os.path.isfile("v1-inference.yaml"):
833
+ # model_type = "v1"
834
+ os.system(
835
+ "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
836
+ " -O v1-inference.yaml"
837
+ )
838
+ original_config_file = "./v1-inference.yaml"
839
+
840
+ original_config = OmegaConf.load(original_config_file)
841
+
842
+ if num_in_channels is not None:
843
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
844
+
845
+ if (
846
+ "parameterization" in original_config["model"]["params"]
847
+ and original_config["model"]["params"]["parameterization"] == "v"
848
+ ):
849
+ if prediction_type is None:
850
+ # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
851
+ # as it relies on a brittle global step parameter here
852
+ prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
853
+ if image_size is None:
854
+ # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
855
+ # as it relies on a brittle global step parameter here
856
+ image_size = 512 if global_step == 875000 else 768
857
+ else:
858
+ if prediction_type is None:
859
+ prediction_type = "epsilon"
860
+ if image_size is None:
861
+ image_size = 512
862
+
863
+ num_train_timesteps = original_config.model.params.timesteps
864
+ beta_start = original_config.model.params.linear_start
865
+ beta_end = original_config.model.params.linear_end
866
+ scheduler = DDIMScheduler(
867
+ beta_end=beta_end,
868
+ beta_schedule="scaled_linear",
869
+ beta_start=beta_start,
870
+ num_train_timesteps=num_train_timesteps,
871
+ steps_offset=1,
872
+ clip_sample=False,
873
+ set_alpha_to_one=False,
874
+ prediction_type=prediction_type,
875
+ )
876
+ # make sure scheduler works correctly with DDIM
877
+ scheduler.register_to_config(clip_sample=False)
878
+
879
+ if scheduler_type == "pndm":
880
+ config = dict(scheduler.config)
881
+ config["skip_prk_steps"] = True
882
+ scheduler = PNDMScheduler.from_config(config)
883
+ elif scheduler_type == "lms":
884
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
885
+ elif scheduler_type == "heun":
886
+ scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
887
+ elif scheduler_type == "euler":
888
+ scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
889
+ elif scheduler_type == "euler-ancestral":
890
+ scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
891
+ elif scheduler_type == "dpm":
892
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
893
+ elif scheduler_type == "ddim":
894
+ scheduler = scheduler
895
+ else:
896
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
897
+
898
+ # Convert the UNet2DConditionModel model.
899
+ unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
900
+ unet_config["upcast_attention"] = False
901
+ unet = UNet2DConditionModel(**unet_config)
902
+
903
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
904
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
905
+ )
906
+ torch.save(converted_unet_checkpoint, dump_path)
907
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ omegaconf
2
+ torch
3
+ torchvision
4
+ einops
5
+ diffusers
6
+ transformers
7
+ pytorch_lightning
test.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.insert(0,'stable_diffusion')
3
+ from train_esd import train_esd
4
+
5
+ ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
6
+ config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
7
+ diffusers_config_path = "stable_diffusion/config.json"
8
+
9
+ train_esd("England",
10
+ 'xattn',
11
+ 3,
12
+ 1,
13
+ 1000,
14
+ .003,
15
+ config_path,
16
+ ckpt_path,
17
+ diffusers_config_path,
18
+ ['cuda', 'cuda']
19
+ )
train_esd.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import os
6
+ from tqdm import tqdm
7
+ from einops import rearrange
8
+ import numpy as np
9
+ from pathlib import Path
10
+ import matplotlib.pyplot as plt
11
+
12
+ from ldm.models.diffusion.ddim import DDIMSampler
13
+ from ldm.util import instantiate_from_config
14
+ import random
15
+ import glob
16
+ import re
17
+ import shutil
18
+ import pdb
19
+ import argparse
20
+ from convertModels import savemodelDiffusers
21
+ # Util Functions
22
+ def load_model_from_config(config, ckpt, device="cpu", verbose=False):
23
+ """Loads a model from config and a ckpt
24
+ if config is a path will use omegaconf to load
25
+ """
26
+ if isinstance(config, (str, Path)):
27
+ config = OmegaConf.load(config)
28
+
29
+ pl_sd = torch.load(ckpt, map_location="cpu")
30
+ global_step = pl_sd["global_step"]
31
+ sd = pl_sd["state_dict"]
32
+ model = instantiate_from_config(config.model)
33
+ m, u = model.load_state_dict(sd, strict=False)
34
+ model.to(device)
35
+ model.eval()
36
+ model.cond_stage_model.device = device
37
+ return model
38
+
39
+ @torch.no_grad()
40
+ def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1,t_start=-1,log_every_t=None,till_T=None,verbose=True):
41
+ """Sample the model"""
42
+ uc = None
43
+ if scale != 1.0:
44
+ uc = model.get_learned_conditioning(n_samples * [""])
45
+ log_t = 100
46
+ if log_every_t is not None:
47
+ log_t = log_every_t
48
+ shape = [4, h // 8, w // 8]
49
+ samples_ddim, inters = sampler.sample(S=ddim_steps,
50
+ conditioning=c,
51
+ batch_size=n_samples,
52
+ shape=shape,
53
+ verbose=False,
54
+ x_T=start_code,
55
+ unconditional_guidance_scale=scale,
56
+ unconditional_conditioning=uc,
57
+ eta=ddim_eta,
58
+ verbose_iter = verbose,
59
+ t_start=t_start,
60
+ log_every_t = log_t,
61
+ till_T = till_T
62
+ )
63
+ if log_every_t is not None:
64
+ return samples_ddim, inters
65
+ return samples_ddim
66
+
67
+ def load_img(path, target_size=512):
68
+ """Load an image, resize and output -1..1"""
69
+ image = Image.open(path).convert("RGB")
70
+
71
+
72
+ tform = transforms.Compose([
73
+ transforms.Resize(target_size),
74
+ transforms.CenterCrop(target_size),
75
+ transforms.ToTensor(),
76
+ ])
77
+ image = tform(image)
78
+ return 2.*image - 1.
79
+
80
+
81
+ def moving_average(a, n=3) :
82
+ ret = np.cumsum(a, dtype=float)
83
+ ret[n:] = ret[n:] - ret[:-n]
84
+ return ret[n - 1:] / n
85
+
86
+ def plot_loss(losses, path,word, n=100):
87
+ v = moving_average(losses, n)
88
+ plt.plot(v, label=f'{word}_loss')
89
+ plt.legend(loc="upper left")
90
+ plt.title('Average loss in trainings', fontsize=20)
91
+ plt.xlabel('Data point', fontsize=16)
92
+ plt.ylabel('Loss value', fontsize=16)
93
+ plt.savefig(path)
94
+
95
+ ##################### ESD Functions
96
+ def get_models(config_path, ckpt_path, devices):
97
+ model_orig = load_model_from_config(config_path, ckpt_path, devices[1])
98
+ sampler_orig = DDIMSampler(model_orig)
99
+
100
+ model = load_model_from_config(config_path, ckpt_path, devices[0])
101
+ sampler = DDIMSampler(model)
102
+
103
+ return model_orig, sampler_orig, model, sampler
104
+
105
+ def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, seperator=None, image_size=512, ddim_steps=50):
106
+ '''
107
+ Function to train diffusion models to erase concepts from model weights
108
+
109
+ Parameters
110
+ ----------
111
+ prompt : str
112
+ The concept to erase from diffusion model (Eg: "Van Gogh").
113
+ train_method : str
114
+ The parameters to train for erasure (ESD-x, ESD-u, full, selfattn).
115
+ start_guidance : float
116
+ Guidance to generate images for training.
117
+ negative_guidance : float
118
+ Guidance to erase the concepts from diffusion model.
119
+ iterations : int
120
+ Number of iterations to train.
121
+ lr : float
122
+ learning rate for fine tuning.
123
+ config_path : str
124
+ config path for compvis diffusion format.
125
+ ckpt_path : str
126
+ checkpoint path for pre-trained compvis diffusion weights.
127
+ diffusers_config_path : str
128
+ Config path for diffusers unet in json format.
129
+ devices : str
130
+ 2 devices used to load the models (Eg: '0,1' will load in cuda:0 and cuda:1).
131
+ seperator : str, optional
132
+ If the prompt has commas can use this to seperate the prompt for individual simulataneous erasures. The default is None.
133
+ image_size : int, optional
134
+ Image size for generated images. The default is 512.
135
+ ddim_steps : int, optional
136
+ Number of diffusion time steps. The default is 50.
137
+
138
+ Returns
139
+ -------
140
+ None
141
+
142
+ '''
143
+ # PROMPT CLEANING
144
+ word_print = prompt.replace(' ','')
145
+ if prompt == 'allartist':
146
+ prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng"
147
+ if prompt == 'i2p':
148
+ prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood"
149
+ if prompt == "artifact":
150
+ prompt = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy"
151
+
152
+ if seperator is not None:
153
+ words = prompt.split(seperator)
154
+ words = [word.strip() for word in words]
155
+ else:
156
+ words = [prompt]
157
+ print(words)
158
+ ddim_eta = 0
159
+ # MODEL TRAINING SETUP
160
+
161
+ model_orig, sampler_orig, model, sampler = get_models(config_path, ckpt_path, devices)
162
+
163
+ # choose parameters to train based on train_method
164
+ parameters = []
165
+ for name, param in model.model.diffusion_model.named_parameters():
166
+ # train all layers except x-attns and time_embed layers
167
+ if train_method == 'noxattn':
168
+ if name.startswith('out.') or 'attn2' in name or 'time_embed' in name:
169
+ pass
170
+ else:
171
+ print(name)
172
+ parameters.append(param)
173
+ # train only self attention layers
174
+ if train_method == 'selfattn':
175
+ if 'attn1' in name:
176
+ print(name)
177
+ parameters.append(param)
178
+ # train only x attention layers
179
+ if train_method == 'xattn':
180
+ if 'attn2' in name:
181
+ print(name)
182
+ parameters.append(param)
183
+ # train all layers
184
+ if train_method == 'full':
185
+ print(name)
186
+ parameters.append(param)
187
+ # train all layers except time embed layers
188
+ if train_method == 'notime':
189
+ if not (name.startswith('out.') or 'time_embed' in name):
190
+ print(name)
191
+ parameters.append(param)
192
+ if train_method == 'xlayer':
193
+ if 'attn2' in name:
194
+ if 'output_blocks.6.' in name or 'output_blocks.8.' in name:
195
+ print(name)
196
+ parameters.append(param)
197
+ if train_method == 'selflayer':
198
+ if 'attn1' in name:
199
+ if 'input_blocks.4.' in name or 'input_blocks.7.' in name:
200
+ print(name)
201
+ parameters.append(param)
202
+ # set model to train
203
+ model.train()
204
+ # create a lambda function for cleaner use of sampling code (only denoising till time step t)
205
+ quick_sample_till_t = lambda x, s, code, t: sample_model(model, sampler,
206
+ x, image_size, image_size, ddim_steps, s, ddim_eta,
207
+ start_code=code, till_T=t, verbose=False)
208
+
209
+ losses = []
210
+ opt = torch.optim.Adam(parameters, lr=lr)
211
+ criteria = torch.nn.MSELoss()
212
+ history = []
213
+
214
+ name = f'compvis-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{negative_guidance}-iter_{iterations}-lr_{lr}'
215
+ # TRAINING CODE
216
+ pbar = tqdm(range(iterations))
217
+ for i in pbar:
218
+ word = random.sample(words,1)[0]
219
+ # get text embeddings for unconditional and conditional prompts
220
+ emb_0 = model.get_learned_conditioning([''])
221
+ emb_p = model.get_learned_conditioning([word])
222
+ emb_n = model.get_learned_conditioning([f'{word}'])
223
+
224
+ opt.zero_grad()
225
+
226
+ t_enc = torch.randint(ddim_steps, (1,), device=devices[0])
227
+ # time step from 1000 to 0 (0 being good)
228
+ og_num = round((int(t_enc)/ddim_steps)*1000)
229
+ og_num_lim = round((int(t_enc+1)/ddim_steps)*1000)
230
+
231
+ t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0])
232
+
233
+ start_code = torch.randn((1, 4, 64, 64)).to(devices[0])
234
+
235
+ with torch.no_grad():
236
+ # generate an image with the concept from ESD model
237
+ z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, int(t_enc)) # emb_p seems to work better instead of emb_0
238
+ # get conditional and unconditional scores from frozen model at time step t and image z
239
+ e_0 = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_0.to(devices[1]))
240
+ e_p = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_p.to(devices[1]))
241
+ # breakpoint()
242
+ # get conditional score from ESD model
243
+ e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0]))
244
+ e_0.requires_grad = False
245
+ e_p.requires_grad = False
246
+ # reconstruction loss for ESD objective from frozen model and conditional score of ESD model
247
+ loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
248
+ # update weights to erase the concept
249
+ loss.backward()
250
+ losses.append(loss.item())
251
+ pbar.set_postfix({"loss": loss.item()})
252
+ history.append(loss.item())
253
+ opt.step()
254
+ # save checkpoint and loss curve
255
+ if (i+1) % 500 == 0 and i+1 != iterations and i+1>= 500:
256
+ save_model(model, name, i-1, save_compvis=True, save_diffusers=False)
257
+
258
+ if i % 100 == 0:
259
+ save_history(losses, name, word_print)
260
+
261
+ model.eval()
262
+
263
+ save_model(model, name, None, save_compvis=True, save_diffusers=True, compvis_config_file=config_path, diffusers_config_file=diffusers_config_path)
264
+ save_history(losses, name, word_print)
265
+
266
+ def save_model(model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True):
267
+ # SAVE MODEL
268
+
269
+ # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt'
270
+
271
+ folder_path = f'models/{name}'
272
+ os.makedirs(folder_path, exist_ok=True)
273
+ if num is not None:
274
+ path = f'{folder_path}/{name}-epoch_{num}.pt'
275
+ else:
276
+ path = f'{folder_path}/{name}.pt'
277
+ if save_compvis:
278
+ torch.save(model.state_dict(), path)
279
+
280
+ if save_diffusers:
281
+ print('Saving Model in Diffusers Format')
282
+ savemodelDiffusers(name, compvis_config_file, diffusers_config_file, device=device )
283
+
284
+ def save_history(losses, name, word_print):
285
+ folder_path = f'models/{name}'
286
+ os.makedirs(folder_path, exist_ok=True)
287
+ with open(f'{folder_path}/loss.txt', 'w') as f:
288
+ f.writelines([str(i) for i in losses])
289
+ plot_loss(losses,f'{folder_path}/loss.png' , word_print, n=3)
290
+
291
+ if __name__ == '__main__':
292
+ parser = argparse.ArgumentParser(
293
+ prog = 'TrainESD',
294
+ description = 'Finetuning stable diffusion model to erase concepts using ESD method')
295
+ parser.add_argument('--prompt', help='prompt corresponding to concept to erase', type=str, required=True)
296
+ parser.add_argument('--train_method', help='method of training', type=str, required=True)
297
+ parser.add_argument('--start_guidance', help='guidance of start image used to train', type=float, required=False, default=3)
298
+ parser.add_argument('--negative_guidance', help='guidance of negative training used to train', type=float, required=False, default=1)
299
+ parser.add_argument('--iterations', help='iterations used to train', type=int, required=False, default=1000)
300
+ parser.add_argument('--lr', help='learning rate used to train', type=int, required=False, default=1e-5)
301
+ parser.add_argument('--config_path', help='config path for stable diffusion v1-4 inference', type=str, required=False, default='configs/stable-diffusion/v1-inference.yaml')
302
+ parser.add_argument('--ckpt_path', help='ckpt path for stable diffusion v1-4', type=str, required=False, default='models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt')
303
+ parser.add_argument('--diffusers_config_path', help='diffusers unet config json path', type=str, required=False, default='diffusers_unet_config.json')
304
+ parser.add_argument('--devices', help='cuda devices to train on', type=str, required=False, default='0,0')
305
+ parser.add_argument('--seperator', help='separator if you want to train bunch of words separately', type=str, required=False, default=None)
306
+ parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
307
+ parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
308
+ args = parser.parse_args()
309
+
310
+ prompt = args.prompt
311
+ train_method = args.train_method
312
+ start_guidance = args.start_guidance
313
+ negative_guidance = args.negative_guidance
314
+ iterations = args.iterations
315
+ lr = args.lr
316
+ config_path = args.config_path
317
+ ckpt_path = args.ckpt_path
318
+ diffusers_config_path = args.diffusers_config_path
319
+ devices = [f'cuda:{int(d.strip())}' for d in args.devices.split(',')]
320
+ seperator = args.seperator
321
+ image_size = args.image_size
322
+ ddim_steps = args.ddim_steps
323
+
324
+ train_esd(prompt=prompt, train_method=train_method, start_guidance=start_guidance, negative_guidance=negative_guidance, iterations=iterations, lr=lr, config_path=config_path, ckpt_path=ckpt_path, diffusers_config_path=diffusers_config_path, devices=devices, seperator=seperator, image_size=image_size, ddim_steps=ddim_steps)