nat997 committed
Commit a27592a • 1 Parent(s): a81a471

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README.md +7 -6
  2. checkpoints/humanparsing/parsing_atr.onnx +3 -0
  3. checkpoints/humanparsing/parsing_lip.onnx +3 -0
  4. checkpoints/ootd/feature_extractor/preprocessor_config.json +20 -0
  5. checkpoints/ootd/model_index.json +38 -0
  6. checkpoints/ootd/ootd_dc/checkpoint-36000/unet_garm/config.json +68 -0
  7. checkpoints/ootd/ootd_dc/checkpoint-36000/unet_garm/diffusion_pytorch_model.safetensors +3 -0
  8. checkpoints/ootd/ootd_dc/checkpoint-36000/unet_vton/config.json +68 -0
  9. checkpoints/ootd/ootd_dc/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors +3 -0
  10. checkpoints/ootd/ootd_hd/checkpoint-36000/unet_garm/config.json +68 -0
  11. checkpoints/ootd/ootd_hd/checkpoint-36000/unet_garm/diffusion_pytorch_model.safetensors +3 -0
  12. checkpoints/ootd/ootd_hd/checkpoint-36000/unet_vton/config.json +68 -0
  13. checkpoints/ootd/ootd_hd/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors +3 -0
  14. checkpoints/ootd/scheduler/scheduler_config.json +25 -0
  15. checkpoints/ootd/text_encoder/config.json +25 -0
  16. checkpoints/ootd/text_encoder/pytorch_model.bin +3 -0
  17. checkpoints/ootd/tokenizer/merges.txt +0 -0
  18. checkpoints/ootd/tokenizer/special_tokens_map.json +24 -0
  19. checkpoints/ootd/tokenizer/tokenizer_config.json +34 -0
  20. checkpoints/ootd/tokenizer/vocab.json +0 -0
  21. checkpoints/ootd/vae/config.json +31 -0
  22. checkpoints/ootd/vae/diffusion_pytorch_model.bin +3 -0
  23. checkpoints/openpose/ckpts/body_pose_model.pth +3 -0
  24. ootd/__pycache__/inference_ootd_dc.cpython-310.pyc +0 -0
  25. ootd/__pycache__/inference_ootd_hd.cpython-310.pyc +0 -0
  26. ootd/inference_ootd.py +133 -0
  27. ootd/inference_ootd_dc.py +132 -0
  28. ootd/inference_ootd_hd.py +132 -0
  29. ootd/pipelines_ootd/__pycache__/attention_garm.cpython-310.pyc +0 -0
  30. ootd/pipelines_ootd/__pycache__/attention_vton.cpython-310.pyc +0 -0
  31. ootd/pipelines_ootd/__pycache__/pipeline_ootd.cpython-310.pyc +0 -0
  32. ootd/pipelines_ootd/__pycache__/transformer_garm_2d.cpython-310.pyc +0 -0
  33. ootd/pipelines_ootd/__pycache__/transformer_vton_2d.cpython-310.pyc +0 -0
  34. ootd/pipelines_ootd/__pycache__/unet_garm_2d_blocks.cpython-310.pyc +0 -0
  35. ootd/pipelines_ootd/__pycache__/unet_garm_2d_condition.cpython-310.pyc +0 -0
  36. ootd/pipelines_ootd/__pycache__/unet_vton_2d_blocks.cpython-310.pyc +0 -0
  37. ootd/pipelines_ootd/__pycache__/unet_vton_2d_condition.cpython-310.pyc +0 -0
  38. ootd/pipelines_ootd/attention_garm.py +402 -0
  39. ootd/pipelines_ootd/attention_vton.py +407 -0
  40. ootd/pipelines_ootd/pipeline_ootd.py +846 -0
  41. ootd/pipelines_ootd/transformer_garm_2d.py +449 -0
  42. ootd/pipelines_ootd/transformer_vton_2d.py +452 -0
  43. ootd/pipelines_ootd/unet_garm_2d_blocks.py +0 -0
  44. ootd/pipelines_ootd/unet_garm_2d_condition.py +1183 -0
  45. ootd/pipelines_ootd/unet_vton_2d_blocks.py +0 -0
  46. ootd/pipelines_ootd/unet_vton_2d_condition.py +1183 -0
  47. preprocess/humanparsing/__pycache__/parsing_api.cpython-310.pyc +0 -0
  48. preprocess/humanparsing/__pycache__/run_parsing.cpython-310.pyc +0 -0
  49. preprocess/humanparsing/datasets/__init__.py +0 -0
  50. preprocess/humanparsing/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
  title: OOTDiffusion
- emoji: 🏠
- colorFrom: purple
- colorTo: blue
+ emoji: 🥼👖👗
+ colorFrom: yellow
+ colorTo: pink
  sdk: gradio
- sdk_version: 4.22.0
- app_file: app.py
+ sdk_version: 4.16.0
+ app_file: ./run/gradio_ootd.py
  pinned: false
+ license: cc-by-nc-sa-4.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
checkpoints/humanparsing/parsing_atr.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:04c7d1d070d0e0ae943d86b18cb5aaaea9e278d97462e9cfb270cbbe4cd977f4
+ size 266859305
checkpoints/humanparsing/parsing_lip.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8436e1dae96e2601c373d1ace29c8f0978b16357d9038c17a8ba756cca376dbc
+ size 266863411
checkpoints/ootd/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "crop_size": 224,
+   "do_center_crop": true,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_resize": true,
+   "feature_extractor_type": "CLIPFeatureExtractor",
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "size": 224
+ }
checkpoints/ootd/model_index.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "OotdPipeline",
+   "_diffusers_version": "0.24.0.dev0",
+   "_name_or_path": "/home/aigc/levi/OOTDemo/Checkpoints/OOTDiffusion",
+   "feature_extractor": [
+     "transformers",
+     "CLIPImageProcessor"
+   ],
+   "requires_safety_checker": true,
+   "safety_checker": [
+     "stable_diffusion",
+     "StableDiffusionSafetyChecker"
+   ],
+   "scheduler": [
+     "diffusers",
+     "PNDMScheduler"
+   ],
+   "text_encoder": [
+     "transformers",
+     "CLIPTextModel"
+   ],
+   "tokenizer": [
+     "transformers",
+     "CLIPTokenizer"
+   ],
+   "unet_garm": [
+     "pipelines_ootd.unet_garm_2d_condition",
+     "UNetGarm2DConditionModel"
+   ],
+   "unet_vton": [
+     "pipelines_ootd.unet_vton_2d_condition",
+     "UNetVton2DConditionModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
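Note: model_index.json tells diffusers which library and class to instantiate for each pipeline component. The unet_garm and unet_vton entries point at the repo-local pipelines_ootd modules rather than at diffusers itself, which is why the inference scripts further down construct those two UNets explicitly and pass them into OotdPipeline.from_pretrained. A minimal sketch of inspecting this mapping, assuming the checkpoint folder sits at checkpoints/ootd as in this upload:

import json

# Assumed local path, matching the layout of this upload.
with open("checkpoints/ootd/model_index.json") as f:
    index = json.load(f)

# Component entries are [library_or_module, class_name] pairs;
# keys starting with "_" are metadata.
for name, spec in index.items():
    if isinstance(spec, list):
        library, cls = spec
        print(f"{name}: {cls} from {library}")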
checkpoints/ootd/ootd_dc/checkpoint-36000/unet_garm/config.json ADDED
@@ -0,0 +1,68 @@
+ {
+   "_class_name": "UNetGarm2DConditionModel",
+   "_diffusers_version": "0.24.0.dev0",
+   "_name_or_path": "/home/aigc/Vton_v4/OOTDiffusion/models/ldm/stablediffusion_v15",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "attention_type": "default",
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 768,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "reverse_transformer_layers_per_block": null,
+   "sample_size": 64,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
checkpoints/ootd/ootd_dc/checkpoint-36000/unet_garm/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:78b0771e1c8dba8a02eb5e8b39f20cbab0c2722bc73696fb7e2d6278f70e6f3d
+ size 3438167536
checkpoints/ootd/ootd_dc/checkpoint-36000/unet_vton/config.json ADDED
@@ -0,0 +1,68 @@
+ {
+   "_class_name": "UNetVton2DConditionModel",
+   "_diffusers_version": "0.24.0.dev0",
+   "_name_or_path": "/home/aigc/Vton_v4/OOTDiffusion/models/ldm/stablediffusion_v15",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "attention_type": "default",
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 768,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 8,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "reverse_transformer_layers_per_block": null,
+   "sample_size": 64,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
checkpoints/ootd/ootd_dc/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b3cb1398172757fe1f49c130d104ec4da8d20d2132958dfff0748a2b6a7506b
+ size 3438213624
checkpoints/ootd/ootd_hd/checkpoint-36000/unet_garm/config.json ADDED
@@ -0,0 +1,68 @@
+ {
+   "_class_name": "UNetGarm2DConditionModel",
+   "_diffusers_version": "0.24.0.dev0",
+   "_name_or_path": "/home/aigc/Vton_v4/OOTDiffusion/models/ldm/stablediffusion_v15",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "attention_type": "default",
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 768,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "reverse_transformer_layers_per_block": null,
+   "sample_size": 64,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
checkpoints/ootd/ootd_hd/checkpoint-36000/unet_garm/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dea03c6b3339f13e1432711608d5c7ac83fcb9b14a430aee52b0015834ba41da
+ size 3438167536
checkpoints/ootd/ootd_hd/checkpoint-36000/unet_vton/config.json ADDED
@@ -0,0 +1,68 @@
+ {
+   "_class_name": "UNetVton2DConditionModel",
+   "_diffusers_version": "0.24.0.dev0",
+   "_name_or_path": "/home/aigc/Vton_v4/OOTDiffusion/models/ldm/stablediffusion_v15",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "attention_type": "default",
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 768,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 8,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "reverse_transformer_layers_per_block": null,
+   "sample_size": 64,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false
+ }
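Note: the four UNet configs in this commit are identical apart from _class_name and in_channels, which is 4 for unet_garm (plain latents) but 8 for unet_vton, consistent with the vton UNet taking the noisy latents concatenated channel-wise with the latents of the masked person image. A quick sketch to confirm this from the committed configs (paths assumed from this upload's layout, mirroring UNET_PATH in the inference scripts below):

import json

root = "checkpoints/ootd/ootd_hd/checkpoint-36000"
for sub in ("unet_garm", "unet_vton"):
    with open(f"{root}/{sub}/config.json") as f:
        cfg = json.load(f)
    print(sub, cfg["_class_name"], "in_channels =", cfg["in_channels"])
# Expected: unet_garm ... in_channels = 4, unet_vton ... in_channels = 8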
checkpoints/ootd/ootd_hd/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3587b5025565060842eac78c74f87fc06d8b82c2b51d9938a492d42858679fe
+ size 3438213624
checkpoints/ootd/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_class_name": "DPMSolverMultistepScheduler",
+   "_diffusers_version": "0.20.0.dev0",
+   "algorithm_type": "dpmsolver++",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "clip_sample": false,
+   "dynamic_thresholding_ratio": 0.995,
+   "lambda_min_clipped": -Infinity,
+   "lower_order_final": true,
+   "num_train_timesteps": 1000,
+   "prediction_type": "epsilon",
+   "sample_max_value": 1.0,
+   "set_alpha_to_one": false,
+   "skip_prk_steps": true,
+   "solver_order": 2,
+   "solver_type": "midpoint",
+   "steps_offset": 1,
+   "thresholding": false,
+   "timestep_spacing": "leading",
+   "trained_betas": null,
+   "use_karras_sigmas": true,
+   "variance_type": null
+ }
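Note: the committed scheduler config is for DPMSolverMultistepScheduler (model_index.json above still names PNDMScheduler), and the inference scripts below swap in UniPCMultistepScheduler at load time anyway. A minimal sketch of loading the scheduler as shipped, with the checkpoint path assumed as above:

from diffusers import DPMSolverMultistepScheduler

# Assumed path; reads checkpoints/ootd/scheduler/scheduler_config.json.
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    "checkpoints/ootd", subfolder="scheduler"
)
print(scheduler.config.algorithm_type)  # "dpmsolver++"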
checkpoints/ootd/text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_name_or_path": "openai/clip-vit-large-patch14",
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 2,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 768,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "projection_dim": 768,
+   "torch_dtype": "float32",
+   "transformers_version": "4.22.0.dev0",
+   "vocab_size": 49408
+ }
checkpoints/ootd/text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:770a47a9ffdcfda0b05506a7888ed714d06131d60267e6cf52765d61cf59fd67
+ size 492305335
checkpoints/ootd/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/ootd/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "bos_token": {
+     "content": "<|startoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<|endoftext|>",
+   "unk_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
checkpoints/ootd/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": {
+     "__type": "AddedToken",
+     "content": "<|startoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "do_lower_case": true,
+   "eos_token": {
+     "__type": "AddedToken",
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "errors": "replace",
+   "model_max_length": 77,
+   "name_or_path": "openai/clip-vit-large-patch14",
+   "pad_token": "<|endoftext|>",
+   "special_tokens_map_file": "./special_tokens_map.json",
+   "tokenizer_class": "CLIPTokenizer",
+   "unk_token": {
+     "__type": "AddedToken",
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
checkpoints/ootd/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/ootd/vae/config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.20.0.dev0",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "layers_per_block": 2,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 512,
+   "scaling_factor": 0.18215,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
checkpoints/ootd/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9814753f897cd32db41ec0bf0c574f6f44b39340103df9f4778b18565946d8b1
+ size 334712113
checkpoints/openpose/ckpts/body_pose_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25a948c16078b0f08e236bda51a385d855ef4c153598947c28c0d47ed94bb746
+ size 209267595
ootd/__pycache__/inference_ootd_dc.cpython-310.pyc ADDED
Binary file (3.62 kB).
 
ootd/__pycache__/inference_ootd_hd.cpython-310.pyc ADDED
Binary file (3.62 kB).
 
ootd/inference_ootd.py ADDED
@@ -0,0 +1,133 @@
+ import pdb
+ from pathlib import Path
+ import sys
+ PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
+ sys.path.insert(0, str(PROJECT_ROOT))
+ import os
+
+ import torch
+ import numpy as np
+ from PIL import Image
+ import cv2
+
+ import random
+ import time
+ import pdb
+
+ from pipelines_ootd.pipeline_ootd import OotdPipeline
+ from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
+ from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
+ from diffusers import UniPCMultistepScheduler
+ from diffusers import AutoencoderKL
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ VIT_PATH = "openai/clip-vit-large-patch14"
+ VAE_PATH = "./checkpoints/ootd"
+ UNET_PATH = "./checkpoints/ootd/ootd_hd/checkpoint-36000"
+ MODEL_PATH = "./checkpoints/ootd"
+
+ class OOTDiffusion:
+
+     def __init__(self, gpu_id):
+         self.gpu_id = 'cuda:' + str(gpu_id)
+
+         vae = AutoencoderKL.from_pretrained(
+             VAE_PATH,
+             subfolder="vae",
+             torch_dtype=torch.float16,
+         )
+
+         unet_garm = UNetGarm2DConditionModel.from_pretrained(
+             UNET_PATH,
+             subfolder="unet_garm",
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+         unet_vton = UNetVton2DConditionModel.from_pretrained(
+             UNET_PATH,
+             subfolder="unet_vton",
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+
+         self.pipe = OotdPipeline.from_pretrained(
+             MODEL_PATH,
+             unet_garm=unet_garm,
+             unet_vton=unet_vton,
+             vae=vae,
+             torch_dtype=torch.float16,
+             variant="fp16",
+             use_safetensors=True,
+             safety_checker=None,
+             requires_safety_checker=False,
+         ).to(self.gpu_id)
+
+         self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+
+         self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
+         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
+
+         self.tokenizer = CLIPTokenizer.from_pretrained(
+             MODEL_PATH,
+             subfolder="tokenizer",
+         )
+         self.text_encoder = CLIPTextModel.from_pretrained(
+             MODEL_PATH,
+             subfolder="text_encoder",
+         ).to(self.gpu_id)
+
+
+     def tokenize_captions(self, captions, max_length):
+         inputs = self.tokenizer(
+             captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+         )
+         return inputs.input_ids
+
+
+     def __call__(self,
+                 model_type='hd',
+                 category='upperbody',
+                 image_garm=None,
+                 image_vton=None,
+                 mask=None,
+                 image_ori=None,
+                 num_samples=1,
+                 num_steps=20,
+                 image_scale=1.0,
+                 seed=-1,
+     ):
+         if seed == -1:
+             random.seed(time.time())
+             seed = random.randint(0, 2147483647)
+         print('Initial seed: ' + str(seed))
+         generator = torch.manual_seed(seed)
+
+         with torch.no_grad():
+             prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
+             prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
+             prompt_image = prompt_image.unsqueeze(1)
+             if model_type == 'hd':
+                 prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
+                 prompt_embeds[:, 1:] = prompt_image[:]
+             elif model_type == 'dc':
+                 prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
+                 prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
+             else:
+                 raise ValueError("model_type must be \'hd\' or \'dc\'!")
+
+             images = self.pipe(prompt_embeds=prompt_embeds,
+                                image_garm=image_garm,
+                                image_vton=image_vton,
+                                mask=mask,
+                                image_ori=image_ori,
+                                num_inference_steps=num_steps,
+                                image_guidance_scale=image_scale,
+                                num_images_per_prompt=num_samples,
+                                generator=generator,
+             ).images
+
+         return images
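For orientation, a minimal usage sketch of the OOTDiffusion class added above. The input paths are hypothetical; preparing mask and image_vton (the person image with the try-on region masked out) is normally handled by the humanparsing and openpose preprocessors also included in this commit:

from PIL import Image
from ootd.inference_ootd import OOTDiffusion

model = OOTDiffusion(gpu_id=0)  # requires a CUDA device; weights load in fp16

# Hypothetical inputs; image_ori is the untouched original person image.
garment = Image.open("examples/garment.jpg").convert("RGB")
person_masked = Image.open("examples/person_masked.jpg").convert("RGB")
mask = Image.open("examples/mask.png")
person = Image.open("examples/person.jpg").convert("RGB")

images = model(
    model_type="hd",
    category="upperbody",
    image_garm=garment,
    image_vton=person_masked,
    mask=mask,
    image_ori=person,
    num_samples=1,
    num_steps=20,
    image_scale=1.0,
    seed=42,
)
images[0].save("result.png")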
ootd/inference_ootd_dc.py ADDED
@@ -0,0 +1,132 @@
+ import pdb
+ from pathlib import Path
+ import sys
+ PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
+ sys.path.insert(0, str(PROJECT_ROOT))
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ import cv2
+
+ import random
+ import time
+ import pdb
+
+ from pipelines_ootd.pipeline_ootd import OotdPipeline
+ from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
+ from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
+ from diffusers import UniPCMultistepScheduler
+ from diffusers import AutoencoderKL
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ VIT_PATH = "openai/clip-vit-large-patch14"
+ VAE_PATH = "./checkpoints/ootd"
+ UNET_PATH = "./checkpoints/ootd/ootd_dc/checkpoint-36000"
+ MODEL_PATH = "./checkpoints/ootd"
+
+ class OOTDiffusionDC:
+
+     def __init__(self, gpu_id):
+         self.gpu_id = 'cuda:' + str(gpu_id)
+
+         vae = AutoencoderKL.from_pretrained(
+             VAE_PATH,
+             subfolder="vae",
+             torch_dtype=torch.float16,
+         )
+
+         unet_garm = UNetGarm2DConditionModel.from_pretrained(
+             UNET_PATH,
+             subfolder="unet_garm",
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+         unet_vton = UNetVton2DConditionModel.from_pretrained(
+             UNET_PATH,
+             subfolder="unet_vton",
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+
+         self.pipe = OotdPipeline.from_pretrained(
+             MODEL_PATH,
+             unet_garm=unet_garm,
+             unet_vton=unet_vton,
+             vae=vae,
+             torch_dtype=torch.float16,
+             variant="fp16",
+             use_safetensors=True,
+             safety_checker=None,
+             requires_safety_checker=False,
+         ).to(self.gpu_id)
+
+         self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+
+         self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
+         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
+
+         self.tokenizer = CLIPTokenizer.from_pretrained(
+             MODEL_PATH,
+             subfolder="tokenizer",
+         )
+         self.text_encoder = CLIPTextModel.from_pretrained(
+             MODEL_PATH,
+             subfolder="text_encoder",
+         ).to(self.gpu_id)
+
+
+     def tokenize_captions(self, captions, max_length):
+         inputs = self.tokenizer(
+             captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+         )
+         return inputs.input_ids
+
+
+     def __call__(self,
+                 model_type='hd',
+                 category='upperbody',
+                 image_garm=None,
+                 image_vton=None,
+                 mask=None,
+                 image_ori=None,
+                 num_samples=1,
+                 num_steps=20,
+                 image_scale=1.0,
+                 seed=-1,
+     ):
+         if seed == -1:
+             random.seed(time.time())
+             seed = random.randint(0, 2147483647)
+         print('Initial seed: ' + str(seed))
+         generator = torch.manual_seed(seed)
+
+         with torch.no_grad():
+             prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
+             prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
+             prompt_image = prompt_image.unsqueeze(1)
+             if model_type == 'hd':
+                 prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
+                 prompt_embeds[:, 1:] = prompt_image[:]
+             elif model_type == 'dc':
+                 prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
+                 prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
+             else:
+                 raise ValueError("model_type must be \'hd\' or \'dc\'!")
+
+             images = self.pipe(prompt_embeds=prompt_embeds,
+                                image_garm=image_garm,
+                                image_vton=image_vton,
+                                mask=mask,
+                                image_ori=image_ori,
+                                num_inference_steps=num_steps,
+                                image_guidance_scale=image_scale,
+                                num_images_per_prompt=num_samples,
+                                generator=generator,
+             ).images
+
+         return images
ootd/inference_ootd_hd.py ADDED
@@ -0,0 +1,132 @@
+ import pdb
+ from pathlib import Path
+ import sys
+ PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
+ sys.path.insert(0, str(PROJECT_ROOT))
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ import cv2
+
+ import random
+ import time
+ import pdb
+
+ from pipelines_ootd.pipeline_ootd import OotdPipeline
+ from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
+ from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
+ from diffusers import UniPCMultistepScheduler
+ from diffusers import AutoencoderKL
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ VIT_PATH = "openai/clip-vit-large-patch14"
+ VAE_PATH = "./checkpoints/ootd"
+ UNET_PATH = "./checkpoints/ootd/ootd_hd/checkpoint-36000"
+ MODEL_PATH = "./checkpoints/ootd"
+
+ class OOTDiffusionHD:
+
+     def __init__(self, gpu_id):
+         self.gpu_id = 'cuda:' + str(gpu_id)
+
+         vae = AutoencoderKL.from_pretrained(
+             VAE_PATH,
+             subfolder="vae",
+             torch_dtype=torch.float16,
+         )
+
+         unet_garm = UNetGarm2DConditionModel.from_pretrained(
+             UNET_PATH,
+             subfolder="unet_garm",
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+         unet_vton = UNetVton2DConditionModel.from_pretrained(
+             UNET_PATH,
+             subfolder="unet_vton",
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+
+         self.pipe = OotdPipeline.from_pretrained(
+             MODEL_PATH,
+             unet_garm=unet_garm,
+             unet_vton=unet_vton,
+             vae=vae,
+             torch_dtype=torch.float16,
+             variant="fp16",
+             use_safetensors=True,
+             safety_checker=None,
+             requires_safety_checker=False,
+         ).to(self.gpu_id)
+
+         self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+
+         self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
+         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
+
+         self.tokenizer = CLIPTokenizer.from_pretrained(
+             MODEL_PATH,
+             subfolder="tokenizer",
+         )
+         self.text_encoder = CLIPTextModel.from_pretrained(
+             MODEL_PATH,
+             subfolder="text_encoder",
+         ).to(self.gpu_id)
+
+
+     def tokenize_captions(self, captions, max_length):
+         inputs = self.tokenizer(
+             captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+         )
+         return inputs.input_ids
+
+
+     def __call__(self,
+                 model_type='hd',
+                 category='upperbody',
+                 image_garm=None,
+                 image_vton=None,
+                 mask=None,
+                 image_ori=None,
+                 num_samples=1,
+                 num_steps=20,
+                 image_scale=1.0,
+                 seed=-1,
+     ):
+         if seed == -1:
+             random.seed(time.time())
+             seed = random.randint(0, 2147483647)
+         print('Initial seed: ' + str(seed))
+         generator = torch.manual_seed(seed)
+
+         with torch.no_grad():
+             prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
+             prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
+             prompt_image = prompt_image.unsqueeze(1)
+             if model_type == 'hd':
+                 prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
+                 prompt_embeds[:, 1:] = prompt_image[:]
+             elif model_type == 'dc':
+                 prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
+                 prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
+             else:
+                 raise ValueError("model_type must be \'hd\' or \'dc\'!")
+
+             images = self.pipe(prompt_embeds=prompt_embeds,
+                                image_garm=image_garm,
+                                image_vton=image_vton,
+                                mask=mask,
+                                image_ori=image_ori,
+                                num_inference_steps=num_steps,
+                                image_guidance_scale=image_scale,
+                                num_images_per_prompt=num_samples,
+                                generator=generator,
+             ).images
+
+         return images
ootd/pipelines_ootd/__pycache__/attention_garm.cpython-310.pyc ADDED
Binary file (11.5 kB).
 
ootd/pipelines_ootd/__pycache__/attention_vton.cpython-310.pyc ADDED
Binary file (11.5 kB).
 
ootd/pipelines_ootd/__pycache__/pipeline_ootd.cpython-310.pyc ADDED
Binary file (27 kB).
 
ootd/pipelines_ootd/__pycache__/transformer_garm_2d.cpython-310.pyc ADDED
Binary file (13.7 kB).
 
ootd/pipelines_ootd/__pycache__/transformer_vton_2d.cpython-310.pyc ADDED
Binary file (13.7 kB).
 
ootd/pipelines_ootd/__pycache__/unet_garm_2d_blocks.cpython-310.pyc ADDED
Binary file (63.5 kB).
 
ootd/pipelines_ootd/__pycache__/unet_garm_2d_condition.cpython-310.pyc ADDED
Binary file (37.1 kB).
 
ootd/pipelines_ootd/__pycache__/unet_vton_2d_blocks.cpython-310.pyc ADDED
Binary file (63.6 kB).
 
ootd/pipelines_ootd/__pycache__/unet_vton_2d_condition.cpython-310.pyc ADDED
Binary file (37.2 kB).
 
ootd/pipelines_ootd/attention_garm.py ADDED
@@ -0,0 +1,402 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
+ from typing import Any, Dict, Optional
+
+ import torch
+ from torch import nn
+
+ from diffusers.utils import USE_PEFT_BACKEND
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+ from diffusers.models.lora import LoRACompatibleLinear
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
+
+
+ @maybe_allow_in_graph
+ class GatedSelfAttentionDense(nn.Module):
+     r"""
+     A gated self-attention dense layer that combines visual features and object features.
+
+     Parameters:
+         query_dim (`int`): The number of channels in the query.
+         context_dim (`int`): The number of channels in the context.
+         n_heads (`int`): The number of heads to use for attention.
+         d_head (`int`): The number of channels in each head.
+     """
+
+     def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+         super().__init__()
+
+         # we need a linear projection since we need cat visual feature and obj feature
+         self.linear = nn.Linear(context_dim, query_dim)
+
+         self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+         self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+         self.norm1 = nn.LayerNorm(query_dim)
+         self.norm2 = nn.LayerNorm(query_dim)
+
+         self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+         self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+         self.enabled = True
+
+     def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+         if not self.enabled:
+             return x
+
+         n_visual = x.shape[1]
+         objs = self.linear(objs)
+
+         x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+         x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+         return x
+
+
+ @maybe_allow_in_graph
+ class BasicTransformerBlock(nn.Module):
+     r"""
+     A basic Transformer block.
+
+     Parameters:
+         dim (`int`): The number of channels in the input and output.
+         num_attention_heads (`int`): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`): The number of channels in each head.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+         num_embeds_ada_norm (:
+             obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+         attention_bias (:
+             obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+         only_cross_attention (`bool`, *optional*):
+             Whether to use only cross-attention layers. In this case two cross attention layers are used.
+         double_self_attention (`bool`, *optional*):
+             Whether to use two self-attention layers. In this case no cross attention layers are used.
+         upcast_attention (`bool`, *optional*):
+             Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+         norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+             Whether to use learnable elementwise affine parameters for normalization.
+         norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+             The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+         final_dropout (`bool` *optional*, defaults to False):
+             Whether to apply a final dropout after the last feed-forward layer.
+         attention_type (`str`, *optional*, defaults to `"default"`):
+             The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+         positional_embeddings (`str`, *optional*, defaults to `None`):
+             The type of positional embeddings to apply to.
+         num_positional_embeddings (`int`, *optional*, defaults to `None`):
+             The maximum number of positional embeddings to apply.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         dropout=0.0,
+         cross_attention_dim: Optional[int] = None,
+         activation_fn: str = "geglu",
+         num_embeds_ada_norm: Optional[int] = None,
+         attention_bias: bool = False,
+         only_cross_attention: bool = False,
+         double_self_attention: bool = False,
+         upcast_attention: bool = False,
+         norm_elementwise_affine: bool = True,
+         norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+         norm_eps: float = 1e-5,
+         final_dropout: bool = False,
+         attention_type: str = "default",
+         positional_embeddings: Optional[str] = None,
+         num_positional_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.only_cross_attention = only_cross_attention
+
+         self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+         self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+         self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+         self.use_layer_norm = norm_type == "layer_norm"
+
+         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+             raise ValueError(
+                 f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                 f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+             )
+
+         if positional_embeddings and (num_positional_embeddings is None):
+             raise ValueError(
+                 "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+             )
+
+         if positional_embeddings == "sinusoidal":
+             self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+         else:
+             self.pos_embed = None
+
+         # Define 3 blocks. Each block has its own normalization layer.
+         # 1. Self-Attn
+         if self.use_ada_layer_norm:
+             self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+         elif self.use_ada_layer_norm_zero:
+             self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+         else:
+             self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+             upcast_attention=upcast_attention,
+         )
+
+         # 2. Cross-Attn
+         if cross_attention_dim is not None or double_self_attention:
+             # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+             # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+             # the second cross attention block.
+             self.norm2 = (
+                 AdaLayerNorm(dim, num_embeds_ada_norm)
+                 if self.use_ada_layer_norm
+                 else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+             )
+             self.attn2 = Attention(
+                 query_dim=dim,
+                 cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 upcast_attention=upcast_attention,
+             )  # is self-attn if encoder_hidden_states is none
+         else:
+             self.norm2 = None
+             self.attn2 = None
+
+         # 3. Feed-forward
+         if not self.use_ada_layer_norm_single:
+             self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+         # 4. Fuser
+         if attention_type == "gated" or attention_type == "gated-text-image":
+             self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+         # 5. Scale-shift for PixArt-Alpha.
+         if self.use_ada_layer_norm_single:
+             self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+         # let chunk size default to None
+         self._chunk_size = None
+         self._chunk_dim = 0
+
+     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+         # Sets chunk feed-forward
+         self._chunk_size = chunk_size
+         self._chunk_dim = dim
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         spatial_attn_inputs = [],
+         attention_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         timestep: Optional[torch.LongTensor] = None,
+         cross_attention_kwargs: Dict[str, Any] = None,
+         class_labels: Optional[torch.LongTensor] = None,
+     ) -> torch.FloatTensor:
+         # Notice that normalization is always applied before the real computation in the following blocks.
+         # 0. Self-Attention
+         batch_size = hidden_states.shape[0]
+
+         spatial_attn_input = hidden_states
+         spatial_attn_inputs.append(spatial_attn_input)
+
+         if self.use_ada_layer_norm:
+             norm_hidden_states = self.norm1(hidden_states, timestep)
+         elif self.use_ada_layer_norm_zero:
+             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                 hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+             )
+         elif self.use_layer_norm:
+             norm_hidden_states = self.norm1(hidden_states)
+         elif self.use_ada_layer_norm_single:
+             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                 self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+             ).chunk(6, dim=1)
+             norm_hidden_states = self.norm1(hidden_states)
+             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+             norm_hidden_states = norm_hidden_states.squeeze(1)
+         else:
+             raise ValueError("Incorrect norm used")
+
+         if self.pos_embed is not None:
+             norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+         # 1. Retrieve lora scale.
+         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+         # 2. Prepare GLIGEN inputs
+         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+         attn_output = self.attn1(
+             norm_hidden_states,
+             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+             attention_mask=attention_mask,
+             **cross_attention_kwargs,
+         )
+         if self.use_ada_layer_norm_zero:
+             attn_output = gate_msa.unsqueeze(1) * attn_output
+         elif self.use_ada_layer_norm_single:
+             attn_output = gate_msa * attn_output
+
+         hidden_states = attn_output + hidden_states
+         if hidden_states.ndim == 4:
+             hidden_states = hidden_states.squeeze(1)
+
+         # 2.5 GLIGEN Control
+         if gligen_kwargs is not None:
+             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+         # 3. Cross-Attention
+         if self.attn2 is not None:
+             if self.use_ada_layer_norm:
+                 norm_hidden_states = self.norm2(hidden_states, timestep)
+             elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+                 norm_hidden_states = self.norm2(hidden_states)
+             elif self.use_ada_layer_norm_single:
+                 # For PixArt norm2 isn't applied here:
+                 # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                 norm_hidden_states = hidden_states
+             else:
+                 raise ValueError("Incorrect norm")
+
+             if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+                 norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+             attn_output = self.attn2(
+                 norm_hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 attention_mask=encoder_attention_mask,
+                 **cross_attention_kwargs,
+             )
+             hidden_states = attn_output + hidden_states
+
+         # 4. Feed-forward
+         if not self.use_ada_layer_norm_single:
+             norm_hidden_states = self.norm3(hidden_states)
+
+         if self.use_ada_layer_norm_zero:
+             norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         if self.use_ada_layer_norm_single:
+             norm_hidden_states = self.norm2(hidden_states)
+             norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+         if self._chunk_size is not None:
+             # "feed_forward_chunk_size" can be used to save memory
+             if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                 raise ValueError(
+                     f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                 )
+
+             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+             ff_output = torch.cat(
+                 [
+                     self.ff(hid_slice, scale=lora_scale)
+                     for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+                 ],
+                 dim=self._chunk_dim,
+             )
+         else:
+             ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+         if self.use_ada_layer_norm_zero:
+             ff_output = gate_mlp.unsqueeze(1) * ff_output
+         elif self.use_ada_layer_norm_single:
+             ff_output = gate_mlp * ff_output
+
+         hidden_states = ff_output + hidden_states
+         if hidden_states.ndim == 4:
+             hidden_states = hidden_states.squeeze(1)
+
+         return hidden_states, spatial_attn_inputs
+
+
+ class FeedForward(nn.Module):
+     r"""
+     A feed-forward layer.
+
+     Parameters:
+         dim (`int`): The number of channels in the input.
+         dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+         mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+         final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         dim_out: Optional[int] = None,
+         mult: int = 4,
+         dropout: float = 0.0,
+         activation_fn: str = "geglu",
+         final_dropout: bool = False,
+     ):
+         super().__init__()
+         inner_dim = int(dim * mult)
+         dim_out = dim_out if dim_out is not None else dim
+         linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+         if activation_fn == "gelu":
+             act_fn = GELU(dim, inner_dim)
+         if activation_fn == "gelu-approximate":
+             act_fn = GELU(dim, inner_dim, approximate="tanh")
+         elif activation_fn == "geglu":
+             act_fn = GEGLU(dim, inner_dim)
+         elif activation_fn == "geglu-approximate":
+             act_fn = ApproximateGELU(dim, inner_dim)
+
+         self.net = nn.ModuleList([])
+         # project in
+         self.net.append(act_fn)
+         # project dropout
+         self.net.append(nn.Dropout(dropout))
+         # project out
+         self.net.append(linear_cls(inner_dim, dim_out))
+         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+         if final_dropout:
+             self.net.append(nn.Dropout(dropout))
+
+     def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+         compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+         for module in self.net:
+             if isinstance(module, compatible_cls):
+                 hidden_states = module(hidden_states, scale)
+             else:
+                 hidden_states = module(hidden_states)
+         return hidden_states
ootd/pipelines_ootd/attention_vton.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from diffusers.utils import USE_PEFT_BACKEND
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+    r"""
+    A gated self-attention dense layer that combines visual features and object features.
+
+    Parameters:
+        query_dim (`int`): The number of channels in the query.
+        context_dim (`int`): The number of channels in the context.
+        n_heads (`int`): The number of heads to use for attention.
+        d_head (`int`): The number of channels in each head.
+    """
+
+    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+        super().__init__()
+
+        # we need a linear projection since we concatenate the visual and object features
+        self.linear = nn.Linear(context_dim, query_dim)
+
+        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+        self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+        self.norm1 = nn.LayerNorm(query_dim)
+        self.norm2 = nn.LayerNorm(query_dim)
+
+        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+        self.enabled = True
+
+    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+        if not self.enabled:
+            return x
+
+        n_visual = x.shape[1]
+        objs = self.linear(objs)
+
+        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+        return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"`, `"gated"`, or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*, defaults to `None`):
+            The type of positional embeddings to apply.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
+        attention_type: str = "default",
+        positional_embeddings: Optional[str] = None,
+        num_positional_embeddings: Optional[int] = None,
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+        self.use_layer_norm = norm_type == "layer_norm"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        if positional_embeddings and (num_positional_embeddings is None):
+            raise ValueError(
+                "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
+            )
+
+        if positional_embeddings == "sinusoidal":
+            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+        else:
+            self.pos_embed = None
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        if self.use_ada_layer_norm:
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_zero:
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = (
+                AdaLayerNorm(dim, num_embeds_ada_norm)
+                if self.use_ada_layer_norm
+                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+            )
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        if not self.use_ada_layer_norm_single:
+            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+        # 4. Fuser
+        if attention_type == "gated" or attention_type == "gated-text-image":
+            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+        # 5. Scale-shift for PixArt-Alpha.
+        if self.use_ada_layer_norm_single:
+            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        spatial_attn_inputs: List[torch.FloatTensor] = [],
+        spatial_attn_idx: int = 0,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.FloatTensor, List[torch.FloatTensor], int]:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 0. Self-Attention
+        batch_size = hidden_states.shape[0]
+
+        # Append the garment features recorded by unet_garm along the token dimension,
+        # so that self-attention can attend across person and garment tokens.
+        spatial_attn_input = spatial_attn_inputs[spatial_attn_idx]
+        spatial_attn_idx += 1
+        hidden_states = torch.cat((hidden_states, spatial_attn_input), dim=1)
+
+        if self.use_ada_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.use_ada_layer_norm_zero:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        elif self.use_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states)
+        elif self.use_ada_layer_norm_single:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+            ).chunk(6, dim=1)
+            norm_hidden_states = self.norm1(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+            norm_hidden_states = norm_hidden_states.squeeze(1)
+        else:
+            raise ValueError("Incorrect norm used")
+
+        if self.pos_embed is not None:
+            norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+        # 1. Retrieve lora scale.
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        # 2. Prepare GLIGEN inputs
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+        if self.use_ada_layer_norm_zero:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        elif self.use_ada_layer_norm_single:
+            attn_output = gate_msa * attn_output
+
+        hidden_states = attn_output + hidden_states
+        # Drop the garment half of the sequence again; only the person tokens continue.
+        hidden_states, _ = hidden_states.chunk(2, dim=1)
+
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        # 2.5 GLIGEN Control
+        if gligen_kwargs is not None:
+            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+        # 3. Cross-Attention
+        if self.attn2 is not None:
+            if self.use_ada_layer_norm:
+                norm_hidden_states = self.norm2(hidden_states, timestep)
+            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+                norm_hidden_states = self.norm2(hidden_states)
+            elif self.use_ada_layer_norm_single:
+                # For PixArt norm2 isn't applied here:
+                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                norm_hidden_states = hidden_states
+            else:
+                raise ValueError("Incorrect norm")
+
+            if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        if not self.use_ada_layer_norm_single:
+            norm_hidden_states = self.norm3(hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self.use_ada_layer_norm_single:
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [
+                    self.ff(hid_slice, scale=lora_scale)
+                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+                ],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+        if self.use_ada_layer_norm_zero:
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+        elif self.use_ada_layer_norm_single:
+            ff_output = gate_mlp * ff_output
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states, spatial_attn_inputs, spatial_attn_idx
+
+
+class FeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`): The number of channels in the input.
+        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh")
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim)
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(linear_cls(inner_dim, dim_out))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+        for module in self.net:
+            if isinstance(module, compatible_cls):
+                hidden_states = module(hidden_states, scale)
+            else:
+                hidden_states = module(hidden_states)
+        return hidden_states
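
The vton-side deviation from stock diffusers is confined to `BasicTransformerBlock.forward`: garment features recorded by `unet_garm` are appended along the token dimension before self-attention and chopped off immediately after, so person tokens can attend to garment tokens without growing the residual stream. A torch-only sketch of that concat-attend-chunk trick (shapes are illustrative assumptions):

    import torch

    person = torch.randn(1, 1024, 320)             # person latents flattened to tokens
    garment = torch.randn(1, 1024, 320)            # matching garment features from unet_garm
    joint = torch.cat((person, garment), dim=1)    # (1, 2048, 320): self-attention sees both
    # ... self.attn1(norm(joint)) would run here ...
    person_out, _ = joint.chunk(2, dim=1)          # keep only the person half afterwards
    assert person_out.shape == person.shape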
ootd/pipelines_ootd/pipeline_ootd.py ADDED
@@ -0,0 +1,846 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from transformers import AutoProcessor, CLIPVisionModelWithProjection
+
+from .unet_vton_2d_condition import UNetVton2DConditionModel
+from .unet_garm_2d_condition import UNetGarm2DConditionModel
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
+    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
+
+
+class OotdPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+    r"""
+    Pipeline for image-based virtual try-on (OOTDiffusion).
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`~transformers.CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer ([`~transformers.CLIPTokenizer`]):
+            A `CLIPTokenizer` to tokenize text.
+        unet_garm ([`UNetGarm2DConditionModel`]):
+            A UNet that encodes the garment image and records its spatial attention features.
+        unet_vton ([`UNetVton2DConditionModel`]):
+            A UNet that denoises the try-on latents while attending to the recorded garment features.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with the UNets to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+            about a model's potential harms.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+    """
+    model_cpu_offload_seq = "text_encoder->unet_garm->unet_vton->vae"
+    _optional_components = ["safety_checker", "feature_extractor"]
+    _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "vton_latents"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet_garm: UNetGarm2DConditionModel,
+        unet_vton: UNetVton2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        requires_safety_checker: bool = True,
+    ):
+        super().__init__()
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet_garm=unet_garm,
+            unet_vton=unet_vton,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image_garm: PipelineImageInput = None,
+        image_vton: PipelineImageInput = None,
+        mask: PipelineImageInput = None,
+        image_ori: PipelineImageInput = None,
+        num_inference_steps: int = 100,
+        guidance_scale: float = 7.5,
+        image_guidance_scale: float = 1.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            image_garm (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image` or tensor representing the garment to transfer. Can also accept image latents; if latents are
+                passed directly they are not encoded again.
+            image_vton (same types as `image_garm`):
+                The person image with the try-on region masked out.
+            mask (same types as `image_garm`):
+                Mask selecting the try-on region; values below 127 are binarized to 0 and the rest to 255.
+            image_ori (same types as `image_garm`):
+                The original, unmasked person image, used to repaint everything outside the mask.
+            num_inference_steps (`int`, *optional*, defaults to 100):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            image_guidance_scale (`float`, *optional*, defaults to 1.5):
+                Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
+                `image_guidance_scale > 1`. A higher image guidance scale encourages generated images that are closely
+                linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
+                value of at least `1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (ฮท) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+            generator (`torch.Generator`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference, with the arguments
+                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
+                `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                second element is a list of `bool`s indicating whether the corresponding generated image contains
+                "not-safe-for-work" (nsfw) content.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        # 0. Check inputs
+        self.check_inputs(
+            prompt,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+        )
+        self._guidance_scale = guidance_scale
+        self._image_guidance_scale = image_guidance_scale
+
+        if (image_vton is None) or (image_garm is None):
+            raise ValueError("`image_garm` and `image_vton` inputs cannot be undefined.")
+
+        # 1. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # check if scheduler is in sigmas space
+        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
+
+        # 2. Encode input prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+        )
+
+        # 3. Preprocess images and binarize the mask
+        image_garm = self.image_processor.preprocess(image_garm)
+        image_vton = self.image_processor.preprocess(image_vton)
+        image_ori = self.image_processor.preprocess(image_ori)
+        mask = np.array(mask)
+        mask[mask < 127] = 0
+        mask[mask >= 127] = 255
+        mask = torch.tensor(mask)
+        mask = mask / 255
+        mask = mask.reshape(-1, 1, mask.size(-2), mask.size(-1))
+
+        # 4. Set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare image latents
+        garm_latents = self.prepare_garm_latents(
+            image_garm,
+            batch_size,
+            num_images_per_prompt,
+            prompt_embeds.dtype,
+            device,
+            self.do_classifier_free_guidance,
+            generator,
+        )
+
+        vton_latents, mask_latents, image_ori_latents = self.prepare_vton_latents(
+            image_vton,
+            mask,
+            image_ori,
+            batch_size,
+            num_images_per_prompt,
+            prompt_embeds.dtype,
+            device,
+            self.do_classifier_free_guidance,
+            generator,
+        )
+
+        height, width = vton_latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.vae.config.latent_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        noise = latents.clone()
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+
+        # Run the garment UNet once (at timestep 0) to record its spatial attention features.
+        _, spatial_attn_outputs = self.unet_garm(
+            garm_latents,
+            0,
+            encoder_hidden_states=prompt_embeds,
+            return_dict=False,
+        )
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                # concat latents, image_latents in the channel dimension
+                scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_vton_model_input = torch.cat([scaled_latent_model_input, vton_latents], dim=1)
+                # latent_vton_model_input = scaled_latent_model_input + vton_latents
+
+                # fresh copy per step; unet_vton consumes the list by index
+                spatial_attn_inputs = spatial_attn_outputs.copy()
+
+                # predict the noise residual
+                noise_pred = self.unet_vton(
+                    latent_vton_model_input,
+                    spatial_attn_inputs,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    return_dict=False,
+                )[0]
+
+                # Hack:
+                # For karras style schedulers the model does classifier free guidance using the
+                # predicted_original_sample instead of the noise_pred. So we need to compute the
+                # predicted_original_sample here if we are using a karras style scheduler.
+                if scheduler_is_in_sigma_space:
+                    step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
+                    sigma = self.scheduler.sigmas[step_index]
+                    noise_pred = latent_model_input - sigma * noise_pred
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_text_image, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = (
+                        noise_pred_text
+                        + self.image_guidance_scale * (noise_pred_text_image - noise_pred_text)
+                    )
+
+                # Hack:
+                # For karras style schedulers the model does classifier free guidance using the
+                # predicted_original_sample instead of the noise_pred. But the scheduler.step function
+                # expects the noise_pred and computes the predicted_original_sample internally. So we
+                # need to overwrite the noise_pred here such that the value of the computed
+                # predicted_original_sample is correct.
+                if scheduler_is_in_sigma_space:
+                    noise_pred = (noise_pred - latents) / (-sigma)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                init_latents_proper = image_ori_latents * self.vae.config.scaling_factor
+
+                # repainting: outside the mask, keep a renoised copy of the original image latents
+                if i < len(timesteps) - 1:
+                    noise_timestep = timesteps[i + 1]
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_latents_proper, noise, torch.tensor([noise_timestep])
+                    )
+
+                latents = (1 - mask_latents) * init_latents_proper + mask_latents * latents
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    vton_latents = callback_outputs.pop("vton_latents", vton_latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        if output_type != "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+        """
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
+
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+        if do_classifier_free_guidance:
+            # OOTDiffusion's CFG is image-based: both guidance branches share the same text
+            # embedding, so the negative-prompt tokenization above is never actually encoded.
+            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+    def decode_latents(self, latents):
+        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+        latents = 1 / self.vae.config.scaling_factor * latents
+        image = self.vae.decode(latents, return_dict=False)[0]
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def check_inputs(
+        self,
+        prompt,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def prepare_garm_latents(
+        self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
+    ):
+        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+            raise ValueError(
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+            )
+
+        image = image.to(device=device, dtype=dtype)
+
+        batch_size = batch_size * num_images_per_prompt
+
+        if image.shape[1] == 4:
+            image_latents = image
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            if isinstance(generator, list):
+                image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
+                image_latents = torch.cat(image_latents, dim=0)
+            else:
+                image_latents = self.vae.encode(image).latent_dist.mode()
+
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)
+
+        if do_classifier_free_guidance:
+            uncond_image_latents = torch.zeros_like(image_latents)
+            image_latents = torch.cat([image_latents, uncond_image_latents], dim=0)
+
+        return image_latents
+
+    def prepare_vton_latents(
+        self, image, mask, image_ori, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
+    ):
+        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+            raise ValueError(
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+            )
+
+        image = image.to(device=device, dtype=dtype)
+        image_ori = image_ori.to(device=device, dtype=dtype)
+
+        batch_size = batch_size * num_images_per_prompt
+
+        if image.shape[1] == 4:
+            image_latents = image
+            image_ori_latents = image_ori
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            if isinstance(generator, list):
+                image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
+                image_latents = torch.cat(image_latents, dim=0)
+                image_ori_latents = [self.vae.encode(image_ori[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
+                image_ori_latents = torch.cat(image_ori_latents, dim=0)
+            else:
+                image_latents = self.vae.encode(image).latent_dist.mode()
+                image_ori_latents = self.vae.encode(image_ori).latent_dist.mode()
+
+        mask = torch.nn.functional.interpolate(
+            mask, size=(image_latents.size(-2), image_latents.size(-1))
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+            mask = torch.cat([mask] * additional_image_per_prompt, dim=0)
+            image_ori_latents = torch.cat([image_ori_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)
+            mask = torch.cat([mask], dim=0)
+            image_ori_latents = torch.cat([image_ori_latents], dim=0)
+
+        if do_classifier_free_guidance:
+            # uncond_image_latents = torch.zeros_like(image_latents)
+            # Unlike the garment latents (zeroed for the unconditional branch), the vton latents
+            # are repeated for both branches so the masked person image is always visible.
+            image_latents = torch.cat([image_latents] * 2, dim=0)
+
+        return image_latents, mask, image_ori_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stages where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        if not hasattr(self, "unet_vton"):
+            raise ValueError("The pipeline must have `unet_vton` for using FreeU.")
+        self.unet_vton.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+    def disable_freeu(self):
+        """Disables the FreeU mechanism if enabled."""
+        self.unet_vton.disable_freeu()
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def image_guidance_scale(self):
+        return self._image_guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self.image_guidance_scale >= 1.0
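
A hedged end-to-end sketch of driving `OotdPipeline` directly. The checkpoint layout below is inferred from the `checkpoints/ootd` folder in this commit, and the image filenames are placeholders; the repo's own wrappers in `ootd/inference_ootd_hd.py` and `ootd/inference_ootd_dc.py` remain the authoritative entry points:

    import torch
    from PIL import Image
    from ootd.pipelines_ootd.pipeline_ootd import OotdPipeline
    from ootd.pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
    from ootd.pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel

    MODEL = "checkpoints/ootd"
    unet_garm = UNetGarm2DConditionModel.from_pretrained(
        f"{MODEL}/ootd_hd/checkpoint-36000/unet_garm", torch_dtype=torch.float16)
    unet_vton = UNetVton2DConditionModel.from_pretrained(
        f"{MODEL}/ootd_hd/checkpoint-36000/unet_vton", torch_dtype=torch.float16)
    pipe = OotdPipeline.from_pretrained(
        MODEL, unet_garm=unet_garm, unet_vton=unet_vton, torch_dtype=torch.float16,
        safety_checker=None, requires_safety_checker=False,
    ).to("cuda")

    result = pipe(
        prompt="upper body clothes",                 # placeholder prompt
        image_garm=Image.open("garment.jpg"),        # garment image
        image_vton=Image.open("person_masked.jpg"),  # person with try-on region masked
        mask=Image.open("mask.png").convert("L"),    # binary try-on mask
        image_ori=Image.open("person.jpg"),          # original person, repainted outside the mask
        num_inference_steps=20,
        image_guidance_scale=2.0,                    # image CFG branch is active when >= 1.0
    )
    result.images[0].save("tryon.png")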
ootd/pipelines_ootd/transformer_garm_2d.py ADDED
@@ -0,0 +1,449 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from .attention_garm import BasicTransformerBlock

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
# from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.embeddings import CaptionProjection, PatchEmbed
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: Optional[int] = None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or `patch_size`:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = linear_cls(in_channels, inner_dim)
            else:
                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
            interpolation_scale = max(interpolation_scale, 1)
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
                interpolation_scale=interpolation_scale,
            )

        # 3. Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    double_self_attention=double_self_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    attention_type=attention_type,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = linear_cls(inner_dim, in_channels)
            else:
                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches and norm_type != "ada_norm_single":
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
        elif self.is_input_patches and norm_type == "ada_norm_single":
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        # 5. PixArt-Alpha blocks.
        self.adaln_single = None
        self.use_additional_conditions = False
        if norm_type == "ada_norm_single":
            self.use_additional_conditions = self.config.sample_size == 128
            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
            # additional conditions until we find better name
            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        spatial_attn_inputs=[],
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerNormZero`.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask (`torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = (
                    self.proj_in(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_in(hidden_states)
                )
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = (
                    self.proj_in(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_in(hidden_states)
                )

        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
            hidden_states = self.pos_embed(hidden_states)

            if self.adaln_single is not None:
                if self.use_additional_conditions and added_cond_kwargs is None:
                    raise ValueError(
                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                    )
                batch_size = hidden_states.shape[0]
                timestep, embedded_timestep = self.adaln_single(
                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
                )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                hidden_states, spatial_attn_inputs = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    spatial_attn_inputs,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states, spatial_attn_inputs = block(
                    hidden_states,
                    spatial_attn_inputs,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = (
                    self.proj_out(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_out(hidden_states)
                )
            else:
                hidden_states = (
                    self.proj_out(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_out(hidden_states)
                )
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()

        if self.is_input_patches:
            if self.config.norm_type != "ada_norm_single":
                conditioning = self.transformer_blocks[0].norm1.emb(
                    timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
                hidden_states = self.proj_out_2(hidden_states)
            elif self.config.norm_type == "ada_norm_single":
                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
                hidden_states = self.norm_out(hidden_states)
                # Modulation
                hidden_states = hidden_states * (1 + scale) + shift
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.squeeze(1)

            # unpatchify
            if self.adaln_single is None:
                height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,), spatial_attn_inputs

        return Transformer2DModelOutput(sample=output), spatial_attn_inputs
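The mask-to-bias conversion at the top of `forward` is easy to verify in isolation. Below is a small self-contained example (the variable names are illustrative, not from the diff): a binary keep/discard mask becomes an additive bias that suppresses discarded keys.

```python
import torch

# (batch, key_tokens) binary mask: 1 = keep, 0 = discard
mask = torch.tensor([[1, 1, 0]])

# Same arithmetic as in forward(): keep -> 0.0, discard -> -10000.0
bias = (1 - mask.to(torch.float32)) * -10000.0

# Singleton query dimension so the bias broadcasts over attention scores of
# shape (batch, heads, query_tokens, key_tokens) or (batch * heads, q, k).
bias = bias.unsqueeze(1)

print(bias.shape)  # torch.Size([1, 1, 3])
print(bias)        # the discarded key gets -10000 added to its attention score
```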
ootd/pipelines_ootd/transformer_vton_2d.py ADDED
@@ -0,0 +1,452 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from .attention_vton import BasicTransformerBlock

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
# from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.embeddings import CaptionProjection, PatchEmbed
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: Optional[int] = None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or `patch_size`:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = linear_cls(in_channels, inner_dim)
            else:
                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
            interpolation_scale = max(interpolation_scale, 1)
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
                interpolation_scale=interpolation_scale,
            )

        # 3. Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    double_self_attention=double_self_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    attention_type=attention_type,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = linear_cls(inner_dim, in_channels)
            else:
                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches and norm_type != "ada_norm_single":
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
        elif self.is_input_patches and norm_type == "ada_norm_single":
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        # 5. PixArt-Alpha blocks.
        self.adaln_single = None
        self.use_additional_conditions = False
        if norm_type == "ada_norm_single":
            self.use_additional_conditions = self.config.sample_size == 128
            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
            # additional conditions until we find better name
            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        spatial_attn_inputs=[],
        spatial_attn_idx=0,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerNormZero`.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask (`torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = (
                    self.proj_in(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_in(hidden_states)
                )
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = (
                    self.proj_in(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_in(hidden_states)
                )

        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
            hidden_states = self.pos_embed(hidden_states)

            if self.adaln_single is not None:
                if self.use_additional_conditions and added_cond_kwargs is None:
                    raise ValueError(
                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                    )
                batch_size = hidden_states.shape[0]
                timestep, embedded_timestep = self.adaln_single(
                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
                )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                hidden_states, spatial_attn_inputs, spatial_attn_idx = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    spatial_attn_inputs,
                    spatial_attn_idx,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states, spatial_attn_inputs, spatial_attn_idx = block(
                    hidden_states,
                    spatial_attn_inputs,
                    spatial_attn_idx,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = (
                    self.proj_out(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_out(hidden_states)
                )
            else:
                hidden_states = (
                    self.proj_out(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_out(hidden_states)
                )
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()

        if self.is_input_patches:
            if self.config.norm_type != "ada_norm_single":
                conditioning = self.transformer_blocks[0].norm1.emb(
                    timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
                hidden_states = self.proj_out_2(hidden_states)
            elif self.config.norm_type == "ada_norm_single":
                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
                hidden_states = self.norm_out(hidden_states)
                # Modulation
                hidden_states = hidden_states * (1 + scale) + shift
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.squeeze(1)

            # unpatchify
            if self.adaln_single is None:
                height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,), spatial_attn_inputs, spatial_attn_idx

        return Transformer2DModelOutput(sample=output), spatial_attn_inputs, spatial_attn_idx
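Taken together, the two transformer variants implement a two-pass scheme: the garment (`garm`) blocks append their self-attention inputs to a shared `spatial_attn_inputs` list, and the try-on (`vton`) blocks read them back via `spatial_attn_idx` to fuse garment features into the person latents. A hedged sketch of that flow (toy shapes and hyperparameters; assumes the attention blocks in this repo accept a pure self-attention configuration with `cross_attention_dim=None`):

```python
import torch
from ootd.pipelines_ootd.transformer_garm_2d import Transformer2DModel as GarmTransformer
from ootd.pipelines_ootd.transformer_vton_2d import Transformer2DModel as VtonTransformer

garm = GarmTransformer(num_attention_heads=2, attention_head_dim=8, in_channels=16, norm_num_groups=8)
vton = VtonTransformer(num_attention_heads=2, attention_head_dim=8, in_channels=16, norm_num_groups=8)

garment_latents = torch.randn(1, 16, 8, 8)  # toy garment feature map
person_latents = torch.randn(1, 16, 8, 8)   # toy person feature map

# Pass 1: the garm transformer fills the queue with its spatial attention inputs.
_, feats = garm(garment_latents, spatial_attn_inputs=[])

# Pass 2: the vton transformer consumes the queue starting at index 0 while
# processing the person latents, and returns the advanced index for later layers.
out, feats, idx = vton(person_latents, spatial_attn_inputs=feats, spatial_attn_idx=0)
print(out.sample.shape, idx)  # same spatial shape as the input; idx advanced by num_layers
```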
ootd/pipelines_ootd/unet_garm_2d_blocks.py ADDED
The diff for this file is too large to render.
 
ootd/pipelines_ootd/unet_garm_2d_condition.py ADDED
@@ -0,0 +1,1183 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+
23
+ from .unet_garm_2d_blocks import (
24
+ UNetMidBlock2D,
25
+ UNetMidBlock2DCrossAttn,
26
+ UNetMidBlock2DSimpleCrossAttn,
27
+ get_down_block,
28
+ get_up_block,
29
+ )
30
+
31
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
32
+ from diffusers.loaders import UNet2DConditionLoadersMixin
33
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
34
+ from diffusers.models.activations import get_activation
35
+ from diffusers.models.attention_processor import (
36
+ ADDED_KV_ATTENTION_PROCESSORS,
37
+ CROSS_ATTENTION_PROCESSORS,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.embeddings import (
43
+ GaussianFourierProjection,
44
+ ImageHintTimeEmbedding,
45
+ ImageProjection,
46
+ ImageTimeEmbedding,
47
+ PositionNet,
48
+ TextImageProjection,
49
+ TextImageTimeEmbedding,
50
+ TextTimeEmbedding,
51
+ TimestepEmbedding,
52
+ Timesteps,
53
+ )
54
+ from diffusers.models.modeling_utils import ModelMixin
55
+ # from diffusers.models.unet_2d_blocks import (
56
+ # UNetMidBlock2D,
57
+ # UNetMidBlock2DCrossAttn,
58
+ # UNetMidBlock2DSimpleCrossAttn,
59
+ # get_down_block,
60
+ # get_up_block,
61
+ # )
62
+
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+
67
+ @dataclass
68
+ class UNet2DConditionOutput(BaseOutput):
69
+ """
70
+ The output of [`UNet2DConditionModel`].
71
+
72
+ Args:
73
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
74
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
75
+ """
76
+
77
+ sample: torch.FloatTensor = None
78
+
79
+
80
+ class UNetGarm2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
81
+ r"""
82
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
83
+ shaped output.
84
+
85
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
86
+ for all models (such as downloading or saving).
87
+
88
+ Parameters:
89
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
90
+ Height and width of input/output sample.
91
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
92
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
93
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
94
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
95
+ Whether to flip the sin to cos in the time embedding.
96
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
97
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
98
+ The tuple of downsample blocks to use.
99
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
100
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
101
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
102
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
103
+ The tuple of upsample blocks to use.
104
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
105
+ Whether to include self-attention in the basic transformer blocks, see
106
+ [`~models.attention.BasicTransformerBlock`].
107
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
108
+ The tuple of output channels for each block.
109
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
110
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
111
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
112
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
113
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
114
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
115
+ If `None`, normalization and activation layers is skipped in post-processing.
116
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
117
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
118
+ The dimension of the cross attention features.
119
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
120
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
121
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
122
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
123
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
124
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
125
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
126
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
127
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
128
+ encoder_hid_dim (`int`, *optional*, defaults to None):
129
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
130
+ dimension to `cross_attention_dim`.
131
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
132
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
133
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
134
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
135
+ num_attention_heads (`int`, *optional*):
136
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
137
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
138
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
139
+ class_embed_type (`str`, *optional*, defaults to `None`):
140
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
141
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
142
+ addition_embed_type (`str`, *optional*, defaults to `None`):
143
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
144
+ "text". "text" will use the `TextTimeEmbedding` layer.
145
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
146
+ Dimension for the timestep embeddings.
147
+ num_class_embeds (`int`, *optional*, defaults to `None`):
148
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
149
+ class conditioning with `class_embed_type` equal to `None`.
150
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
151
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
152
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
153
+ An optional override for the dimension of the projected time embedding.
154
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
155
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
156
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
157
+ timestep_post_act (`str`, *optional*, defaults to `None`):
158
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
159
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
160
+ The dimension of `cond_proj` layer in the timestep embedding.
161
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
162
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
163
+ *optional*): The dimension of the `class_labels` input when
164
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
165
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
166
+ embeddings with the class embeddings.
167
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
168
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
169
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
170
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
171
+ otherwise.
172
+ """
173
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # it is set to `cross_attention_dim` here, as this is exactly the required dimension for the currently only
+ # use case: `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here, as this is exactly the required dimension for the currently only
+ # use case: `addition_embed_type == "text_image"` (Kandinsky 2.1)
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = PositionNet(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
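The constructor above is config-driven end to end, so the full topology can be described by keyword arguments alone. As a hedged illustration only (assuming the garment UNet class in this file is named `UNetGarm2DConditionModel`, mirroring the `UNetVton2DConditionModel` in the sibling file below; the values are typical SD 1.x-scale settings, not read from the OOTDiffusion checkpoints):

# Sketch: build an SD 1.x-shaped instance of the UNet defined above.
unet = UNetGarm2DConditionModel(
    sample_size=64,            # latent height/width
    in_channels=4,
    out_channels=4,
    cross_attention_dim=768,   # CLIP text hidden size for SD 1.x
    attention_head_dim=8,
)
print(unet.config.block_out_channels)  # (320, 640, 1280, 1280) by default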
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
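The three hooks above are the standard diffusers processor API: `attn_processors` recursively collects every child exposing `get_processor`, keyed by module path, and `set_attn_processor` accepts either a single processor for all layers or a dict keyed by those same paths. A minimal usage sketch (the `unet` instance is the hypothetical one from the earlier example):

from diffusers.models.attention_processor import AttnProcessor

# Inspect every attention processor, keyed by module path.
for name, proc in unet.attn_processors.items():
    print(name, type(proc).__name__)

# Replace all processors at once, or pass a dict with the same keys as
# `unet.attn_processors` to target layers individually.
unet.set_attn_processor(AttnProcessor())
unet.set_default_attn_processor()  # restore the library defaults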
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any child which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
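`set_attention_slice` trades speed for peak memory by chunking attention over the head dimension, while `_set_gradient_checkpointing` is the per-module toggle that diffusers' `ModelMixin.enable_gradient_checkpointing()` drives. A hedged sketch of both knobs (again using the hypothetical `unet` instance):

# "auto" computes attention in two steps per layer; "max" runs one slice
# at a time; an int n must evenly divide each sliceable head dim.
unet.set_attention_slice("auto")

# Recompute activations during the backward pass to save training memory;
# exposed by diffusers' ModelMixin on top of _set_gradient_checkpointing.
unet.enable_gradient_checkpointing()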
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
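FreeU is a training-free reweighting of backbone vs. skip features in the up blocks. A hedged example call (these values sit in the range the FreeU repository suggests for SD 1.x-class UNets; tune per pipeline):

# Amplify backbone features (b1, b2) and damp skip features (s1, s2)
# on every up block; disable_freeu() resets all four attributes.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
# ... run the denoising loop ...
unet.disable_freeu()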
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings
+ passed along to the UNet blocks.
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ Additional residuals to be added to the UNet's long skip connections from down blocks to up blocks,
+ for example from ControlNet side model(s).
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
+ An additional residual to be added to the output of the middle UNet block, for example from a
+ ControlNet side model.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ Additional residuals to be added within the UNet's down blocks, for example from T2I-Adapter side
+ model(s).
+ encoder_attention_mask (`torch.Tensor`, *optional*):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. The mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # For Vton
+ spatial_attn_inputs = []
+
+ # 3. down
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples, spatial_attn_inputs = downsample_block(
+ hidden_states=sample,
+ spatial_attn_inputs=spatial_attn_inputs,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ scale=lora_scale,
+ )
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ # if is_controlnet:
+ # new_down_block_res_samples = ()
+
+ # for down_block_res_sample, down_block_additional_residual in zip(
+ # down_block_res_samples, down_block_additional_residuals
+ # ):
+ # down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ # new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ # down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample, spatial_attn_inputs = self.mid_block(
+ sample,
+ spatial_attn_inputs=spatial_attn_inputs,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample, spatial_attn_inputs = upsample_block(
+ hidden_states=sample,
+ spatial_attn_inputs=spatial_attn_inputs,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ scale=lora_scale,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,), spatial_attn_inputs
+
+ return UNet2DConditionOutput(sample=sample), spatial_attn_inputs
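That closes out the garment-side UNet. Relative to the stock diffusers UNet, its cross-attention down/mid/up blocks additionally thread a `spatial_attn_inputs` list through `forward`, which is then returned next to the denoised sample. A hedged sketch of how the two UNets in this commit are meant to compose (the exact hand-off, timesteps, and latents live in `ootd/pipelines_ootd/pipeline_ootd.py`; the `spatial_attn_inputs` kwarg on the try-on UNet and its two-element return are inferred from the symmetric file below):

# 1) Encode the garment once: the garm UNet returns its per-block spatial
#    attention features alongside the (unused here) output sample.
garm_out, spatial_attn_outputs = unet_garm(
    garm_latents,                         # encoded garment image latents
    timestep=0,                           # assumed: a single fixed step
    encoder_hidden_states=prompt_embeds,
)

# 2) At each denoising step t, feed those features to the try-on UNet.
vton_out, _ = unet_vton(
    vton_latents,                         # noisy try-on latents
    spatial_attn_inputs=list(spatial_attn_outputs),  # consumed per block
    timestep=t,
    encoder_hidden_states=prompt_embeds,
)
noise_pred = vton_out.sample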
ootd/pipelines_ootd/unet_vton_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
ootd/pipelines_ootd/unet_vton_2d_condition.py ADDED
@@ -0,0 +1,1183 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint
+
+ from .unet_vton_2d_blocks import (
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ get_down_block,
+ get_up_block,
+ )
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import UNet2DConditionLoadersMixin
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.models.activations import get_activation
+ from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+ )
+ from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ PositionNet,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+ )
+ from diffusers.models.modeling_utils import ModelMixin
+ # from ..diffusers.src.diffusers.models.unet_2d_blocks import (
+ # UNetMidBlock2D,
+ # UNetMidBlock2DCrossAttn,
+ # UNetMidBlock2DSimpleCrossAttn,
+ # get_down_block,
+ # get_up_block,
+ # )
+
+
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+ @dataclass
+ class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+ class UNetVton2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for the middle of the UNet; it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to `None`):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ `"text"`. `"text"` will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in the timestep embedding. Choose from `silu`, `mish`, and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
197
+ mid_block_scale_factor: float = 1,
198
+ dropout: float = 0.0,
199
+ act_fn: str = "silu",
200
+ norm_num_groups: Optional[int] = 32,
201
+ norm_eps: float = 1e-5,
202
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
203
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
204
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
205
+ encoder_hid_dim: Optional[int] = None,
206
+ encoder_hid_dim_type: Optional[str] = None,
207
+ attention_head_dim: Union[int, Tuple[int]] = 8,
208
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
+ dual_cross_attention: bool = False,
210
+ use_linear_projection: bool = False,
211
+ class_embed_type: Optional[str] = None,
212
+ addition_embed_type: Optional[str] = None,
213
+ addition_time_embed_dim: Optional[int] = None,
214
+ num_class_embeds: Optional[int] = None,
215
+ upcast_attention: bool = False,
216
+ resnet_time_scale_shift: str = "default",
217
+ resnet_skip_time_act: bool = False,
218
+ resnet_out_scale_factor: int = 1.0,
219
+ time_embedding_type: str = "positional",
220
+ time_embedding_dim: Optional[int] = None,
221
+ time_embedding_act_fn: Optional[str] = None,
222
+ timestep_post_act: Optional[str] = None,
223
+ time_cond_proj_dim: Optional[int] = None,
224
+ conv_in_kernel: int = 3,
225
+ conv_out_kernel: int = 3,
226
+ projection_class_embeddings_input_dim: Optional[int] = None,
227
+ attention_type: str = "default",
228
+ class_embeddings_concat: bool = False,
229
+ mid_block_only_cross_attention: Optional[bool] = None,
230
+ cross_attention_norm: Optional[str] = None,
231
+ addition_embed_type_num_heads=64,
232
+ ):
233
+ super().__init__()
234
+
235
+ self.sample_size = sample_size
236
+
237
+ if num_attention_heads is not None:
238
+ raise ValueError(
239
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
240
+ )
241
+
242
+ # If `num_attention_heads` is not defined (which is the case for most models)
243
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
244
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
245
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
246
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
247
+ # which is why we correct for the naming here.
248
+ num_attention_heads = num_attention_heads or attention_head_dim
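+        # e.g. for a Stable Diffusion v1.5-style config, `attention_head_dim=8` per block actually
+        # means 8 attention *heads*; with `block_out_channels=(320, 640, 1280, 1280)` the true
+        # per-head dimension works out to 320/8=40, 640/8=80, 1280/8=160, 1280/8=160.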
+
+        # Check inputs
+        if len(down_block_types) != len(up_block_types):
+            raise ValueError(
+                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+            )
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+        # input
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        )
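+        # `(kernel - 1) // 2` is "same" padding for odd kernels, so `conv_in` preserves the
+        # latent's spatial size; e.g. the default kernel of 3 gives padding 1.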
+
+        # time
+        if time_embedding_type == "fourier":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+            if time_embed_dim % 2 != 0:
+                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+            self.time_proj = GaussianFourierProjection(
+                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+            )
+            timestep_input_dim = time_embed_dim
+        elif time_embedding_type == "positional":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+            timestep_input_dim = block_out_channels[0]
+        else:
+            raise ValueError(
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+            )
+
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+            post_act_fn=timestep_post_act,
+            cond_proj_dim=time_cond_proj_dim,
+        )
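+        # Dimension walk-through under the default config: with `positional` embeddings and
+        # `block_out_channels[0] == 320`, `self.time_proj` maps a (batch,) timestep tensor to
+        # (batch, 320) sinusoidal features, and `self.time_embedding` (an MLP) lifts that to
+        # (batch, 1280) since `time_embed_dim = 320 * 4`.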
+
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently
+            # only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
+        # class embedding
+        if class_embed_type is None and num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+        elif class_embed_type == "timestep":
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+        elif class_embed_type == "identity":
+            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        elif class_embed_type == "projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+            # 2. it projects from an arbitrary input dimension.
+            #
+            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif class_embed_type == "simple_projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+        else:
+            self.class_embedding = None
+
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
+            # too much they are set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type is not None:
+            raise ValueError(
+                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
+            )
+
+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        else:
+            self.time_embed_act = get_activation(time_embedding_act_fn)
+
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+        if class_embeddings_concat:
+            # The time embeddings are concatenated with the class embeddings. The dimension of the
+            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+            # regular time embeddings
+            blocks_time_embed_dim = time_embed_dim * 2
+        else:
+            blocks_time_embed_dim = time_embed_dim
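+        # e.g. with `class_embeddings_concat=True` the per-block time embedding doubles
+        # (1280 -> 2560 under the default config), which is why every ResNet below is built with
+        # `temb_channels=blocks_time_embed_dim` rather than `time_embed_dim`.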
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                dropout=dropout,
+            )
+            self.down_blocks.append(down_block)
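+        # Under the default 4-block config the encoder channel path is
+        # 320 -> 320 -> 640 -> 1280 -> 1280, with a 2x spatial downsample after each of the first
+        # three blocks (`add_downsample=not is_final_block`), i.e. 8x total.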
+
+        # mid
+        if mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
+                resnet_groups=norm_num_groups,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+                attention_type=attention_type,
+            )
+        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                cross_attention_dim=cross_attention_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
+            )
+        elif mid_block_type == "UNetMidBlock2D":
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                dropout=dropout,
+                num_layers=0,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                add_attention=False,
+            )
+        elif mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = (
+            list(reversed(transformer_layers_per_block))
+            if reverse_transformer_layers_per_block is None
+            else reverse_transformer_layers_per_block
+        )
+        only_cross_attention = list(reversed(only_cross_attention))
+
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resolution_idx=i,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=reversed_cross_attention_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                dropout=dropout,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
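+        # Each up block runs one more ResNet than its down counterpart
+        # (`reversed_layers_per_block[i] + 1`) because it consumes one skip tensor per layer: the
+        # down path emits `layers_per_block` residuals per block plus one per downsampler, with
+        # `conv_in` contributing the remaining skip.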
+
+        # out
+        if norm_num_groups is not None:
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+            )
+
+            self.conv_act = get_activation(act_fn)
+
+        else:
+            self.conv_norm_out = None
+            self.conv_act = None
+
+        conv_out_padding = (conv_out_kernel - 1) // 2
+        self.conv_out = nn.Conv2d(
+            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        )
+
+        if attention_type in ["gated", "gated-text-image"]:
+            positive_len = 768
+            if isinstance(cross_attention_dim, int):
+                positive_len = cross_attention_dim
+            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+                positive_len = cross_attention_dim[0]
+
+            feature_type = "text-only" if attention_type == "gated" else "text-image"
+            self.position_net = PositionNet(
+                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+            )
+
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary of all attention processors used in the model, indexed by
+            their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
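+    # Usage sketch for the property above (`unet` is an instantiated model). Keys look like
+    # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor", which is also the key
+    # format `set_attn_processor` expects when passed a dict:
+    #
+    #     procs = unet.attn_processors                       # {qualified_name: AttentionProcessor}
+    #     unet.set_attn_processor({k: AttnProcessor() for k in procs})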
+
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor, _remove_lora=_remove_lora)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor, _remove_lora=True)
+
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
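+    # Usage sketch for the method above, assuming an instantiated `unet`:
+    #
+    #     unet.set_attention_slice("auto")  # halve each sliceable head dim -> two slices per layer
+    #     unet.set_attention_slice("max")   # slice size 1 everywhere -> minimal peak memory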
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def enable_freeu(self, s1, s2, b1, b2):
+        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        for i, upsample_block in enumerate(self.up_blocks):
+            setattr(upsample_block, "s1", s1)
+            setattr(upsample_block, "s2", s2)
+            setattr(upsample_block, "b1", b1)
+            setattr(upsample_block, "b2", b2)
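+    # Usage sketch: the FreeU repository lists per-pipeline values; commonly cited numbers for
+    # SD1.5-style checkpoints are roughly b1=1.5, b2=1.6, s1=0.9, s2=0.2 (verify against the repo):
+    #
+    #     unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)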
+
+    def disable_freeu(self):
+        """Disables the FreeU mechanism."""
+        freeu_keys = {"s1", "s2", "b1", "b2"}
+        for i, upsample_block in enumerate(self.up_blocks):
+            for k in freeu_keys:
+                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+                    setattr(upsample_block, k, None)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        spatial_attn_inputs,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
+        r"""
+        The [`UNet2DConditionModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            spatial_attn_inputs (`list` of `torch.Tensor`):
+                Self-attention feature maps collected from the garment UNet, consumed in order by this model's
+                attention blocks for outfitting fusion.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+                through the `self.time_embedding` layer to obtain the timestep embeddings.
+            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added to the UNet long skip connections from down blocks to up blocks, for
+                example from ControlNet side model(s).
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                Additional residual to be added to the UNet mid block output, for example from a ControlNet side model.
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s).
+            encoder_attention_mask (`torch.Tensor`):
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
+        """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
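+            # Worked example of the conversion above: a keep/discard mask [1, 1, 0] becomes the
+            # additive bias [0.0, 0.0, -10000.0] with shape (batch, 1, key_tokens), so masked keys
+            # get effectively -inf logits before the softmax.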
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+
+        emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        if self.class_embedding is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+            if self.config.class_embed_type == "timestep":
+                class_labels = self.time_proj(class_labels)
+
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb
+
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+            aug_emb = self.add_embedding(text_embs, image_embs)
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        # 2. pre-process
+        sample = self.conv_in(sample)
+
+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+        # for spatial attention
+        spatial_attn_idx = 0
+
+        # 3. down
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        # but can only use one or the other
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate(
+                "T2I should not use down_block_additional_residuals",
+                "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                       and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                       for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+                standard_warn=False,
+            )
+            down_intrablock_additional_residuals = down_block_additional_residuals
+            is_adapter = True
+
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+                sample, res_samples, spatial_attn_inputs, spatial_attn_idx = downsample_block(
+                    hidden_states=sample,
+                    spatial_attn_inputs=spatial_attn_inputs,
+                    spatial_attn_idx=spatial_attn_idx,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                    **additional_residuals,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)
+
+            down_block_res_samples += res_samples
+
+        if is_controlnet:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
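+        # When driven by a ControlNet, each long skip connection is offset by its matching control
+        # residual *before* the decoder consumes it; the residual tuple must line up one-to-one
+        # with `down_block_res_samples` (one tensor per ResNet layer plus the `conv_in` stem).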
+
+        # 4. mid
+        if self.mid_block is not None:
+            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+                sample, spatial_attn_inputs, spatial_attn_idx = self.mid_block(
+                    sample,
+                    spatial_attn_inputs=spatial_attn_inputs,
+                    spatial_attn_idx=spatial_attn_idx,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = self.mid_block(sample, emb)
+
+            # To support T2I-Adapter-XL
+            if (
+                is_adapter
+                and len(down_intrablock_additional_residuals) > 0
+                and sample.shape == down_intrablock_additional_residuals[0].shape
+            ):
+                sample += down_intrablock_additional_residuals.pop(0)
+
+        if is_controlnet:
+            sample = sample + mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample, spatial_attn_inputs, spatial_attn_idx = upsample_block(
+                    hidden_states=sample,
+                    spatial_attn_inputs=spatial_attn_inputs,
+                    spatial_attn_idx=spatial_attn_idx,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
+                    scale=lora_scale,
+                )
+
+        # 6. post-process
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet2DConditionOutput(sample=sample)
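+
+# Example (sketch) of how this variant's extra `spatial_attn_inputs` argument is wired up in the
+# OOTD pipeline (see ootd/pipelines_ootd/pipeline_ootd.py). Roughly: the garment UNet collects
+# per-layer self-attention features once, and the try-on UNet replays a copy of them at every
+# denoising step, advancing `spatial_attn_idx` as each attention block consumes one entry:
+#
+#     _, spatial_attn_outputs = unet_garm(garm_latents, 0, prompt_embeds, return_dict=False)
+#     for t in timesteps:
+#         noise_pred = unet_vton(
+#             latent_model_input,
+#             spatial_attn_outputs.copy(),
+#             t,
+#             encoder_hidden_states=prompt_embeds,
+#             return_dict=False,
+#         )[0]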
preprocess/humanparsing/__pycache__/parsing_api.cpython-310.pyc ADDED
Binary file (5.94 kB).
 
preprocess/humanparsing/__pycache__/run_parsing.cpython-310.pyc ADDED
Binary file (1.59 kB).
 
preprocess/humanparsing/datasets/__init__.py ADDED
File without changes
preprocess/humanparsing/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (180 Bytes).