wlyu-adobe commited on
Commit
fc411f7
·
1 Parent(s): 0879767

Initial commit

Browse files
mvdiffusion/models/unet_mv2d_condition.py CHANGED
@@ -35,7 +35,7 @@ from diffusers.models.embeddings import (
35
  TimestepEmbedding,
36
  Timesteps,
37
  )
38
- from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
  from diffusers.models.unets.unet_2d_blocks import (
40
  CrossAttnDownBlock2D,
41
  CrossAttnUpBlock2D,
@@ -1506,6 +1506,7 @@ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixi
1506
  del state_dict[checkpoint_key]
1507
  return mismatched_keys
1508
 
 
1509
  if state_dict is not None:
1510
  # Whole checkpoint
1511
  mismatched_keys = _find_mismatched_keys(
@@ -1514,7 +1515,11 @@ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixi
1514
  original_loaded_keys,
1515
  ignore_mismatched_sizes,
1516
  )
1517
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
 
 
 
 
1518
 
1519
  if len(error_msgs) > 0:
1520
  error_msg = "\n\t".join(error_msgs)
 
35
  TimestepEmbedding,
36
  Timesteps,
37
  )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict
39
  from diffusers.models.unets.unet_2d_blocks import (
40
  CrossAttnDownBlock2D,
41
  CrossAttnUpBlock2D,
 
1506
  del state_dict[checkpoint_key]
1507
  return mismatched_keys
1508
 
1509
+ error_msgs = []
1510
  if state_dict is not None:
1511
  # Whole checkpoint
1512
  mismatched_keys = _find_mismatched_keys(
 
1515
  original_loaded_keys,
1516
  ignore_mismatched_sizes,
1517
  )
1518
+ # Use PyTorch's load_state_dict with strict=False to handle mismatched keys
1519
+ incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
1520
+ if incompatible_keys.missing_keys:
1521
+ error_msgs.append(f"Missing keys: {incompatible_keys.missing_keys}")
1522
+ # Note: unexpected_keys are already tracked separately, so we don't include them here
1523
 
1524
  if len(error_msgs) > 0:
1525
  error_msg = "\n\t".join(error_msgs)