bluestarburst
commited on
Commit
·
83f7703
1
Parent(s):
96eabf2
Upload folder using huggingface_hub
Browse files
animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc
CHANGED
Binary files a/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc and b/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc differ
|
|
animatediff/utils/convert_from_ckpt.py
CHANGED
@@ -198,20 +198,21 @@ def assign_to_checkpoint(
|
|
198 |
new_path = new_path.replace(replacement["old"], replacement["new"])
|
199 |
|
200 |
# proj_attn.weight has to be converted from conv 1D to linear
|
201 |
-
if "
|
202 |
-
|
|
|
203 |
else:
|
204 |
checkpoint[new_path] = old_checkpoint[path["old"]]
|
205 |
|
206 |
|
207 |
def conv_attn_to_linear(checkpoint):
|
208 |
keys = list(checkpoint.keys())
|
209 |
-
attn_keys = ["
|
210 |
for key in keys:
|
211 |
if ".".join(key.split(".")[-2:]) in attn_keys:
|
212 |
if checkpoint[key].ndim > 2:
|
213 |
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
214 |
-
elif "
|
215 |
if checkpoint[key].ndim > 2:
|
216 |
checkpoint[key] = checkpoint[key][:, :, 0]
|
217 |
|
@@ -632,7 +633,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|
632 |
oldKey = {"old": "key", "new": "to_k"}
|
633 |
oldQuery = {"old": "query", "new": "to_q"}
|
634 |
oldValue = {"old": "value", "new": "to_v"}
|
635 |
-
oldOut = {"old": "proj_attn", "new": "to_out"}
|
636 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
637 |
conv_attn_to_linear(new_checkpoint)
|
638 |
|
@@ -669,7 +670,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|
669 |
oldKey = {"old": "key", "new": "to_k"}
|
670 |
oldQuery = {"old": "query", "new": "to_q"}
|
671 |
oldValue = {"old": "value", "new": "to_v"}
|
672 |
-
oldOut = {"old": "proj_attn", "new": "to_out"}
|
673 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
674 |
conv_attn_to_linear(new_checkpoint)
|
675 |
return new_checkpoint
|
|
|
198 |
new_path = new_path.replace(replacement["old"], replacement["new"])
|
199 |
|
200 |
# proj_attn.weight has to be converted from conv 1D to linear
|
201 |
+
if "to_out.0.weight" in new_path and "decoder" in new_path:
|
202 |
+
# turn [512, 512, 1] into [512, 512]
|
203 |
+
checkpoint[new_path] = old_checkpoint[path["old"]].squeeze(-1)
|
204 |
else:
|
205 |
checkpoint[new_path] = old_checkpoint[path["old"]]
|
206 |
|
207 |
|
208 |
def conv_attn_to_linear(checkpoint):
|
209 |
keys = list(checkpoint.keys())
|
210 |
+
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
|
211 |
for key in keys:
|
212 |
if ".".join(key.split(".")[-2:]) in attn_keys:
|
213 |
if checkpoint[key].ndim > 2:
|
214 |
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
215 |
+
elif "to_out.0.weight" in key:
|
216 |
if checkpoint[key].ndim > 2:
|
217 |
checkpoint[key] = checkpoint[key][:, :, 0]
|
218 |
|
|
|
633 |
oldKey = {"old": "key", "new": "to_k"}
|
634 |
oldQuery = {"old": "query", "new": "to_q"}
|
635 |
oldValue = {"old": "value", "new": "to_v"}
|
636 |
+
oldOut = {"old": "proj_attn", "new": "to_out.0"}
|
637 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
638 |
conv_attn_to_linear(new_checkpoint)
|
639 |
|
|
|
670 |
oldKey = {"old": "key", "new": "to_k"}
|
671 |
oldQuery = {"old": "query", "new": "to_q"}
|
672 |
oldValue = {"old": "value", "new": "to_v"}
|
673 |
+
oldOut = {"old": "proj_attn", "new": "to_out.0"}
|
674 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
675 |
conv_attn_to_linear(new_checkpoint)
|
676 |
return new_checkpoint
|