harsh99 committed
Commit 24ffbfb · 1 Parent(s): b9e9532

bug fixes & using finetuned self-attention weights

Files changed (7)
  1. agnostic_mask.png +0 -0
  2. dog.jpg +0 -0
  3. garment.jpg +0 -0
  4. model.py +2 -2
  5. model_converter.py +83 -1
  6. person.jpg +0 -0
  7. test.ipynb +25 -22
agnostic_mask.png DELETED
Binary file (6.26 kB)
 
dog.jpg DELETED
Binary file (71.1 kB)
 
garment.jpg DELETED
Binary file (56.5 kB)
 
model.py CHANGED
@@ -5,12 +5,12 @@ from diffusion import Diffusion
 
 import model_converter
 
-def preload_models_from_standard_weights(ckpt_path, device):
+def preload_models_from_standard_weights(ckpt_path, device, finetune_weight_path=None):
     # CatVTON parameters
     in_channels = 9
     out_channels = 4
 
-    state_dict=model_converter.load_from_standard_weights(ckpt_path, device)
+    state_dict=model_converter.load_from_standard_weights(ckpt_path, device, finetune_weight_path)
 
     encoder=VAE_Encoder().to(device)
     encoder.load_state_dict(state_dict['encoder'], strict=True)
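
Note: the added finetune_weight_path argument defaults to None, so existing call sites keep working. A minimal usage sketch, assuming the checkpoint and safetensors filenames used in test.ipynb below ("sd-v1-5-inpainting.ckpt", "model.safetensors"):

import model

# Base SD 1.5 inpainting weights only; finetune_weight_path defaults to None.
models = model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")

# Base weights plus the finetuned self-attention weights merged in by model_converter.
models = model.preload_models_from_standard_weights(
    ckpt_path="sd-v1-5-inpainting.ckpt",
    device="cuda",
    finetune_weight_path="model.safetensors",
)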
model_converter.py CHANGED
@@ -1,6 +1,7 @@
 import torch
+import safetensors.torch
 
-def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.Tensor]:
+def load_from_standard_weights(input_file: str, device: str, finetuned_weights_path: str=None) -> dict[str, torch.Tensor]:
     # Taken from: https://github.com/kjsman/stable-diffusion-pytorch/issues/7#issuecomment-1426839447
     # original_model = torch.load(input_file, map_location=device, weights_only = False)["state_dict"]
     original_model=torch.load(input_file, weights_only = False)["state_dict"]
@@ -1054,4 +1055,85 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['clip']['layers.11.attention.in_proj.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight']), 0)
     converted['clip']['layers.11.attention.in_proj.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias']), 0)
 
+    if finetuned_weights_path is not None:
+        converted=convert_safetensors_to_combined_weights(finetuned_weights_path, converted)
+
+    return converted
+
+
+def convert_safetensors_to_combined_weights(safetensors_path, converted):
+    """
+    Convert safetensors with separate q,k,v weights to combined in_proj weights
+    """
+    # Load the original safetensors
+    state_dict = safetensors.torch.load_file(safetensors_path)
+
+    # Create mapping from your safetensors indices to UNet attention paths
+    # Based on dimension analysis:
+    # 320-dim layers: 0, 8, 96, 104, 112 -> encoders.1,2 and decoders.9,10,11
+    # 640-dim layers: 16, 24, 72, 80, 88 -> encoders.4,5 and decoders.6,7,8
+    # 1280-dim layers: 32, 40, 48, 56, 64, 120 -> encoders.7,8, bottleneck, decoders.3,4,5
+
+    layer_mappings = {
+        # 320-dim layers (encoders)
+        0: "encoders.1.1.attention_1",    # [320,320] -> [960,320]
+        8: "encoders.2.1.attention_1",    # [320,320] -> [960,320]
+
+        # 640-dim layers (encoders)
+        16: "encoders.4.1.attention_1",   # [640,640] -> [1920,640]
+        24: "encoders.5.1.attention_1",   # [640,640] -> [1920,640]
+
+        # 1280-dim layers (encoders)
+        32: "encoders.7.1.attention_1",   # [1280,1280] -> [3840,1280]
+        40: "encoders.8.1.attention_1",   # [1280,1280] -> [3840,1280]
+
+        # 1280-dim layers (bottleneck)
+        48: "bottleneck.1.attention_1",   # [1280,1280] -> [3840,1280]
+
+        # 1280-dim layers (decoders)
+        56: "decoders.3.1.attention_1",   # [1280,1280] -> [3840,1280]
+        64: "decoders.4.1.attention_1",   # [1280,1280] -> [3840,1280]
+        120: "decoders.5.1.attention_1",  # [1280,1280] -> [3840,1280]
+
+        # 640-dim layers (decoders)
+        72: "decoders.6.1.attention_1",   # [640,640] -> [1920,640]
+        80: "decoders.7.1.attention_1",   # [640,640] -> [1920,640]
+        88: "decoders.8.1.attention_1",   # [640,640] -> [1920,640]
+
+        # 320-dim layers (decoders)
+        96: "decoders.9.1.attention_1",   # [320,320] -> [960,320]
+        104: "decoders.10.1.attention_1", # [320,320] -> [960,320]
+        112: "decoders.11.1.attention_1"  # [320,320] -> [960,320]
+    }
+
+
+    for layer_idx, unet_path in layer_mappings.items():
+        # Get the q, k, v weights for this layer
+        q_key = f"{layer_idx}.to_q.weight"
+        k_key = f"{layer_idx}.to_k.weight"
+        v_key = f"{layer_idx}.to_v.weight"
+        out_weight_key = f"{layer_idx}.to_out.0.weight"
+        out_bias_key = f"{layer_idx}.to_out.0.bias"
+
+        if all(key in state_dict for key in [q_key, k_key, v_key]):
+            # Concatenate q, k, v weights along dimension 0 to create in_proj weight
+            q_weight = state_dict[q_key]
+            k_weight = state_dict[k_key]
+            v_weight = state_dict[v_key]
+
+            # Combine into single in_proj matrix
+            in_proj_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+
+            # Store in converted format
+            converted['diffusion'][f"unet.{unet_path}.in_proj.weight"] = in_proj_weight
+
+            # Also handle output projection weights
+            if out_weight_key in state_dict:
+                converted['diffusion'][f"unet.{unet_path}.out_proj.weight"] = state_dict[out_weight_key]
+
+            if out_bias_key in state_dict:
+                converted['diffusion'][f"unet.{unet_path}.out_proj.bias"] = state_dict[out_bias_key]
+
+            print(f"Converted layer {layer_idx}: {q_weight.shape} + {k_weight.shape} + {v_weight.shape} -> {in_proj_weight.shape}")
+
     return converted
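
Note: the heart of convert_safetensors_to_combined_weights is stacking the finetuned to_q/to_k/to_v matrices into a single in_proj matrix. A self-contained sketch of that shape arithmetic, with random tensors standing in for the real finetuned weights (d=320 corresponds to the encoders.1/2 and decoders.9/10/11 layers in the mapping above):

import torch

d = 320                       # e.g. the 320-dim attention layers
q_weight = torch.randn(d, d)  # stand-ins for "<idx>.to_q.weight" etc. from the safetensors file
k_weight = torch.randn(d, d)
v_weight = torch.randn(d, d)

# Same operation as in the diff: concatenate along dim 0 so one Linear computes q, k and v.
in_proj_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
assert in_proj_weight.shape == (3 * d, d)   # three [320,320] matrices -> one [960,320]

# Splitting the combined matrix back recovers the original projections,
# which is what the attention block effectively does at runtime.
q_back, k_back, v_back = in_proj_weight.chunk(3, dim=0)
assert torch.equal(q_back, q_weight) and torch.equal(v_back, v_weight)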
person.jpg DELETED
Binary file (30.9 kB)
 
test.ipynb CHANGED
@@ -169,28 +169,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": null,
    "id": "13c59a6c",
    "metadata": {},
    "outputs": [
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Prepared image shape: torch.Size([1, 3, 512, 512]), condition image shape: torch.Size([1, 3, 512, 512]), mask shape: torch.Size([1, 1, 512, 512])\n",
-      "Masked image shape: torch.Size([1, 3, 512, 512])\n",
-      "Masked latent shape: torch.Size([1, 4, 64, 64]), condition latent shape: torch.Size([1, 4, 64, 64])\n",
-      "Masked Person latent + garment latent: torch.Size([1, 4, 64, 128])\n",
-      "Mask latent concat shape: torch.Size([1, 1, 64, 128])\n",
-      "Latents shape: torch.Size([1, 4, 64, 128])\n",
-      "Masked latent concat for classifier-free guidance: torch.Size([2, 4, 64, 128]), mask latent concat: torch.Size([2, 1, 64, 128])\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "100%|██████████| 50/50 [01:20<00:00, 1.62s/it]\n"
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'ddpm'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+      "\u001b[0;32m/tmp/ipykernel_391/3664407558.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mddpm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDDPMSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPIL\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'ddpm'"
     ]
    }
   ],
@@ -586,12 +577,12 @@
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Example usage\n",
-    "    image = Image.open(\"person.jpg\").convert(\"RGB\")\n",
-    "    condition_image = Image.open(\"image.png\").convert(\"RGB\")\n",
-    "    mask = Image.open(\"agnostic_mask.png\").convert(\"L\")\n",
+    "    image = Image.open(\"sample_dataset/image.png\").convert(\"RGB\")\n",
+    "    condition_image = Image.open(\"sample_dataset/cloth.png\").convert(\"RGB\")\n",
+    "    mask = Image.open(\"sample_dataset/agnostic_mask.png\").convert(\"L\")\n",
    "\n",
    "    # Load models\n",
-    "    models=model.preload_models_from_standard_weights(\"sd-v1-5-inpainting.ckpt\", device=\"cuda\")\n",
+    "    models=model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weight_path=\"model.safetensors\")\n",
    "\n",
    "    # Generate image\n",
    "    generated_image = generate(\n",
@@ -733,6 +724,18 @@
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.11"
   }
  },
 "nbformat": 4,