Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
fc411f7
1
Parent(s):
0879767
Initial commit
Browse files
mvdiffusion/models/unet_mv2d_condition.py
CHANGED
|
@@ -35,7 +35,7 @@ from diffusers.models.embeddings import (
|
|
| 35 |
TimestepEmbedding,
|
| 36 |
Timesteps,
|
| 37 |
)
|
| 38 |
-
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
|
| 39 |
from diffusers.models.unets.unet_2d_blocks import (
|
| 40 |
CrossAttnDownBlock2D,
|
| 41 |
CrossAttnUpBlock2D,
|
|
@@ -1506,6 +1506,7 @@ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixi
|
|
| 1506 |
del state_dict[checkpoint_key]
|
| 1507 |
return mismatched_keys
|
| 1508 |
|
|
|
|
| 1509 |
if state_dict is not None:
|
| 1510 |
# Whole checkpoint
|
| 1511 |
mismatched_keys = _find_mismatched_keys(
|
|
@@ -1514,7 +1515,11 @@ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixi
|
|
| 1514 |
original_loaded_keys,
|
| 1515 |
ignore_mismatched_sizes,
|
| 1516 |
)
|
| 1517 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1518 |
|
| 1519 |
if len(error_msgs) > 0:
|
| 1520 |
error_msg = "\n\t".join(error_msgs)
|
|
|
|
| 35 |
TimestepEmbedding,
|
| 36 |
Timesteps,
|
| 37 |
)
|
| 38 |
+
from diffusers.models.modeling_utils import ModelMixin, load_state_dict
|
| 39 |
from diffusers.models.unets.unet_2d_blocks import (
|
| 40 |
CrossAttnDownBlock2D,
|
| 41 |
CrossAttnUpBlock2D,
|
|
|
|
| 1506 |
del state_dict[checkpoint_key]
|
| 1507 |
return mismatched_keys
|
| 1508 |
|
| 1509 |
+
error_msgs = []
|
| 1510 |
if state_dict is not None:
|
| 1511 |
# Whole checkpoint
|
| 1512 |
mismatched_keys = _find_mismatched_keys(
|
|
|
|
| 1515 |
original_loaded_keys,
|
| 1516 |
ignore_mismatched_sizes,
|
| 1517 |
)
|
| 1518 |
+
# Use PyTorch's load_state_dict with strict=False to handle mismatched keys
|
| 1519 |
+
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
|
| 1520 |
+
if incompatible_keys.missing_keys:
|
| 1521 |
+
error_msgs.append(f"Missing keys: {incompatible_keys.missing_keys}")
|
| 1522 |
+
# Note: unexpected_keys are already tracked separately, so we don't include them here
|
| 1523 |
|
| 1524 |
if len(error_msgs) > 0:
|
| 1525 |
error_msg = "\n\t".join(error_msgs)
|