bug fixes & using finetuned self-attention weights

Files changed:
- agnostic_mask.png +0 -0
- dog.jpg +0 -0
- garment.jpg +0 -0
- model.py +2 -2
- model_converter.py +83 -1
- person.jpg +0 -0
- test.ipynb +25 -22
agnostic_mask.png
DELETED
Binary file (6.26 kB)
dog.jpg
DELETED
Binary file (71.1 kB)
garment.jpg
DELETED
Binary file (56.5 kB)
model.py
CHANGED

@@ -5,12 +5,12 @@ from diffusion import Diffusion
 
 import model_converter
 
-def preload_models_from_standard_weights(ckpt_path, device):
+def preload_models_from_standard_weights(ckpt_path, device, finetune_weight_path=None):
     # CatVTON parameters
     in_channels = 9
     out_channels = 4
 
-    state_dict=model_converter.load_from_standard_weights(ckpt_path, device)
+    state_dict=model_converter.load_from_standard_weights(ckpt_path, device, finetune_weight_path)
 
     encoder=VAE_Encoder().to(device)
     encoder.load_state_dict(state_dict['encoder'], strict=True)
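
For reference, a minimal call sketch for the new signature, assuming the checkpoint and fine-tuned weight filenames that appear later in test.ipynb; the paths and device string are placeholders rather than values fixed by this commit.

import model

# Hypothetical paths; substitute your own checkpoint and fine-tuned weights.
models = model.preload_models_from_standard_weights(
    ckpt_path="sd-v1-5-inpainting.ckpt",       # base SD 1.5 inpainting weights
    device="cuda",                             # or "cpu"
    finetune_weight_path="model.safetensors",  # optional fine-tuned self-attention weights
)
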
model_converter.py
CHANGED

@@ -1,6 +1,7 @@
 import torch
+import safetensors.torch
 
-def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.Tensor]:
+def load_from_standard_weights(input_file: str, device: str, finetuned_weights_path: str=None) -> dict[str, torch.Tensor]:
     # Taken from: https://github.com/kjsman/stable-diffusion-pytorch/issues/7#issuecomment-1426839447
     # original_model = torch.load(input_file, map_location=device, weights_only = False)["state_dict"]
     original_model=torch.load(input_file, weights_only = False)["state_dict"]

@@ -1054,4 +1055,85 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['clip']['layers.11.attention.in_proj.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight']), 0)
     converted['clip']['layers.11.attention.in_proj.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias']), 0)
 
+    if finetuned_weights_path is not None:
+        converted=convert_safetensors_to_combined_weights(finetuned_weights_path, converted)
+
     return converted
+
+
+def convert_safetensors_to_combined_weights(safetensors_path, converted):
+    """
+    Convert safetensors with separate q, k, v weights to combined in_proj weights.
+    """
+    # Load the original safetensors
+    state_dict = safetensors.torch.load_file(safetensors_path)
+
+    # Create mapping from the safetensors indices to UNet attention paths
+    # Based on dimension analysis:
+    # 320-dim layers: 0, 8, 96, 104, 112 -> encoders.1,2 and decoders.9,10,11
+    # 640-dim layers: 16, 24, 72, 80, 88 -> encoders.4,5 and decoders.6,7,8
+    # 1280-dim layers: 32, 40, 48, 56, 64, 120 -> encoders.7,8, bottleneck, decoders.3,4,5
+    layer_mappings = {
+        # 320-dim layers (encoders)
+        0: "encoders.1.1.attention_1",     # [320,320] -> [960,320]
+        8: "encoders.2.1.attention_1",     # [320,320] -> [960,320]
+
+        # 640-dim layers (encoders)
+        16: "encoders.4.1.attention_1",    # [640,640] -> [1920,640]
+        24: "encoders.5.1.attention_1",    # [640,640] -> [1920,640]
+
+        # 1280-dim layers (encoders)
+        32: "encoders.7.1.attention_1",    # [1280,1280] -> [3840,1280]
+        40: "encoders.8.1.attention_1",    # [1280,1280] -> [3840,1280]
+
+        # 1280-dim layer (bottleneck)
+        48: "bottleneck.1.attention_1",    # [1280,1280] -> [3840,1280]
+
+        # 1280-dim layers (decoders)
+        56: "decoders.3.1.attention_1",    # [1280,1280] -> [3840,1280]
+        64: "decoders.4.1.attention_1",    # [1280,1280] -> [3840,1280]
+        120: "decoders.5.1.attention_1",   # [1280,1280] -> [3840,1280]
+
+        # 640-dim layers (decoders)
+        72: "decoders.6.1.attention_1",    # [640,640] -> [1920,640]
+        80: "decoders.7.1.attention_1",    # [640,640] -> [1920,640]
+        88: "decoders.8.1.attention_1",    # [640,640] -> [1920,640]
+
+        # 320-dim layers (decoders)
+        96: "decoders.9.1.attention_1",    # [320,320] -> [960,320]
+        104: "decoders.10.1.attention_1",  # [320,320] -> [960,320]
+        112: "decoders.11.1.attention_1"   # [320,320] -> [960,320]
+    }
+
+    for layer_idx, unet_path in layer_mappings.items():
+        # Get the q, k, v weights for this layer
+        q_key = f"{layer_idx}.to_q.weight"
+        k_key = f"{layer_idx}.to_k.weight"
+        v_key = f"{layer_idx}.to_v.weight"
+        out_weight_key = f"{layer_idx}.to_out.0.weight"
+        out_bias_key = f"{layer_idx}.to_out.0.bias"
+
+        if all(key in state_dict for key in [q_key, k_key, v_key]):
+            # Concatenate q, k, v weights along dimension 0 to create the in_proj weight
+            q_weight = state_dict[q_key]
+            k_weight = state_dict[k_key]
+            v_weight = state_dict[v_key]
+
+            # Combine into a single in_proj matrix
+            in_proj_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+
+            # Store in the converted format
+            converted['diffusion'][f"unet.{unet_path}.in_proj.weight"] = in_proj_weight
+
+            # Also handle output projection weights
+            if out_weight_key in state_dict:
+                converted['diffusion'][f"unet.{unet_path}.out_proj.weight"] = state_dict[out_weight_key]
+
+            if out_bias_key in state_dict:
+                converted['diffusion'][f"unet.{unet_path}.out_proj.bias"] = state_dict[out_bias_key]
+
+            print(f"Converted layer {layer_idx}: {q_weight.shape} + {k_weight.shape} + {v_weight.shape} -> {in_proj_weight.shape}")
+
+    return converted
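
The layer_mappings table above is justified only by a "dimension analysis" comment. A small, assumed inspection helper like the sketch below (summarize_finetuned_attention is not part of the commit) shows how that grouping can be checked by listing each to_q.weight shape in the fine-tuned safetensors file.

import safetensors.torch

def summarize_finetuned_attention(safetensors_path="model.safetensors"):
    # Load the fine-tuned weights and group the numeric layer indices by the
    # width of their query projection, mirroring the comments in layer_mappings.
    state_dict = safetensors.torch.load_file(safetensors_path)
    by_dim = {}
    for key, tensor in state_dict.items():
        if key.endswith(".to_q.weight"):
            layer_idx = int(key.split(".")[0])
            by_dim.setdefault(tensor.shape[0], []).append(layer_idx)
    for dim, indices in sorted(by_dim.items()):
        print(f"{dim}-dim self-attention layers: {sorted(indices)}")

# Per the mapping above, the expected groups are 320 -> 0, 8, 96, 104, 112;
# 640 -> 16, 24, 72, 80, 88; and 1280 -> 32, 40, 48, 56, 64, 120.
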
person.jpg
DELETED
Binary file (30.9 kB)
test.ipynb
CHANGED

@@ -169,28 +169,19 @@
     },
     {
      "cell_type": "code",
-     "execution_count":
+     "execution_count": null,
      "id": "13c59a6c",
      "metadata": {},
      "outputs": [
       {
-       "name": "stdout",
-       "output_type": "stream",
-       "text": [
-        "Latents shape: torch.Size([1, 4, 64, 128])\n",
-        "Masked latent concat for classifier-free guidance: torch.Size([2, 4, 64, 128]), mask latent concat: torch.Size([2, 1, 64, 128])\n"
-       ]
-      },
-      {
-       "name": "stderr",
-       "output_type": "stream",
-       "text": [
-        "100%|██████████| 50/50 [01:20<00:00, 1.62s/it]\n"
+       "ename": "ModuleNotFoundError",
+       "evalue": "No module named 'ddpm'",
+       "output_type": "error",
+       "traceback": [
+        "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+        "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+        "\u001b[0;32m/tmp/ipykernel_391/3664407558.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mddpm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDDPMSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPIL\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+        "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'ddpm'"
       ]
      }
     ],

@@ -586,12 +577,12 @@
     "\n",
     "if __name__ == \"__main__\":\n",
     "    # Example usage\n",
-    "    image = Image.open(\"
-    "    condition_image = Image.open(\"
-    "    mask = Image.open(\"agnostic_mask.png\").convert(\"L\")\n",
+    "    image = Image.open(\"sample_dataset/image.png\").convert(\"RGB\")\n",
+    "    condition_image = Image.open(\"sample_dataset/cloth.png\").convert(\"RGB\")\n",
+    "    mask = Image.open(\"sample_dataset/agnostic_mask.png\").convert(\"L\")\n",
     "\n",
     "    # Load models\n",
-    "    models=model.preload_models_from_standard_weights(
+    "    models=model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weight_path=\"model.safetensors\")\n",
     "\n",
     "    # Generate image\n",
     "    generated_image = generate(\n",

@@ -733,6 +724,18 @@
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.11"
   }
  },
 "nbformat": 4,
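
The new notebook output records a ModuleNotFoundError for ddpm, which suggests the kernel was started outside the repository root. A minimal, assumed workaround (not part of this commit) is to put the directory containing ddpm.py on sys.path before the imports:

import sys, os

# Hypothetical location; point this at the folder that holds ddpm.py, model.py, etc.
repo_root = os.path.abspath(".")
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

from ddpm import DDPMSampler  # should now resolve if ddpm.py is present
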