sayakpaul committed
Commit 3304f7d
1 Parent(s): f16a80a

apply styling.

app.py CHANGED
@@ -1,7 +1,7 @@
-import gradio as gr
-from convert import run_conversion
-from hub_utils import save_model_card, push_to_hub
+import gradio as gr
+
+from convert import run_conversion
+from hub_utils import push_to_hub, save_model_card
 
 PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
 DESCRIPTION = """
@@ -20,25 +20,37 @@ This Space lets you convert KerasCV Stable Diffusion weights to a format compati
 Check [here](https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/dreambooth/train_dreambooth_lora.py#L975) for an example on how you can change the scheduler of an already initialized pipeline.
 """
 
+
 def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
     if text_encoder_weights == "":
-        text_encoder_weights = None
+        text_encoder_weights = None
     if unet_weights == "":
-        unet_weights = None
+        unet_weights = None
     pipeline = run_conversion(text_encoder_weights, unet_weights)
     output_path = "kerascv_sd_diffusers_pipeline"
    pipeline.save_pretrained(output_path)
-    save_model_card(base_model=PRETRAINED_CKPT, repo_folder=output_path, weight_paths=[text_encoder_weights, unet_weights], repo_prefix=repo_prefix)
+    save_model_card(
+        base_model=PRETRAINED_CKPT,
+        repo_folder=output_path,
+        weight_paths=[text_encoder_weights, unet_weights],
+        repo_prefix=repo_prefix,
+    )
     push_str = push_to_hub(hf_token, output_path, repo_prefix)
     return push_str
 
-demo = gr.Interface(
-    title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
-    description=DESCRIPTION,
-    allow_flagging="never",
-    inputs=[gr.Text(max_lines=1, label="your_hf_token"), gr.Text(max_lines=1, label="text_encoder_weights"), gr.Text(max_lines=1, label="unet_weights"), gr.Text(max_lines=1, label="output_repo_prefix")],
-    outputs=[gr.Markdown(label="output")],
-    fn=run,
-)
-
-demo.launch()
+demo = gr.Interface(
+    title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
+    description=DESCRIPTION,
+    allow_flagging="never",
+    inputs=[
+        gr.Text(max_lines=1, label="your_hf_token"),
+        gr.Text(max_lines=1, label="text_encoder_weights"),
+        gr.Text(max_lines=1, label="unet_weights"),
+        gr.Text(max_lines=1, label="output_repo_prefix"),
+    ],
+    outputs=[gr.Markdown(label="output")],
+    fn=run,
+)
+
+demo.launch()
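
The DESCRIPTION above points readers at swapping the scheduler of an already initialized pipeline. As a reference, a minimal sketch of such a swap with diffusers; the repo id is a hypothetical placeholder and `DPMSolverMultistepScheduler` is just one possible choice:

```python
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

# Load the converted pipeline (repo id below is a hypothetical placeholder).
pipeline = StableDiffusionPipeline.from_pretrained(
    "someone/kerascv_sd_diffusers_pipeline"
)

# Re-initialize the scheduler from the existing scheduler's config.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config
)
```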
conversion_utils/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .text_encoder import populate_text_encoder
 from .unet import populate_unet
-from .utils import run_assertion
+from .utils import run_assertion
conversion_utils/text_encoder.py CHANGED
@@ -1,16 +1,23 @@
-from keras_cv.models import stable_diffusion
+from typing import Dict
+
 import tensorflow as tf
 import torch
-from typing import Dict
+from keras_cv.models import stable_diffusion
 
 MAX_SEQ_LENGTH = 77
 
+
 def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
     """Populates the state dict from the provided TensorFlow model
     (applicable only for the text encoder)."""
     text_state_dict = dict()
     num_encoder_layers = 0
 
+    # Position ids.
+    text_state_dict["text_model.embeddings.position_ids"] = torch.tensor(
+        list(range(MAX_SEQ_LENGTH))
+    ).unsqueeze(0)
+
     for layer in tf_text_encoder.layers:
         # Embeddings.
         if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
@@ -102,9 +109,4 @@ def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Te
             layer.get_weights()[1]
         )
 
-    # Position ids.
-    text_state_dict["text_model.embeddings.position_ids"] = torch.tensor(
-        list(range(MAX_SEQ_LENGTH))
-    ).unsqueeze(0)
-
-    return text_state_dict
+    return text_state_dict
conversion_utils/unet.py CHANGED
@@ -1,10 +1,14 @@
-import tensorflow as tf
-import torch
-from typing import Dict
 from itertools import product
+from typing import Dict
+
+import tensorflow as tf
+import torch
 from keras_cv.models import stable_diffusion
 
-def port_transformer_block(transformer_block: tf.keras.Model, up_down: int, block_id: int, attention_id: int) -> Dict[str, torch.Tensor]:
+
+def port_transformer_block(
+    transformer_block: tf.keras.Model, up_down: int, block_id: int, attention_id: int
+) -> Dict[str, torch.Tensor]:
     """Populates a Transformer block."""
     transformer_dict = dict()
     if block_id is not None:
@@ -15,36 +19,58 @@ def port_transformer_block(transformer_block: tf.keras.Model, up_down: int, bloc
     # Norms.
     for i in range(1, 4):
         if i == 1:
-            norm = transformer_block.norm1
+            norm = transformer_block.norm1
         elif i == 2:
             norm = transformer_block.norm2
         elif i == 3:
             norm = transformer_block.norm3
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.weight"] = torch.from_numpy(norm.get_weights()[0])
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.bias"] = torch.from_numpy(norm.get_weights()[1])
-
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.weight"
+        ] = torch.from_numpy(norm.get_weights()[0])
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.bias"
+        ] = torch.from_numpy(norm.get_weights()[1])
+
     # Attentions.
     for i in range(1, 3):
         if i == 1:
             attn = transformer_block.attn1
         else:
             attn = transformer_block.attn2
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_q.weight"] = torch.from_numpy(attn.to_q.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_k.weight"] = torch.from_numpy(attn.to_k.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_v.weight"] = torch.from_numpy(attn.to_v.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.weight"] = torch.from_numpy(attn.out_proj.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.bias"] = torch.from_numpy(attn.out_proj.get_weights()[1])
-
-    # Dense.
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_q.weight"
+        ] = torch.from_numpy(attn.to_q.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_k.weight"
+        ] = torch.from_numpy(attn.to_k.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_v.weight"
+        ] = torch.from_numpy(attn.to_v.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.weight"
+        ] = torch.from_numpy(attn.out_proj.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.bias"
+        ] = torch.from_numpy(attn.out_proj.get_weights()[1])
+
+    # Dense.
     for i in range(0, 3, 2):
         if i == 0:
             layer = transformer_block.geglu.dense
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.bias"] = torch.from_numpy(layer.get_weights()[1])
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.weight"
+            ] = torch.from_numpy(layer.get_weights()[0].transpose())
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.bias"
+            ] = torch.from_numpy(layer.get_weights()[1])
         else:
             layer = transformer_block.dense
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.bias"] = torch.from_numpy(layer.get_weights()[1])
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.weight"
+            ] = torch.from_numpy(layer.get_weights()[0].transpose())
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.bias"
+            ] = torch.from_numpy(layer.get_weights()[1])
 
     return transformer_dict
 
@@ -54,7 +80,7 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
     (applicable only for the UNet)."""
     unet_state_dict = dict()
 
-    timstep_emb = 1
+    timstep_emb = 1
     padded_conv = 1
     up_block = 0
 
@@ -67,37 +93,66 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
     for layer in tf_unet.layers:
         # Timstep embedding.
         if isinstance(layer, tf.keras.layers.Dense):
-            unet_state_dict[f"time_embedding.linear_{timstep_emb}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
-            unet_state_dict[f"time_embedding.linear_{timstep_emb}.bias"] = torch.from_numpy(layer.get_weights()[1])
+            unet_state_dict[
+                f"time_embedding.linear_{timstep_emb}.weight"
+            ] = torch.from_numpy(layer.get_weights()[0].transpose())
+            unet_state_dict[
+                f"time_embedding.linear_{timstep_emb}.bias"
+            ] = torch.from_numpy(layer.get_weights()[1])
             timstep_emb += 1
-
+
         # Padded convs (downsamplers).
-        elif isinstance(layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+        elif isinstance(
+            layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D
+        ):
             if padded_conv == 1:
                 # Transposition axes taken from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
-                unet_state_dict["conv_in.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict["conv_in.bias"] = torch.from_numpy(layer.get_weights()[1])
+                unet_state_dict["conv_in.weight"] = torch.from_numpy(
+                    layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict["conv_in.bias"] = torch.from_numpy(
+                    layer.get_weights()[1]
+                )
             elif padded_conv in [2, 3, 4]:
-                unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.bias"] = torch.from_numpy(layer.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{padded_conv-2}.downsamplers.0.conv.weight"
+                ] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"down_blocks.{padded_conv-2}.downsamplers.0.conv.bias"
+                ] = torch.from_numpy(layer.get_weights()[1])
             elif padded_conv == 5:
-                unet_state_dict["conv_out.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict["conv_out.bias"] = torch.from_numpy(layer.get_weights()[1])
-
+                unet_state_dict["conv_out.weight"] = torch.from_numpy(
+                    layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict["conv_out.bias"] = torch.from_numpy(
+                    layer.get_weights()[1]
+                )
+
             padded_conv += 1
 
         # Upsamplers.
         elif isinstance(layer, stable_diffusion.diffusion_model.Upsample):
             conv = layer.conv
-            unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.weight"] = torch.from_numpy(conv.get_weights()[0].transpose(3, 2, 0, 1))
-            unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.bias"] = torch.from_numpy(conv.get_weights()[1])
+            unet_state_dict[
+                f"up_blocks.{up_block}.upsamplers.0.conv.weight"
+            ] = torch.from_numpy(conv.get_weights()[0].transpose(3, 2, 0, 1))
+            unet_state_dict[
+                f"up_blocks.{up_block}.upsamplers.0.conv.bias"
+            ] = torch.from_numpy(conv.get_weights()[1])
             up_block += 1
 
         # Output norms.
-        elif isinstance(layer, stable_diffusion.__internal__.layers.group_normalization.GroupNormalization):
-            unet_state_dict["conv_norm_out.weight"] = torch.from_numpy(layer.get_weights()[0])
-            unet_state_dict["conv_norm_out.bias"] = torch.from_numpy(layer.get_weights()[1])
-
+        elif isinstance(
+            layer,
+            stable_diffusion.__internal__.layers.group_normalization.GroupNormalization,
+        ):
+            unet_state_dict["conv_norm_out.weight"] = torch.from_numpy(
+                layer.get_weights()[0]
+            )
+            unet_state_dict["conv_norm_out.bias"] = torch.from_numpy(
+                layer.get_weights()[1]
+            )
+
         # All ResBlocks.
        elif isinstance(layer, stable_diffusion.diffusion_model.ResBlock):
             layer_name = layer.name
@@ -105,8 +160,8 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
 
             # Down.
             if len(parts) == 2 or int(parts[-1]) < 8:
-                entry_flow = layer.entry_flow
-                embedding_flow = layer.embedding_flow
+                entry_flow = layer.entry_flow
+                embedding_flow = layer.embedding_flow
                 exit_flow = layer.exit_flow
 
                 down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
@@ -114,72 +169,138 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
 
                 # Conv blocks.
                 first_conv_layer = entry_flow[-1]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.weight"
+                ] = torch.from_numpy(
+                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.bias"
+                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                 second_conv_layer = exit_flow[-1]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
-
-                # Residual blocks.
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.weight"
+                ] = torch.from_numpy(
+                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.bias"
+                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
+
+                # Residual blocks.
                 if hasattr(layer, "residual_projection"):
-                    if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+                    if isinstance(
+                        layer.residual_projection,
+                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
+                    ):
                         residual = layer.residual_projection
-                        unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
-                        unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
+                        unet_state_dict[
+                            f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.weight"
+                        ] = torch.from_numpy(
+                            residual.get_weights()[0].transpose(3, 2, 0, 1)
+                        )
+                        unet_state_dict[
+                            f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.bias"
+                        ] = torch.from_numpy(residual.get_weights()[1])
 
                 # Timestep embedding.
                 embedding_proj = embedding_flow[-1]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
-
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.weight"
+                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.bias"
+                ] = torch.from_numpy(embedding_proj.get_weights()[1])
+
                 # Norms.
                 first_group_norm = entry_flow[0]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.weight"
+                ] = torch.from_numpy(first_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.bias"
+                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                 second_group_norm = exit_flow[0]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.weight"
+                ] = torch.from_numpy(second_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.bias"
+                ] = torch.from_numpy(second_group_norm.get_weights()[1])
 
             # Middle.
             elif int(parts[-1]) == 8 or int(parts[-1]) == 9:
-                entry_flow = layer.entry_flow
-                embedding_flow = layer.embedding_flow
+                entry_flow = layer.entry_flow
+                embedding_flow = layer.embedding_flow
                 exit_flow = layer.exit_flow
-
+
                 mid_resnet_id = int(parts[-1]) % 2
-
+
                 # Conv blocks.
                 first_conv_layer = entry_flow[-1]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv1.weight"
+                ] = torch.from_numpy(
+                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv1.bias"
+                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                 second_conv_layer = exit_flow[-1]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
-
-                # Residual blocks.
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv2.weight"
+                ] = torch.from_numpy(
+                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv2.bias"
+                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
+
+                # Residual blocks.
                 if hasattr(layer, "residual_projection"):
-                    if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+                    if isinstance(
+                        layer.residual_projection,
+                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
+                    ):
                         residual = layer.residual_projection
-                        unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
-                        unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
+                        unet_state_dict[
+                            f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.weight"
+                        ] = torch.from_numpy(
+                            residual.get_weights()[0].transpose(3, 2, 0, 1)
+                        )
+                        unet_state_dict[
+                            f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.bias"
+                        ] = torch.from_numpy(residual.get_weights()[1])
 
                 # Timestep embedding.
                 embedding_proj = embedding_flow[-1]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.weight"
+                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.bias"
+                ] = torch.from_numpy(embedding_proj.get_weights()[1])
 
                 # Norms.
                 first_group_norm = entry_flow[0]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm1.weight"
+                ] = torch.from_numpy(first_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm1.bias"
+                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                 second_group_norm = exit_flow[0]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm2.weight"
+                ] = torch.from_numpy(second_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm2.bias"
+                ] = torch.from_numpy(second_group_norm.get_weights()[1])
 
-            # Up.
+            # Up.
             elif int(parts[-1]) > 9 and up_res_block_flag < len(up_res_blocks):
-                entry_flow = layer.entry_flow
-                embedding_flow = layer.embedding_flow
+                entry_flow = layer.entry_flow
+                embedding_flow = layer.embedding_flow
                 exit_flow = layer.exit_flow
 
                 up_res_block = up_res_blocks[up_res_block_flag]
@@ -188,32 +309,65 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
 
                 # Conv blocks.
                 first_conv_layer = entry_flow[-1]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.weight"
+                ] = torch.from_numpy(
+                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.bias"
+                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                 second_conv_layer = exit_flow[-1]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
-
-                # Residual blocks.
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.weight"
+                ] = torch.from_numpy(
+                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.bias"
+                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
+
+                # Residual blocks.
                 if hasattr(layer, "residual_projection"):
-                    if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+                    if isinstance(
+                        layer.residual_projection,
+                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
+                    ):
                         residual = layer.residual_projection
-                        unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
-                        unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
+                        unet_state_dict[
+                            f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.weight"
+                        ] = torch.from_numpy(
+                            residual.get_weights()[0].transpose(3, 2, 0, 1)
+                        )
+                        unet_state_dict[
+                            f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.bias"
+                        ] = torch.from_numpy(residual.get_weights()[1])
 
                 # Timestep embedding.
                 embedding_proj = embedding_flow[-1]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
-
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.weight"
+                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.bias"
+                ] = torch.from_numpy(embedding_proj.get_weights()[1])
+
                 # Norms.
                 first_group_norm = entry_flow[0]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.weight"
+                ] = torch.from_numpy(first_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.bias"
+                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                 second_group_norm = exit_flow[0]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
-
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.weight"
+                ] = torch.from_numpy(second_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.bias"
+                ] = torch.from_numpy(second_group_norm.get_weights()[1])
+
                 up_res_block_flag += 1
 
         # All SpatialTransformer blocks.
@@ -225,67 +379,119 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
             if len(parts) == 2 or int(parts[-1]) < 6:
                 down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
                 down_attention_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
-
+
                 # Convs.
                 proj1 = layer.proj1
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.weight"
+                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.bias"
+                ] = torch.from_numpy(proj1.get_weights()[1])
                 proj2 = layer.proj2
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.weight"
+                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.bias"
+                ] = torch.from_numpy(proj2.get_weights()[1])
 
                 # Transformer blocks.
                 transformer_block = layer.transformer_block
-                unet_state_dict.update(port_transformer_block(transformer_block, "down", down_block_id, down_attention_id))
+                unet_state_dict.update(
+                    port_transformer_block(
+                        transformer_block, "down", down_block_id, down_attention_id
+                    )
+                )
 
                 # Norms.
                 norm = layer.norm
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.weight"
+                ] = torch.from_numpy(norm.get_weights()[0])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.bias"
+                ] = torch.from_numpy(norm.get_weights()[1])
 
             # Middle.
             elif int(parts[-1]) == 6:
                 mid_attention_id = int(parts[-1]) % 2
                 # Convs.
                 proj1 = layer.proj1
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.proj_in.weight"
+                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.proj_in.bias"
+                ] = torch.from_numpy(proj1.get_weights()[1])
                 proj2 = layer.proj2
-                unet_state_dict[f"mid_block.attentions.{mid_resnet_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_resnet_id}.proj_out.weight"
+                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.proj_out.bias"
+                ] = torch.from_numpy(proj2.get_weights()[1])
 
                 # Transformer blocks.
                 transformer_block = layer.transformer_block
-                unet_state_dict.update(port_transformer_block(transformer_block, "mid", None, mid_attention_id))
+                unet_state_dict.update(
+                    port_transformer_block(
+                        transformer_block, "mid", None, mid_attention_id
+                    )
+                )
 
                 # Norms.
                 norm = layer.norm
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.norm.weight"
+                ] = torch.from_numpy(norm.get_weights()[0])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.norm.bias"
+                ] = torch.from_numpy(norm.get_weights()[1])
 
             # Up.
-            elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(up_spatial_transformer_blocks):
-                up_spatial_transformer_block = up_spatial_transformer_blocks[up_spatial_transformer_flag]
+            elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(
+                up_spatial_transformer_blocks
+            ):
+                up_spatial_transformer_block = up_spatial_transformer_blocks[
+                    up_spatial_transformer_flag
+                ]
                 up_block_id = up_spatial_transformer_block[0]
                 up_attention_id = up_spatial_transformer_block[1]
 
                 # Convs.
                 proj1 = layer.proj1
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.weight"
+                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.bias"
+                ] = torch.from_numpy(proj1.get_weights()[1])
                 proj2 = layer.proj2
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.weight"
+                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.bias"
+                ] = torch.from_numpy(proj2.get_weights()[1])
 
                 # Transformer blocks.
                 transformer_block = layer.transformer_block
-                unet_state_dict.update(port_transformer_block(transformer_block, "up", up_block_id, up_attention_id))
+                unet_state_dict.update(
+                    port_transformer_block(
+                        transformer_block, "up", up_block_id, up_attention_id
+                    )
+                )
 
                 # Norms.
                 norm = layer.norm
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
-
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.weight"
+                ] = torch.from_numpy(norm.get_weights()[0])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.bias"
+                ] = torch.from_numpy(norm.get_weights()[1])
+
                 up_spatial_transformer_flag += 1
 
-    return unet_state_dict
+    return unet_state_dict
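
The recurring `transpose(3, 2, 0, 1)` above follows the axis mapping the linked transformers utility uses: Keras stores conv kernels as (height, width, in_channels, out_channels), while PyTorch expects (out_channels, in_channels, height, width). A minimal sketch of just that reordering, with an arbitrary example shape:

```python
import numpy as np

# Keras Conv2D kernel layout: (H, W, in_channels, out_channels).
tf_kernel = np.random.randn(3, 3, 320, 640).astype("float32")

# PyTorch Conv2d kernel layout: (out_channels, in_channels, H, W).
pt_kernel = tf_kernel.transpose(3, 2, 0, 1)
assert pt_kernel.shape == (640, 320, 3, 3)
```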
conversion_utils/utils.py CHANGED
@@ -1,15 +1,19 @@
+from typing import Dict
+
 import numpy as np
-import torch
-from typing import Dict
+import torch
 
 
-def run_assertion(orig_pt_state_dict: Dict[str, torch.Tensor], pt_state_dict_from_tf: Dict[str, torch.Tensor]):
+def run_assertion(
+    orig_pt_state_dict: Dict[str, torch.Tensor],
+    pt_state_dict_from_tf: Dict[str, torch.Tensor],
+):
     for k in orig_pt_state_dict:
         try:
             np.testing.assert_allclose(
-                orig_pt_state_dict[k].numpy(),
-                pt_state_dict_from_tf[k].numpy()
+                orig_pt_state_dict[k].numpy(), pt_state_dict_from_tf[k].numpy()
             )
         except:
-            raise ValueError("There are problems in the parameter population process. Cannot proceed :(")
+            raise ValueError(
+                "There are problems in the parameter population process. Cannot proceed :("
+            )
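
`run_assertion` leans on `np.testing.assert_allclose`, which raises an `AssertionError` (caught above and re-raised as a `ValueError`) once tensors differ beyond its default relative tolerance of 1e-7. A small sketch of that failure mode, with a hypothetical key name and a deliberate perturbation:

```python
import numpy as np
import torch

orig = {"norm.weight": torch.ones(4)}
converted = {"norm.weight": torch.ones(4) + 1e-3}  # deliberately perturbed

try:
    np.testing.assert_allclose(
        orig["norm.weight"].numpy(), converted["norm.weight"].numpy()
    )
except AssertionError:
    print("Mismatch detected.")  # run_assertion would raise ValueError here
```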
convert.py CHANGED
@@ -1,26 +1,25 @@
-from conversion_utils import populate_text_encoder, populate_unet, run_assertion
-
-from diffusers import (
-    AutoencoderKL,
-    StableDiffusionPipeline,
-    UNet2DConditionModel,
-)
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import CLIPTextModel
 import keras_cv
 import tensorflow as tf
+from diffusers import (AutoencoderKL, StableDiffusionPipeline,
+                       UNet2DConditionModel)
+from diffusers.pipelines.stable_diffusion.safety_checker import \
+    StableDiffusionSafetyChecker
+from transformers import CLIPTextModel
 
+from conversion_utils import (populate_text_encoder, populate_unet,
+                              run_assertion)
 
 PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
 REVISION = None
 NON_EMA_REVISION = None
 IMG_HEIGHT = IMG_WIDTH = 512
 
+
 def initialize_pt_models():
     """Initializes the separate models of Stable Diffusion from diffusers and downloads
     their pre-trained weights."""
     pt_text_encoder = CLIPTextModel.from_pretrained(
-        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
+        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
     )
     pt_vae = AutoencoderKL.from_pretrained(
         PRETRAINED_CKPT, subfolder="vae", revision=REVISION
@@ -34,14 +33,17 @@ def initialize_pt_models():
 
     return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
 
+
 def initialize_tf_models():
     """Initializes the separate models of Stable Diffusion from KerasCV and downloads
     their pre-trained weights."""
-    tf_sd_model = keras_cv.models.StableDiffusion(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)
-    _ = tf_sd_model.text_to_image("Cartoon")  # To download the weights.
+    tf_sd_model = keras_cv.models.StableDiffusion(
+        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
+    )
+    _ = tf_sd_model.text_to_image("Cartoon")  # To download the weights.
 
-    tf_text_encoder = tf_sd_model.text_encoder
-    tf_vae = tf_sd_model.image_encoder
+    tf_text_encoder = tf_sd_model.text_encoder
+    tf_vae = tf_sd_model.image_encoder
     tf_unet = tf_sd_model.diffusion_model
     return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
 
@@ -50,7 +52,7 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
     pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
     tf_sd_model, tf_text_encoder, tf_vae, tf_unet = initialize_tf_models()
     print("Pre-trained model weights downloaded.")
-
+
     if text_encoder_weights is not None:
         print("Loading fine-tuned text encoder weights.")
         text_encoder_weights_path = tf.keras.utils.get_file(text_encoder_weights)
@@ -72,7 +74,9 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
     unet_state_dict_from_pt = pt_text_encoder.state_dict()
     run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)
 
-    print("Assertions successful, populating the converted parameters into the diffusers models...")
+    print(
+        "Assertions successful, populating the converted parameters into the diffusers models..."
+    )
     pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
     pt_unet.load_state_dict(unet_state_dict_from_tf)
 
@@ -86,5 +90,3 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
         revision=None,
     )
     return pipeline
-
-
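
Putting the module together, a hedged usage sketch of `run_conversion` as defined above; both weight URLs are hypothetical placeholders, and omitting either argument keeps the corresponding pre-trained weights:

```python
from convert import run_conversion

# URLs below are hypothetical placeholders for publicly hosted Keras weights.
pipeline = run_conversion(
    text_encoder_weights="https://example.com/finetuned_text_encoder.h5",
    unet_weights="https://example.com/finetuned_unet.h5",
)
pipeline.save_pretrained("kerascv_sd_diffusers_pipeline")
```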
 
 
hub_utils/__init__.py CHANGED
@@ -1,2 +1,2 @@
-from .readme import save_model_card
-from .repo import push_to_hub
+from .readme import save_model_card
+from .repo import push_to_hub
hub_utils/readme.py CHANGED
@@ -23,7 +23,7 @@ The pipeline contained in this repository was created using [this Space](https:/
     """
 
     if weight_paths is not None:
-        model_card += "Following weight paths (KerasCV) were used: {weight_paths}"
+        model_card += "Following weight paths (KerasCV) were used: {weight_paths}"
 
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
-        f.write(yaml + model_card)
+        f.write(yaml + model_card)
hub_utils/repo.py CHANGED
@@ -1,5 +1,6 @@
 from huggingface_hub import HfApi, create_repo
 
+
 def push_to_hub(hf_token: str, push_dir: str, repo_prefix: None) -> str:
     try:
         if hf_token == "":
@@ -7,9 +8,15 @@ def push_to_hub(hf_token: str, push_dir: str, repo_prefix: None) -> str:
         else:
             hf_api = HfApi(token=hf_token)
             user = hf_api.whoami()["name"]
-            repo_id = f"{user}/{push_dir}" if repo_prefix == "" else f"{user}/{repo_prefix}-{push_dir}"
+            repo_id = (
+                f"{user}/{push_dir}"
+                if repo_prefix == ""
+                else f"{user}/{repo_prefix}-{push_dir}"
+            )
             _ = create_repo(repo_id=repo_id, token=hf_token)
-            url = hf_api.upload_folder(folder_path=push_dir, repo_id=repo_id, exist_ok=True)
+            url = hf_api.upload_folder(
+                folder_path=push_dir, repo_id=repo_id, exist_ok=True
+            )
             return f"Model successfully pushed: [{url}]({url})"
     except Exception as e:
-        return f"{e}"
+        return f"{e}"