Spaces:
Runtime error
Runtime error
xinyuanc91
committed on
Commit
·
95e7ff7
1
Parent(s):
db5f6ff
Update models/unet.py
Browse files
- models/unet.py +3 -105
models/unet.py
CHANGED
@@ -610,112 +610,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|
610 |
# config["num_class_embeds"] = 100
|
611 |
|
612 |
from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
|
613 |
-
|
614 |
-
# {'_class_name': 'UNet3DConditionModel',
|
615 |
-
# '_diffusers_version': '0.2.2',
|
616 |
-
# 'act_fn': 'silu',
|
617 |
-
# 'attention_head_dim': 8,
|
618 |
-
# 'block_out_channels': [320, 640, 1280, 1280],
|
619 |
-
# 'center_input_sample': False,
|
620 |
-
# 'cross_attention_dim': 768,
|
621 |
-
# 'down_block_types':
|
622 |
-
# ['CrossAttnDownBlock3D',
|
623 |
-
# 'CrossAttnDownBlock3D',
|
624 |
-
# 'CrossAttnDownBlock3D',
|
625 |
-
# 'DownBlock3D'],
|
626 |
-
# 'downsample_padding': 1,
|
627 |
-
# 'flip_sin_to_cos': True,
|
628 |
-
# 'freq_shift': 0,
|
629 |
-
# 'in_channels': 4,
|
630 |
-
# 'layers_per_block': 2,
|
631 |
-
# 'mid_block_scale_factor': 1,
|
632 |
-
# 'norm_eps': 1e-05,
|
633 |
-
# 'norm_num_groups': 32,
|
634 |
-
# 'out_channels': 4,
|
635 |
-
# 'sample_size': 64,
|
636 |
-
# 'up_block_types':
|
637 |
-
# ['UpBlock3D',
|
638 |
-
# 'CrossAttnUpBlock3D',
|
639 |
-
# 'CrossAttnUpBlock3D',
|
640 |
-
# 'CrossAttnUpBlock3D']}
|
641 |
|
642 |
model = cls.from_config(config)
|
643 |
-
# model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
644 |
-
# if not os.path.isfile(model_file):
|
645 |
-
# raise RuntimeError(f"{model_file} does not exist")
|
646 |
-
# state_dict = torch.load(model_file, map_location="cpu")
|
647 |
-
|
648 |
-
# if use_concat:
|
649 |
-
# new_state_dict = {}
|
650 |
-
# conv_in_weight = state_dict["conv_in.weight"]
|
651 |
-
# new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
|
652 |
-
|
653 |
-
# for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
|
654 |
-
# new_conv_weight[:, j] = conv_in_weight[:, i]
|
655 |
-
# new_state_dict["conv_in.weight"] = new_conv_weight
|
656 |
-
# new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
|
657 |
-
# for k, v in model.state_dict().items():
|
658 |
-
# # print(k)
|
659 |
-
# if '_temp.' in k:
|
660 |
-
# new_state_dict.update({k: v})
|
661 |
-
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
662 |
-
# k = k.replace('attn_fcross', 'attn1')
|
663 |
-
# state_dict.update({k: state_dict[k]})
|
664 |
-
# if 'norm_fcross' in k:
|
665 |
-
# k = k.replace('norm_fcross', 'norm1')
|
666 |
-
# state_dict.update({k: state_dict[k]})
|
667 |
-
|
668 |
-
# if 'conv_in' in k:
|
669 |
-
# continue
|
670 |
-
# else:
|
671 |
-
# new_state_dict[k] = v
|
672 |
-
# # # tmp
|
673 |
-
# # if 'class_embedding' in k:
|
674 |
-
# # state_dict.update({k: v})
|
675 |
-
# # breakpoint()
|
676 |
-
# model.load_state_dict(new_state_dict)
|
677 |
-
# else:
|
678 |
-
# for k, v in model.state_dict().items():
|
679 |
-
# # print(k)
|
680 |
-
# if '_temp' in k:
|
681 |
-
# state_dict.update({k: v})
|
682 |
-
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
683 |
-
# k = k.replace('attn_fcross', 'attn1')
|
684 |
-
# state_dict.update({k: state_dict[k]})
|
685 |
-
# if 'norm_fcross' in k:
|
686 |
-
# k = k.replace('norm_fcross', 'norm1')
|
687 |
-
# state_dict.update({k: state_dict[k]})
|
688 |
-
|
689 |
-
# model.load_state_dict(state_dict)
|
690 |
|
691 |
-
return model
|
692 |
-
|
693 |
-
if __name__ == '__main__':
|
694 |
-
import torch
|
695 |
-
# from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
696 |
-
|
697 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
698 |
-
|
699 |
-
# pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base/" # p cluster
|
700 |
-
pretrained_model_path = "/mnt/petrelfs/share_data/zhanglingjun/stable-diffusion-v1-4/" # p cluster
|
701 |
-
unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
|
702 |
-
# unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
703 |
-
unet.enable_xformers_memory_efficient_attention()
|
704 |
-
unet.enable_gradient_checkpointing()
|
705 |
-
|
706 |
-
unet.train()
|
707 |
-
|
708 |
-
use_image_num = 5
|
709 |
-
noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device)
|
710 |
-
bsz = noisy_latents.shape[0]
|
711 |
-
timesteps = torch.randint(0, 1000, (bsz,)).to(device)
|
712 |
-
timesteps = timesteps.long()
|
713 |
-
encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device)
|
714 |
-
# class_labels = torch.randn((bsz, )).to(device)
|
715 |
-
|
716 |
|
717 |
-
|
718 |
-
|
719 |
-
class_labels=None,
|
720 |
-
use_image_num=use_image_num).sample
|
721 |
-
print(model_pred.shape)
|
|
|
610 |
# config["num_class_embeds"] = 100
|
611 |
|
612 |
from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
|
613 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
|
615 |
model = cls.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
617 |
|
618 |
+
return model
|
619 |
+
|
|
|
|
|
|