maxin-cn committed on
Commit db5f6ff
Parent: ceb3f0c

Update models/unet.py

Files changed (1): models/unet.py (+45 −45)
models/unet.py CHANGED
@@ -640,53 +640,53 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         # 'CrossAttnUpBlock3D']}
 
         model = cls.from_config(config)
-        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
-        if not os.path.isfile(model_file):
-            raise RuntimeError(f"{model_file} does not exist")
-        state_dict = torch.load(model_file, map_location="cpu")
-
-        if use_concat:
-            new_state_dict = {}
-            conv_in_weight = state_dict["conv_in.weight"]
-            new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
+        # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+        # if not os.path.isfile(model_file):
+        #     raise RuntimeError(f"{model_file} does not exist")
+        # state_dict = torch.load(model_file, map_location="cpu")
+
+        # if use_concat:
+        #     new_state_dict = {}
+        #     conv_in_weight = state_dict["conv_in.weight"]
+        #     new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
 
-            for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
-                new_conv_weight[:, j] = conv_in_weight[:, i]
-            new_state_dict["conv_in.weight"] = new_conv_weight
-            new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
-            for k, v in model.state_dict().items():
-                # print(k)
-                if '_temp.' in k:
-                    new_state_dict.update({k: v})
-                if 'attn_fcross' in k:  # copy params of attn1 to attn_fcross
-                    k = k.replace('attn_fcross', 'attn1')
-                    state_dict.update({k: state_dict[k]})
-                if 'norm_fcross' in k:
-                    k = k.replace('norm_fcross', 'norm1')
-                    state_dict.update({k: state_dict[k]})
+        #     for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
+        #         new_conv_weight[:, j] = conv_in_weight[:, i]
+        #     new_state_dict["conv_in.weight"] = new_conv_weight
+        #     new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
+        #     for k, v in model.state_dict().items():
+        #         # print(k)
+        #         if '_temp.' in k:
+        #             new_state_dict.update({k: v})
+        #         if 'attn_fcross' in k:  # copy params of attn1 to attn_fcross
+        #             k = k.replace('attn_fcross', 'attn1')
+        #             state_dict.update({k: state_dict[k]})
+        #         if 'norm_fcross' in k:
+        #             k = k.replace('norm_fcross', 'norm1')
+        #             state_dict.update({k: state_dict[k]})
 
-                if 'conv_in' in k:
-                    continue
-                else:
-                    new_state_dict[k] = v
-                # # tmp
-                # if 'class_embedding' in k:
-                #     state_dict.update({k: v})
-                # breakpoint()
-            model.load_state_dict(new_state_dict)
-        else:
-            for k, v in model.state_dict().items():
-                # print(k)
-                if '_temp' in k:
-                    state_dict.update({k: v})
-                if 'attn_fcross' in k:  # copy params of attn1 to attn_fcross
-                    k = k.replace('attn_fcross', 'attn1')
-                    state_dict.update({k: state_dict[k]})
-                if 'norm_fcross' in k:
-                    k = k.replace('norm_fcross', 'norm1')
-                    state_dict.update({k: state_dict[k]})
-
-            model.load_state_dict(state_dict)
+        #         if 'conv_in' in k:
+        #             continue
+        #         else:
+        #             new_state_dict[k] = v
+        #         # # tmp
+        #         # if 'class_embedding' in k:
+        #         #     state_dict.update({k: v})
+        #         # breakpoint()
+        #     model.load_state_dict(new_state_dict)
+        # else:
+        #     for k, v in model.state_dict().items():
+        #         # print(k)
+        #         if '_temp' in k:
+        #             state_dict.update({k: v})
+        #         if 'attn_fcross' in k:  # copy params of attn1 to attn_fcross
+        #             k = k.replace('attn_fcross', 'attn1')
+        #             state_dict.update({k: state_dict[k]})
+        #         if 'norm_fcross' in k:
+        #             k = k.replace('norm_fcross', 'norm1')
+        #             state_dict.update({k: state_dict[k]})
+
+        #     model.load_state_dict(state_dict)
 
         return model
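
For reference, the disabled `use_concat` branch adapted a 4-input-channel 2D checkpoint to a 9-channel `conv_in` and seeded the `*_fcross` layers from `attn1`/`norm1`. Below is a minimal sketch of those two adaptations, assuming a Stable-Diffusion-style `conv_in.weight` of shape (out_channels, 4, 3, 3); the helper names are hypothetical, not part of this repository:

import torch

def expand_conv_in(conv_in_weight: torch.Tensor, new_in_channels: int = 9) -> torch.Tensor:
    # In the commented-out loop, zip([0, 1, 2, 3], [0, ..., 8]) truncates to
    # four pairs, so only the first four input channels are copied and
    # channels 4..8 remain zero-initialized.
    out_ch, in_ch, *kernel = conv_in_weight.shape
    new_weight = torch.zeros((out_ch, new_in_channels, *kernel), dtype=conv_in_weight.dtype)
    new_weight[:, :in_ch] = conv_in_weight  # copy the pretrained channels
    return new_weight

def init_fcross_from_attn1(model_keys, state_dict):
    # Stated intent of the attn_fcross/norm_fcross branch: seed *_fcross
    # parameters from the matching attn1/norm1 parameters. Note the original
    # loop rebound k before updating, so it wrote the value back under the
    # attn1/norm1 key; this sketch writes to the *_fcross key, matching the
    # "copy params of attn1 to attn_fcross" comment.
    for k in model_keys:
        if 'attn_fcross' in k:
            state_dict[k] = state_dict[k.replace('attn_fcross', 'attn1')]
        if 'norm_fcross' in k:
            state_dict[k] = state_dict[k.replace('norm_fcross', 'norm1')]
    return state_dict

# Usage with a typical SD conv_in weight of shape (320, 4, 3, 3):
w = torch.randn(320, 4, 3, 3)
w9 = expand_conv_in(w)
assert w9.shape == (320, 9, 3, 3)
assert torch.equal(w9[:, :4], w) and torch.all(w9[:, 4:] == 0)

Zero-initializing the extra input channels keeps the pretrained response to the original four latent channels unchanged at initialization; the five concatenated conditioning channels start with no influence and are learned during fine-tuning.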