xinyuanc91 commited on
Commit
95e7ff7
1 Parent(s): db5f6ff

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +3 -105
models/unet.py CHANGED
@@ -610,112 +610,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
610
  # config["num_class_embeds"] = 100
611
 
612
  from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
613
-
614
- # {'_class_name': 'UNet3DConditionModel',
615
- # '_diffusers_version': '0.2.2',
616
- # 'act_fn': 'silu',
617
- # 'attention_head_dim': 8,
618
- # 'block_out_channels': [320, 640, 1280, 1280],
619
- # 'center_input_sample': False,
620
- # 'cross_attention_dim': 768,
621
- # 'down_block_types':
622
- # ['CrossAttnDownBlock3D',
623
- # 'CrossAttnDownBlock3D',
624
- # 'CrossAttnDownBlock3D',
625
- # 'DownBlock3D'],
626
- # 'downsample_padding': 1,
627
- # 'flip_sin_to_cos': True,
628
- # 'freq_shift': 0,
629
- # 'in_channels': 4,
630
- # 'layers_per_block': 2,
631
- # 'mid_block_scale_factor': 1,
632
- # 'norm_eps': 1e-05,
633
- # 'norm_num_groups': 32,
634
- # 'out_channels': 4,
635
- # 'sample_size': 64,
636
- # 'up_block_types':
637
- # ['UpBlock3D',
638
- # 'CrossAttnUpBlock3D',
639
- # 'CrossAttnUpBlock3D',
640
- # 'CrossAttnUpBlock3D']}
641
 
642
  model = cls.from_config(config)
643
- # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
644
- # if not os.path.isfile(model_file):
645
- # raise RuntimeError(f"{model_file} does not exist")
646
- # state_dict = torch.load(model_file, map_location="cpu")
647
-
648
- # if use_concat:
649
- # new_state_dict = {}
650
- # conv_in_weight = state_dict["conv_in.weight"]
651
- # new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
652
-
653
- # for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
654
- # new_conv_weight[:, j] = conv_in_weight[:, i]
655
- # new_state_dict["conv_in.weight"] = new_conv_weight
656
- # new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
657
- # for k, v in model.state_dict().items():
658
- # # print(k)
659
- # if '_temp.' in k:
660
- # new_state_dict.update({k: v})
661
- # if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
662
- # k = k.replace('attn_fcross', 'attn1')
663
- # state_dict.update({k: state_dict[k]})
664
- # if 'norm_fcross' in k:
665
- # k = k.replace('norm_fcross', 'norm1')
666
- # state_dict.update({k: state_dict[k]})
667
-
668
- # if 'conv_in' in k:
669
- # continue
670
- # else:
671
- # new_state_dict[k] = v
672
- # # # tmp
673
- # # if 'class_embedding' in k:
674
- # # state_dict.update({k: v})
675
- # # breakpoint()
676
- # model.load_state_dict(new_state_dict)
677
- # else:
678
- # for k, v in model.state_dict().items():
679
- # # print(k)
680
- # if '_temp' in k:
681
- # state_dict.update({k: v})
682
- # if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
683
- # k = k.replace('attn_fcross', 'attn1')
684
- # state_dict.update({k: state_dict[k]})
685
- # if 'norm_fcross' in k:
686
- # k = k.replace('norm_fcross', 'norm1')
687
- # state_dict.update({k: state_dict[k]})
688
-
689
- # model.load_state_dict(state_dict)
690
 
691
- return model
692
-
693
- if __name__ == '__main__':
694
- import torch
695
- # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
696
-
697
- device = "cuda" if torch.cuda.is_available() else "cpu"
698
-
699
- # pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base/" # p cluster
700
- pretrained_model_path = "/mnt/petrelfs/share_data/zhanglingjun/stable-diffusion-v1-4/" # p cluster
701
- unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
702
- # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
703
- unet.enable_xformers_memory_efficient_attention()
704
- unet.enable_gradient_checkpointing()
705
-
706
- unet.train()
707
-
708
- use_image_num = 5
709
- noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device)
710
- bsz = noisy_latents.shape[0]
711
- timesteps = torch.randint(0, 1000, (bsz,)).to(device)
712
- timesteps = timesteps.long()
713
- encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device)
714
- # class_labels = torch.randn((bsz, )).to(device)
715
-
716
 
717
- model_pred = unet(sample=noisy_latents, timestep=timesteps,
718
- encoder_hidden_states=encoder_hidden_states,
719
- class_labels=None,
720
- use_image_num=use_image_num).sample
721
- print(model_pred.shape)
 
610
  # config["num_class_embeds"] = 100
611
 
612
  from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
613
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
614
 
615
  model = cls.from_config(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
 
618
+ return model
619
+