elias3446 commited on
Commit
a087ce1
Β·
1 Parent(s): 30a5522

Upload 37 files

Browse files
Files changed (37) hide show
  1. Text2Human/configs/index_pred_net.yml +84 -0
  2. Text2Human/configs/parsing_gen.yml +40 -0
  3. Text2Human/configs/parsing_token.yml +47 -0
  4. Text2Human/configs/sample_from_parsing.yml +93 -0
  5. Text2Human/configs/sample_from_pose.yml +107 -0
  6. Text2Human/configs/sampler.yml +83 -0
  7. Text2Human/configs/vqvae_bottom.yml +72 -0
  8. Text2Human/configs/vqvae_top.yml +53 -0
  9. models/__init__.py +42 -0
  10. models/archs/__init__.py +0 -0
  11. models/archs/__pycache__/__init__.cpython-38.pyc +0 -0
  12. models/archs/__pycache__/fcn_arch.cpython-38.pyc +0 -0
  13. models/archs/__pycache__/shape_attr_embedding_arch.cpython-38.pyc +0 -0
  14. models/archs/__pycache__/transformer_arch.cpython-38.pyc +0 -0
  15. models/archs/__pycache__/unet_arch.cpython-38.pyc +0 -0
  16. models/archs/__pycache__/vqgan_arch.cpython-38.pyc +0 -0
  17. models/archs/fcn_arch.py +418 -0
  18. models/archs/shape_attr_embedding_arch.py +35 -0
  19. models/archs/transformer_arch.py +273 -0
  20. models/archs/unet_arch.py +693 -0
  21. models/archs/vqgan_arch.py +1203 -0
  22. models/hierarchy_inference_model.py +363 -0
  23. models/hierarchy_vqgan_model.py +374 -0
  24. models/losses/__init__.py +0 -0
  25. models/losses/__pycache__/__init__.cpython-38.pyc +0 -0
  26. models/losses/__pycache__/accuracy.cpython-38.pyc +0 -0
  27. models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc +0 -0
  28. models/losses/__pycache__/segmentation_loss.cpython-38.pyc +0 -0
  29. models/losses/__pycache__/vqgan_loss.cpython-38.pyc +0 -0
  30. models/losses/accuracy.py +46 -0
  31. models/losses/cross_entropy_loss.py +246 -0
  32. models/losses/segmentation_loss.py +25 -0
  33. models/losses/vqgan_loss.py +114 -0
  34. models/parsing_gen_model.py +220 -0
  35. models/sample_model.py +498 -0
  36. models/transformer_model.py +482 -0
  37. models/vqgan_model.py +551 -0
Text2Human/configs/index_pred_net.yml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: index_prediction_network
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: VQGANTextureAwareSpatialHierarchyInferenceModel
19
+ # network configs
20
+ embed_dim: 256
21
+ n_embed: 1024
22
+ codebook_spatial_size: 2
23
+
24
+ # bottom level vqvae
25
+ bot_n_embed: 512
26
+ bot_double_z: false
27
+ bot_z_channels: 256
28
+ bot_resolution: 512
29
+ bot_in_channels: 3
30
+ bot_out_ch: 3
31
+ bot_ch: 128
32
+ bot_ch_mult: [1, 1, 2, 4]
33
+ bot_num_res_blocks: 2
34
+ bot_attn_resolutions: [64]
35
+ bot_dropout: 0.0
36
+ bot_vae_path: ./pretrained_models/vqvae_bottom.pth
37
+
38
+ # top level vqgan
39
+ top_double_z: false
40
+ top_z_channels: 256
41
+ top_resolution: 512
42
+ top_in_channels: 3
43
+ top_out_ch: 3
44
+ top_ch: 128
45
+ top_ch_mult: [1, 1, 2, 2, 4]
46
+ top_num_res_blocks: 2
47
+ top_attn_resolutions: [32]
48
+ top_dropout: 0.0
49
+ top_vae_path: ./pretrained_models/vqvae_top.pth
50
+
51
+ # unet configs
52
+ encoder_in_channels: 256
53
+ fc_in_channels: 64
54
+ fc_in_index: 4
55
+ fc_channels: 64
56
+ fc_num_convs: 1
57
+ fc_concat_input: False
58
+ fc_dropout_ratio: 0.1
59
+ fc_num_classes: 512
60
+ fc_align_corners: False
61
+
62
+ disc_layers: 3
63
+ disc_weight_max: 1
64
+ disc_start_step: 30001
65
+ n_channels: 3
66
+ ndf: 64
67
+ nf: 128
68
+ perceptual_weight: 1.0
69
+
70
+ num_segm_classes: 24
71
+
72
+ # training configs
73
+ val_freq: 5
74
+ print_freq: 100
75
+ weight_decay: 0
76
+ manual_seed: 2021
77
+ num_epochs: 100
78
+ lr: !!float 1.0e-04
79
+ lr_decay: step
80
+ gamma: 1.0
81
+ step: 50
82
+ optimizer: Adam
83
+ loss_function: cross_entropy
84
+
Text2Human/configs/parsing_gen.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: parsing_generation
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 8
8
+ num_workers: 4
9
+ segm_dir: ./datasets/segm
10
+ pose_dir: ./datasets/densepose
11
+ train_ann_file: ./datasets/shape_ann/train_ann_file.txt
12
+ val_ann_file: ./datasets/shape_ann/val_ann_file.txt
13
+ test_ann_file: ./datasets/shape_ann/test_ann_file.txt
14
+ downsample_factor: 2
15
+
16
+ model_type: ParsingGenModel
17
+ # network configs
18
+ embedder_dim: 8
19
+ embedder_out_dim: 128
20
+ attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
21
+ encoder_in_channels: 1
22
+ fc_in_channels: 64
23
+ fc_in_index: 4
24
+ fc_channels: 64
25
+ fc_num_convs: 1
26
+ fc_concat_input: False
27
+ fc_dropout_ratio: 0.1
28
+ fc_num_classes: 24
29
+ fc_align_corners: False
30
+
31
+ # training configs
32
+ val_freq: 5
33
+ print_freq: 100
34
+ weight_decay: 0
35
+ manual_seed: 2021
36
+ num_epochs: 100
37
+ lr: !!float 1e-4
38
+ lr_decay: step
39
+ gamma: 0.1
40
+ step: 50
Text2Human/configs/parsing_token.yml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: parsing_tokenization
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: VQSegmentationModel
19
+ # network configs
20
+ embed_dim: 32
21
+ n_embed: 1024
22
+ image_key: "segmentation"
23
+ n_labels: 24
24
+ double_z: false
25
+ z_channels: 32
26
+ resolution: 512
27
+ in_channels: 24
28
+ out_ch: 24
29
+ ch: 64
30
+ ch_mult: [1, 1, 2, 2, 4]
31
+ num_res_blocks: 1
32
+ attn_resolutions: [16]
33
+ dropout: 0.0
34
+
35
+ num_segm_classes: 24
36
+
37
+
38
+ # training configs
39
+ val_freq: 5
40
+ print_freq: 100
41
+ weight_decay: 0
42
+ manual_seed: 2021
43
+ num_epochs: 100
44
+ lr: !!float 4.5e-05
45
+ lr_decay: step
46
+ gamma: 0.1
47
+ step: 50
Text2Human/configs/sample_from_parsing.yml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sample_from_parsing
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ test_img_dir: ./datasets/test_images
10
+ segm_dir: ./datasets/segm
11
+ pose_dir: ./datasets/densepose
12
+ test_ann_file: ./datasets/texture_ann/test
13
+ downsample_factor: 2
14
+
15
+ model_type: SampleFromParsingModel
16
+ # network configs
17
+ embed_dim: 256
18
+ n_embed: 1024
19
+ codebook_spatial_size: 2
20
+
21
+ # bottom level vqvae
22
+ bot_n_embed: 512
23
+ bot_codebook_spatial_size: 2
24
+ bot_double_z: false
25
+ bot_z_channels: 256
26
+ bot_resolution: 512
27
+ bot_in_channels: 3
28
+ bot_out_ch: 3
29
+ bot_ch: 128
30
+ bot_ch_mult: [1, 1, 2, 4]
31
+ bot_num_res_blocks: 2
32
+ bot_attn_resolutions: [64]
33
+ bot_dropout: 0.0
34
+ bot_vae_path: ./pretrained_models/vqvae_bottom.pth
35
+
36
+ # top level vqgan
37
+ top_double_z: false
38
+ top_z_channels: 256
39
+ top_resolution: 512
40
+ top_in_channels: 3
41
+ top_out_ch: 3
42
+ top_ch: 128
43
+ top_ch_mult: [1, 1, 2, 2, 4]
44
+ top_num_res_blocks: 2
45
+ top_attn_resolutions: [32]
46
+ top_dropout: 0.0
47
+ top_vae_path: ./pretrained_models/vqvae_top.pth
48
+
49
+ # unet configs
50
+ index_pred_encoder_in_channels: 256
51
+ index_pred_fc_in_channels: 64
52
+ index_pred_fc_in_index: 4
53
+ index_pred_fc_channels: 64
54
+ index_pred_fc_num_convs: 1
55
+ index_pred_fc_concat_input: False
56
+ index_pred_fc_dropout_ratio: 0.1
57
+ index_pred_fc_num_classes: 512
58
+ index_pred_fc_align_corners: False
59
+ pretrained_index_network: ./pretrained_models/index_pred_net.pth
60
+
61
+ # segmentation tokenization
62
+ segm_double_z: false
63
+ segm_z_channels: 32
64
+ segm_resolution: 512
65
+ segm_in_channels: 24
66
+ segm_out_ch: 24
67
+ segm_ch: 64
68
+ segm_ch_mult: [1, 1, 2, 2, 4]
69
+ segm_num_res_blocks: 1
70
+ segm_attn_resolutions: [16]
71
+ segm_dropout: 0.0
72
+ segm_num_segm_classes: 24
73
+ segm_n_embed: 1024
74
+ segm_embed_dim: 32
75
+ segm_token_path: ./pretrained_models/parsing_token.pth
76
+
77
+ # sampler configs
78
+ codebook_size: 18432
79
+ segm_codebook_size: 1024
80
+ texture_codebook_size: 18
81
+ bert_n_emb: 512
82
+ bert_n_layers: 24
83
+ bert_n_head: 8
84
+ block_size: 512 # 32 x 16
85
+ latent_shape: [32, 16]
86
+ embd_pdrop: 0.0
87
+ resid_pdrop: 0.0
88
+ attn_pdrop: 0.0
89
+ num_head: 18
90
+ pretrained_sampler: ./pretrained_models/sampler.pth
91
+
92
+ manual_seed: 2021
93
+ sample_steps: 256
Text2Human/configs/sample_from_pose.yml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sample_from_pose
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ pose_dir: ./datasets/densepose
10
+ texture_ann_file: ./datasets/texture_ann/test
11
+ shape_ann_path: ./datasets/shape_ann/test_ann_file.txt
12
+ downsample_factor: 2
13
+
14
+ model_type: SampleFromPoseModel
15
+ # network configs
16
+ embed_dim: 256
17
+ n_embed: 1024
18
+ codebook_spatial_size: 2
19
+
20
+ # bottom level vqgan
21
+ bot_n_embed: 512
22
+ bot_codebook_spatial_size: 2
23
+ bot_double_z: false
24
+ bot_z_channels: 256
25
+ bot_resolution: 512
26
+ bot_in_channels: 3
27
+ bot_out_ch: 3
28
+ bot_ch: 128
29
+ bot_ch_mult: [1, 1, 2, 4]
30
+ bot_num_res_blocks: 2
31
+ bot_attn_resolutions: [64]
32
+ bot_dropout: 0.0
33
+ bot_vae_path: ./pretrained_models/vqvae_bottom.pth
34
+
35
+ # top level vqgan
36
+ top_double_z: false
37
+ top_z_channels: 256
38
+ top_resolution: 512
39
+ top_in_channels: 3
40
+ top_out_ch: 3
41
+ top_ch: 128
42
+ top_ch_mult: [1, 1, 2, 2, 4]
43
+ top_num_res_blocks: 2
44
+ top_attn_resolutions: [32]
45
+ top_dropout: 0.0
46
+ top_vae_path: ./pretrained_models/vqvae_top.pth
47
+
48
+ # unet configs
49
+ index_pred_encoder_in_channels: 256
50
+ index_pred_fc_in_channels: 64
51
+ index_pred_fc_in_index: 4
52
+ index_pred_fc_channels: 64
53
+ index_pred_fc_num_convs: 1
54
+ index_pred_fc_concat_input: False
55
+ index_pred_fc_dropout_ratio: 0.1
56
+ index_pred_fc_num_classes: 512
57
+ index_pred_fc_align_corners: False
58
+ pretrained_index_network: ./pretrained_models/index_pred_net.pth
59
+
60
+ # segmentation tokenization
61
+ segm_double_z: false
62
+ segm_z_channels: 32
63
+ segm_resolution: 512
64
+ segm_in_channels: 24
65
+ segm_out_ch: 24
66
+ segm_ch: 64
67
+ segm_ch_mult: [1, 1, 2, 2, 4]
68
+ segm_num_res_blocks: 1
69
+ segm_attn_resolutions: [16]
70
+ segm_dropout: 0.0
71
+ segm_num_segm_classes: 24
72
+ segm_n_embed: 1024
73
+ segm_embed_dim: 32
74
+ segm_token_path: ./pretrained_models/parsing_token.pth
75
+
76
+ # sampler configs
77
+ codebook_size: 18432
78
+ segm_codebook_size: 1024
79
+ texture_codebook_size: 18
80
+ bert_n_emb: 512
81
+ bert_n_layers: 24
82
+ bert_n_head: 8
83
+ block_size: 512 # 32 x 16
84
+ latent_shape: [32, 16]
85
+ embd_pdrop: 0.0
86
+ resid_pdrop: 0.0
87
+ attn_pdrop: 0.0
88
+ num_head: 18
89
+ pretrained_sampler: ./pretrained_models/sampler.pth
90
+
91
+ # shape network configs
92
+ shape_embedder_dim: 8
93
+ shape_embedder_out_dim: 128
94
+ shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
95
+ shape_encoder_in_channels: 1
96
+ shape_fc_in_channels: 64
97
+ shape_fc_in_index: 4
98
+ shape_fc_channels: 64
99
+ shape_fc_num_convs: 1
100
+ shape_fc_concat_input: False
101
+ shape_fc_dropout_ratio: 0.1
102
+ shape_fc_num_classes: 24
103
+ shape_fc_align_corners: False
104
+ pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth
105
+
106
+ manual_seed: 2021
107
+ sample_steps: 256
Text2Human/configs/sampler.yml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sampler
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 1
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ # pretrained models
19
+ img_ae_path: ./pretrained_models/vqvae_top.pth
20
+ segm_ae_path: ./pretrained_models/parsing_token.pth
21
+
22
+ model_type: TransformerTextureAwareModel
23
+ # network configs
24
+
25
+ # image autoencoder
26
+ img_embed_dim: 256
27
+ img_n_embed: 1024
28
+ img_double_z: false
29
+ img_z_channels: 256
30
+ img_resolution: 512
31
+ img_in_channels: 3
32
+ img_out_ch: 3
33
+ img_ch: 128
34
+ img_ch_mult: [1, 1, 2, 2, 4]
35
+ img_num_res_blocks: 2
36
+ img_attn_resolutions: [32]
37
+ img_dropout: 0.0
38
+
39
+ # segmentation tokenization
40
+ segm_double_z: false
41
+ segm_z_channels: 32
42
+ segm_resolution: 512
43
+ segm_in_channels: 24
44
+ segm_out_ch: 24
45
+ segm_ch: 64
46
+ segm_ch_mult: [1, 1, 2, 2, 4]
47
+ segm_num_res_blocks: 1
48
+ segm_attn_resolutions: [16]
49
+ segm_dropout: 0.0
50
+ segm_num_segm_classes: 24
51
+ segm_n_embed: 1024
52
+ segm_embed_dim: 32
53
+
54
+ # sampler configs
55
+ codebook_size: 18432
56
+ segm_codebook_size: 1024
57
+ texture_codebook_size: 18
58
+ bert_n_emb: 512
59
+ bert_n_layers: 24
60
+ bert_n_head: 8
61
+ block_size: 512 # 32 x 16
62
+ latent_shape: [32, 16]
63
+ embd_pdrop: 0.0
64
+ resid_pdrop: 0.0
65
+ attn_pdrop: 0.0
66
+ num_head: 18
67
+
68
+ # loss configs
69
+ loss_type: reweighted_elbo
70
+ mask_schedule: random
71
+
72
+ sample_steps: 256
73
+
74
+ # training configs
75
+ val_freq: 5
76
+ print_freq: 100
77
+ weight_decay: 0
78
+ manual_seed: 2021
79
+ num_epochs: 100
80
+ lr: !!float 1e-4
81
+ lr_decay: step
82
+ gamma: 1.0
83
+ step: 50
Text2Human/configs/vqvae_bottom.yml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: vqvae_bottom
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: HierarchyVQSpatialTextureAwareModel
19
+ # network configs
20
+ embed_dim: 256
21
+ n_embed: 1024
22
+ codebook_spatial_size: 2
23
+
24
+ # bottom level vqvae
25
+ bot_n_embed: 512
26
+ bot_double_z: false
27
+ bot_z_channels: 256
28
+ bot_resolution: 512
29
+ bot_in_channels: 3
30
+ bot_out_ch: 3
31
+ bot_ch: 128
32
+ bot_ch_mult: [1, 1, 2, 4]
33
+ bot_num_res_blocks: 2
34
+ bot_attn_resolutions: [64]
35
+ bot_dropout: 0.0
36
+
37
+ # top level vqgan
38
+ top_double_z: false
39
+ top_z_channels: 256
40
+ top_resolution: 512
41
+ top_in_channels: 3
42
+ top_out_ch: 3
43
+ top_ch: 128
44
+ top_ch_mult: [1, 1, 2, 2, 4]
45
+ top_num_res_blocks: 2
46
+ top_attn_resolutions: [32]
47
+ top_dropout: 0.0
48
+ top_vae_path: ./pretrained_models/vqvae_top.pth
49
+
50
+ fix_decoder: false
51
+
52
+ disc_layers: 3
53
+ disc_weight_max: 1
54
+ disc_start_step: 1
55
+ n_channels: 3
56
+ ndf: 64
57
+ nf: 128
58
+ perceptual_weight: 1.0
59
+
60
+ num_segm_classes: 24
61
+
62
+ # training configs
63
+ val_freq: 5
64
+ print_freq: 100
65
+ weight_decay: 0
66
+ manual_seed: 2021
67
+ num_epochs: 1000
68
+ lr: !!float 1.0e-04
69
+ lr_decay: step
70
+ gamma: 1.0
71
+ step: 50
72
+
Text2Human/configs/vqvae_top.yml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: vqvae_top
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: VQImageSegmTextureModel
19
+ # network configs
20
+ embed_dim: 256
21
+ n_embed: 1024
22
+ double_z: false
23
+ z_channels: 256
24
+ resolution: 512
25
+ in_channels: 3
26
+ out_ch: 3
27
+ ch: 128
28
+ ch_mult: [1, 1, 2, 2, 4]
29
+ num_res_blocks: 2
30
+ attn_resolutions: [32]
31
+ dropout: 0.0
32
+
33
+ disc_layers: 3
34
+ disc_weight_max: 0
35
+ disc_start_step: 3000000000000000000000000001
36
+ n_channels: 3
37
+ ndf: 64
38
+ nf: 128
39
+ perceptual_weight: 1.0
40
+
41
+ num_segm_classes: 24
42
+
43
+
44
+ # training configs
45
+ val_freq: 5
46
+ print_freq: 100
47
+ weight_decay: 0
48
+ manual_seed: 2021
49
+ num_epochs: 1000
50
+ lr: !!float 1.0e-04
51
+ lr_decay: step
52
+ gamma: 1.0
53
+ step: 50
models/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import importlib
3
+ import logging
4
+ import os.path as osp
5
+
6
+ # automatically scan and import model modules
7
+ # scan all the files under the 'models' folder and collect files ending with
8
+ # '_model.py'
9
+ model_folder = osp.dirname(osp.abspath(__file__))
10
+ model_filenames = [
11
+ osp.splitext(osp.basename(v))[0]
12
+ for v in glob.glob(f'{model_folder}/*_model.py')
13
+ ]
14
+ # import all the model modules
15
+ _model_modules = [
16
+ importlib.import_module(f'models.{file_name}')
17
+ for file_name in model_filenames
18
+ ]
19
+
20
+
21
+ def create_model(opt):
22
+ """Create model.
23
+
24
+ Args:
25
+ opt (dict): Configuration. It constains:
26
+ model_type (str): Model type.
27
+ """
28
+ model_type = opt['model_type']
29
+
30
+ # dynamically instantiation
31
+ for module in _model_modules:
32
+ model_cls = getattr(module, model_type, None)
33
+ if model_cls is not None:
34
+ break
35
+ if model_cls is None:
36
+ raise ValueError(f'Model {model_type} is not found.')
37
+
38
+ model = model_cls(opt)
39
+
40
+ logger = logging.getLogger('base')
41
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
42
+ return model
models/archs/__init__.py ADDED
File without changes
models/archs/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (126 Bytes). View file
 
models/archs/__pycache__/fcn_arch.cpython-38.pyc ADDED
Binary file (10.5 kB). View file
 
models/archs/__pycache__/shape_attr_embedding_arch.cpython-38.pyc ADDED
Binary file (1.33 kB). View file
 
models/archs/__pycache__/transformer_arch.cpython-38.pyc ADDED
Binary file (7.61 kB). View file
 
models/archs/__pycache__/unet_arch.cpython-38.pyc ADDED
Binary file (21.9 kB). View file
 
models/archs/__pycache__/vqgan_arch.cpython-38.pyc ADDED
Binary file (24.5 kB). View file
 
models/archs/fcn_arch.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from mmcv.cnn import ConvModule, normal_init
4
+ from mmseg.ops import resize
5
+
6
+
7
+ class BaseDecodeHead(nn.Module):
8
+ """Base class for BaseDecodeHead.
9
+
10
+ Args:
11
+ in_channels (int|Sequence[int]): Input channels.
12
+ channels (int): Channels after modules, before conv_seg.
13
+ num_classes (int): Number of classes.
14
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
15
+ conv_cfg (dict|None): Config of conv layers. Default: None.
16
+ norm_cfg (dict|None): Config of norm layers. Default: None.
17
+ act_cfg (dict): Config of activation layers.
18
+ Default: dict(type='ReLU')
19
+ in_index (int|Sequence[int]): Input feature index. Default: -1
20
+ input_transform (str|None): Transformation type of input features.
21
+ Options: 'resize_concat', 'multiple_select', None.
22
+ 'resize_concat': Multiple feature maps will be resize to the
23
+ same size as first one and than concat together.
24
+ Usually used in FCN head of HRNet.
25
+ 'multiple_select': Multiple feature maps will be bundle into
26
+ a list and passed into decode head.
27
+ None: Only one select feature map is allowed.
28
+ Default: None.
29
+ loss_decode (dict): Config of decode loss.
30
+ Default: dict(type='CrossEntropyLoss').
31
+ ignore_index (int | None): The label index to be ignored. When using
32
+ masked BCE loss, ignore_index should be set to None. Default: 255
33
+ sampler (dict|None): The config of segmentation map sampler.
34
+ Default: None.
35
+ align_corners (bool): align_corners argument of F.interpolate.
36
+ Default: False.
37
+ """
38
+
39
+ def __init__(self,
40
+ in_channels,
41
+ channels,
42
+ *,
43
+ num_classes,
44
+ dropout_ratio=0.1,
45
+ conv_cfg=None,
46
+ norm_cfg=dict(type='BN'),
47
+ act_cfg=dict(type='ReLU'),
48
+ in_index=-1,
49
+ input_transform=None,
50
+ ignore_index=255,
51
+ align_corners=False):
52
+ super(BaseDecodeHead, self).__init__()
53
+ self._init_inputs(in_channels, in_index, input_transform)
54
+ self.channels = channels
55
+ self.num_classes = num_classes
56
+ self.dropout_ratio = dropout_ratio
57
+ self.conv_cfg = conv_cfg
58
+ self.norm_cfg = norm_cfg
59
+ self.act_cfg = act_cfg
60
+ self.in_index = in_index
61
+
62
+ self.ignore_index = ignore_index
63
+ self.align_corners = align_corners
64
+
65
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
66
+ if dropout_ratio > 0:
67
+ self.dropout = nn.Dropout2d(dropout_ratio)
68
+ else:
69
+ self.dropout = None
70
+
71
+ def extra_repr(self):
72
+ """Extra repr."""
73
+ s = f'input_transform={self.input_transform}, ' \
74
+ f'ignore_index={self.ignore_index}, ' \
75
+ f'align_corners={self.align_corners}'
76
+ return s
77
+
78
+ def _init_inputs(self, in_channels, in_index, input_transform):
79
+ """Check and initialize input transforms.
80
+
81
+ The in_channels, in_index and input_transform must match.
82
+ Specifically, when input_transform is None, only single feature map
83
+ will be selected. So in_channels and in_index must be of type int.
84
+ When input_transform
85
+
86
+ Args:
87
+ in_channels (int|Sequence[int]): Input channels.
88
+ in_index (int|Sequence[int]): Input feature index.
89
+ input_transform (str|None): Transformation type of input features.
90
+ Options: 'resize_concat', 'multiple_select', None.
91
+ 'resize_concat': Multiple feature maps will be resize to the
92
+ same size as first one and than concat together.
93
+ Usually used in FCN head of HRNet.
94
+ 'multiple_select': Multiple feature maps will be bundle into
95
+ a list and passed into decode head.
96
+ None: Only one select feature map is allowed.
97
+ """
98
+
99
+ if input_transform is not None:
100
+ assert input_transform in ['resize_concat', 'multiple_select']
101
+ self.input_transform = input_transform
102
+ self.in_index = in_index
103
+ if input_transform is not None:
104
+ assert isinstance(in_channels, (list, tuple))
105
+ assert isinstance(in_index, (list, tuple))
106
+ assert len(in_channels) == len(in_index)
107
+ if input_transform == 'resize_concat':
108
+ self.in_channels = sum(in_channels)
109
+ else:
110
+ self.in_channels = in_channels
111
+ else:
112
+ assert isinstance(in_channels, int)
113
+ assert isinstance(in_index, int)
114
+ self.in_channels = in_channels
115
+
116
+ def init_weights(self):
117
+ """Initialize weights of classification layer."""
118
+ normal_init(self.conv_seg, mean=0, std=0.01)
119
+
120
+ def _transform_inputs(self, inputs):
121
+ """Transform inputs for decoder.
122
+
123
+ Args:
124
+ inputs (list[Tensor]): List of multi-level img features.
125
+
126
+ Returns:
127
+ Tensor: The transformed inputs
128
+ """
129
+
130
+ if self.input_transform == 'resize_concat':
131
+ inputs = [inputs[i] for i in self.in_index]
132
+ upsampled_inputs = [
133
+ resize(
134
+ input=x,
135
+ size=inputs[0].shape[2:],
136
+ mode='bilinear',
137
+ align_corners=self.align_corners) for x in inputs
138
+ ]
139
+ inputs = torch.cat(upsampled_inputs, dim=1)
140
+ elif self.input_transform == 'multiple_select':
141
+ inputs = [inputs[i] for i in self.in_index]
142
+ else:
143
+ inputs = inputs[self.in_index]
144
+
145
+ return inputs
146
+
147
+ def forward(self, inputs):
148
+ """Placeholder of forward function."""
149
+ pass
150
+
151
+ def cls_seg(self, feat):
152
+ """Classify each pixel."""
153
+ if self.dropout is not None:
154
+ feat = self.dropout(feat)
155
+ output = self.conv_seg(feat)
156
+ return output
157
+
158
+
159
+ class FCNHead(BaseDecodeHead):
160
+ """Fully Convolution Networks for Semantic Segmentation.
161
+
162
+ This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
163
+
164
+ Args:
165
+ num_convs (int): Number of convs in the head. Default: 2.
166
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
167
+ concat_input (bool): Whether concat the input and output of convs
168
+ before classification layer.
169
+ """
170
+
171
+ def __init__(self,
172
+ num_convs=2,
173
+ kernel_size=3,
174
+ concat_input=True,
175
+ **kwargs):
176
+ assert num_convs >= 0
177
+ self.num_convs = num_convs
178
+ self.concat_input = concat_input
179
+ self.kernel_size = kernel_size
180
+ super(FCNHead, self).__init__(**kwargs)
181
+ if num_convs == 0:
182
+ assert self.in_channels == self.channels
183
+
184
+ convs = []
185
+ convs.append(
186
+ ConvModule(
187
+ self.in_channels,
188
+ self.channels,
189
+ kernel_size=kernel_size,
190
+ padding=kernel_size // 2,
191
+ conv_cfg=self.conv_cfg,
192
+ norm_cfg=self.norm_cfg,
193
+ act_cfg=self.act_cfg))
194
+ for i in range(num_convs - 1):
195
+ convs.append(
196
+ ConvModule(
197
+ self.channels,
198
+ self.channels,
199
+ kernel_size=kernel_size,
200
+ padding=kernel_size // 2,
201
+ conv_cfg=self.conv_cfg,
202
+ norm_cfg=self.norm_cfg,
203
+ act_cfg=self.act_cfg))
204
+ if num_convs == 0:
205
+ self.convs = nn.Identity()
206
+ else:
207
+ self.convs = nn.Sequential(*convs)
208
+ if self.concat_input:
209
+ self.conv_cat = ConvModule(
210
+ self.in_channels + self.channels,
211
+ self.channels,
212
+ kernel_size=kernel_size,
213
+ padding=kernel_size // 2,
214
+ conv_cfg=self.conv_cfg,
215
+ norm_cfg=self.norm_cfg,
216
+ act_cfg=self.act_cfg)
217
+
218
+ def forward(self, inputs):
219
+ """Forward function."""
220
+ x = self._transform_inputs(inputs)
221
+ output = self.convs(x)
222
+ if self.concat_input:
223
+ output = self.conv_cat(torch.cat([x, output], dim=1))
224
+ output = self.cls_seg(output)
225
+ return output
226
+
227
+
228
+ class MultiHeadFCNHead(nn.Module):
229
+ """Fully Convolution Networks for Semantic Segmentation.
230
+
231
+ This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
232
+
233
+ Args:
234
+ num_convs (int): Number of convs in the head. Default: 2.
235
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
236
+ concat_input (bool): Whether concat the input and output of convs
237
+ before classification layer.
238
+ """
239
+
240
+ def __init__(self,
241
+ in_channels,
242
+ channels,
243
+ *,
244
+ num_classes,
245
+ dropout_ratio=0.1,
246
+ conv_cfg=None,
247
+ norm_cfg=dict(type='BN'),
248
+ act_cfg=dict(type='ReLU'),
249
+ in_index=-1,
250
+ input_transform=None,
251
+ ignore_index=255,
252
+ align_corners=False,
253
+ num_convs=2,
254
+ kernel_size=3,
255
+ concat_input=True,
256
+ num_head=18,
257
+ **kwargs):
258
+ super(MultiHeadFCNHead, self).__init__()
259
+ assert num_convs >= 0
260
+ self.num_convs = num_convs
261
+ self.concat_input = concat_input
262
+ self.kernel_size = kernel_size
263
+ self._init_inputs(in_channels, in_index, input_transform)
264
+ self.channels = channels
265
+ self.num_classes = num_classes
266
+ self.dropout_ratio = dropout_ratio
267
+ self.conv_cfg = conv_cfg
268
+ self.norm_cfg = norm_cfg
269
+ self.act_cfg = act_cfg
270
+ self.in_index = in_index
271
+ self.num_head = num_head
272
+
273
+ self.ignore_index = ignore_index
274
+ self.align_corners = align_corners
275
+
276
+ if dropout_ratio > 0:
277
+ self.dropout = nn.Dropout2d(dropout_ratio)
278
+
279
+ conv_seg_head_list = []
280
+ for _ in range(self.num_head):
281
+ conv_seg_head_list.append(
282
+ nn.Conv2d(channels, num_classes, kernel_size=1))
283
+
284
+ self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)
285
+
286
+ self.init_weights()
287
+
288
+ if num_convs == 0:
289
+ assert self.in_channels == self.channels
290
+
291
+ convs_list = []
292
+ conv_cat_list = []
293
+
294
+ for _ in range(self.num_head):
295
+ convs = []
296
+ convs.append(
297
+ ConvModule(
298
+ self.in_channels,
299
+ self.channels,
300
+ kernel_size=kernel_size,
301
+ padding=kernel_size // 2,
302
+ conv_cfg=self.conv_cfg,
303
+ norm_cfg=self.norm_cfg,
304
+ act_cfg=self.act_cfg))
305
+ for _ in range(num_convs - 1):
306
+ convs.append(
307
+ ConvModule(
308
+ self.channels,
309
+ self.channels,
310
+ kernel_size=kernel_size,
311
+ padding=kernel_size // 2,
312
+ conv_cfg=self.conv_cfg,
313
+ norm_cfg=self.norm_cfg,
314
+ act_cfg=self.act_cfg))
315
+ if num_convs == 0:
316
+ convs_list.append(nn.Identity())
317
+ else:
318
+ convs_list.append(nn.Sequential(*convs))
319
+ if self.concat_input:
320
+ conv_cat_list.append(
321
+ ConvModule(
322
+ self.in_channels + self.channels,
323
+ self.channels,
324
+ kernel_size=kernel_size,
325
+ padding=kernel_size // 2,
326
+ conv_cfg=self.conv_cfg,
327
+ norm_cfg=self.norm_cfg,
328
+ act_cfg=self.act_cfg))
329
+
330
+ self.convs_list = nn.ModuleList(convs_list)
331
+ self.conv_cat_list = nn.ModuleList(conv_cat_list)
332
+
333
+ def forward(self, inputs):
334
+ """Forward function."""
335
+ x = self._transform_inputs(inputs)
336
+
337
+ output_list = []
338
+ for head_idx in range(self.num_head):
339
+ output = self.convs_list[head_idx](x)
340
+ if self.concat_input:
341
+ output = self.conv_cat_list[head_idx](
342
+ torch.cat([x, output], dim=1))
343
+ if self.dropout is not None:
344
+ output = self.dropout(output)
345
+ output = self.conv_seg_head_list[head_idx](output)
346
+ output_list.append(output)
347
+
348
+ return output_list
349
+
350
+ def _init_inputs(self, in_channels, in_index, input_transform):
351
+ """Check and initialize input transforms.
352
+
353
+ The in_channels, in_index and input_transform must match.
354
+ Specifically, when input_transform is None, only single feature map
355
+ will be selected. So in_channels and in_index must be of type int.
356
+ When input_transform
357
+
358
+ Args:
359
+ in_channels (int|Sequence[int]): Input channels.
360
+ in_index (int|Sequence[int]): Input feature index.
361
+ input_transform (str|None): Transformation type of input features.
362
+ Options: 'resize_concat', 'multiple_select', None.
363
+ 'resize_concat': Multiple feature maps will be resize to the
364
+ same size as first one and than concat together.
365
+ Usually used in FCN head of HRNet.
366
+ 'multiple_select': Multiple feature maps will be bundle into
367
+ a list and passed into decode head.
368
+ None: Only one select feature map is allowed.
369
+ """
370
+
371
+ if input_transform is not None:
372
+ assert input_transform in ['resize_concat', 'multiple_select']
373
+ self.input_transform = input_transform
374
+ self.in_index = in_index
375
+ if input_transform is not None:
376
+ assert isinstance(in_channels, (list, tuple))
377
+ assert isinstance(in_index, (list, tuple))
378
+ assert len(in_channels) == len(in_index)
379
+ if input_transform == 'resize_concat':
380
+ self.in_channels = sum(in_channels)
381
+ else:
382
+ self.in_channels = in_channels
383
+ else:
384
+ assert isinstance(in_channels, int)
385
+ assert isinstance(in_index, int)
386
+ self.in_channels = in_channels
387
+
388
+ def init_weights(self):
389
+ """Initialize weights of classification layer."""
390
+ for conv_seg_head in self.conv_seg_head_list:
391
+ normal_init(conv_seg_head, mean=0, std=0.01)
392
+
393
+ def _transform_inputs(self, inputs):
394
+ """Transform inputs for decoder.
395
+
396
+ Args:
397
+ inputs (list[Tensor]): List of multi-level img features.
398
+
399
+ Returns:
400
+ Tensor: The transformed inputs
401
+ """
402
+
403
+ if self.input_transform == 'resize_concat':
404
+ inputs = [inputs[i] for i in self.in_index]
405
+ upsampled_inputs = [
406
+ resize(
407
+ input=x,
408
+ size=inputs[0].shape[2:],
409
+ mode='bilinear',
410
+ align_corners=self.align_corners) for x in inputs
411
+ ]
412
+ inputs = torch.cat(upsampled_inputs, dim=1)
413
+ elif self.input_transform == 'multiple_select':
414
+ inputs = [inputs[i] for i in self.in_index]
415
+ else:
416
+ inputs = inputs[self.in_index]
417
+
418
+ return inputs
models/archs/shape_attr_embedding_arch.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class ShapeAttrEmbedding(nn.Module):
7
+
8
+ def __init__(self, dim, out_dim, cls_num_list):
9
+ super(ShapeAttrEmbedding, self).__init__()
10
+
11
+ for idx, cls_num in enumerate(cls_num_list):
12
+ setattr(
13
+ self, f'attr_{idx}',
14
+ nn.Sequential(
15
+ nn.Linear(cls_num, dim), nn.LeakyReLU(),
16
+ nn.Linear(dim, dim)))
17
+ self.cls_num_list = cls_num_list
18
+ self.attr_num = len(cls_num_list)
19
+ self.fusion = nn.Sequential(
20
+ nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
21
+ nn.Linear(out_dim, out_dim))
22
+
23
+ def forward(self, attr):
24
+ attr_embedding_list = []
25
+ for idx in range(self.attr_num):
26
+ attr_embed_fc = getattr(self, f'attr_{idx}')
27
+ attr_embedding_list.append(
28
+ attr_embed_fc(
29
+ F.one_hot(
30
+ attr[:, idx],
31
+ num_classes=self.cls_num_list[idx]).to(torch.float32)))
32
+ attr_embedding = torch.cat(attr_embedding_list, dim=1)
33
+ attr_embedding = self.fusion(attr_embedding)
34
+
35
+ return attr_embedding
models/archs/transformer_arch.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class CausalSelfAttention(nn.Module):
10
+ """
11
+ A vanilla multi-head masked self-attention layer with a projection at the end.
12
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
13
+ explicit implementation here to show that there is nothing too scary here.
14
+ """
15
+
16
+ def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
17
+ latent_shape, sampler):
18
+ super().__init__()
19
+ assert bert_n_emb % bert_n_head == 0
20
+ # key, query, value projections for all heads
21
+ self.key = nn.Linear(bert_n_emb, bert_n_emb)
22
+ self.query = nn.Linear(bert_n_emb, bert_n_emb)
23
+ self.value = nn.Linear(bert_n_emb, bert_n_emb)
24
+ # regularization
25
+ self.attn_drop = nn.Dropout(attn_pdrop)
26
+ self.resid_drop = nn.Dropout(resid_pdrop)
27
+ # output projection
28
+ self.proj = nn.Linear(bert_n_emb, bert_n_emb)
29
+ self.n_head = bert_n_head
30
+ self.causal = True if sampler == 'autoregressive' else False
31
+ if self.causal:
32
+ block_size = np.prod(latent_shape)
33
+ mask = torch.tril(torch.ones(block_size, block_size))
34
+ self.register_buffer("mask", mask.view(1, 1, block_size,
35
+ block_size))
36
+
37
+ def forward(self, x, layer_past=None):
38
+ B, T, C = x.size()
39
+
40
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
41
+ k = self.key(x).view(B, T, self.n_head,
42
+ C // self.n_head).transpose(1,
43
+ 2) # (B, nh, T, hs)
44
+ q = self.query(x).view(B, T, self.n_head,
45
+ C // self.n_head).transpose(1,
46
+ 2) # (B, nh, T, hs)
47
+ v = self.value(x).view(B, T, self.n_head,
48
+ C // self.n_head).transpose(1,
49
+ 2) # (B, nh, T, hs)
50
+
51
+ present = torch.stack((k, v))
52
+ if self.causal and layer_past is not None:
53
+ past_key, past_value = layer_past
54
+ k = torch.cat((past_key, k), dim=-2)
55
+ v = torch.cat((past_value, v), dim=-2)
56
+
57
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
58
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
59
+
60
+ if self.causal and layer_past is None:
61
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
62
+
63
+ att = F.softmax(att, dim=-1)
64
+ att = self.attn_drop(att)
65
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
66
+ # re-assemble all head outputs side by side
67
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
68
+
69
+ # output projection
70
+ y = self.resid_drop(self.proj(y))
71
+ return y, present
72
+
73
+
74
+ class Block(nn.Module):
75
+ """ an unassuming Transformer block """
76
+
77
+ def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
78
+ latent_shape, sampler):
79
+ super().__init__()
80
+ self.ln1 = nn.LayerNorm(bert_n_emb)
81
+ self.ln2 = nn.LayerNorm(bert_n_emb)
82
+ self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
83
+ resid_pdrop, latent_shape, sampler)
84
+ self.mlp = nn.Sequential(
85
+ nn.Linear(bert_n_emb, 4 * bert_n_emb),
86
+ nn.GELU(), # nice
87
+ nn.Linear(4 * bert_n_emb, bert_n_emb),
88
+ nn.Dropout(resid_pdrop),
89
+ )
90
+
91
+ def forward(self, x, layer_past=None, return_present=False):
92
+
93
+ attn, present = self.attn(self.ln1(x), layer_past)
94
+ x = x + attn
95
+ x = x + self.mlp(self.ln2(x))
96
+
97
+ if layer_past is not None or return_present:
98
+ return x, present
99
+ return x
100
+
101
+
102
+ class Transformer(nn.Module):
103
+ """ the full GPT language model, with a context size of block_size """
104
+
105
+ def __init__(self,
106
+ codebook_size,
107
+ segm_codebook_size,
108
+ bert_n_emb,
109
+ bert_n_layers,
110
+ bert_n_head,
111
+ block_size,
112
+ latent_shape,
113
+ embd_pdrop,
114
+ resid_pdrop,
115
+ attn_pdrop,
116
+ sampler='absorbing'):
117
+ super().__init__()
118
+
119
+ self.vocab_size = codebook_size + 1
120
+ self.n_embd = bert_n_emb
121
+ self.block_size = block_size
122
+ self.n_layers = bert_n_layers
123
+ self.codebook_size = codebook_size
124
+ self.segm_codebook_size = segm_codebook_size
125
+ self.causal = sampler == 'autoregressive'
126
+ if self.causal:
127
+ self.vocab_size = codebook_size
128
+
129
+ self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
130
+ self.pos_emb = nn.Parameter(
131
+ torch.zeros(1, self.block_size, self.n_embd))
132
+ self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
133
+ self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
134
+ self.drop = nn.Dropout(embd_pdrop)
135
+
136
+ # transformer
137
+ self.blocks = nn.Sequential(*[
138
+ Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
139
+ latent_shape, sampler) for _ in range(self.n_layers)
140
+ ])
141
+ # decoder head
142
+ self.ln_f = nn.LayerNorm(self.n_embd)
143
+ self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False)
144
+
145
+ def get_block_size(self):
146
+ return self.block_size
147
+
148
+ def _init_weights(self, module):
149
+ if isinstance(module, (nn.Linear, nn.Embedding)):
150
+ module.weight.data.normal_(mean=0.0, std=0.02)
151
+ if isinstance(module, nn.Linear) and module.bias is not None:
152
+ module.bias.data.zero_()
153
+ elif isinstance(module, nn.LayerNorm):
154
+ module.bias.data.zero_()
155
+ module.weight.data.fill_(1.0)
156
+
157
+ def forward(self, idx, segm_tokens, t=None):
158
+ # each index maps to a (learnable) vector
159
+ token_embeddings = self.tok_emb(idx)
160
+
161
+ segm_embeddings = self.segm_emb(segm_tokens)
162
+
163
+ if self.causal:
164
+ token_embeddings = torch.cat((self.start_tok.repeat(
165
+ token_embeddings.size(0), 1, 1), token_embeddings),
166
+ dim=1)
167
+
168
+ t = token_embeddings.shape[1]
169
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
170
+ # each position maps to a (learnable) vector
171
+
172
+ position_embeddings = self.pos_emb[:, :t, :]
173
+
174
+ x = token_embeddings + position_embeddings + segm_embeddings
175
+ x = self.drop(x)
176
+ for block in self.blocks:
177
+ x = block(x)
178
+ x = self.ln_f(x)
179
+ logits = self.head(x)
180
+
181
+ return logits
182
+
183
+
184
+ class TransformerMultiHead(nn.Module):
185
+ """ the full GPT language model, with a context size of block_size """
186
+
187
+ def __init__(self,
188
+ codebook_size,
189
+ segm_codebook_size,
190
+ texture_codebook_size,
191
+ bert_n_emb,
192
+ bert_n_layers,
193
+ bert_n_head,
194
+ block_size,
195
+ latent_shape,
196
+ embd_pdrop,
197
+ resid_pdrop,
198
+ attn_pdrop,
199
+ num_head,
200
+ sampler='absorbing'):
201
+ super().__init__()
202
+
203
+ self.vocab_size = codebook_size + 1
204
+ self.n_embd = bert_n_emb
205
+ self.block_size = block_size
206
+ self.n_layers = bert_n_layers
207
+ self.codebook_size = codebook_size
208
+ self.segm_codebook_size = segm_codebook_size
209
+ self.texture_codebook_size = texture_codebook_size
210
+ self.causal = sampler == 'autoregressive'
211
+ if self.causal:
212
+ self.vocab_size = codebook_size
213
+
214
+ self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
215
+ self.pos_emb = nn.Parameter(
216
+ torch.zeros(1, self.block_size, self.n_embd))
217
+ self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
218
+ self.texture_emb = nn.Embedding(self.texture_codebook_size,
219
+ self.n_embd)
220
+ self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
221
+ self.drop = nn.Dropout(embd_pdrop)
222
+
223
+ # transformer
224
+ self.blocks = nn.Sequential(*[
225
+ Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
226
+ latent_shape, sampler) for _ in range(self.n_layers)
227
+ ])
228
+ # decoder head
229
+ self.num_head = num_head
230
+ self.head_class_num = codebook_size // self.num_head
231
+ self.ln_f = nn.LayerNorm(self.n_embd)
232
+ self.head_list = nn.ModuleList([
233
+ nn.Linear(self.n_embd, self.head_class_num, bias=False)
234
+ for _ in range(self.num_head)
235
+ ])
236
+
237
+ def get_block_size(self):
238
+ return self.block_size
239
+
240
+ def _init_weights(self, module):
241
+ if isinstance(module, (nn.Linear, nn.Embedding)):
242
+ module.weight.data.normal_(mean=0.0, std=0.02)
243
+ if isinstance(module, nn.Linear) and module.bias is not None:
244
+ module.bias.data.zero_()
245
+ elif isinstance(module, nn.LayerNorm):
246
+ module.bias.data.zero_()
247
+ module.weight.data.fill_(1.0)
248
+
249
+ def forward(self, idx, segm_tokens, texture_tokens, t=None):
250
+ # each index maps to a (learnable) vector
251
+ token_embeddings = self.tok_emb(idx)
252
+ segm_embeddings = self.segm_emb(segm_tokens)
253
+ texture_embeddings = self.texture_emb(texture_tokens)
254
+
255
+ if self.causal:
256
+ token_embeddings = torch.cat((self.start_tok.repeat(
257
+ token_embeddings.size(0), 1, 1), token_embeddings),
258
+ dim=1)
259
+
260
+ t = token_embeddings.shape[1]
261
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
262
+ # each position maps to a (learnable) vector
263
+
264
+ position_embeddings = self.pos_emb[:, :t, :]
265
+
266
+ x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings
267
+ x = self.drop(x)
268
+ for block in self.blocks:
269
+ x = block(x)
270
+ x = self.ln_f(x)
271
+ logits_list = [self.head_list[i](x) for i in range(self.num_head)]
272
+
273
+ return logits_list
models/archs/unet_arch.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint as cp
4
+ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
5
+ build_norm_layer, build_upsample_layer, constant_init,
6
+ kaiming_init)
7
+ from mmcv.runner import load_checkpoint
8
+ from mmcv.utils.parrots_wrapper import _BatchNorm
9
+ from mmseg.utils import get_root_logger
10
+
11
+
12
+ class UpConvBlock(nn.Module):
13
+ """Upsample convolution block in decoder for UNet.
14
+
15
+ This upsample convolution block consists of one upsample module
16
+ followed by one convolution block. The upsample module expands the
17
+ high-level low-resolution feature map and the convolution block fuses
18
+ the upsampled high-level low-resolution feature map and the low-level
19
+ high-resolution feature map from encoder.
20
+
21
+ Args:
22
+ conv_block (nn.Sequential): Sequential of convolutional layers.
23
+ in_channels (int): Number of input channels of the high-level
24
+ skip_channels (int): Number of input channels of the low-level
25
+ high-resolution feature map from encoder.
26
+ out_channels (int): Number of output channels.
27
+ num_convs (int): Number of convolutional layers in the conv_block.
28
+ Default: 2.
29
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
30
+ dilation (int): Dilation rate of convolutional layer in conv_block.
31
+ Default: 1.
32
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
33
+ memory while slowing down the training speed. Default: False.
34
+ conv_cfg (dict | None): Config dict for convolution layer.
35
+ Default: None.
36
+ norm_cfg (dict | None): Config dict for normalization layer.
37
+ Default: dict(type='BN').
38
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
39
+ Default: dict(type='ReLU').
40
+ upsample_cfg (dict): The upsample config of the upsample module in
41
+ decoder. Default: dict(type='InterpConv'). If the size of
42
+ high-level feature map is the same as that of skip feature map
43
+ (low-level feature map from encoder), it does not need upsample the
44
+ high-level feature map and the upsample_cfg is None.
45
+ dcn (bool): Use deformable convoluton in convolutional layer or not.
46
+ Default: None.
47
+ plugins (dict): plugins for convolutional layers. Default: None.
48
+ """
49
+
50
+ def __init__(self,
51
+ conv_block,
52
+ in_channels,
53
+ skip_channels,
54
+ out_channels,
55
+ num_convs=2,
56
+ stride=1,
57
+ dilation=1,
58
+ with_cp=False,
59
+ conv_cfg=None,
60
+ norm_cfg=dict(type='BN'),
61
+ act_cfg=dict(type='ReLU'),
62
+ upsample_cfg=dict(type='InterpConv'),
63
+ dcn=None,
64
+ plugins=None):
65
+ super(UpConvBlock, self).__init__()
66
+ assert dcn is None, 'Not implemented yet.'
67
+ assert plugins is None, 'Not implemented yet.'
68
+
69
+ self.conv_block = conv_block(
70
+ in_channels=2 * skip_channels,
71
+ out_channels=out_channels,
72
+ num_convs=num_convs,
73
+ stride=stride,
74
+ dilation=dilation,
75
+ with_cp=with_cp,
76
+ conv_cfg=conv_cfg,
77
+ norm_cfg=norm_cfg,
78
+ act_cfg=act_cfg,
79
+ dcn=None,
80
+ plugins=None)
81
+ if upsample_cfg is not None:
82
+ self.upsample = build_upsample_layer(
83
+ cfg=upsample_cfg,
84
+ in_channels=in_channels,
85
+ out_channels=skip_channels,
86
+ with_cp=with_cp,
87
+ norm_cfg=norm_cfg,
88
+ act_cfg=act_cfg)
89
+ else:
90
+ self.upsample = ConvModule(
91
+ in_channels,
92
+ skip_channels,
93
+ kernel_size=1,
94
+ stride=1,
95
+ padding=0,
96
+ conv_cfg=conv_cfg,
97
+ norm_cfg=norm_cfg,
98
+ act_cfg=act_cfg)
99
+
100
+ def forward(self, skip, x):
101
+ """Forward function."""
102
+
103
+ x = self.upsample(x)
104
+ out = torch.cat([skip, x], dim=1)
105
+ out = self.conv_block(out)
106
+
107
+ return out
108
+
109
+
110
+ class BasicConvBlock(nn.Module):
111
+ """Basic convolutional block for UNet.
112
+
113
+ This module consists of several plain convolutional layers.
114
+
115
+ Args:
116
+ in_channels (int): Number of input channels.
117
+ out_channels (int): Number of output channels.
118
+ num_convs (int): Number of convolutional layers. Default: 2.
119
+ stride (int): Whether use stride convolution to downsample
120
+ the input feature map. If stride=2, it only uses stride convolution
121
+ in the first convolutional layer to downsample the input feature
122
+ map. Options are 1 or 2. Default: 1.
123
+ dilation (int): Whether use dilated convolution to expand the
124
+ receptive field. Set dilation rate of each convolutional layer and
125
+ the dilation rate of the first convolutional layer is always 1.
126
+ Default: 1.
127
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
128
+ memory while slowing down the training speed. Default: False.
129
+ conv_cfg (dict | None): Config dict for convolution layer.
130
+ Default: None.
131
+ norm_cfg (dict | None): Config dict for normalization layer.
132
+ Default: dict(type='BN').
133
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
134
+ Default: dict(type='ReLU').
135
+ dcn (bool): Use deformable convoluton in convolutional layer or not.
136
+ Default: None.
137
+ plugins (dict): plugins for convolutional layers. Default: None.
138
+ """
139
+
140
+ def __init__(self,
141
+ in_channels,
142
+ out_channels,
143
+ num_convs=2,
144
+ stride=1,
145
+ dilation=1,
146
+ with_cp=False,
147
+ conv_cfg=None,
148
+ norm_cfg=dict(type='BN'),
149
+ act_cfg=dict(type='ReLU'),
150
+ dcn=None,
151
+ plugins=None):
152
+ super(BasicConvBlock, self).__init__()
153
+ assert dcn is None, 'Not implemented yet.'
154
+ assert plugins is None, 'Not implemented yet.'
155
+
156
+ self.with_cp = with_cp
157
+ convs = []
158
+ for i in range(num_convs):
159
+ convs.append(
160
+ ConvModule(
161
+ in_channels=in_channels if i == 0 else out_channels,
162
+ out_channels=out_channels,
163
+ kernel_size=3,
164
+ stride=stride if i == 0 else 1,
165
+ dilation=1 if i == 0 else dilation,
166
+ padding=1 if i == 0 else dilation,
167
+ conv_cfg=conv_cfg,
168
+ norm_cfg=norm_cfg,
169
+ act_cfg=act_cfg))
170
+
171
+ self.convs = nn.Sequential(*convs)
172
+
173
+ def forward(self, x):
174
+ """Forward function."""
175
+
176
+ if self.with_cp and x.requires_grad:
177
+ out = cp.checkpoint(self.convs, x)
178
+ else:
179
+ out = self.convs(x)
180
+ return out
181
+
182
+
183
+ class DeconvModule(nn.Module):
184
+ """Deconvolution upsample module in decoder for UNet (2X upsample).
185
+
186
+ This module uses deconvolution to upsample feature map in the decoder
187
+ of UNet.
188
+
189
+ Args:
190
+ in_channels (int): Number of input channels.
191
+ out_channels (int): Number of output channels.
192
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
193
+ memory while slowing down the training speed. Default: False.
194
+ norm_cfg (dict | None): Config dict for normalization layer.
195
+ Default: dict(type='BN').
196
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
197
+ Default: dict(type='ReLU').
198
+ kernel_size (int): Kernel size of the convolutional layer. Default: 4.
199
+ """
200
+
201
+ def __init__(self,
202
+ in_channels,
203
+ out_channels,
204
+ with_cp=False,
205
+ norm_cfg=dict(type='BN'),
206
+ act_cfg=dict(type='ReLU'),
207
+ *,
208
+ kernel_size=4,
209
+ scale_factor=2):
210
+ super(DeconvModule, self).__init__()
211
+
212
+ assert (kernel_size - scale_factor >= 0) and\
213
+ (kernel_size - scale_factor) % 2 == 0,\
214
+ f'kernel_size should be greater than or equal to scale_factor '\
215
+ f'and (kernel_size - scale_factor) should be even numbers, '\
216
+ f'while the kernel size is {kernel_size} and scale_factor is '\
217
+ f'{scale_factor}.'
218
+
219
+ stride = scale_factor
220
+ padding = (kernel_size - scale_factor) // 2
221
+ self.with_cp = with_cp
222
+ deconv = nn.ConvTranspose2d(
223
+ in_channels,
224
+ out_channels,
225
+ kernel_size=kernel_size,
226
+ stride=stride,
227
+ padding=padding)
228
+
229
+ norm_name, norm = build_norm_layer(norm_cfg, out_channels)
230
+ activate = build_activation_layer(act_cfg)
231
+ self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
232
+
233
+ def forward(self, x):
234
+ """Forward function."""
235
+
236
+ if self.with_cp and x.requires_grad:
237
+ out = cp.checkpoint(self.deconv_upsamping, x)
238
+ else:
239
+ out = self.deconv_upsamping(x)
240
+ return out
241
+
242
+
243
+ @UPSAMPLE_LAYERS.register_module()
244
+ class InterpConv(nn.Module):
245
+ """Interpolation upsample module in decoder for UNet.
246
+
247
+ This module uses interpolation to upsample feature map in the decoder
248
+ of UNet. It consists of one interpolation upsample layer and one
249
+ convolutional layer. It can be one interpolation upsample layer followed
250
+ by one convolutional layer (conv_first=False) or one convolutional layer
251
+ followed by one interpolation upsample layer (conv_first=True).
252
+
253
+ Args:
254
+ in_channels (int): Number of input channels.
255
+ out_channels (int): Number of output channels.
256
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
257
+ memory while slowing down the training speed. Default: False.
258
+ norm_cfg (dict | None): Config dict for normalization layer.
259
+ Default: dict(type='BN').
260
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
261
+ Default: dict(type='ReLU').
262
+ conv_cfg (dict | None): Config dict for convolution layer.
263
+ Default: None.
264
+ conv_first (bool): Whether convolutional layer or interpolation
265
+ upsample layer first. Default: False. It means interpolation
266
+ upsample layer followed by one convolutional layer.
267
+ kernel_size (int): Kernel size of the convolutional layer. Default: 1.
268
+ stride (int): Stride of the convolutional layer. Default: 1.
269
+ padding (int): Padding of the convolutional layer. Default: 1.
270
+ upsampe_cfg (dict): Interpolation config of the upsample layer.
271
+ Default: dict(
272
+ scale_factor=2, mode='bilinear', align_corners=False).
273
+ """
274
+
275
+ def __init__(self,
276
+ in_channels,
277
+ out_channels,
278
+ with_cp=False,
279
+ norm_cfg=dict(type='BN'),
280
+ act_cfg=dict(type='ReLU'),
281
+ *,
282
+ conv_cfg=None,
283
+ conv_first=False,
284
+ kernel_size=1,
285
+ stride=1,
286
+ padding=0,
287
+ upsampe_cfg=dict(
288
+ scale_factor=2, mode='bilinear', align_corners=False)):
289
+ super(InterpConv, self).__init__()
290
+
291
+ self.with_cp = with_cp
292
+ conv = ConvModule(
293
+ in_channels,
294
+ out_channels,
295
+ kernel_size=kernel_size,
296
+ stride=stride,
297
+ padding=padding,
298
+ conv_cfg=conv_cfg,
299
+ norm_cfg=norm_cfg,
300
+ act_cfg=act_cfg)
301
+ upsample = nn.Upsample(**upsampe_cfg)
302
+ if conv_first:
303
+ self.interp_upsample = nn.Sequential(conv, upsample)
304
+ else:
305
+ self.interp_upsample = nn.Sequential(upsample, conv)
306
+
307
+ def forward(self, x):
308
+ """Forward function."""
309
+
310
+ if self.with_cp and x.requires_grad:
311
+ out = cp.checkpoint(self.interp_upsample, x)
312
+ else:
313
+ out = self.interp_upsample(x)
314
+ return out
315
+
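For reference, a dependency-free sketch of what InterpConv's default path (conv_first=False) computes, using plain torch modules instead of mmcv's ConvModule; the channel sizes are illustrative assumptions:

import torch
import torch.nn as nn

interp_then_conv = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # upsample first ...
    nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0),              # ... then a 1x1 conv
)
x = torch.randn(1, 64, 16, 16)
print(interp_then_conv(x).shape)  # torch.Size([1, 32, 32, 32])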
316
+
317
+ class UNet(nn.Module):
318
+ """UNet backbone.
319
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
320
+ https://arxiv.org/pdf/1505.04597.pdf
321
+
322
+ Args:
323
+ in_channels (int): Number of input image channels. Default: 3.

324
+ base_channels (int): Number of base channels of each stage.
325
+ The output channels of the first stage. Default: 64.
326
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
327
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
328
+ len(strides) is equal to num_stages. Normally the stride of the
329
+ first stage in encoder is 1. If strides[i]=2, it uses stride
330
+ convolution to downsample in the correspondence encoder stage.
331
+ Default: (1, 1, 1, 1, 1).
332
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
333
+ convolution block of the correspondence encoder stage.
334
+ Default: (2, 2, 2, 2, 2).
335
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
336
+ convolution block of the correspondence decoder stage.
337
+ Default: (2, 2, 2, 2).
338
+ downsamples (Sequence[int]): Whether use MaxPool to downsample the
339
+ feature map after the first stage of encoder
340
+ (stages: [1, num_stages)). If the correspondence encoder stage uses
341
+ stride convolution (strides[i]=2), it will never use MaxPool to
342
+ downsample, even if downsamples[i-1]=True.
343
+ Default: (True, True, True, True).
344
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
345
+ Default: (1, 1, 1, 1, 1).
346
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
347
+ Default: (1, 1, 1, 1).
348
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
349
+ memory while slowing down the training speed. Default: False.
350
+ conv_cfg (dict | None): Config dict for convolution layer.
351
+ Default: None.
352
+ norm_cfg (dict | None): Config dict for normalization layer.
353
+ Default: dict(type='BN').
354
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
355
+ Default: dict(type='ReLU').
356
+ upsample_cfg (dict): The upsample config of the upsample module in
357
+ decoder. Default: dict(type='InterpConv').
358
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
359
+ freeze running stats (mean and var). Note: Effect on Batch Norm
360
+ and its variants only. Default: False.
361
+ dcn (bool): Use deformable convolution in convolutional layer or not.
362
+ Default: None.
363
+ plugins (dict): plugins for convolutional layers. Default: None.
364
+
365
+ Notice:
366
+ The input image size should be divisible by the whole downsample rate
367
+ of the encoder. More detail of the whole downsample rate can be found
368
+ in UNet._check_input_devisible.
369
+
370
+ """
371
+
372
+ def __init__(self,
373
+ in_channels=3,
374
+ base_channels=64,
375
+ num_stages=5,
376
+ strides=(1, 1, 1, 1, 1),
377
+ enc_num_convs=(2, 2, 2, 2, 2),
378
+ dec_num_convs=(2, 2, 2, 2),
379
+ downsamples=(True, True, True, True),
380
+ enc_dilations=(1, 1, 1, 1, 1),
381
+ dec_dilations=(1, 1, 1, 1),
382
+ with_cp=False,
383
+ conv_cfg=None,
384
+ norm_cfg=dict(type='BN'),
385
+ act_cfg=dict(type='ReLU'),
386
+ upsample_cfg=dict(type='InterpConv'),
387
+ norm_eval=False,
388
+ dcn=None,
389
+ plugins=None):
390
+ super(UNet, self).__init__()
391
+ assert dcn is None, 'Not implemented yet.'
392
+ assert plugins is None, 'Not implemented yet.'
393
+ assert len(strides) == num_stages, \
394
+ 'The length of strides should be equal to num_stages, '\
395
+ f'while the strides is {strides}, the length of '\
396
+ f'strides is {len(strides)}, and the num_stages is '\
397
+ f'{num_stages}.'
398
+ assert len(enc_num_convs) == num_stages, \
399
+ 'The length of enc_num_convs should be equal to num_stages, '\
400
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
401
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
402
+ f'{num_stages}.'
403
+ assert len(dec_num_convs) == (num_stages-1), \
404
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
405
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
406
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
407
+ f'{num_stages}.'
408
+ assert len(downsamples) == (num_stages-1), \
409
+ 'The length of downsamples should be equal to (num_stages-1), '\
410
+ f'while the downsamples is {downsamples}, the length of '\
411
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
412
+ f'{num_stages}.'
413
+ assert len(enc_dilations) == num_stages, \
414
+ 'The length of enc_dilations should be equal to num_stages, '\
415
+ f'while the enc_dilations is {enc_dilations}, the length of '\
416
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
417
+ f'{num_stages}.'
418
+ assert len(dec_dilations) == (num_stages-1), \
419
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
420
+ f'while the dec_dilations is {dec_dilations}, the length of '\
421
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
422
+ f'{num_stages}.'
423
+ self.num_stages = num_stages
424
+ self.strides = strides
425
+ self.downsamples = downsamples
426
+ self.norm_eval = norm_eval
427
+
428
+ self.encoder = nn.ModuleList()
429
+ self.decoder = nn.ModuleList()
430
+
431
+ for i in range(num_stages):
432
+ enc_conv_block = []
433
+ if i != 0:
434
+ if strides[i] == 1 and downsamples[i - 1]:
435
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
436
+ upsample = (strides[i] != 1 or downsamples[i - 1])
437
+ self.decoder.append(
438
+ UpConvBlock(
439
+ conv_block=BasicConvBlock,
440
+ in_channels=base_channels * 2**i,
441
+ skip_channels=base_channels * 2**(i - 1),
442
+ out_channels=base_channels * 2**(i - 1),
443
+ num_convs=dec_num_convs[i - 1],
444
+ stride=1,
445
+ dilation=dec_dilations[i - 1],
446
+ with_cp=with_cp,
447
+ conv_cfg=conv_cfg,
448
+ norm_cfg=norm_cfg,
449
+ act_cfg=act_cfg,
450
+ upsample_cfg=upsample_cfg if upsample else None,
451
+ dcn=None,
452
+ plugins=None))
453
+
454
+ enc_conv_block.append(
455
+ BasicConvBlock(
456
+ in_channels=in_channels,
457
+ out_channels=base_channels * 2**i,
458
+ num_convs=enc_num_convs[i],
459
+ stride=strides[i],
460
+ dilation=enc_dilations[i],
461
+ with_cp=with_cp,
462
+ conv_cfg=conv_cfg,
463
+ norm_cfg=norm_cfg,
464
+ act_cfg=act_cfg,
465
+ dcn=None,
466
+ plugins=None))
467
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
468
+ in_channels = base_channels * 2**i
469
+
470
+ def forward(self, x):
471
+ enc_outs = []
472
+
473
+ for enc in self.encoder:
474
+ x = enc(x)
475
+ enc_outs.append(x)
476
+ dec_outs = [x]
477
+ for i in reversed(range(len(self.decoder))):
478
+ x = self.decoder[i](enc_outs[i], x)
479
+ dec_outs.append(x)
480
+
481
+ return dec_outs
482
+
483
+ def init_weights(self, pretrained=None):
484
+ """Initialize the weights in backbone.
485
+
486
+ Args:
487
+ pretrained (str, optional): Path to pre-trained weights.
488
+ Defaults to None.
489
+ """
490
+ if isinstance(pretrained, str):
491
+ logger = get_root_logger()
492
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
493
+ elif pretrained is None:
494
+ for m in self.modules():
495
+ if isinstance(m, nn.Conv2d):
496
+ kaiming_init(m)
497
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
498
+ constant_init(m, 1)
499
+ else:
500
+ raise TypeError('pretrained must be a str or None')
501
+
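A usage sketch for the backbone above, assuming the repository root is on PYTHONPATH and mmcv is installed (the defaults downsample by 16, so the 256x256 input below is divisible by the whole downsample rate):

import torch
from models.archs.unet_arch import UNet  # this file

net = UNet(in_channels=3, base_channels=64, num_stages=5)
net.init_weights()  # pretrained=None -> Kaiming init for convs, constant init for norms
feats = net(torch.randn(1, 3, 256, 256))
for f in feats:  # decoder outputs, coarse to fine
    print(f.shape)
# expected: [1, 1024, 16, 16], [1, 512, 32, 32], [1, 256, 64, 64],
#           [1, 128, 128, 128], [1, 64, 256, 256]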
502
+
503
+ class ShapeUNet(nn.Module):
504
+ """ShapeUNet backbone with small modifications.
505
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
506
+ https://arxiv.org/pdf/1505.04597.pdf
507
+
508
+ Args:
509
+ in_channels (int): Number of input image channels. Default: 3.
510
+ base_channels (int): Number of base channels of each stage.
511
+ The output channels of the first stage. Default: 64.
512
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
513
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
514
+ len(strides) is equal to num_stages. Normally the stride of the
515
+ first stage in encoder is 1. If strides[i]=2, it uses stride
516
+ convolution to downsample in the correspondence encoder stage.
517
+ Default: (1, 1, 1, 1, 1).
518
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
519
+ convolution block of the correspondence encoder stage.
520
+ Default: (2, 2, 2, 2, 2).
521
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
522
+ convolution block of the correspondence decoder stage.
523
+ Default: (2, 2, 2, 2).
524
+ downsamples (Sequence[int]): Whether use MaxPool to downsample the
525
+ feature map after the first stage of encoder
526
+ (stages: [1, num_stages)). If the correspondence encoder stage uses
527
+ stride convolution (strides[i]=2), it will never use MaxPool to
528
+ downsample, even if downsamples[i-1]=True.
529
+ Default: (True, True, True, True).
530
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
531
+ Default: (1, 1, 1, 1, 1).
532
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
533
+ Default: (1, 1, 1, 1).
534
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
535
+ memory while slowing down the training speed. Default: False.
536
+ conv_cfg (dict | None): Config dict for convolution layer.
537
+ Default: None.
538
+ norm_cfg (dict | None): Config dict for normalization layer.
539
+ Default: dict(type='BN').
540
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
541
+ Default: dict(type='ReLU').
542
+ upsample_cfg (dict): The upsample config of the upsample module in
543
+ decoder. Default: dict(type='InterpConv').
544
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
545
+ freeze running stats (mean and var). Note: Effect on Batch Norm
546
+ and its variants only. Default: False.
547
+ dcn (bool): Use deformable convolution in convolutional layer or not.
548
+ Default: None.
549
+ plugins (dict): plugins for convolutional layers. Default: None.
550
+
551
+ Notice:
552
+ The input image size should be divisible by the whole downsample rate
553
+ of the encoder. More detail of the whole downsample rate can be found
554
+ in UNet._check_input_devisible.
555
+
556
+ """
557
+
558
+ def __init__(self,
559
+ in_channels=3,
560
+ base_channels=64,
561
+ num_stages=5,
562
+ attr_embedding=128,
563
+ strides=(1, 1, 1, 1, 1),
564
+ enc_num_convs=(2, 2, 2, 2, 2),
565
+ dec_num_convs=(2, 2, 2, 2),
566
+ downsamples=(True, True, True, True),
567
+ enc_dilations=(1, 1, 1, 1, 1),
568
+ dec_dilations=(1, 1, 1, 1),
569
+ with_cp=False,
570
+ conv_cfg=None,
571
+ norm_cfg=dict(type='BN'),
572
+ act_cfg=dict(type='ReLU'),
573
+ upsample_cfg=dict(type='InterpConv'),
574
+ norm_eval=False,
575
+ dcn=None,
576
+ plugins=None):
577
+ super(ShapeUNet, self).__init__()
578
+ assert dcn is None, 'Not implemented yet.'
579
+ assert plugins is None, 'Not implemented yet.'
580
+ assert len(strides) == num_stages, \
581
+ 'The length of strides should be equal to num_stages, '\
582
+ f'while the strides is {strides}, the length of '\
583
+ f'strides is {len(strides)}, and the num_stages is '\
584
+ f'{num_stages}.'
585
+ assert len(enc_num_convs) == num_stages, \
586
+ 'The length of enc_num_convs should be equal to num_stages, '\
587
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
588
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
589
+ f'{num_stages}.'
590
+ assert len(dec_num_convs) == (num_stages-1), \
591
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
592
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
593
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
594
+ f'{num_stages}.'
595
+ assert len(downsamples) == (num_stages-1), \
596
+ 'The length of downsamples should be equal to (num_stages-1), '\
597
+ f'while the downsamples is {downsamples}, the length of '\
598
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
599
+ f'{num_stages}.'
600
+ assert len(enc_dilations) == num_stages, \
601
+ 'The length of enc_dilations should be equal to num_stages, '\
602
+ f'while the enc_dilations is {enc_dilations}, the length of '\
603
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
604
+ f'{num_stages}.'
605
+ assert len(dec_dilations) == (num_stages-1), \
606
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
607
+ f'while the dec_dilations is {dec_dilations}, the length of '\
608
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
609
+ f'{num_stages}.'
610
+ self.num_stages = num_stages
611
+ self.strides = strides
612
+ self.downsamples = downsamples
613
+ self.norm_eval = norm_eval
614
+
615
+ self.encoder = nn.ModuleList()
616
+ self.decoder = nn.ModuleList()
617
+
618
+ for i in range(num_stages):
619
+ enc_conv_block = []
620
+ if i != 0:
621
+ if strides[i] == 1 and downsamples[i - 1]:
622
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
623
+ upsample = (strides[i] != 1 or downsamples[i - 1])
624
+ self.decoder.append(
625
+ UpConvBlock(
626
+ conv_block=BasicConvBlock,
627
+ in_channels=base_channels * 2**i,
628
+ skip_channels=base_channels * 2**(i - 1),
629
+ out_channels=base_channels * 2**(i - 1),
630
+ num_convs=dec_num_convs[i - 1],
631
+ stride=1,
632
+ dilation=dec_dilations[i - 1],
633
+ with_cp=with_cp,
634
+ conv_cfg=conv_cfg,
635
+ norm_cfg=norm_cfg,
636
+ act_cfg=act_cfg,
637
+ upsample_cfg=upsample_cfg if upsample else None,
638
+ dcn=None,
639
+ plugins=None))
640
+
641
+ enc_conv_block.append(
642
+ BasicConvBlock(
643
+ in_channels=in_channels + attr_embedding,
644
+ out_channels=base_channels * 2**i,
645
+ num_convs=enc_num_convs[i],
646
+ stride=strides[i],
647
+ dilation=enc_dilations[i],
648
+ with_cp=with_cp,
649
+ conv_cfg=conv_cfg,
650
+ norm_cfg=norm_cfg,
651
+ act_cfg=act_cfg,
652
+ dcn=None,
653
+ plugins=None))
654
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
655
+ in_channels = base_channels * 2**i
656
+
657
+ def forward(self, x, attr_embedding):
658
+ enc_outs = []
659
+ Be, Ce = attr_embedding.size()
660
+ for enc in self.encoder:
661
+ _, _, H, W = x.size()
662
+ x = enc(
663
+ torch.cat([
664
+ x,
665
+ attr_embedding.view(Be, Ce, 1, 1).expand((Be, Ce, H, W))
666
+ ],
667
+ dim=1))
668
+ enc_outs.append(x)
669
+ dec_outs = [x]
670
+ for i in reversed(range(len(self.decoder))):
671
+ x = self.decoder[i](enc_outs[i], x)
672
+ dec_outs.append(x)
673
+
674
+ return dec_outs
675
+
676
+ def init_weights(self, pretrained=None):
677
+ """Initialize the weights in backbone.
678
+
679
+ Args:
680
+ pretrained (str, optional): Path to pre-trained weights.
681
+ Defaults to None.
682
+ """
683
+ if isinstance(pretrained, str):
684
+ logger = get_root_logger()
685
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
686
+ elif pretrained is None:
687
+ for m in self.modules():
688
+ if isinstance(m, nn.Conv2d):
689
+ kaiming_init(m)
690
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
691
+ constant_init(m, 1)
692
+ else:
693
+ raise TypeError('pretrained must be a str or None')
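The only change ShapeUNet makes to the encoder path is the attribute conditioning in forward(); a minimal sketch of that broadcast-and-concatenate step (tensor sizes are illustrative assumptions):

import torch

x = torch.randn(2, 24, 64, 64)         # current feature map (e.g. parsing features)
attr_embedding = torch.randn(2, 128)   # matches the attr_embedding=128 default
Be, Ce = attr_embedding.size()
_, _, H, W = x.size()
conditioned = torch.cat(
    [x, attr_embedding.view(Be, Ce, 1, 1).expand(Be, Ce, H, W)], dim=1)
print(conditioned.shape)  # torch.Size([2, 152, 64, 64]) fed to the encoder stage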
models/archs/vqgan_arch.py ADDED
@@ -0,0 +1,1203 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ from urllib.request import proxy_bypass
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+
12
+ class VectorQuantizer(nn.Module):
13
+ """
14
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
15
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
16
+ """
17
+
18
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
19
+ # backwards compatibility we use the buggy version by default, but you can
20
+ # specify legacy=False to fix it.
21
+ def __init__(self,
22
+ n_e,
23
+ e_dim,
24
+ beta,
25
+ remap=None,
26
+ unknown_index="random",
27
+ sane_index_shape=False,
28
+ legacy=True):
29
+ super().__init__()
30
+ self.n_e = n_e
31
+ self.e_dim = e_dim
32
+ self.beta = beta
33
+ self.legacy = legacy
34
+
35
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
36
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
37
+
38
+ self.remap = remap
39
+ if self.remap is not None:
40
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
41
+ self.re_embed = self.used.shape[0]
42
+ self.unknown_index = unknown_index # "random" or "extra" or integer
43
+ if self.unknown_index == "extra":
44
+ self.unknown_index = self.re_embed
45
+ self.re_embed = self.re_embed + 1
46
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
47
+ f"Using {self.unknown_index} for unknown indices.")
48
+ else:
49
+ self.re_embed = n_e
50
+
51
+ self.sane_index_shape = sane_index_shape
52
+
53
+ def remap_to_used(self, inds):
54
+ ishape = inds.shape
55
+ assert len(ishape) > 1
56
+ inds = inds.reshape(ishape[0], -1)
57
+ used = self.used.to(inds)
58
+ match = (inds[:, :, None] == used[None, None, ...]).long()
59
+ new = match.argmax(-1)
60
+ unknown = match.sum(2) < 1
61
+ if self.unknown_index == "random":
62
+ new[unknown] = torch.randint(
63
+ 0, self.re_embed,
64
+ size=new[unknown].shape).to(device=new.device)
65
+ else:
66
+ new[unknown] = self.unknown_index
67
+ return new.reshape(ishape)
68
+
69
+ def unmap_to_all(self, inds):
70
+ ishape = inds.shape
71
+ assert len(ishape) > 1
72
+ inds = inds.reshape(ishape[0], -1)
73
+ used = self.used.to(inds)
74
+ if self.re_embed > self.used.shape[0]: # extra token
75
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
76
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
77
+ return back.reshape(ishape)
78
+
79
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
80
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
81
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
82
+ assert return_logits == False, "Only for interface compatible with Gumbel"
83
+ # reshape z -> (batch, height, width, channel) and flatten
84
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
85
+ z_flattened = z.view(-1, self.e_dim)
86
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
87
+
88
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
89
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
90
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
91
+
92
+ min_encoding_indices = torch.argmin(d, dim=1)
93
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
94
+ perplexity = None
95
+ min_encodings = None
96
+
97
+ # compute loss for embedding
98
+ if not self.legacy:
99
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
100
+ torch.mean((z_q - z.detach()) ** 2)
101
+ else:
102
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
103
+ torch.mean((z_q - z.detach()) ** 2)
104
+
105
+ # preserve gradients
106
+ z_q = z + (z_q - z).detach()
107
+
108
+ # reshape back to match original input shape
109
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
110
+
111
+ if self.remap is not None:
112
+ min_encoding_indices = min_encoding_indices.reshape(
113
+ z.shape[0], -1) # add batch axis
114
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
115
+ min_encoding_indices = min_encoding_indices.reshape(-1,
116
+ 1) # flatten
117
+
118
+ if self.sane_index_shape:
119
+ min_encoding_indices = min_encoding_indices.reshape(
120
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
121
+
122
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
123
+
124
+ def get_codebook_entry(self, indices, shape):
125
+ # shape specifying (batch, height, width, channel)
126
+ if self.remap is not None:
127
+ indices = indices.reshape(shape[0], -1) # add batch axis
128
+ indices = self.unmap_to_all(indices)
129
+ indices = indices.reshape(-1) # flatten again
130
+
131
+ # get quantized latent vectors
132
+ z_q = self.embedding(indices)
133
+
134
+ if shape is not None:
135
+ z_q = z_q.view(shape)
136
+ # reshape back to match original input shape
137
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
138
+
139
+ return z_q
140
+
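A usage sketch for the quantizer above (assumes the repo layout models/archs/vqgan_arch.py, as in the imports further down this commit):

import torch
from models.archs.vqgan_arch import VectorQuantizer

quant = VectorQuantizer(n_e=1024, e_dim=256, beta=0.25, sane_index_shape=True)
z = torch.randn(2, 256, 16, 16, requires_grad=True)   # encoder features
z_q, codebook_loss, (_, _, indices) = quant(z)
print(z_q.shape, indices.shape)  # (2, 256, 16, 16) and (2, 16, 16)
# the straight-through trick (z + (z_q - z).detach()) keeps gradients flowing:
z_q.sum().backward()
print(z.grad is not None)  # True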
141
+
142
+ class VectorQuantizerTexture(nn.Module):
143
+ """
144
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
145
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
146
+ """
147
+
148
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
149
+ # backwards compatibility we use the buggy version by default, but you can
150
+ # specify legacy=False to fix it.
151
+ def __init__(self,
152
+ n_e,
153
+ e_dim,
154
+ beta,
155
+ remap=None,
156
+ unknown_index="random",
157
+ sane_index_shape=False,
158
+ legacy=True):
159
+ super().__init__()
160
+ self.n_e = n_e
161
+ self.e_dim = e_dim
162
+ self.beta = beta
163
+ self.legacy = legacy
164
+
165
+ # TODO: decide number of embeddings
166
+ self.embedding_list = nn.ModuleList(
167
+ [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
168
+ for embedding in self.embedding_list:
169
+ embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
170
+
171
+ self.remap = remap
172
+ if self.remap is not None:
173
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
174
+ self.re_embed = self.used.shape[0]
175
+ self.unknown_index = unknown_index # "random" or "extra" or integer
176
+ if self.unknown_index == "extra":
177
+ self.unknown_index = self.re_embed
178
+ self.re_embed = self.re_embed + 1
179
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
180
+ f"Using {self.unknown_index} for unknown indices.")
181
+ else:
182
+ self.re_embed = n_e
183
+
184
+ self.sane_index_shape = sane_index_shape
185
+
186
+ def remap_to_used(self, inds):
187
+ ishape = inds.shape
188
+ assert len(ishape) > 1
189
+ inds = inds.reshape(ishape[0], -1)
190
+ used = self.used.to(inds)
191
+ match = (inds[:, :, None] == used[None, None, ...]).long()
192
+ new = match.argmax(-1)
193
+ unknown = match.sum(2) < 1
194
+ if self.unknown_index == "random":
195
+ new[unknown] = torch.randint(
196
+ 0, self.re_embed,
197
+ size=new[unknown].shape).to(device=new.device)
198
+ else:
199
+ new[unknown] = self.unknown_index
200
+ return new.reshape(ishape)
201
+
202
+ def unmap_to_all(self, inds):
203
+ ishape = inds.shape
204
+ assert len(ishape) > 1
205
+ inds = inds.reshape(ishape[0], -1)
206
+ used = self.used.to(inds)
207
+ if self.re_embed > self.used.shape[0]: # extra token
208
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
209
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
210
+ return back.reshape(ishape)
211
+
212
+ def forward(self,
213
+ z,
214
+ segm_map,
215
+ temp=None,
216
+ rescale_logits=False,
217
+ return_logits=False):
218
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
219
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
220
+ assert return_logits == False, "Only for interface compatible with Gumbel"
221
+
222
+ segm_map = F.interpolate(segm_map, size=z.size()[2:], mode='nearest')
223
+ # reshape z -> (batch, height, width, channel) and flatten
224
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
225
+ z_flattened = z.view(-1, self.e_dim)
226
+
227
+ # flatten segm_map (b, h, w)
228
+ segm_map_flatten = segm_map.view(-1)
229
+
230
+ z_q = torch.zeros_like(z_flattened)
231
+ min_encoding_indices_list = []
232
+ min_encoding_indices_continual = torch.full(
233
+ segm_map_flatten.size(),
234
+ fill_value=-1,
235
+ dtype=torch.long,
236
+ device=segm_map_flatten.device)
237
+ for codebook_idx in range(18):
238
+ min_encoding_indices = torch.full(
239
+ segm_map_flatten.size(),
240
+ fill_value=-1,
241
+ dtype=torch.long,
242
+ device=segm_map_flatten.device)
243
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
244
+ z_selected = z_flattened[segm_map_flatten == codebook_idx]
245
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
246
+ d_selected = torch.sum(
247
+ z_selected**2, dim=1, keepdim=True) + torch.sum(
248
+ self.embedding_list[codebook_idx].weight**2,
249
+ dim=1) - 2 * torch.einsum(
250
+ 'bd,dn->bn', z_selected,
251
+ rearrange(self.embedding_list[codebook_idx].weight,
252
+ 'n d -> d n'))
253
+ min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
254
+ z_q_selected = self.embedding_list[codebook_idx](
255
+ min_encoding_indices_selected)
256
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
257
+ min_encoding_indices[
258
+ segm_map_flatten ==
259
+ codebook_idx] = min_encoding_indices_selected
260
+ min_encoding_indices_continual[
261
+ segm_map_flatten ==
262
+ codebook_idx] = min_encoding_indices_selected + 1024 * codebook_idx
263
+ min_encoding_indices = min_encoding_indices.reshape(
264
+ z.shape[0], z.shape[1], z.shape[2])
265
+ min_encoding_indices_list.append(min_encoding_indices)
266
+
267
+ min_encoding_indices_continual = min_encoding_indices_continual.reshape(
268
+ z.shape[0], z.shape[1], z.shape[2])
269
+ z_q = z_q.view(z.shape)
270
+ perplexity = None
271
+
272
+ # compute loss for embedding
273
+ if not self.legacy:
274
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
275
+ torch.mean((z_q - z.detach()) ** 2)
276
+ else:
277
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
278
+ torch.mean((z_q - z.detach()) ** 2)
279
+
280
+ # preserve gradients
281
+ z_q = z + (z_q - z).detach()
282
+
283
+ # reshape back to match original input shape
284
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
285
+
286
+ return z_q, loss, (perplexity, min_encoding_indices_continual,
287
+ min_encoding_indices_list)
288
+
289
+ def get_codebook_entry(self, indices_list, segm_map, shape):
290
+ # flatten segm_map (b, h, w)
291
+ segm_map = F.interpolate(
292
+ segm_map, size=(shape[1], shape[2]), mode='nearest')
293
+ segm_map_flatten = segm_map.view(-1)
294
+
295
+ z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
296
+ self.e_dim).to(segm_map.device)
297
+ for codebook_idx in range(18):
298
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
299
+ min_encoding_indices_selected = indices_list[
300
+ codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
301
+ z_q_selected = self.embedding_list[codebook_idx](
302
+ min_encoding_indices_selected)
303
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
304
+
305
+ z_q = z_q.view(shape)
306
+ # reshape back to match original input shape
307
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
308
+
309
+ return z_q
310
+
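A usage sketch: each of the 18 codebooks only quantizes the positions whose (nearest-downsampled) segmentation label equals its index; the sizes below are illustrative assumptions:

import torch
from models.archs.vqgan_arch import VectorQuantizerTexture

quant = VectorQuantizerTexture(n_e=1024, e_dim=256, beta=0.25)
z = torch.randn(2, 256, 32, 32)                            # encoder features
segm_map = torch.randint(0, 18, (2, 1, 512, 512)).float()  # 18 texture classes
z_q, loss, (_, indices_continual, indices_list) = quant(z, segm_map)
print(z_q.shape, indices_continual.shape, len(indices_list))
# (2, 256, 32, 32), (2, 32, 32), 18 per-codebook index maps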
311
+
312
+ def sample_patches(inputs, patch_size=3, stride=1):
313
+ """Extract sliding local patches from an input feature tensor.
314
+ The sampled patches are row-major.
315
+ Args:
316
+ inputs (Tensor): the input feature maps, shape: (n, c, h, w).
317
+ patch_size (int): the spatial size of sampled patches. Default: 3.
318
+ stride (int): the stride of sampling. Default: 1.
319
+ Returns:
320
+ patches (Tensor): extracted patches, shape: (n, c * patch_size *
321
+ patch_size, n_patches).
322
+ """
323
+
324
+ patches = F.unfold(inputs, (patch_size, patch_size), stride=stride)
325
+
326
+ return patches
327
+
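A shape check for sample_patches: with a 2x2 patch and stride 2, F.unfold turns an (n, c, h, w) map into (n, c*2*2, (h/2)*(w/2)), which is the layout the spatially-aware quantizer below relies on:

import torch
import torch.nn.functional as F

inputs = torch.randn(1, 256, 32, 32)
patches = F.unfold(inputs, (2, 2), stride=2)
print(patches.shape)  # torch.Size([1, 1024, 256])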
328
+
329
+ class VectorQuantizerSpatialTextureAware(nn.Module):
330
+ """
331
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
332
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
333
+ """
334
+
335
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
336
+ # backwards compatibility we use the buggy version by default, but you can
337
+ # specify legacy=False to fix it.
338
+ def __init__(self,
339
+ n_e,
340
+ e_dim,
341
+ beta,
342
+ spatial_size,
343
+ remap=None,
344
+ unknown_index="random",
345
+ sane_index_shape=False,
346
+ legacy=True):
347
+ super().__init__()
348
+ self.n_e = n_e
349
+ self.e_dim = e_dim * spatial_size * spatial_size
350
+ self.beta = beta
351
+ self.legacy = legacy
352
+ self.spatial_size = spatial_size
353
+
354
+ # TODO: decide number of embeddings
355
+ self.embedding_list = nn.ModuleList(
356
+ [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
357
+ for embedding in self.embedding_list:
358
+ embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
359
+
360
+ self.remap = remap
361
+ if self.remap is not None:
362
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
363
+ self.re_embed = self.used.shape[0]
364
+ self.unknown_index = unknown_index # "random" or "extra" or integer
365
+ if self.unknown_index == "extra":
366
+ self.unknown_index = self.re_embed
367
+ self.re_embed = self.re_embed + 1
368
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
369
+ f"Using {self.unknown_index} for unknown indices.")
370
+ else:
371
+ self.re_embed = n_e
372
+
373
+ self.sane_index_shape = sane_index_shape
374
+
375
+ def forward(self,
376
+ z,
377
+ segm_map,
378
+ temp=None,
379
+ rescale_logits=False,
380
+ return_logits=False):
381
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
382
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
383
+ assert return_logits == False, "Only for interface compatible with Gumbel"
384
+
385
+ segm_map = F.interpolate(
386
+ segm_map,
387
+ size=(z.size(2) // self.spatial_size,
388
+ z.size(3) // self.spatial_size),
389
+ mode='nearest')
390
+
391
+ # reshape z -> (batch, height, width, channel) and flatten
392
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous() ?
393
+ z_patches = sample_patches(
394
+ z, patch_size=self.spatial_size,
395
+ stride=self.spatial_size).permute(0, 2, 1)
396
+ z_patches_flattened = z_patches.reshape(-1, self.e_dim)
397
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
398
+
399
+ # flatten segm_map (b, h, w)
400
+ segm_map_flatten = segm_map.view(-1)
401
+
402
+ z_q = torch.zeros_like(z_patches_flattened)
403
+ min_encoding_indices_list = []
404
+ min_encoding_indices_continual = torch.full(
405
+ segm_map_flatten.size(),
406
+ fill_value=-1,
407
+ dtype=torch.long,
408
+ device=segm_map_flatten.device)
409
+
410
+ for codebook_idx in range(18):
411
+ min_encoding_indices = torch.full(
412
+ segm_map_flatten.size(),
413
+ fill_value=-1,
414
+ dtype=torch.long,
415
+ device=segm_map_flatten.device)
416
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
417
+ z_selected = z_patches_flattened[segm_map_flatten ==
418
+ codebook_idx]
419
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
420
+ d_selected = torch.sum(
421
+ z_selected**2, dim=1, keepdim=True) + torch.sum(
422
+ self.embedding_list[codebook_idx].weight**2,
423
+ dim=1) - 2 * torch.einsum(
424
+ 'bd,dn->bn', z_selected,
425
+ rearrange(self.embedding_list[codebook_idx].weight,
426
+ 'n d -> d n'))
427
+ min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
428
+ z_q_selected = self.embedding_list[codebook_idx](
429
+ min_encoding_indices_selected)
430
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
431
+ min_encoding_indices[
432
+ segm_map_flatten ==
433
+ codebook_idx] = min_encoding_indices_selected
434
+ min_encoding_indices_continual[
435
+ segm_map_flatten ==
436
+ codebook_idx] = min_encoding_indices_selected + self.n_e * codebook_idx
437
+ min_encoding_indices = min_encoding_indices.reshape(
438
+ z_patches.shape[0], segm_map.shape[2], segm_map.shape[3])
439
+ min_encoding_indices_list.append(min_encoding_indices)
440
+
441
+ z_q = F.fold(
442
+ z_q.view(z_patches.shape).permute(0, 2, 1),
443
+ z.size()[2:],
444
+ kernel_size=(self.spatial_size, self.spatial_size),
445
+ stride=self.spatial_size)
446
+
447
+ perplexity = None
448
+
449
+ # compute loss for embedding
450
+ if not self.legacy:
451
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
452
+ torch.mean((z_q - z.detach()) ** 2)
453
+ else:
454
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
455
+ torch.mean((z_q - z.detach()) ** 2)
456
+
457
+ # preserve gradients
458
+ z_q = z + (z_q - z).detach()
459
+
460
+ return z_q, loss, (perplexity, min_encoding_indices_continual,
461
+ min_encoding_indices_list)
462
+
463
+ def get_codebook_entry(self, indices_list, segm_map, shape):
464
+ # flatten segm_map (b, h, w)
465
+ segm_map = F.interpolate(
466
+ segm_map, size=(shape[1], shape[2]), mode='nearest')
467
+ segm_map_flatten = segm_map.view(-1)
468
+
469
+ z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
470
+ self.e_dim).to(segm_map.device)
471
+ for codebook_idx in range(18):
472
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
473
+ min_encoding_indices_selected = indices_list[
474
+ codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
475
+ z_q_selected = self.embedding_list[codebook_idx](
476
+ min_encoding_indices_selected)
477
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
478
+
479
+ z_q = F.fold(
480
+ z_q.view(((shape[0], shape[1] * shape[2],
481
+ self.e_dim))).permute(0, 2, 1),
482
+ (shape[1] * self.spatial_size, shape[2] * self.spatial_size),
483
+ kernel_size=(self.spatial_size, self.spatial_size),
484
+ stride=self.spatial_size)
485
+
486
+ return z_q
487
+
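A usage sketch: with spatial_size=2 each code covers a 2x2 patch of the latent map, so the returned index maps have half the latent resolution, while F.fold restores the original spatial size (sizes are illustrative assumptions):

import torch
from models.archs.vqgan_arch import VectorQuantizerSpatialTextureAware

quant = VectorQuantizerSpatialTextureAware(
    n_e=512, e_dim=256, beta=0.25, spatial_size=2)
z = torch.randn(1, 256, 32, 32)
segm_map = torch.randint(0, 18, (1, 1, 512, 512)).float()
z_q, loss, (_, _, indices_list) = quant(z, segm_map)
print(z_q.shape, indices_list[0].shape)  # (1, 256, 32, 32) and (1, 16, 16)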
488
+
489
+ def get_timestep_embedding(timesteps, embedding_dim):
490
+ """
491
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
492
+ From Fairseq.
493
+ Build sinusoidal embeddings.
494
+ This matches the implementation in tensor2tensor, but differs slightly
495
+ from the description in Section 3.5 of "Attention Is All You Need".
496
+ """
497
+ assert len(timesteps.shape) == 1
498
+
499
+ half_dim = embedding_dim // 2
500
+ emb = math.log(10000) / (half_dim - 1)
501
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
502
+ emb = emb.to(device=timesteps.device)
503
+ emb = timesteps.float()[:, None] * emb[None, :]
504
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
505
+ if embedding_dim % 2 == 1: # zero pad
506
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
507
+ return emb
508
+
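A quick check of the sinusoidal embedding above: timestep t maps to [sin(t*w_0), ..., sin(t*w_{d/2-1}), cos(t*w_0), ...] with w_i = exp(-i * log(10000) / (d/2 - 1)):

import torch
from models.archs.vqgan_arch import get_timestep_embedding

t = torch.arange(4)  # a 1-D tensor of timesteps
emb = get_timestep_embedding(t, embedding_dim=128)
print(emb.shape)   # torch.Size([4, 128])
print(emb[0, :4])  # zeros: the sine half of timestep 0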
509
+
510
+ def nonlinearity(x):
511
+ # swish
512
+ return x * torch.sigmoid(x)
513
+
514
+
515
+ def Normalize(in_channels):
516
+ return torch.nn.GroupNorm(
517
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
518
+
519
+
520
+ class Upsample(nn.Module):
521
+
522
+ def __init__(self, in_channels, with_conv):
523
+ super().__init__()
524
+ self.with_conv = with_conv
525
+ if self.with_conv:
526
+ self.conv = torch.nn.Conv2d(
527
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1)
528
+
529
+ def forward(self, x):
530
+ x = torch.nn.functional.interpolate(
531
+ x, scale_factor=2.0, mode="nearest")
532
+ if self.with_conv:
533
+ x = self.conv(x)
534
+ return x
535
+
536
+
537
+ class Downsample(nn.Module):
538
+
539
+ def __init__(self, in_channels, with_conv):
540
+ super().__init__()
541
+ self.with_conv = with_conv
542
+ if self.with_conv:
543
+ # no asymmetric padding in torch conv, must do it ourselves
544
+ self.conv = torch.nn.Conv2d(
545
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0)
546
+
547
+ def forward(self, x):
548
+ if self.with_conv:
549
+ pad = (0, 1, 0, 1)
550
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
551
+ x = self.conv(x)
552
+ else:
553
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
554
+ return x
555
+
556
+
557
+ class ResnetBlock(nn.Module):
558
+
559
+ def __init__(self,
560
+ *,
561
+ in_channels,
562
+ out_channels=None,
563
+ conv_shortcut=False,
564
+ dropout,
565
+ temb_channels=512):
566
+ super().__init__()
567
+ self.in_channels = in_channels
568
+ out_channels = in_channels if out_channels is None else out_channels
569
+ self.out_channels = out_channels
570
+ self.use_conv_shortcut = conv_shortcut
571
+
572
+ self.norm1 = Normalize(in_channels)
573
+ self.conv1 = torch.nn.Conv2d(
574
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
575
+ if temb_channels > 0:
576
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
577
+ self.norm2 = Normalize(out_channels)
578
+ self.dropout = torch.nn.Dropout(dropout)
579
+ self.conv2 = torch.nn.Conv2d(
580
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
581
+ if self.in_channels != self.out_channels:
582
+ if self.use_conv_shortcut:
583
+ self.conv_shortcut = torch.nn.Conv2d(
584
+ in_channels,
585
+ out_channels,
586
+ kernel_size=3,
587
+ stride=1,
588
+ padding=1)
589
+ else:
590
+ self.nin_shortcut = torch.nn.Conv2d(
591
+ in_channels,
592
+ out_channels,
593
+ kernel_size=1,
594
+ stride=1,
595
+ padding=0)
596
+
597
+ def forward(self, x, temb):
598
+ h = x
599
+ h = self.norm1(h)
600
+ h = nonlinearity(h)
601
+ h = self.conv1(h)
602
+
603
+ if temb is not None:
604
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
605
+
606
+ h = self.norm2(h)
607
+ h = nonlinearity(h)
608
+ h = self.dropout(h)
609
+ h = self.conv2(h)
610
+
611
+ if self.in_channels != self.out_channels:
612
+ if self.use_conv_shortcut:
613
+ x = self.conv_shortcut(x)
614
+ else:
615
+ x = self.nin_shortcut(x)
616
+
617
+ return x + h
618
+
619
+
620
+ class AttnBlock(nn.Module):
621
+
622
+ def __init__(self, in_channels):
623
+ super().__init__()
624
+ self.in_channels = in_channels
625
+
626
+ self.norm = Normalize(in_channels)
627
+ self.q = torch.nn.Conv2d(
628
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
629
+ self.k = torch.nn.Conv2d(
630
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
631
+ self.v = torch.nn.Conv2d(
632
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
633
+ self.proj_out = torch.nn.Conv2d(
634
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
635
+
636
+ def forward(self, x):
637
+ h_ = x
638
+ h_ = self.norm(h_)
639
+ q = self.q(h_)
640
+ k = self.k(h_)
641
+ v = self.v(h_)
642
+
643
+ # compute attention
644
+ b, c, h, w = q.shape
645
+ q = q.reshape(b, c, h * w)
646
+ q = q.permute(0, 2, 1) # b,hw,c
647
+ k = k.reshape(b, c, h * w) # b,c,hw
648
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
649
+ w_ = w_ * (int(c)**(-0.5))
650
+ w_ = torch.nn.functional.softmax(w_, dim=2)
651
+
652
+ # attend to values
653
+ v = v.reshape(b, c, h * w)
654
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
655
+ h_ = torch.bmm(
656
+ v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
657
+ h_ = h_.reshape(b, c, h, w)
658
+
659
+ h_ = self.proj_out(h_)
660
+
661
+ return x + h_
662
+
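AttnBlock is single-head self-attention over all h*w positions (scaled by 1/sqrt(c)) with a residual connection, so it preserves the input shape; a small usage sketch:

import torch
from models.archs.vqgan_arch import AttnBlock

attn = AttnBlock(in_channels=256)  # channels must be divisible by 32 for GroupNorm
x = torch.randn(1, 256, 32, 32)
print(attn(x).shape)  # torch.Size([1, 256, 32, 32])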
663
+
664
+ class Model(nn.Module):
665
+
666
+ def __init__(self,
667
+ *,
668
+ ch,
669
+ out_ch,
670
+ ch_mult=(1, 2, 4, 8),
671
+ num_res_blocks,
672
+ attn_resolutions,
673
+ dropout=0.0,
674
+ resamp_with_conv=True,
675
+ in_channels,
676
+ resolution,
677
+ use_timestep=True):
678
+ super().__init__()
679
+ self.ch = ch
680
+ self.temb_ch = self.ch * 4
681
+ self.num_resolutions = len(ch_mult)
682
+ self.num_res_blocks = num_res_blocks
683
+ self.resolution = resolution
684
+ self.in_channels = in_channels
685
+
686
+ self.use_timestep = use_timestep
687
+ if self.use_timestep:
688
+ # timestep embedding
689
+ self.temb = nn.Module()
690
+ self.temb.dense = nn.ModuleList([
691
+ torch.nn.Linear(self.ch, self.temb_ch),
692
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
693
+ ])
694
+
695
+ # downsampling
696
+ self.conv_in = torch.nn.Conv2d(
697
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1)
698
+
699
+ curr_res = resolution
700
+ in_ch_mult = (1, ) + tuple(ch_mult)
701
+ self.down = nn.ModuleList()
702
+ for i_level in range(self.num_resolutions):
703
+ block = nn.ModuleList()
704
+ attn = nn.ModuleList()
705
+ block_in = ch * in_ch_mult[i_level]
706
+ block_out = ch * ch_mult[i_level]
707
+ for i_block in range(self.num_res_blocks):
708
+ block.append(
709
+ ResnetBlock(
710
+ in_channels=block_in,
711
+ out_channels=block_out,
712
+ temb_channels=self.temb_ch,
713
+ dropout=dropout))
714
+ block_in = block_out
715
+ if curr_res in attn_resolutions:
716
+ attn.append(AttnBlock(block_in))
717
+ down = nn.Module()
718
+ down.block = block
719
+ down.attn = attn
720
+ if i_level != self.num_resolutions - 1:
721
+ down.downsample = Downsample(block_in, resamp_with_conv)
722
+ curr_res = curr_res // 2
723
+ self.down.append(down)
724
+
725
+ # middle
726
+ self.mid = nn.Module()
727
+ self.mid.block_1 = ResnetBlock(
728
+ in_channels=block_in,
729
+ out_channels=block_in,
730
+ temb_channels=self.temb_ch,
731
+ dropout=dropout)
732
+ self.mid.attn_1 = AttnBlock(block_in)
733
+ self.mid.block_2 = ResnetBlock(
734
+ in_channels=block_in,
735
+ out_channels=block_in,
736
+ temb_channels=self.temb_ch,
737
+ dropout=dropout)
738
+
739
+ # upsampling
740
+ self.up = nn.ModuleList()
741
+ for i_level in reversed(range(self.num_resolutions)):
742
+ block = nn.ModuleList()
743
+ attn = nn.ModuleList()
744
+ block_out = ch * ch_mult[i_level]
745
+ skip_in = ch * ch_mult[i_level]
746
+ for i_block in range(self.num_res_blocks + 1):
747
+ if i_block == self.num_res_blocks:
748
+ skip_in = ch * in_ch_mult[i_level]
749
+ block.append(
750
+ ResnetBlock(
751
+ in_channels=block_in + skip_in,
752
+ out_channels=block_out,
753
+ temb_channels=self.temb_ch,
754
+ dropout=dropout))
755
+ block_in = block_out
756
+ if curr_res in attn_resolutions:
757
+ attn.append(AttnBlock(block_in))
758
+ up = nn.Module()
759
+ up.block = block
760
+ up.attn = attn
761
+ if i_level != 0:
762
+ up.upsample = Upsample(block_in, resamp_with_conv)
763
+ curr_res = curr_res * 2
764
+ self.up.insert(0, up) # prepend to get consistent order
765
+
766
+ # end
767
+ self.norm_out = Normalize(block_in)
768
+ self.conv_out = torch.nn.Conv2d(
769
+ block_in, out_ch, kernel_size=3, stride=1, padding=1)
770
+
771
+ def forward(self, x, t=None):
772
+ #assert x.shape[2] == x.shape[3] == self.resolution
773
+
774
+ if self.use_timestep:
775
+ # timestep embedding
776
+ assert t is not None
777
+ temb = get_timestep_embedding(t, self.ch)
778
+ temb = self.temb.dense[0](temb)
779
+ temb = nonlinearity(temb)
780
+ temb = self.temb.dense[1](temb)
781
+ else:
782
+ temb = None
783
+
784
+ # downsampling
785
+ hs = [self.conv_in(x)]
786
+ for i_level in range(self.num_resolutions):
787
+ for i_block in range(self.num_res_blocks):
788
+ h = self.down[i_level].block[i_block](hs[-1], temb)
789
+ if len(self.down[i_level].attn) > 0:
790
+ h = self.down[i_level].attn[i_block](h)
791
+ hs.append(h)
792
+ if i_level != self.num_resolutions - 1:
793
+ hs.append(self.down[i_level].downsample(hs[-1]))
794
+
795
+ # middle
796
+ h = hs[-1]
797
+ h = self.mid.block_1(h, temb)
798
+ h = self.mid.attn_1(h)
799
+ h = self.mid.block_2(h, temb)
800
+
801
+ # upsampling
802
+ for i_level in reversed(range(self.num_resolutions)):
803
+ for i_block in range(self.num_res_blocks + 1):
804
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
805
+ dim=1), temb)
806
+ if len(self.up[i_level].attn) > 0:
807
+ h = self.up[i_level].attn[i_block](h)
808
+ if i_level != 0:
809
+ h = self.up[i_level].upsample(h)
810
+
811
+ # end
812
+ h = self.norm_out(h)
813
+ h = nonlinearity(h)
814
+ h = self.conv_out(h)
815
+ return h
816
+
817
+
818
+ class Encoder(nn.Module):
819
+
820
+ def __init__(self,
821
+ ch,
822
+ num_res_blocks,
823
+ attn_resolutions,
824
+ in_channels,
825
+ resolution,
826
+ z_channels,
827
+ ch_mult=(1, 2, 4, 8),
828
+ dropout=0.0,
829
+ resamp_with_conv=True,
830
+ double_z=True):
831
+ super().__init__()
832
+ self.ch = ch
833
+ self.temb_ch = 0
834
+ self.num_resolutions = len(ch_mult)
835
+ self.num_res_blocks = num_res_blocks
836
+ self.resolution = resolution
837
+ self.in_channels = in_channels
838
+
839
+ # downsampling
840
+ self.conv_in = torch.nn.Conv2d(
841
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1)
842
+
843
+ curr_res = resolution
844
+ in_ch_mult = (1, ) + tuple(ch_mult)
845
+ self.down = nn.ModuleList()
846
+ for i_level in range(self.num_resolutions):
847
+ block = nn.ModuleList()
848
+ attn = nn.ModuleList()
849
+ block_in = ch * in_ch_mult[i_level]
850
+ block_out = ch * ch_mult[i_level]
851
+ for i_block in range(self.num_res_blocks):
852
+ block.append(
853
+ ResnetBlock(
854
+ in_channels=block_in,
855
+ out_channels=block_out,
856
+ temb_channels=self.temb_ch,
857
+ dropout=dropout))
858
+ block_in = block_out
859
+ if curr_res in attn_resolutions:
860
+ attn.append(AttnBlock(block_in))
861
+ down = nn.Module()
862
+ down.block = block
863
+ down.attn = attn
864
+ if i_level != self.num_resolutions - 1:
865
+ down.downsample = Downsample(block_in, resamp_with_conv)
866
+ curr_res = curr_res // 2
867
+ self.down.append(down)
868
+
869
+ # middle
870
+ self.mid = nn.Module()
871
+ self.mid.block_1 = ResnetBlock(
872
+ in_channels=block_in,
873
+ out_channels=block_in,
874
+ temb_channels=self.temb_ch,
875
+ dropout=dropout)
876
+ self.mid.attn_1 = AttnBlock(block_in)
877
+ self.mid.block_2 = ResnetBlock(
878
+ in_channels=block_in,
879
+ out_channels=block_in,
880
+ temb_channels=self.temb_ch,
881
+ dropout=dropout)
882
+
883
+ # end
884
+ self.norm_out = Normalize(block_in)
885
+ self.conv_out = torch.nn.Conv2d(
886
+ block_in,
887
+ 2 * z_channels if double_z else z_channels,
888
+ kernel_size=3,
889
+ stride=1,
890
+ padding=1)
891
+
892
+ def forward(self, x):
893
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
894
+
895
+ # timestep embedding
896
+ temb = None
897
+
898
+ # downsampling
899
+ hs = [self.conv_in(x)]
900
+ for i_level in range(self.num_resolutions):
901
+ for i_block in range(self.num_res_blocks):
902
+ h = self.down[i_level].block[i_block](hs[-1], temb)
903
+ if len(self.down[i_level].attn) > 0:
904
+ h = self.down[i_level].attn[i_block](h)
905
+ hs.append(h)
906
+ if i_level != self.num_resolutions - 1:
907
+ hs.append(self.down[i_level].downsample(hs[-1]))
908
+
909
+ # middle
910
+ h = hs[-1]
911
+ h = self.mid.block_1(h, temb)
912
+ h = self.mid.attn_1(h)
913
+ h = self.mid.block_2(h, temb)
914
+
915
+ # end
916
+ h = self.norm_out(h)
917
+ h = nonlinearity(h)
918
+ h = self.conv_out(h)
919
+ return h
920
+
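The Encoder downsamples by 2**(len(ch_mult) - 1); a usage sketch with illustrative settings (a 5-entry ch_mult maps a 512x512 input to a 32x32 latent):

import torch
from models.archs.vqgan_arch import Encoder

enc = Encoder(ch=64, num_res_blocks=1, attn_resolutions=[32], in_channels=3,
              resolution=512, z_channels=256, ch_mult=(1, 1, 2, 2, 4),
              double_z=False)
print(enc(torch.randn(1, 3, 512, 512)).shape)  # torch.Size([1, 256, 32, 32])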
921
+
922
+ class Decoder(nn.Module):
923
+
924
+ def __init__(self,
925
+ in_channels,
926
+ resolution,
927
+ z_channels,
928
+ ch,
929
+ out_ch,
930
+ num_res_blocks,
931
+ attn_resolutions,
932
+ ch_mult=(1, 2, 4, 8),
933
+ dropout=0.0,
934
+ resamp_with_conv=True,
935
+ give_pre_end=False):
936
+ super().__init__()
937
+ self.ch = ch
938
+ self.temb_ch = 0
939
+ self.num_resolutions = len(ch_mult)
940
+ self.num_res_blocks = num_res_blocks
941
+ self.resolution = resolution
942
+ self.in_channels = in_channels
943
+ self.give_pre_end = give_pre_end
944
+
945
+ # compute in_ch_mult, block_in and curr_res at lowest res
946
+ in_ch_mult = (1, ) + tuple(ch_mult)
947
+ block_in = ch * ch_mult[self.num_resolutions - 1]
948
+ curr_res = resolution // 2**(self.num_resolutions - 1)
949
+ self.z_shape = (1, z_channels, curr_res, curr_res // 2)
950
+ print("Working with z of shape {} = {} dimensions.".format(
951
+ self.z_shape, np.prod(self.z_shape)))
952
+
953
+ # z to block_in
954
+ self.conv_in = torch.nn.Conv2d(
955
+ z_channels, block_in, kernel_size=3, stride=1, padding=1)
956
+
957
+ # middle
958
+ self.mid = nn.Module()
959
+ self.mid.block_1 = ResnetBlock(
960
+ in_channels=block_in,
961
+ out_channels=block_in,
962
+ temb_channels=self.temb_ch,
963
+ dropout=dropout)
964
+ self.mid.attn_1 = AttnBlock(block_in)
965
+ self.mid.block_2 = ResnetBlock(
966
+ in_channels=block_in,
967
+ out_channels=block_in,
968
+ temb_channels=self.temb_ch,
969
+ dropout=dropout)
970
+
971
+ # upsampling
972
+ self.up = nn.ModuleList()
973
+ for i_level in reversed(range(self.num_resolutions)):
974
+ block = nn.ModuleList()
975
+ attn = nn.ModuleList()
976
+ block_out = ch * ch_mult[i_level]
977
+ for i_block in range(self.num_res_blocks + 1):
978
+ block.append(
979
+ ResnetBlock(
980
+ in_channels=block_in,
981
+ out_channels=block_out,
982
+ temb_channels=self.temb_ch,
983
+ dropout=dropout))
984
+ block_in = block_out
985
+ if curr_res in attn_resolutions:
986
+ attn.append(AttnBlock(block_in))
987
+ up = nn.Module()
988
+ up.block = block
989
+ up.attn = attn
990
+ if i_level != 0:
991
+ up.upsample = Upsample(block_in, resamp_with_conv)
992
+ curr_res = curr_res * 2
993
+ self.up.insert(0, up) # prepend to get consistent order
994
+
995
+ # end
996
+ self.norm_out = Normalize(block_in)
997
+ self.conv_out = torch.nn.Conv2d(
998
+ block_in, out_ch, kernel_size=3, stride=1, padding=1)
999
+
1000
+ def forward(self, z, bot_h=None):
1001
+ #assert z.shape[1:] == self.z_shape[1:]
1002
+ self.last_z_shape = z.shape
1003
+
1004
+ # timestep embedding
1005
+ temb = None
1006
+
1007
+ # z to block_in
1008
+ h = self.conv_in(z)
1009
+
1010
+ # middle
1011
+ h = self.mid.block_1(h, temb)
1012
+ h = self.mid.attn_1(h)
1013
+ h = self.mid.block_2(h, temb)
1014
+
1015
+ # upsampling
1016
+ for i_level in reversed(range(self.num_resolutions)):
1017
+ for i_block in range(self.num_res_blocks + 1):
1018
+ h = self.up[i_level].block[i_block](h, temb)
1019
+ if len(self.up[i_level].attn) > 0:
1020
+ h = self.up[i_level].attn[i_block](h)
1021
+ if i_level != 0:
1022
+ h = self.up[i_level].upsample(h)
1023
+ if i_level == 4 and bot_h is not None:
1024
+ h += bot_h
1025
+
1026
+ # end
1027
+ if self.give_pre_end:
1028
+ return h
1029
+
1030
+ h = self.norm_out(h)
1031
+ h = nonlinearity(h)
1032
+ h = self.conv_out(h)
1033
+ return h
1034
+
1035
+ def get_feature_top(self, z):
1036
+ #assert z.shape[1:] == self.z_shape[1:]
1037
+ self.last_z_shape = z.shape
1038
+
1039
+ # timestep embedding
1040
+ temb = None
1041
+
1042
+ # z to block_in
1043
+ h = self.conv_in(z)
1044
+
1045
+ # middle
1046
+ h = self.mid.block_1(h, temb)
1047
+ h = self.mid.attn_1(h)
1048
+ h = self.mid.block_2(h, temb)
1049
+
1050
+ # upsampling
1051
+ for i_level in reversed(range(self.num_resolutions)):
1052
+ for i_block in range(self.num_res_blocks + 1):
1053
+ h = self.up[i_level].block[i_block](h, temb)
1054
+ if len(self.up[i_level].attn) > 0:
1055
+ h = self.up[i_level].attn[i_block](h)
1056
+ if i_level != 0:
1057
+ h = self.up[i_level].upsample(h)
1058
+ if i_level == 4:
1059
+ return h
1060
+
1061
+ def get_feature_middle(self, z, mid_h):
1062
+ #assert z.shape[1:] == self.z_shape[1:]
1063
+ self.last_z_shape = z.shape
1064
+
1065
+ # timestep embedding
1066
+ temb = None
1067
+
1068
+ # z to block_in
1069
+ h = self.conv_in(z)
1070
+
1071
+ # middle
1072
+ h = self.mid.block_1(h, temb)
1073
+ h = self.mid.attn_1(h)
1074
+ h = self.mid.block_2(h, temb)
1075
+
1076
+ # upsampling
1077
+ for i_level in reversed(range(self.num_resolutions)):
1078
+ for i_block in range(self.num_res_blocks + 1):
1079
+ h = self.up[i_level].block[i_block](h, temb)
1080
+ if len(self.up[i_level].attn) > 0:
1081
+ h = self.up[i_level].attn[i_block](h)
1082
+ if i_level != 0:
1083
+ h = self.up[i_level].upsample(h)
1084
+ if i_level == 4:
1085
+ h += mid_h
1086
+ if i_level == 3:
1087
+ return h
1088
+
1089
+
1090
+ class DecoderRes(nn.Module):
1091
+
1092
+ def __init__(self,
1093
+ in_channels,
1094
+ resolution,
1095
+ z_channels,
1096
+ ch,
1097
+ num_res_blocks,
1098
+ ch_mult=(1, 2, 4, 8),
1099
+ dropout=0.0,
1100
+ give_pre_end=False):
1101
+ super().__init__()
1102
+ self.ch = ch
1103
+ self.temb_ch = 0
1104
+ self.num_resolutions = len(ch_mult)
1105
+ self.num_res_blocks = num_res_blocks
1106
+ self.resolution = resolution
1107
+ self.in_channels = in_channels
1108
+ self.give_pre_end = give_pre_end
1109
+
1110
+ # compute in_ch_mult, block_in and curr_res at lowest res
1111
+ in_ch_mult = (1, ) + tuple(ch_mult)
1112
+ block_in = ch * ch_mult[self.num_resolutions - 1]
1113
+ curr_res = resolution // 2**(self.num_resolutions - 1)
1114
+ self.z_shape = (1, z_channels, curr_res, curr_res // 2)
1115
+ print("Working with z of shape {} = {} dimensions.".format(
1116
+ self.z_shape, np.prod(self.z_shape)))
1117
+
1118
+ # z to block_in
1119
+ self.conv_in = torch.nn.Conv2d(
1120
+ z_channels, block_in, kernel_size=3, stride=1, padding=1)
1121
+
1122
+ # middle
1123
+ self.mid = nn.Module()
1124
+ self.mid.block_1 = ResnetBlock(
1125
+ in_channels=block_in,
1126
+ out_channels=block_in,
1127
+ temb_channels=self.temb_ch,
1128
+ dropout=dropout)
1129
+ self.mid.attn_1 = AttnBlock(block_in)
1130
+ self.mid.block_2 = ResnetBlock(
1131
+ in_channels=block_in,
1132
+ out_channels=block_in,
1133
+ temb_channels=self.temb_ch,
1134
+ dropout=dropout)
1135
+
1136
+ def forward(self, z):
1137
+ #assert z.shape[1:] == self.z_shape[1:]
1138
+ self.last_z_shape = z.shape
1139
+
1140
+ # timestep embedding
1141
+ temb = None
1142
+
1143
+ # z to block_in
1144
+ h = self.conv_in(z)
1145
+
1146
+ # middle
1147
+ h = self.mid.block_1(h, temb)
1148
+ h = self.mid.attn_1(h)
1149
+ h = self.mid.block_2(h, temb)
1150
+
1151
+ return h
1152
+
1153
+
1154
+ # patch based discriminator
1155
+ class Discriminator(nn.Module):
1156
+
1157
+ def __init__(self, nc, ndf, n_layers=3):
1158
+ super().__init__()
1159
+
1160
+ layers = [
1161
+ nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
1162
+ nn.LeakyReLU(0.2, True)
1163
+ ]
1164
+ ndf_mult = 1
1165
+ ndf_mult_prev = 1
1166
+ for n in range(1,
1167
+ n_layers): # gradually increase the number of filters
1168
+ ndf_mult_prev = ndf_mult
1169
+ ndf_mult = min(2**n, 8)
1170
+ layers += [
1171
+ nn.Conv2d(
1172
+ ndf * ndf_mult_prev,
1173
+ ndf * ndf_mult,
1174
+ kernel_size=4,
1175
+ stride=2,
1176
+ padding=1,
1177
+ bias=False),
1178
+ nn.BatchNorm2d(ndf * ndf_mult),
1179
+ nn.LeakyReLU(0.2, True)
1180
+ ]
1181
+
1182
+ ndf_mult_prev = ndf_mult
1183
+ ndf_mult = min(2**n_layers, 8)
1184
+
1185
+ layers += [
1186
+ nn.Conv2d(
1187
+ ndf * ndf_mult_prev,
1188
+ ndf * ndf_mult,
1189
+ kernel_size=4,
1190
+ stride=1,
1191
+ padding=1,
1192
+ bias=False),
1193
+ nn.BatchNorm2d(ndf * ndf_mult),
1194
+ nn.LeakyReLU(0.2, True)
1195
+ ]
1196
+
1197
+ layers += [
1198
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
1199
+ ] # output 1 channel prediction map
1200
+ self.main = nn.Sequential(*layers)
1201
+
1202
+ def forward(self, x):
1203
+ return self.main(x)
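A minimal usage sketch for the patch-based Discriminator above, assuming it is imported from models.archs.vqgan_arch (as the model files below do); the image size and filter counts here are illustrative only.

import torch

from models.archs.vqgan_arch import Discriminator

disc = Discriminator(nc=3, ndf=64, n_layers=3)
fake = torch.randn(2, 3, 256, 256)   # batch of RGB images in [-1, 1]
patch_logits = disc(fake)            # one logit per image patch, not per image
print(patch_logits.shape)            # torch.Size([2, 1, 30, 30]) for this input size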
models/hierarchy_inference_model.py ADDED
@@ -0,0 +1,363 @@
1
+ import logging
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torchvision.utils import save_image
8
+
9
+ from models.archs.fcn_arch import MultiHeadFCNHead
10
+ from models.archs.unet_arch import UNet
11
+ from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
12
+ VectorQuantizerSpatialTextureAware,
13
+ VectorQuantizerTexture)
14
+ from models.losses.accuracy import accuracy
15
+ from models.losses.cross_entropy_loss import CrossEntropyLoss
16
+
17
+ logger = logging.getLogger('base')
18
+
19
+
20
+ class VQGANTextureAwareSpatialHierarchyInferenceModel():
21
+
22
+ def __init__(self, opt):
23
+ self.opt = opt
24
+ self.device = torch.device('cuda')
25
+ self.is_train = opt['is_train']
26
+
27
+ self.top_encoder = Encoder(
28
+ ch=opt['top_ch'],
29
+ num_res_blocks=opt['top_num_res_blocks'],
30
+ attn_resolutions=opt['top_attn_resolutions'],
31
+ ch_mult=opt['top_ch_mult'],
32
+ in_channels=opt['top_in_channels'],
33
+ resolution=opt['top_resolution'],
34
+ z_channels=opt['top_z_channels'],
35
+ double_z=opt['top_double_z'],
36
+ dropout=opt['top_dropout']).to(self.device)
37
+ self.decoder = Decoder(
38
+ in_channels=opt['top_in_channels'],
39
+ resolution=opt['top_resolution'],
40
+ z_channels=opt['top_z_channels'],
41
+ ch=opt['top_ch'],
42
+ out_ch=opt['top_out_ch'],
43
+ num_res_blocks=opt['top_num_res_blocks'],
44
+ attn_resolutions=opt['top_attn_resolutions'],
45
+ ch_mult=opt['top_ch_mult'],
46
+ dropout=opt['top_dropout'],
47
+ resamp_with_conv=True,
48
+ give_pre_end=False).to(self.device)
49
+ self.top_quantize = VectorQuantizerTexture(
50
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
51
+ self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
52
+ opt['embed_dim'],
53
+ 1).to(self.device)
54
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
55
+ opt["top_z_channels"],
56
+ 1).to(self.device)
57
+ self.load_top_pretrain_models()
58
+
59
+ self.bot_encoder = Encoder(
60
+ ch=opt['bot_ch'],
61
+ num_res_blocks=opt['bot_num_res_blocks'],
62
+ attn_resolutions=opt['bot_attn_resolutions'],
63
+ ch_mult=opt['bot_ch_mult'],
64
+ in_channels=opt['bot_in_channels'],
65
+ resolution=opt['bot_resolution'],
66
+ z_channels=opt['bot_z_channels'],
67
+ double_z=opt['bot_double_z'],
68
+ dropout=opt['bot_dropout']).to(self.device)
69
+ self.bot_decoder_res = DecoderRes(
70
+ in_channels=opt['bot_in_channels'],
71
+ resolution=opt['bot_resolution'],
72
+ z_channels=opt['bot_z_channels'],
73
+ ch=opt['bot_ch'],
74
+ num_res_blocks=opt['bot_num_res_blocks'],
75
+ ch_mult=opt['bot_ch_mult'],
76
+ dropout=opt['bot_dropout'],
77
+ give_pre_end=False).to(self.device)
78
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
79
+ opt['bot_n_embed'],
80
+ opt['embed_dim'],
81
+ beta=0.25,
82
+ spatial_size=opt['codebook_spatial_size']).to(self.device)
83
+ self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
84
+ opt['embed_dim'],
85
+ 1).to(self.device)
86
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
87
+ opt["bot_z_channels"],
88
+ 1).to(self.device)
89
+
90
+ self.load_bot_pretrain_network()
91
+
92
+ self.guidance_encoder = UNet(
93
+ in_channels=opt['encoder_in_channels']).to(self.device)
94
+ self.index_decoder = MultiHeadFCNHead(
95
+ in_channels=opt['fc_in_channels'],
96
+ in_index=opt['fc_in_index'],
97
+ channels=opt['fc_channels'],
98
+ num_convs=opt['fc_num_convs'],
99
+ concat_input=opt['fc_concat_input'],
100
+ dropout_ratio=opt['fc_dropout_ratio'],
101
+ num_classes=opt['fc_num_classes'],
102
+ align_corners=opt['fc_align_corners'],
103
+ num_head=18).to(self.device)
104
+
105
+ self.init_training_settings()
106
+
107
+ def init_training_settings(self):
108
+ optim_params = []
109
+ for v in self.guidance_encoder.parameters():
110
+ if v.requires_grad:
111
+ optim_params.append(v)
112
+ for v in self.index_decoder.parameters():
113
+ if v.requires_grad:
114
+ optim_params.append(v)
115
+ # set up optimizers
116
+ if self.opt['optimizer'] == 'Adam':
117
+ self.optimizer = torch.optim.Adam(
118
+ optim_params,
119
+ self.opt['lr'],
120
+ weight_decay=self.opt['weight_decay'])
121
+ elif self.opt['optimizer'] == 'SGD':
122
+ self.optimizer = torch.optim.SGD(
123
+ optim_params,
124
+ self.opt['lr'],
125
+ momentum=self.opt['momentum'],
126
+ weight_decay=self.opt['weight_decay'])
127
+ self.log_dict = OrderedDict()
128
+ if self.opt['loss_function'] == 'cross_entropy':
129
+ self.loss_func = CrossEntropyLoss().to(self.device)
130
+
131
+ def load_top_pretrain_models(self):
132
+ # load pretrained vqgan for segmentation mask
133
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
134
+ self.top_encoder.load_state_dict(
135
+ top_vae_checkpoint['encoder'], strict=True)
136
+ self.decoder.load_state_dict(
137
+ top_vae_checkpoint['decoder'], strict=True)
138
+ self.top_quantize.load_state_dict(
139
+ top_vae_checkpoint['quantize'], strict=True)
140
+ self.top_quant_conv.load_state_dict(
141
+ top_vae_checkpoint['quant_conv'], strict=True)
142
+ self.top_post_quant_conv.load_state_dict(
143
+ top_vae_checkpoint['post_quant_conv'], strict=True)
144
+ self.top_encoder.eval()
145
+ self.top_quantize.eval()
146
+ self.top_quant_conv.eval()
147
+ self.top_post_quant_conv.eval()
148
+
149
+ def load_bot_pretrain_network(self):
150
+ checkpoint = torch.load(self.opt['bot_vae_path'])
151
+ self.bot_encoder.load_state_dict(
152
+ checkpoint['bot_encoder'], strict=True)
153
+ self.bot_decoder_res.load_state_dict(
154
+ checkpoint['bot_decoder_res'], strict=True)
155
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
156
+ self.bot_quantize.load_state_dict(
157
+ checkpoint['bot_quantize'], strict=True)
158
+ self.bot_quant_conv.load_state_dict(
159
+ checkpoint['bot_quant_conv'], strict=True)
160
+ self.bot_post_quant_conv.load_state_dict(
161
+ checkpoint['bot_post_quant_conv'], strict=True)
162
+
163
+ self.bot_encoder.eval()
164
+ self.bot_decoder_res.eval()
165
+ self.decoder.eval()
166
+ self.bot_quantize.eval()
167
+ self.bot_quant_conv.eval()
168
+ self.bot_post_quant_conv.eval()
169
+
170
+ def top_encode(self, x, mask):
171
+ h = self.top_encoder(x)
172
+ h = self.top_quant_conv(h)
173
+ quant, _, _ = self.top_quantize(h, mask)
174
+ quant = self.top_post_quant_conv(quant)
175
+
176
+ return quant, quant
177
+
178
+ def feed_data(self, data):
179
+ self.image = data['image'].to(self.device)
180
+ self.texture_mask = data['texture_mask'].float().to(self.device)
181
+ self.get_gt_indices()
182
+
183
+ self.texture_tokens = F.interpolate(
184
+ self.texture_mask, size=(32, 16),
185
+ mode='nearest').view(self.image.size(0), -1).long()
186
+
187
+ def bot_encode(self, x, mask):
188
+ h = self.bot_encoder(x)
189
+ h = self.bot_quant_conv(h)
190
+ _, _, (_, _, indices_list) = self.bot_quantize(h, mask)
191
+
192
+ return indices_list
193
+
194
+ def get_gt_indices(self):
195
+ self.quant_t, self.feature_t = self.top_encode(self.image,
196
+ self.texture_mask)
197
+ self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)
198
+
199
+ def index_to_image(self, index_bottom_list, texture_mask):
200
+ quant_b = self.bot_quantize.get_codebook_entry(
201
+ index_bottom_list, texture_mask,
202
+ (index_bottom_list[0].size(0), index_bottom_list[0].size(1),
203
+ index_bottom_list[0].size(2),
204
+ self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
205
+ quant_b = self.bot_post_quant_conv(quant_b)
206
+ bot_dec_res = self.bot_decoder_res(quant_b)
207
+
208
+ dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
209
+
210
+ return dec
211
+
212
+ def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
213
+ rec_img = self.index_to_image(rec_img_index, texture_mask)
214
+ pred_img = self.index_to_image(pred_img_index, texture_mask)
215
+
216
+ base_img = self.decoder(self.quant_t)
217
+ img_cat = torch.cat([
218
+ self.image,
219
+ rec_img,
220
+ base_img,
221
+ pred_img,
222
+ ], dim=3).detach()
223
+ img_cat = ((img_cat + 1) / 2)
224
+ img_cat = img_cat.clamp_(0, 1)
225
+ save_image(img_cat, save_path, nrow=1, padding=4)
226
+
227
+ def optimize_parameters(self):
228
+ self.guidance_encoder.train()
229
+ self.index_decoder.train()
230
+
231
+ self.feature_enc = self.guidance_encoder(self.feature_t)
232
+ self.memory_logits_list = self.index_decoder(self.feature_enc)
233
+
234
+ loss = 0
235
+ for i in range(18):
236
+ loss += self.loss_func(
237
+ self.memory_logits_list[i],
238
+ self.gt_indices_list[i],
239
+ ignore_index=-1)
240
+
241
+ self.optimizer.zero_grad()
242
+ loss.backward()
243
+ self.optimizer.step()
244
+
245
+ self.log_dict['loss_total'] = loss
246
+
247
+ def inference(self, data_loader, save_dir):
248
+ self.guidance_encoder.eval()
249
+ self.index_decoder.eval()
250
+
251
+ acc = 0
252
+ num = 0
253
+
254
+ for _, data in enumerate(data_loader):
255
+ self.feed_data(data)
256
+ img_name = data['img_name']
257
+
258
+ num += self.image.size(0)
259
+
260
+ texture_mask_flatten = self.texture_tokens.view(-1)
261
+ min_encodings_indices_list = [
262
+ torch.full(
263
+ texture_mask_flatten.size(),
264
+ fill_value=-1,
265
+ dtype=torch.long,
266
+ device=texture_mask_flatten.device) for _ in range(18)
267
+ ]
268
+ with torch.no_grad():
269
+ self.feature_enc = self.guidance_encoder(self.feature_t)
270
+ memory_logits_list = self.index_decoder(self.feature_enc)
271
+ # memory_indices_pred = memory_logits.argmax(dim=1)
272
+ batch_acc = 0
273
+ for codebook_idx, memory_logits in enumerate(memory_logits_list):
274
+ region_of_interest = texture_mask_flatten == codebook_idx
275
+ if torch.sum(region_of_interest) > 0:
276
+ memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
277
+ batch_acc += torch.sum(
278
+ memory_indices_pred[region_of_interest] ==
279
+ self.gt_indices_list[codebook_idx].view(
280
+ -1)[region_of_interest])
281
+ memory_indices_pred = memory_indices_pred
282
+ min_encodings_indices_list[codebook_idx][
283
+ region_of_interest] = memory_indices_pred[
284
+ region_of_interest]
285
+ min_encodings_indices_return_list = [
286
+ min_encodings_indices.view(self.gt_indices_list[0].size())
287
+ for min_encodings_indices in min_encodings_indices_list
288
+ ]
289
+ batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel(
290
+ ) * self.image.size(0)
291
+ acc += batch_acc
292
+ self.get_vis(min_encodings_indices_return_list,
293
+ self.gt_indices_list, self.texture_mask,
294
+ f'{save_dir}/{img_name[0]}')
295
+
296
+ self.guidance_encoder.train()
297
+ self.index_decoder.train()
298
+ return (acc / num).item()
299
+
300
+ def load_network(self):
301
+ checkpoint = torch.load(self.opt['pretrained_models'])
302
+ self.guidance_encoder.load_state_dict(
303
+ checkpoint['guidance_encoder'], strict=True)
304
+ self.guidance_encoder.eval()
305
+
306
+ self.index_decoder.load_state_dict(
307
+ checkpoint['index_decoder'], strict=True)
308
+ self.index_decoder.eval()
309
+
310
+ def save_network(self, save_path):
311
+ """Save networks.
312
+
313
+ Args:
314
+ net (nn.Module): Network to be saved.
315
+ net_label (str): Network label.
316
+ current_iter (int): Current iter number.
317
+ """
318
+
319
+ save_dict = {}
320
+ save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
321
+ save_dict['index_decoder'] = self.index_decoder.state_dict()
322
+
323
+ torch.save(save_dict, save_path)
324
+
325
+ def update_learning_rate(self, epoch):
326
+ """Update learning rate.
327
+
328
+ Args:
329
+ current_iter (int): Current iteration.
330
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
331
+ Default: -1.
332
+ """
333
+ lr = self.optimizer.param_groups[0]['lr']
334
+
335
+ if self.opt['lr_decay'] == 'step':
336
+ lr = self.opt['lr'] * (
337
+ self.opt['gamma']**(epoch // self.opt['step']))
338
+ elif self.opt['lr_decay'] == 'cos':
339
+ lr = self.opt['lr'] * (
340
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
341
+ elif self.opt['lr_decay'] == 'linear':
342
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
343
+ elif self.opt['lr_decay'] == 'linear2exp':
344
+ if epoch < self.opt['turning_point'] + 1:
345
+ # learning rate decay as 95%
346
+ # at the turning point (1 / 95% = 1.0526)
347
+ lr = self.opt['lr'] * (
348
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
349
+ else:
350
+ lr *= self.opt['gamma']
351
+ elif self.opt['lr_decay'] == 'schedule':
352
+ if epoch in self.opt['schedule']:
353
+ lr *= self.opt['gamma']
354
+ else:
355
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
356
+ # set learning rate
357
+ for param_group in self.optimizer.param_groups:
358
+ param_group['lr'] = lr
359
+
360
+ return lr
361
+
362
+ def get_current_log(self):
363
+ return self.log_dict
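For reference, a standalone sketch of how the decay modes handled by update_learning_rate above behave; the helper name and the hyperparameter values are made up for illustration.

import math

def decayed_lr(base_lr, epoch, num_epochs, mode, gamma=0.1, step=30):
    # mirrors the 'step', 'cos' and 'linear' branches of update_learning_rate
    if mode == 'step':
        return base_lr * gamma ** (epoch // step)
    if mode == 'cos':
        return base_lr * (1 + math.cos(math.pi * epoch / num_epochs)) / 2
    if mode == 'linear':
        return base_lr * (1 - epoch / num_epochs)
    raise ValueError(f'Unknown lr mode {mode}')

print(decayed_lr(1e-4, 60, 100, 'step'))   # ~1e-06 (two step decays of 0.1)
print(decayed_lr(1e-4, 50, 100, 'cos'))    # ~5e-05 (halfway through the cosine)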
models/hierarchy_vqgan_model.py ADDED
@@ -0,0 +1,374 @@
1
+ import math
2
+ import sys
3
+ from collections import OrderedDict
4
+
5
+ sys.path.append('..')
6
+ import lpips
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torchvision.utils import save_image
10
+
11
+ from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
12
+ Encoder,
13
+ VectorQuantizerSpatialTextureAware,
14
+ VectorQuantizerTexture)
15
+ from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
16
+ calculate_adaptive_weight, hinge_d_loss)
17
+
18
+
19
+ class HierarchyVQSpatialTextureAwareModel():
20
+
21
+ def __init__(self, opt):
22
+ self.opt = opt
23
+ self.device = torch.device('cuda')
24
+ self.top_encoder = Encoder(
25
+ ch=opt['top_ch'],
26
+ num_res_blocks=opt['top_num_res_blocks'],
27
+ attn_resolutions=opt['top_attn_resolutions'],
28
+ ch_mult=opt['top_ch_mult'],
29
+ in_channels=opt['top_in_channels'],
30
+ resolution=opt['top_resolution'],
31
+ z_channels=opt['top_z_channels'],
32
+ double_z=opt['top_double_z'],
33
+ dropout=opt['top_dropout']).to(self.device)
34
+ self.decoder = Decoder(
35
+ in_channels=opt['top_in_channels'],
36
+ resolution=opt['top_resolution'],
37
+ z_channels=opt['top_z_channels'],
38
+ ch=opt['top_ch'],
39
+ out_ch=opt['top_out_ch'],
40
+ num_res_blocks=opt['top_num_res_blocks'],
41
+ attn_resolutions=opt['top_attn_resolutions'],
42
+ ch_mult=opt['top_ch_mult'],
43
+ dropout=opt['top_dropout'],
44
+ resamp_with_conv=True,
45
+ give_pre_end=False).to(self.device)
46
+ self.top_quantize = VectorQuantizerTexture(
47
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
48
+ self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
49
+ opt['embed_dim'],
50
+ 1).to(self.device)
51
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
52
+ opt["top_z_channels"],
53
+ 1).to(self.device)
54
+ self.load_top_pretrain_models()
55
+
56
+ self.bot_encoder = Encoder(
57
+ ch=opt['bot_ch'],
58
+ num_res_blocks=opt['bot_num_res_blocks'],
59
+ attn_resolutions=opt['bot_attn_resolutions'],
60
+ ch_mult=opt['bot_ch_mult'],
61
+ in_channels=opt['bot_in_channels'],
62
+ resolution=opt['bot_resolution'],
63
+ z_channels=opt['bot_z_channels'],
64
+ double_z=opt['bot_double_z'],
65
+ dropout=opt['bot_dropout']).to(self.device)
66
+ self.bot_decoder_res = DecoderRes(
67
+ in_channels=opt['bot_in_channels'],
68
+ resolution=opt['bot_resolution'],
69
+ z_channels=opt['bot_z_channels'],
70
+ ch=opt['bot_ch'],
71
+ num_res_blocks=opt['bot_num_res_blocks'],
72
+ ch_mult=opt['bot_ch_mult'],
73
+ dropout=opt['bot_dropout'],
74
+ give_pre_end=False).to(self.device)
75
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
76
+ opt['bot_n_embed'],
77
+ opt['embed_dim'],
78
+ beta=0.25,
79
+ spatial_size=opt['codebook_spatial_size']).to(self.device)
80
+ self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
81
+ opt['embed_dim'],
82
+ 1).to(self.device)
83
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
84
+ opt["bot_z_channels"],
85
+ 1).to(self.device)
86
+
87
+ self.disc = Discriminator(
88
+ opt['n_channels'], opt['ndf'],
89
+ n_layers=opt['disc_layers']).to(self.device)
90
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
91
+ self.perceptual_weight = opt['perceptual_weight']
92
+ self.disc_start_step = opt['disc_start_step']
93
+ self.disc_weight_max = opt['disc_weight_max']
94
+ self.diff_aug = opt['diff_aug']
95
+ self.policy = "color,translation"
96
+
97
+ self.load_discriminator_models()
98
+
99
+ self.disc.train()
100
+
101
+ self.fix_decoder = opt['fix_decoder']
102
+
103
+ self.init_training_settings()
104
+
105
+ def load_top_pretrain_models(self):
106
+ # load pretrained vqgan for segmentation mask
107
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
108
+ self.top_encoder.load_state_dict(
109
+ top_vae_checkpoint['encoder'], strict=True)
110
+ self.decoder.load_state_dict(
111
+ top_vae_checkpoint['decoder'], strict=True)
112
+ self.top_quantize.load_state_dict(
113
+ top_vae_checkpoint['quantize'], strict=True)
114
+ self.top_quant_conv.load_state_dict(
115
+ top_vae_checkpoint['quant_conv'], strict=True)
116
+ self.top_post_quant_conv.load_state_dict(
117
+ top_vae_checkpoint['post_quant_conv'], strict=True)
118
+ self.top_encoder.eval()
119
+ self.top_quantize.eval()
120
+ self.top_quant_conv.eval()
121
+ self.top_post_quant_conv.eval()
122
+
123
+ def init_training_settings(self):
124
+ self.log_dict = OrderedDict()
125
+ self.configure_optimizers()
126
+
127
+ def configure_optimizers(self):
128
+ optim_params = []
129
+ for v in self.bot_encoder.parameters():
130
+ if v.requires_grad:
131
+ optim_params.append(v)
132
+ for v in self.bot_decoder_res.parameters():
133
+ if v.requires_grad:
134
+ optim_params.append(v)
135
+ for v in self.bot_quantize.parameters():
136
+ if v.requires_grad:
137
+ optim_params.append(v)
138
+ for v in self.bot_quant_conv.parameters():
139
+ if v.requires_grad:
140
+ optim_params.append(v)
141
+ for v in self.bot_post_quant_conv.parameters():
142
+ if v.requires_grad:
143
+ optim_params.append(v)
144
+ if not self.fix_decoder:
145
+ for name, v in self.decoder.named_parameters():
146
+ if v.requires_grad:
147
+ if 'up.0' in name:
148
+ optim_params.append(v)
149
+ if 'up.1' in name:
150
+ optim_params.append(v)
151
+ if 'up.2' in name:
152
+ optim_params.append(v)
153
+ if 'up.3' in name:
154
+ optim_params.append(v)
155
+
156
+ self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])
157
+
158
+ self.disc_optimizer = torch.optim.Adam(
159
+ self.disc.parameters(), lr=self.opt['lr'])
160
+
161
+ def load_discriminator_models(self):
162
+ # load pretrained vqgan for segmentation mask
163
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
164
+ self.disc.load_state_dict(
165
+ top_vae_checkpoint['discriminator'], strict=True)
166
+
167
+ def save_network(self, save_path):
168
+ """Save networks.
169
+ """
170
+
171
+ save_dict = {}
172
+ save_dict['bot_encoder'] = self.bot_encoder.state_dict()
173
+ save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
174
+ save_dict['decoder'] = self.decoder.state_dict()
175
+ save_dict['bot_quantize'] = self.bot_quantize.state_dict()
176
+ save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
177
+ save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict(
178
+ )
179
+ save_dict['discriminator'] = self.disc.state_dict()
180
+ torch.save(save_dict, save_path)
181
+
182
+ def load_network(self):
183
+ checkpoint = torch.load(self.opt['pretrained_models'])
184
+ self.bot_encoder.load_state_dict(
185
+ checkpoint['bot_encoder'], strict=True)
186
+ self.bot_decoder_res.load_state_dict(
187
+ checkpoint['bot_decoder_res'], strict=True)
188
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
189
+ self.bot_quantize.load_state_dict(
190
+ checkpoint['bot_quantize'], strict=True)
191
+ self.bot_quant_conv.load_state_dict(
192
+ checkpoint['bot_quant_conv'], strict=True)
193
+ self.bot_post_quant_conv.load_state_dict(
194
+ checkpoint['bot_post_quant_conv'], strict=True)
195
+
196
+ def optimize_parameters(self, data, step):
197
+ self.bot_encoder.train()
198
+ self.bot_decoder_res.train()
199
+ if not self.fix_decoder:
200
+ self.decoder.train()
201
+ self.bot_quantize.train()
202
+ self.bot_quant_conv.train()
203
+ self.bot_post_quant_conv.train()
204
+
205
+ loss, d_loss = self.training_step(data, step)
206
+ self.optimizer.zero_grad()
207
+ loss.backward()
208
+ self.optimizer.step()
209
+
210
+ if step > self.disc_start_step:
211
+ self.disc_optimizer.zero_grad()
212
+ d_loss.backward()
213
+ self.disc_optimizer.step()
214
+
215
+ def top_encode(self, x, mask):
216
+ h = self.top_encoder(x)
217
+ h = self.top_quant_conv(h)
218
+ quant, _, _ = self.top_quantize(h, mask)
219
+ quant = self.top_post_quant_conv(quant)
220
+ return quant
221
+
222
+ def bot_encode(self, x, mask):
223
+ h = self.bot_encoder(x)
224
+ h = self.bot_quant_conv(h)
225
+ quant, emb_loss, info = self.bot_quantize(h, mask)
226
+ quant = self.bot_post_quant_conv(quant)
227
+ bot_dec_res = self.bot_decoder_res(quant)
228
+ return bot_dec_res, emb_loss, info
229
+
230
+ def decode(self, quant_top, bot_dec_res):
231
+ dec = self.decoder(quant_top, bot_h=bot_dec_res)
232
+ return dec
233
+
234
+ def forward_step(self, input, mask):
235
+ with torch.no_grad():
236
+ quant_top = self.top_encode(input, mask)
237
+ bot_dec_res, diff, _ = self.bot_encode(input, mask)
238
+ dec = self.decode(quant_top, bot_dec_res)
239
+ return dec, diff
240
+
241
+ def feed_data(self, data):
242
+ x = data['image'].float().to(self.device)
243
+ mask = data['texture_mask'].float().to(self.device)
244
+
245
+ return x, mask
246
+
247
+ def training_step(self, data, step):
248
+ x, mask = self.feed_data(data)
249
+ xrec, codebook_loss = self.forward_step(x, mask)
250
+
251
+ # get recon/perceptual loss
252
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
253
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
254
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
255
+ nll_loss = torch.mean(nll_loss)
256
+
257
+ # augment for input to discriminator
258
+ if self.diff_aug:
259
+ xrec = DiffAugment(xrec, policy=self.policy)
260
+
261
+ # update generator
262
+ logits_fake = self.disc(xrec)
263
+ g_loss = -torch.mean(logits_fake)
264
+ last_layer = self.decoder.conv_out.weight
265
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
266
+ self.disc_weight_max)
267
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
268
+ loss = nll_loss + d_weight * g_loss + codebook_loss
269
+
270
+ self.log_dict["loss"] = loss
271
+ self.log_dict["l1"] = recon_loss.mean().item()
272
+ self.log_dict["perceptual"] = p_loss.mean().item()
273
+ self.log_dict["nll_loss"] = nll_loss.item()
274
+ self.log_dict["g_loss"] = g_loss.item()
275
+ self.log_dict["d_weight"] = d_weight
276
+ self.log_dict["codebook_loss"] = codebook_loss.item()
277
+
278
+ if step > self.disc_start_step:
279
+ if self.diff_aug:
280
+ logits_real = self.disc(
281
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
282
+ else:
283
+ logits_real = self.disc(x.contiguous().detach())
284
+ logits_fake = self.disc(xrec.contiguous().detach(
285
+ )) # detach so that generator isn't also updated
286
+ d_loss = hinge_d_loss(logits_real, logits_fake)
287
+ self.log_dict["d_loss"] = d_loss
288
+ else:
289
+ d_loss = None
290
+
291
+ return loss, d_loss
292
+
293
+ @torch.no_grad()
294
+ def inference(self, data_loader, save_dir):
295
+ self.bot_encoder.eval()
296
+ self.bot_decoder_res.eval()
297
+ self.decoder.eval()
298
+ self.bot_quantize.eval()
299
+ self.bot_quant_conv.eval()
300
+ self.bot_post_quant_conv.eval()
301
+
302
+ loss_total = 0
303
+ num = 0
304
+
305
+ for _, data in enumerate(data_loader):
306
+ img_name = data['img_name'][0]
307
+ x, mask = self.feed_data(data)
308
+ xrec, _ = self.forward_step(x, mask)
309
+
310
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
311
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
312
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
313
+ nll_loss = torch.mean(nll_loss)
314
+ loss_total += nll_loss
315
+
316
+ num += x.size(0)
317
+
318
+ if x.shape[1] > 3:
319
+ # colorize with random projection
320
+ assert xrec.shape[1] > 3
321
+ # convert logits to indices
322
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
323
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
324
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
325
+ x = self.to_rgb(x)
326
+ xrec = self.to_rgb(xrec)
327
+
328
+ img_cat = torch.cat([x, xrec], dim=3).detach()
329
+ img_cat = ((img_cat + 1) / 2)
330
+ img_cat = img_cat.clamp_(0, 1)
331
+ save_image(
332
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
333
+
334
+ return (loss_total / num).item()
335
+
336
+ def get_current_log(self):
337
+ return self.log_dict
338
+
339
+ def update_learning_rate(self, epoch):
340
+ """Update learning rate.
341
+
342
+ Args:
343
+ current_iter (int): Current iteration.
344
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
345
+ Default: -1.
346
+ """
347
+ lr = self.optimizer.param_groups[0]['lr']
348
+
349
+ if self.opt['lr_decay'] == 'step':
350
+ lr = self.opt['lr'] * (
351
+ self.opt['gamma']**(epoch // self.opt['step']))
352
+ elif self.opt['lr_decay'] == 'cos':
353
+ lr = self.opt['lr'] * (
354
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
355
+ elif self.opt['lr_decay'] == 'linear':
356
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
357
+ elif self.opt['lr_decay'] == 'linear2exp':
358
+ if epoch < self.opt['turning_point'] + 1:
359
+ # learning rate decay as 95%
360
+ # at the turning point (1 / 95% = 1.0526)
361
+ lr = self.opt['lr'] * (
362
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
363
+ else:
364
+ lr *= self.opt['gamma']
365
+ elif self.opt['lr_decay'] == 'schedule':
366
+ if epoch in self.opt['schedule']:
367
+ lr *= self.opt['gamma']
368
+ else:
369
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
370
+ # set learning rate
371
+ for param_group in self.optimizer.param_groups:
372
+ param_group['lr'] = lr
373
+
374
+ return lr
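A toy sketch of the alternating generator/discriminator update used by optimize_parameters and training_step above: hinge-style losses, with the discriminator only stepped after a warm-up threshold. The tiny linear layers stand in for the real VQGAN networks.

import torch

gen = torch.nn.Linear(8, 8)      # stands in for the bottom encoder/decoder path
disc = torch.nn.Linear(8, 1)     # stands in for the patch discriminator
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)
disc_start_step = 50

for step in range(100):
    x = torch.randn(4, 8)
    xrec = gen(x)
    g_loss = -disc(xrec).mean()              # generator pushes fake logits up
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()

    if step > disc_start_step:               # discriminator joins after warm-up
        d_loss = 0.5 * (torch.relu(1. - disc(x)).mean()
                        + torch.relu(1. + disc(xrec.detach())).mean())
        opt_d.zero_grad(); d_loss.backward(); opt_d.step()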
models/losses/__init__.py ADDED
File without changes
models/losses/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (127 Bytes)
models/losses/__pycache__/accuracy.cpython-38.pyc ADDED
Binary file (2.02 kB)
models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc ADDED
Binary file (6.76 kB)
models/losses/__pycache__/segmentation_loss.cpython-38.pyc ADDED
Binary file (1.3 kB)
models/losses/__pycache__/vqgan_loss.cpython-38.pyc ADDED
Binary file (3.65 kB)
models/losses/accuracy.py ADDED
@@ -0,0 +1,46 @@
+ def accuracy(pred, target, topk=1, thresh=None):
+     """Calculate accuracy according to the prediction and target.
+
+     Args:
+         pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
+         target (torch.Tensor): The target of each prediction, shape (N, ...)
+         topk (int | tuple[int], optional): If the predictions in ``topk``
+             match the target, the predictions will be regarded as
+             correct ones. Defaults to 1.
+         thresh (float, optional): If not None, predictions with scores under
+             this threshold are considered incorrect. Defaults to None.
+
+     Returns:
+         float | tuple[float]: If the input ``topk`` is a single integer,
+             the function will return a single float as accuracy. If
+             ``topk`` is a tuple containing multiple integers, the
+             function will return a tuple containing accuracies of
+             each ``topk`` number.
+     """
+     assert isinstance(topk, (int, tuple))
+     if isinstance(topk, int):
+         topk = (topk, )
+         return_single = True
+     else:
+         return_single = False
+
+     maxk = max(topk)
+     if pred.size(0) == 0:
+         accu = [pred.new_tensor(0.) for i in range(len(topk))]
+         return accu[0] if return_single else accu
+     assert pred.ndim == target.ndim + 1
+     assert pred.size(0) == target.size(0)
+     assert maxk <= pred.size(1), \
+         f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+     pred_value, pred_label = pred.topk(maxk, dim=1)
+     # transpose to shape (maxk, N, ...)
+     pred_label = pred_label.transpose(0, 1)
+     correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
+     if thresh is not None:
+         # Only prediction values larger than thresh are counted as correct
+         correct = correct & (pred_value > thresh).t()
+     res = []
+     for k in topk:
+         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+         res.append(correct_k.mul_(100.0 / target.numel()))
+     return res[0] if return_single else res
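A small usage sketch for accuracy() with hand-made logits; the numbers are chosen only so the expected outputs are easy to verify.

import torch

from models.losses.accuracy import accuracy

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.2, 3.0, 0.5]])      # (N=2, num_class=3)
target = torch.tensor([0, 2])                 # first sample correct, second not

print(accuracy(logits, target, topk=1))       # tensor([50.])
print(accuracy(logits, target, topk=(1, 2)))  # [tensor([50.]), tensor([100.])]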
models/losses/cross_entropy_loss.py ADDED
@@ -0,0 +1,246 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def reduce_loss(loss, reduction):
7
+ """Reduce loss as specified.
8
+
9
+ Args:
10
+ loss (Tensor): Elementwise loss tensor.
11
+ reduction (str): Options are "none", "mean" and "sum".
12
+
13
+ Return:
14
+ Tensor: Reduced loss tensor.
15
+ """
16
+ reduction_enum = F._Reduction.get_enum(reduction)
17
+ # none: 0, elementwise_mean:1, sum: 2
18
+ if reduction_enum == 0:
19
+ return loss
20
+ elif reduction_enum == 1:
21
+ return loss.mean()
22
+ elif reduction_enum == 2:
23
+ return loss.sum()
24
+
25
+
26
+ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
27
+ """Apply element-wise weight and reduce loss.
28
+
29
+ Args:
30
+ loss (Tensor): Element-wise loss.
31
+ weight (Tensor): Element-wise weights.
32
+ reduction (str): Same as built-in losses of PyTorch.
33
+ avg_factor (float): Average factor when computing the mean of losses.
34
+
35
+ Returns:
36
+ Tensor: Processed loss values.
37
+ """
38
+ # if weight is specified, apply element-wise weight
39
+ if weight is not None:
40
+ assert weight.dim() == loss.dim()
41
+ if weight.dim() > 1:
42
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
43
+ loss = loss * weight
44
+
45
+ # if avg_factor is not specified, just reduce the loss
46
+ if avg_factor is None:
47
+ loss = reduce_loss(loss, reduction)
48
+ else:
49
+ # if reduction is mean, then average the loss by avg_factor
50
+ if reduction == 'mean':
51
+ loss = loss.sum() / avg_factor
52
+ # if reduction is 'none', then do nothing, otherwise raise an error
53
+ elif reduction != 'none':
54
+ raise ValueError('avg_factor can not be used with reduction="sum"')
55
+ return loss
56
+
57
+
58
+ def cross_entropy(pred,
59
+ label,
60
+ weight=None,
61
+ class_weight=None,
62
+ reduction='mean',
63
+ avg_factor=None,
64
+ ignore_index=-100):
65
+ """The wrapper function for :func:`F.cross_entropy`"""
66
+ # class_weight is a manual rescaling weight given to each class.
67
+ # If given, has to be a Tensor of size C element-wise losses
68
+ loss = F.cross_entropy(
69
+ pred,
70
+ label,
71
+ weight=class_weight,
72
+ reduction='none',
73
+ ignore_index=ignore_index)
74
+
75
+ # apply weights and do the reduction
76
+ if weight is not None:
77
+ weight = weight.float()
78
+ loss = weight_reduce_loss(
79
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
80
+
81
+ return loss
82
+
83
+
84
+ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
85
+ """Expand onehot labels to match the size of prediction."""
86
+ bin_labels = labels.new_zeros(target_shape)
87
+ valid_mask = (labels >= 0) & (labels != ignore_index)
88
+ inds = torch.nonzero(valid_mask, as_tuple=True)
89
+
90
+ if inds[0].numel() > 0:
91
+ if labels.dim() == 3:
92
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
93
+ else:
94
+ bin_labels[inds[0], labels[valid_mask]] = 1
95
+
96
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
97
+ if label_weights is None:
98
+ bin_label_weights = valid_mask
99
+ else:
100
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
101
+ bin_label_weights *= valid_mask
102
+
103
+ return bin_labels, bin_label_weights
104
+
105
+
106
+ def binary_cross_entropy(pred,
107
+ label,
108
+ weight=None,
109
+ reduction='mean',
110
+ avg_factor=None,
111
+ class_weight=None,
112
+ ignore_index=255):
113
+ """Calculate the binary CrossEntropy loss.
114
+
115
+ Args:
116
+ pred (torch.Tensor): The prediction with shape (N, 1).
117
+ label (torch.Tensor): The learning label of the prediction.
118
+ weight (torch.Tensor, optional): Sample-wise loss weight.
119
+ reduction (str, optional): The method used to reduce the loss.
120
+ Options are "none", "mean" and "sum".
121
+ avg_factor (int, optional): Average factor that is used to average
122
+ the loss. Defaults to None.
123
+ class_weight (list[float], optional): The weight for each class.
124
+ ignore_index (int | None): The label index to be ignored. Default: 255
125
+
126
+ Returns:
127
+ torch.Tensor: The calculated loss
128
+ """
129
+ if pred.dim() != label.dim():
130
+ assert (pred.dim() == 2 and label.dim() == 1) or (
131
+ pred.dim() == 4 and label.dim() == 3), \
132
+ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
133
+ 'H, W], label shape [N, H, W] are supported'
134
+ label, weight = _expand_onehot_labels(label, weight, pred.shape,
135
+ ignore_index)
136
+
137
+ # weighted element-wise losses
138
+ if weight is not None:
139
+ weight = weight.float()
140
+ loss = F.binary_cross_entropy_with_logits(
141
+ pred, label.float(), pos_weight=class_weight, reduction='none')
142
+ # do the reduction for the weighted loss
143
+ loss = weight_reduce_loss(
144
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
145
+
146
+ return loss
147
+
148
+
149
+ def mask_cross_entropy(pred,
150
+ target,
151
+ label,
152
+ reduction='mean',
153
+ avg_factor=None,
154
+ class_weight=None,
155
+ ignore_index=None):
156
+ """Calculate the CrossEntropy loss for masks.
157
+
158
+ Args:
159
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
160
+ of classes.
161
+ target (torch.Tensor): The learning label of the prediction.
162
+ label (torch.Tensor): ``label`` indicates the class label of the mask's
163
+ corresponding object. This will be used to select the mask of the
164
+ class which the object belongs to when the mask prediction is not
165
+ class-agnostic.
166
+ reduction (str, optional): The method used to reduce the loss.
167
+ Options are "none", "mean" and "sum".
168
+ avg_factor (int, optional): Average factor that is used to average
169
+ the loss. Defaults to None.
170
+ class_weight (list[float], optional): The weight for each class.
171
+ ignore_index (None): Placeholder, to be consistent with other loss.
172
+ Default: None.
173
+
174
+ Returns:
175
+ torch.Tensor: The calculated loss
176
+ """
177
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
178
+ # TODO: handle these two reserved arguments
179
+ assert reduction == 'mean' and avg_factor is None
180
+ num_rois = pred.size()[0]
181
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
182
+ pred_slice = pred[inds, label].squeeze(1)
183
+ return F.binary_cross_entropy_with_logits(
184
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
185
+
186
+
187
+ class CrossEntropyLoss(nn.Module):
188
+ """CrossEntropyLoss.
189
+
190
+ Args:
191
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
192
+ instead of softmax. Defaults to False.
193
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
194
+ Defaults to False.
195
+ reduction (str, optional): . Defaults to 'mean'.
196
+ Options are "none", "mean" and "sum".
197
+ class_weight (list[float], optional): Weight of each class.
198
+ Defaults to None.
199
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
200
+ """
201
+
202
+ def __init__(self,
203
+ use_sigmoid=False,
204
+ use_mask=False,
205
+ reduction='mean',
206
+ class_weight=None,
207
+ loss_weight=1.0):
208
+ super(CrossEntropyLoss, self).__init__()
209
+ assert (use_sigmoid is False) or (use_mask is False)
210
+ self.use_sigmoid = use_sigmoid
211
+ self.use_mask = use_mask
212
+ self.reduction = reduction
213
+ self.loss_weight = loss_weight
214
+ self.class_weight = class_weight
215
+
216
+ if self.use_sigmoid:
217
+ self.cls_criterion = binary_cross_entropy
218
+ elif self.use_mask:
219
+ self.cls_criterion = mask_cross_entropy
220
+ else:
221
+ self.cls_criterion = cross_entropy
222
+
223
+ def forward(self,
224
+ cls_score,
225
+ label,
226
+ weight=None,
227
+ avg_factor=None,
228
+ reduction_override=None,
229
+ **kwargs):
230
+ """Forward function."""
231
+ assert reduction_override in (None, 'none', 'mean', 'sum')
232
+ reduction = (
233
+ reduction_override if reduction_override else self.reduction)
234
+ if self.class_weight is not None:
235
+ class_weight = cls_score.new_tensor(self.class_weight)
236
+ else:
237
+ class_weight = None
238
+ loss_cls = self.loss_weight * self.cls_criterion(
239
+ cls_score,
240
+ label,
241
+ weight,
242
+ class_weight=class_weight,
243
+ reduction=reduction,
244
+ avg_factor=avg_factor,
245
+ **kwargs)
246
+ return loss_cls
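A usage sketch for CrossEntropyLoss on segmentation-style logits, mirroring how the index-prediction model above calls it with ignore_index=-1; the shapes are illustrative.

import torch

from models.losses.cross_entropy_loss import CrossEntropyLoss

criterion = CrossEntropyLoss(reduction='mean')
logits = torch.randn(2, 5, 8, 8, requires_grad=True)   # (N, C, H, W) class scores
labels = torch.randint(0, 5, (2, 8, 8))                # (N, H, W) class indices
labels[0, 0, 0] = -1                                   # positions labelled -1 are ignored

loss = criterion(logits, labels, ignore_index=-1)
loss.backward()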
models/losses/segmentation_loss.py ADDED
@@ -0,0 +1,25 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class BCELoss(nn.Module):
+
+     def forward(self, prediction, target):
+         loss = F.binary_cross_entropy_with_logits(prediction, target)
+         return loss, {}
+
+
+ class BCELossWithQuant(nn.Module):
+
+     def __init__(self, codebook_weight=1.):
+         super().__init__()
+         self.codebook_weight = codebook_weight
+
+     def forward(self, qloss, target, prediction, split):
+         bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
+         loss = bce_loss + self.codebook_weight * qloss
+         return loss, {
+             "{}/total_loss".format(split): loss.clone().detach().mean(),
+             "{}/bce_loss".format(split): bce_loss.detach().mean(),
+             "{}/quant_loss".format(split): qloss.detach().mean()
+         }
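A usage sketch for BCELossWithQuant, assuming one-hot style segmentation logits and a scalar quantization loss coming from the VQ codebook; the shapes are illustrative.

import torch

from models.losses.segmentation_loss import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
logits = torch.randn(2, 24, 32, 16)   # predicted segmentation logits
target = torch.rand(2, 24, 32, 16)    # soft / one-hot targets in [0, 1]
qloss = torch.tensor(0.05)            # commitment loss from the quantizer

loss, log = criterion(qloss, target, logits, split='train')
print(log['train/bce_loss'], log['train/quant_loss'])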
models/losses/vqgan_loss.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
6
+ recon_grads = torch.autograd.grad(
7
+ recon_loss, last_layer, retain_graph=True)[0]
8
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
9
+
10
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
11
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
12
+ return d_weight
13
+
14
+
15
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
16
+ if global_step < threshold:
17
+ weight = value
18
+ return weight
19
+
20
+
21
+ @torch.jit.script
22
+ def hinge_d_loss(logits_real, logits_fake):
23
+ loss_real = torch.mean(F.relu(1. - logits_real))
24
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
25
+ d_loss = 0.5 * (loss_real + loss_fake)
26
+ return d_loss
27
+
28
+
29
+ def DiffAugment(x, policy='', channels_first=True):
30
+ if policy:
31
+ if not channels_first:
32
+ x = x.permute(0, 3, 1, 2)
33
+ for p in policy.split(','):
34
+ for f in AUGMENT_FNS[p]:
35
+ x = f(x)
36
+ if not channels_first:
37
+ x = x.permute(0, 2, 3, 1)
38
+ x = x.contiguous()
39
+ return x
40
+
41
+
42
+ def rand_brightness(x):
43
+ x = x + (
44
+ torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
45
+ return x
46
+
47
+
48
+ def rand_saturation(x):
49
+ x_mean = x.mean(dim=1, keepdim=True)
50
+ x = (x - x_mean) * (torch.rand(
51
+ x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
52
+ return x
53
+
54
+
55
+ def rand_contrast(x):
56
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
57
+ x = (x - x_mean) * (torch.rand(
58
+ x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
59
+ return x
60
+
61
+
62
+ def rand_translation(x, ratio=0.125):
63
+ shift_x, shift_y = int(x.size(2) * ratio +
64
+ 0.5), int(x.size(3) * ratio + 0.5)
65
+ translation_x = torch.randint(
66
+ -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
67
+ translation_y = torch.randint(
68
+ -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
69
+ grid_batch, grid_x, grid_y = torch.meshgrid(
70
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
71
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
72
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
73
+ )
74
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
75
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
76
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
77
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
78
+ grid_y].permute(0, 3, 1, 2)
79
+ return x
80
+
81
+
82
+ def rand_cutout(x, ratio=0.5):
83
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
84
+ offset_x = torch.randint(
85
+ 0,
86
+ x.size(2) + (1 - cutout_size[0] % 2),
87
+ size=[x.size(0), 1, 1],
88
+ device=x.device)
89
+ offset_y = torch.randint(
90
+ 0,
91
+ x.size(3) + (1 - cutout_size[1] % 2),
92
+ size=[x.size(0), 1, 1],
93
+ device=x.device)
94
+ grid_batch, grid_x, grid_y = torch.meshgrid(
95
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
96
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
97
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
98
+ )
99
+ grid_x = torch.clamp(
100
+ grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
101
+ grid_y = torch.clamp(
102
+ grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
103
+ mask = torch.ones(
104
+ x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
105
+ mask[grid_batch, grid_x, grid_y] = 0
106
+ x = x * mask.unsqueeze(1)
107
+ return x
108
+
109
+
110
+ AUGMENT_FNS = {
111
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
112
+ 'translation': [rand_translation],
113
+ 'cutout': [rand_cutout],
114
+ }
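A small sketch of the helpers above; the random batch below is only meant to show shapes and the warm-up gating of the GAN term.

import torch

from models.losses.vqgan_loss import DiffAugment, adopt_weight, hinge_d_loss

x = torch.rand(4, 3, 64, 64) * 2 - 1                  # images in [-1, 1]
x_aug = DiffAugment(x, policy='color,translation')    # same shape, randomly jittered

logits_real = torch.randn(4, 1, 6, 6)
logits_fake = torch.randn(4, 1, 6, 6)
print(hinge_d_loss(logits_real, logits_fake))         # scalar discriminator loss

# the adversarial weight stays at 0 until global_step passes the threshold
print(adopt_weight(1.0, global_step=100, threshold=3000))   # 0.0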
models/parsing_gen_model.py ADDED
@@ -0,0 +1,220 @@
1
+ import logging
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import mmcv
6
+ import numpy as np
7
+ import torch
8
+ from torchvision.utils import save_image
9
+
10
+ from models.archs.fcn_arch import FCNHead
11
+ from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
12
+ from models.archs.unet_arch import ShapeUNet
13
+ from models.losses.accuracy import accuracy
14
+ from models.losses.cross_entropy_loss import CrossEntropyLoss
15
+
16
+ logger = logging.getLogger('base')
17
+
18
+
19
+ class ParsingGenModel():
20
+ """Paring Generation model.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ self.opt = opt
25
+ self.device = torch.device('cuda')
26
+ self.is_train = opt['is_train']
27
+
28
+ self.attr_embedder = ShapeAttrEmbedding(
29
+ dim=opt['embedder_dim'],
30
+ out_dim=opt['embedder_out_dim'],
31
+ cls_num_list=opt['attr_class_num']).to(self.device)
32
+ self.parsing_encoder = ShapeUNet(
33
+ in_channels=opt['encoder_in_channels']).to(self.device)
34
+ self.parsing_decoder = FCNHead(
35
+ in_channels=opt['fc_in_channels'],
36
+ in_index=opt['fc_in_index'],
37
+ channels=opt['fc_channels'],
38
+ num_convs=opt['fc_num_convs'],
39
+ concat_input=opt['fc_concat_input'],
40
+ dropout_ratio=opt['fc_dropout_ratio'],
41
+ num_classes=opt['fc_num_classes'],
42
+ align_corners=opt['fc_align_corners'],
43
+ ).to(self.device)
44
+
45
+ self.init_training_settings()
46
+
47
+ self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
48
+ [250, 235, 215], [255, 250, 205], [211, 211, 211],
49
+ [70, 130, 180], [127, 255, 212], [0, 100, 0],
50
+ [50, 205, 50], [255, 255, 0], [245, 222, 179],
51
+ [255, 140, 0], [255, 0, 0], [16, 78, 139],
52
+ [144, 238, 144], [50, 205, 174], [50, 155, 250],
53
+ [160, 140, 88], [213, 140, 88], [90, 140, 90],
54
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
55
+
56
+ def init_training_settings(self):
57
+ optim_params = []
58
+ for v in self.attr_embedder.parameters():
59
+ if v.requires_grad:
60
+ optim_params.append(v)
61
+ for v in self.parsing_encoder.parameters():
62
+ if v.requires_grad:
63
+ optim_params.append(v)
64
+ for v in self.parsing_decoder.parameters():
65
+ if v.requires_grad:
66
+ optim_params.append(v)
67
+ # set up optimizers
68
+ self.optimizer = torch.optim.Adam(
69
+ optim_params,
70
+ self.opt['lr'],
71
+ weight_decay=self.opt['weight_decay'])
72
+ self.log_dict = OrderedDict()
73
+ self.entropy_loss = CrossEntropyLoss().to(self.device)
74
+
75
+ def feed_data(self, data):
76
+ self.pose = data['densepose'].to(self.device)
77
+ self.attr = data['attr'].to(self.device)
78
+ self.segm = data['segm'].to(self.device)
79
+
80
+ def optimize_parameters(self):
81
+ self.attr_embedder.train()
82
+ self.parsing_encoder.train()
83
+ self.parsing_decoder.train()
84
+
85
+ self.attr_embedding = self.attr_embedder(self.attr)
86
+ self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
87
+ self.seg_logits = self.parsing_decoder(self.pose_enc)
88
+
89
+ loss = self.entropy_loss(self.seg_logits, self.segm)
90
+
91
+ self.optimizer.zero_grad()
92
+ loss.backward()
93
+ self.optimizer.step()
94
+
95
+ self.log_dict['loss_total'] = loss
96
+
97
+ def get_vis(self, save_path):
98
+ img_cat = torch.cat([
99
+ self.pose,
100
+ self.segm,
101
+ ], dim=3).detach()
102
+ img_cat = ((img_cat + 1) / 2)
103
+
104
+ img_cat = img_cat.clamp_(0, 1)
105
+
106
+ save_image(img_cat, save_path, nrow=1, padding=4)
107
+
108
+ def inference(self, data_loader, save_dir):
109
+ self.attr_embedder.eval()
110
+ self.parsing_encoder.eval()
111
+ self.parsing_decoder.eval()
112
+
113
+ acc = 0
114
+ num = 0
115
+
116
+ for _, data in enumerate(data_loader):
117
+ pose = data['densepose'].to(self.device)
118
+ attr = data['attr'].to(self.device)
119
+ segm = data['segm'].to(self.device)
120
+ img_name = data['img_name']
121
+
122
+ num += pose.size(0)
123
+ with torch.no_grad():
124
+ attr_embedding = self.attr_embedder(attr)
125
+ pose_enc = self.parsing_encoder(pose, attr_embedding)
126
+ seg_logits = self.parsing_decoder(pose_enc)
127
+ seg_pred = seg_logits.argmax(dim=1)
128
+ acc += accuracy(seg_logits, segm)
129
+ palette_label = self.palette_result(segm.cpu().numpy())
130
+ palette_pred = self.palette_result(seg_pred.cpu().numpy())
131
+ pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
132
+ 3,
133
+ pose[0].size(1),
134
+ pose[0].size(2),
135
+ ).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
136
+ concat_result = np.concatenate(
137
+ (pose_numpy, palette_pred, palette_label), axis=1)
138
+ mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')
139
+
140
+ self.attr_embedder.train()
141
+ self.parsing_encoder.train()
142
+ self.parsing_decoder.train()
143
+ return (acc / num).item()
144
+
145
+ def get_current_log(self):
146
+ return self.log_dict
147
+
148
+ def update_learning_rate(self, epoch):
149
+ """Update learning rate.
150
+
151
+ Args:
152
+ current_iter (int): Current iteration.
153
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
154
+ Default: -1.
155
+ """
156
+ lr = self.optimizer.param_groups[0]['lr']
157
+
158
+ if self.opt['lr_decay'] == 'step':
159
+ lr = self.opt['lr'] * (
160
+ self.opt['gamma']**(epoch // self.opt['step']))
161
+ elif self.opt['lr_decay'] == 'cos':
162
+ lr = self.opt['lr'] * (
163
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
164
+ elif self.opt['lr_decay'] == 'linear':
165
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
166
+ elif self.opt['lr_decay'] == 'linear2exp':
167
+ if epoch < self.opt['turning_point'] + 1:
168
+ # learning rate decay as 95%
169
+ # at the turning point (1 / 95% = 1.0526)
170
+ lr = self.opt['lr'] * (
171
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
172
+ else:
173
+ lr *= self.opt['gamma']
174
+ elif self.opt['lr_decay'] == 'schedule':
175
+ if epoch in self.opt['schedule']:
176
+ lr *= self.opt['gamma']
177
+ else:
178
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
179
+ # set learning rate
180
+ for param_group in self.optimizer.param_groups:
181
+ param_group['lr'] = lr
182
+
183
+ return lr
184
+
185
+ def save_network(self, save_path):
186
+ """Save networks.
187
+ """
188
+
189
+ save_dict = {}
190
+ save_dict['embedder'] = self.attr_embedder.state_dict()
191
+ save_dict['encoder'] = self.parsing_encoder.state_dict()
192
+ save_dict['decoder'] = self.parsing_decoder.state_dict()
193
+
194
+ torch.save(save_dict, save_path)
195
+
196
+ def load_network(self):
197
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
198
+
199
+ self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
200
+ self.attr_embedder.eval()
201
+
202
+ self.parsing_encoder.load_state_dict(
203
+ checkpoint['encoder'], strict=True)
204
+ self.parsing_encoder.eval()
205
+
206
+ self.parsing_decoder.load_state_dict(
207
+ checkpoint['decoder'], strict=True)
208
+ self.parsing_decoder.eval()
209
+
210
+ def palette_result(self, result):
211
+ seg = result[0]
212
+ palette = np.array(self.palette)
213
+ assert palette.shape[1] == 3
214
+ assert len(palette.shape) == 2
215
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
216
+ for label, color in enumerate(palette):
217
+ color_seg[seg == label, :] = color
218
+ # convert to BGR
219
+ color_seg = color_seg[..., ::-1]
220
+ return color_seg
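A standalone sketch of the colorization done by palette_result above, using a toy two-entry palette instead of the 24-entry one.

import numpy as np

palette = np.array([[0, 0, 0], [255, 0, 0]])   # label 0 -> black, label 1 -> red
seg = np.array([[0, 1],
                [1, 0]])                        # (H, W) predicted labels

color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
    color_seg[seg == label, :] = color
color_seg = color_seg[..., ::-1]                # RGB -> BGR, as expected by mmcv.imwrite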
models/sample_model.py ADDED
@@ -0,0 +1,498 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.distributions as dists
6
+ import torch.nn.functional as F
7
+ from torchvision.utils import save_image
8
+
9
+ from models.archs.fcn_arch import FCNHead, MultiHeadFCNHead
10
+ from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
11
+ from models.archs.transformer_arch import TransformerMultiHead
12
+ from models.archs.unet_arch import ShapeUNet, UNet
13
+ from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
14
+ VectorQuantizer,
15
+ VectorQuantizerSpatialTextureAware,
16
+ VectorQuantizerTexture)
17
+
18
+ logger = logging.getLogger('base')
19
+
20
+
21
+ class BaseSampleModel():
22
+ """Base Model"""
23
+
24
+ def __init__(self, opt):
25
+ self.opt = opt
26
+ self.device = torch.device('cuda')
27
+
28
+ # hierarchical VQVAE
29
+ self.decoder = Decoder(
30
+ in_channels=opt['top_in_channels'],
31
+ resolution=opt['top_resolution'],
32
+ z_channels=opt['top_z_channels'],
33
+ ch=opt['top_ch'],
34
+ out_ch=opt['top_out_ch'],
35
+ num_res_blocks=opt['top_num_res_blocks'],
36
+ attn_resolutions=opt['top_attn_resolutions'],
37
+ ch_mult=opt['top_ch_mult'],
38
+ dropout=opt['top_dropout'],
39
+ resamp_with_conv=True,
40
+ give_pre_end=False).to(self.device)
41
+ self.top_quantize = VectorQuantizerTexture(
42
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
43
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
44
+ opt["top_z_channels"],
45
+ 1).to(self.device)
46
+ self.load_top_pretrain_models()
47
+
48
+ self.bot_decoder_res = DecoderRes(
49
+ in_channels=opt['bot_in_channels'],
50
+ resolution=opt['bot_resolution'],
51
+ z_channels=opt['bot_z_channels'],
52
+ ch=opt['bot_ch'],
53
+ num_res_blocks=opt['bot_num_res_blocks'],
54
+ ch_mult=opt['bot_ch_mult'],
55
+ dropout=opt['bot_dropout'],
56
+ give_pre_end=False).to(self.device)
57
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
58
+ opt['bot_n_embed'],
59
+ opt['embed_dim'],
60
+ beta=0.25,
61
+ spatial_size=opt['bot_codebook_spatial_size']).to(self.device)
62
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
63
+ opt["bot_z_channels"],
64
+ 1).to(self.device)
65
+ self.load_bot_pretrain_network()
66
+
67
+ # top -> bot prediction
68
+ self.index_pred_guidance_encoder = UNet(
69
+ in_channels=opt['index_pred_encoder_in_channels']).to(self.device)
70
+ self.index_pred_decoder = MultiHeadFCNHead(
71
+ in_channels=opt['index_pred_fc_in_channels'],
72
+ in_index=opt['index_pred_fc_in_index'],
73
+ channels=opt['index_pred_fc_channels'],
74
+ num_convs=opt['index_pred_fc_num_convs'],
75
+ concat_input=opt['index_pred_fc_concat_input'],
76
+ dropout_ratio=opt['index_pred_fc_dropout_ratio'],
77
+ num_classes=opt['index_pred_fc_num_classes'],
78
+ align_corners=opt['index_pred_fc_align_corners'],
79
+ num_head=18).to(self.device)
80
+ self.load_index_pred_network()
81
+
82
+ # VAE for segmentation mask
83
+ self.segm_encoder = Encoder(
84
+ ch=opt['segm_ch'],
85
+ num_res_blocks=opt['segm_num_res_blocks'],
86
+ attn_resolutions=opt['segm_attn_resolutions'],
87
+ ch_mult=opt['segm_ch_mult'],
88
+ in_channels=opt['segm_in_channels'],
89
+ resolution=opt['segm_resolution'],
90
+ z_channels=opt['segm_z_channels'],
91
+ double_z=opt['segm_double_z'],
92
+ dropout=opt['segm_dropout']).to(self.device)
93
+ self.segm_quantizer = VectorQuantizer(
94
+ opt['segm_n_embed'],
95
+ opt['segm_embed_dim'],
96
+ beta=0.25,
97
+ sane_index_shape=True).to(self.device)
98
+ self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
99
+ opt['segm_embed_dim'],
100
+ 1).to(self.device)
101
+ self.load_pretrained_segm_token()
102
+
103
+ # define sampler
104
+ self.sampler_fn = TransformerMultiHead(
105
+ codebook_size=opt['codebook_size'],
106
+ segm_codebook_size=opt['segm_codebook_size'],
107
+ texture_codebook_size=opt['texture_codebook_size'],
108
+ bert_n_emb=opt['bert_n_emb'],
109
+ bert_n_layers=opt['bert_n_layers'],
110
+ bert_n_head=opt['bert_n_head'],
111
+ block_size=opt['block_size'],
112
+ latent_shape=opt['latent_shape'],
113
+ embd_pdrop=opt['embd_pdrop'],
114
+ resid_pdrop=opt['resid_pdrop'],
115
+ attn_pdrop=opt['attn_pdrop'],
116
+ num_head=opt['num_head']).to(self.device)
117
+ self.load_sampler_pretrained_network()
118
+
119
+ self.shape = tuple(opt['latent_shape'])
120
+
121
+ self.mask_id = opt['codebook_size']
122
+ self.sample_steps = opt['sample_steps']
123
+
124
+ def load_top_pretrain_models(self):
125
+ # load pretrained vqgan
126
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
127
+
128
+ self.decoder.load_state_dict(
129
+ top_vae_checkpoint['decoder'], strict=True)
130
+ self.top_quantize.load_state_dict(
131
+ top_vae_checkpoint['quantize'], strict=True)
132
+ self.top_post_quant_conv.load_state_dict(
133
+ top_vae_checkpoint['post_quant_conv'], strict=True)
134
+
135
+ self.decoder.eval()
136
+ self.top_quantize.eval()
137
+ self.top_post_quant_conv.eval()
138
+
139
+ def load_bot_pretrain_network(self):
140
+ checkpoint = torch.load(self.opt['bot_vae_path'])
141
+ self.bot_decoder_res.load_state_dict(
142
+ checkpoint['bot_decoder_res'], strict=True)
143
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
144
+ self.bot_quantize.load_state_dict(
145
+ checkpoint['bot_quantize'], strict=True)
146
+ self.bot_post_quant_conv.load_state_dict(
147
+ checkpoint['bot_post_quant_conv'], strict=True)
148
+
149
+ self.bot_decoder_res.eval()
150
+ self.decoder.eval()
151
+ self.bot_quantize.eval()
152
+ self.bot_post_quant_conv.eval()
153
+
154
+ def load_pretrained_segm_token(self):
155
+ # load pretrained vqgan for segmentation mask
156
+ segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
157
+ self.segm_encoder.load_state_dict(
158
+ segm_token_checkpoint['encoder'], strict=True)
159
+ self.segm_quantizer.load_state_dict(
160
+ segm_token_checkpoint['quantize'], strict=True)
161
+ self.segm_quant_conv.load_state_dict(
162
+ segm_token_checkpoint['quant_conv'], strict=True)
163
+
164
+ self.segm_encoder.eval()
165
+ self.segm_quantizer.eval()
166
+ self.segm_quant_conv.eval()
167
+
168
+ def load_index_pred_network(self):
169
+ checkpoint = torch.load(self.opt['pretrained_index_network'])
170
+ self.index_pred_guidance_encoder.load_state_dict(
171
+ checkpoint['guidance_encoder'], strict=True)
172
+ self.index_pred_decoder.load_state_dict(
173
+ checkpoint['index_decoder'], strict=True)
174
+
175
+ self.index_pred_guidance_encoder.eval()
176
+ self.index_pred_decoder.eval()
177
+
178
+ def load_sampler_pretrained_network(self):
179
+ checkpoint = torch.load(self.opt['pretrained_sampler'])
180
+ self.sampler_fn.load_state_dict(checkpoint, strict=True)
181
+ self.sampler_fn.eval()
182
+
183
+ def bot_index_prediction(self, feature_top, texture_mask):
184
+ self.index_pred_guidance_encoder.eval()
185
+ self.index_pred_decoder.eval()
186
+
187
+ texture_mask_flatten = F.interpolate(
188
+ texture_mask, (32, 16), mode='nearest').view(-1).long()
189
+
190
+ min_encodings_indices_list = [
191
+ torch.full(
192
+ texture_mask_flatten.size(),
193
+ fill_value=-1,
194
+ dtype=torch.long,
195
+ device=texture_mask_flatten.device) for _ in range(18)
196
+ ]
197
+ with torch.no_grad():
198
+ feature_enc = self.index_pred_guidance_encoder(feature_top)
199
+ memory_logits_list = self.index_pred_decoder(feature_enc)
200
+ for codebook_idx, memory_logits in enumerate(memory_logits_list):
201
+ region_of_interest = texture_mask_flatten == codebook_idx
202
+ if torch.sum(region_of_interest) > 0:
203
+ memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
205
+ min_encodings_indices_list[codebook_idx][
206
+ region_of_interest] = memory_indices_pred[
207
+ region_of_interest]
208
+ min_encodings_indices_return_list = [
209
+ min_encodings_indices.view((1, 32, 16))
210
+ for min_encodings_indices in min_encodings_indices_list
211
+ ]
212
+
213
+ return min_encodings_indices_return_list
214
+
215
+ def sample_and_refine(self, save_dir=None, img_name=None):
216
+ # sample 32x16 features indices
217
+ sampled_top_indices_list = self.sample_fn(
218
+ temp=1, sample_steps=self.sample_steps)
219
+
220
+ for sample_idx in range(self.batch_size):
221
+ sample_indices = [
222
+ sampled_indices_cur[sample_idx:sample_idx + 1]
223
+ for sampled_indices_cur in sampled_top_indices_list
224
+ ]
225
+ top_quant = self.top_quantize.get_codebook_entry(
226
+ sample_indices, self.texture_mask[sample_idx:sample_idx + 1],
227
+ (sample_indices[0].size(0), self.shape[0], self.shape[1],
228
+ self.opt["top_z_channels"]))
229
+
230
+ top_quant = self.top_post_quant_conv(top_quant)
231
+
232
+ bot_indices_list = self.bot_index_prediction(
233
+ top_quant, self.texture_mask[sample_idx:sample_idx + 1])
234
+
235
+ quant_bot = self.bot_quantize.get_codebook_entry(
236
+ bot_indices_list, self.texture_mask[sample_idx:sample_idx + 1],
237
+ (bot_indices_list[0].size(0), bot_indices_list[0].size(1),
238
+ bot_indices_list[0].size(2),
239
+ self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
240
+ quant_bot = self.bot_post_quant_conv(quant_bot)
241
+ bot_dec_res = self.bot_decoder_res(quant_bot)
242
+
243
+ dec = self.decoder(top_quant, bot_h=bot_dec_res)
244
+
245
+ dec = ((dec + 1) / 2)
246
+ dec = dec.clamp_(0, 1)
247
+ if save_dir is None and img_name is None:
248
+ return dec
249
+ else:
250
+ save_image(
251
+ dec,
252
+ f'{save_dir}/{img_name[sample_idx]}',
253
+ nrow=1,
254
+ padding=4)
255
+
256
+ def sample_fn(self, temp=1.0, sample_steps=None):
257
+ self.sampler_fn.eval()
258
+
259
+ x_t = torch.ones((self.batch_size, np.prod(self.shape)),
260
+ device=self.device).long() * self.mask_id
261
+ unmasked = torch.zeros_like(x_t, device=self.device).bool()
262
+ sample_steps = list(range(1, sample_steps + 1))
263
+
264
+ texture_tokens = F.interpolate(
265
+ self.texture_mask, (32, 16),
266
+ mode='nearest').view(self.batch_size, -1).long()
267
+
268
+ texture_mask_flatten = texture_tokens.view(-1)
269
+
270
+ # min_encodings_indices_list is later used to decode and visualize the image
271
+ min_encodings_indices_list = [
272
+ torch.full(
273
+ texture_mask_flatten.size(),
274
+ fill_value=-1,
275
+ dtype=torch.long,
276
+ device=texture_mask_flatten.device) for _ in range(18)
277
+ ]
278
+
279
+ for t in reversed(sample_steps):
280
+ t = torch.full((self.batch_size, ),
281
+ t,
282
+ device=self.device,
283
+ dtype=torch.long)
284
+
285
+ # where to unmask
286
+ changes = torch.rand(
287
+ x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
288
+ # don't unmask somewhere already unmasked
289
+ changes = torch.bitwise_xor(changes,
290
+ torch.bitwise_and(changes, unmasked))
291
+ # update mask with changes
292
+ unmasked = torch.bitwise_or(unmasked, changes)
293
+
294
+ x_0_logits_list = self.sampler_fn(
295
+ x_t, self.segm_tokens, texture_tokens, t=t)
296
+
297
+ changes_flatten = changes.view(-1)
298
+ ori_shape = x_t.shape # [b, h*w]
299
+ x_t = x_t.view(-1) # [b*h*w]
300
+ for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
301
+ if torch.sum(texture_mask_flatten[changes_flatten] ==
302
+ codebook_idx) > 0:
303
+ # scale by temperature
304
+ x_0_logits = x_0_logits / temp
305
+ x_0_dist = dists.Categorical(logits=x_0_logits)
306
+ x_0_hat = x_0_dist.sample().long()
307
+ x_0_hat = x_0_hat.view(-1)
308
+
309
+ # only replace the changed indices with corresponding codebook_idx
310
+ changes_segm = torch.bitwise_and(
311
+ changes_flatten, texture_mask_flatten == codebook_idx)
312
+
313
+ # x_t is fed back to the transformer, so its indices must lie in one continuous range across the 18 codebooks
314
+ x_t[changes_segm] = x_0_hat[
315
+ changes_segm] + 1024 * codebook_idx
316
+ min_encodings_indices_list[codebook_idx][
317
+ changes_segm] = x_0_hat[changes_segm]
318
+
319
+ x_t = x_t.view(ori_shape) # [b, h*w]
320
+
321
+ min_encodings_indices_return_list = [
322
+ min_encodings_indices.view(ori_shape)
323
+ for min_encodings_indices in min_encodings_indices_list
324
+ ]
325
+
326
+ self.sampler_fn.train()
327
+
328
+ return min_encodings_indices_return_list
329
+
330
+ @torch.no_grad()
331
+ def get_quantized_segm(self, segm):
332
+ segm_one_hot = F.one_hot(
333
+ segm.squeeze(1).long(),
334
+ num_classes=self.opt['segm_num_segm_classes']).permute(
335
+ 0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
336
+ encoded_segm_mask = self.segm_encoder(segm_one_hot)
337
+ encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
338
+ _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
339
+
340
+ return segm_tokens
341
+
342
+
343
+ class SampleFromParsingModel(BaseSampleModel):
344
+ """SampleFromParsing model.
345
+ """
346
+
347
+ def feed_data(self, data):
348
+ self.segm = data['segm'].to(self.device)
349
+ self.texture_mask = data['texture_mask'].to(self.device)
350
+ self.batch_size = self.segm.size(0)
351
+
352
+ self.segm_tokens = self.get_quantized_segm(self.segm)
353
+ self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
354
+
355
+ def inference(self, data_loader, save_dir):
356
+ for _, data in enumerate(data_loader):
357
+ img_name = data['img_name']
358
+ self.feed_data(data)
359
+ with torch.no_grad():
360
+ self.sample_and_refine(save_dir, img_name)
361
+
362
+
363
+ class SampleFromPoseModel(BaseSampleModel):
364
+ """SampleFromPose model.
365
+ """
366
+
367
+ def __init__(self, opt):
368
+ super().__init__(opt)
369
+ # pose-to-parsing
370
+ self.shape_attr_embedder = ShapeAttrEmbedding(
371
+ dim=opt['shape_embedder_dim'],
372
+ out_dim=opt['shape_embedder_out_dim'],
373
+ cls_num_list=opt['shape_attr_class_num']).to(self.device)
374
+ self.shape_parsing_encoder = ShapeUNet(
375
+ in_channels=opt['shape_encoder_in_channels']).to(self.device)
376
+ self.shape_parsing_decoder = FCNHead(
377
+ in_channels=opt['shape_fc_in_channels'],
378
+ in_index=opt['shape_fc_in_index'],
379
+ channels=opt['shape_fc_channels'],
380
+ num_convs=opt['shape_fc_num_convs'],
381
+ concat_input=opt['shape_fc_concat_input'],
382
+ dropout_ratio=opt['shape_fc_dropout_ratio'],
383
+ num_classes=opt['shape_fc_num_classes'],
384
+ align_corners=opt['shape_fc_align_corners'],
385
+ ).to(self.device)
386
+ self.load_shape_generation_models()
387
+
388
+ self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
389
+ [250, 235, 215], [255, 250, 205], [211, 211, 211],
390
+ [70, 130, 180], [127, 255, 212], [0, 100, 0],
391
+ [50, 205, 50], [255, 255, 0], [245, 222, 179],
392
+ [255, 140, 0], [255, 0, 0], [16, 78, 139],
393
+ [144, 238, 144], [50, 205, 174], [50, 155, 250],
394
+ [160, 140, 88], [213, 140, 88], [90, 140, 90],
395
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
396
+
397
+ def load_shape_generation_models(self):
398
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
399
+
400
+ self.shape_attr_embedder.load_state_dict(
401
+ checkpoint['embedder'], strict=True)
402
+ self.shape_attr_embedder.eval()
403
+
404
+ self.shape_parsing_encoder.load_state_dict(
405
+ checkpoint['encoder'], strict=True)
406
+ self.shape_parsing_encoder.eval()
407
+
408
+ self.shape_parsing_decoder.load_state_dict(
409
+ checkpoint['decoder'], strict=True)
410
+ self.shape_parsing_decoder.eval()
411
+
412
+ def feed_data(self, data):
413
+ self.pose = data['densepose'].to(self.device)
414
+ self.batch_size = self.pose.size(0)
415
+
416
+ self.shape_attr = data['shape_attr'].to(self.device)
417
+ self.upper_fused_attr = data['upper_fused_attr'].to(self.device)
418
+ self.lower_fused_attr = data['lower_fused_attr'].to(self.device)
419
+ self.outer_fused_attr = data['outer_fused_attr'].to(self.device)
420
+
421
+ def inference(self, data_loader, save_dir):
422
+ for _, data in enumerate(data_loader):
423
+ img_name = data['img_name']
424
+ self.feed_data(data)
425
+ with torch.no_grad():
426
+ self.generate_parsing_map()
427
+ self.generate_quantized_segm()
428
+ self.generate_texture_map()
429
+ self.sample_and_refine(save_dir, img_name)
430
+
431
+ def generate_parsing_map(self):
432
+ with torch.no_grad():
433
+ attr_embedding = self.shape_attr_embedder(self.shape_attr)
434
+ pose_enc = self.shape_parsing_encoder(self.pose, attr_embedding)
435
+ seg_logits = self.shape_parsing_decoder(pose_enc)
436
+ self.segm = seg_logits.argmax(dim=1)
437
+ self.segm = self.segm.unsqueeze(1)
438
+
439
+ def generate_quantized_segm(self):
440
+ self.segm_tokens = self.get_quantized_segm(self.segm)
441
+ self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
442
+
443
+ def generate_texture_map(self):
444
+ upper_cls = [1., 4.]
445
+ lower_cls = [3., 5., 21.]
446
+ outer_cls = [2.]
447
+
448
+ mask_batch = []
449
+ for idx in range(self.batch_size):
450
+ mask = torch.zeros_like(self.segm[idx])
451
+ upper_fused_attr = self.upper_fused_attr[idx]
452
+ lower_fused_attr = self.lower_fused_attr[idx]
453
+ outer_fused_attr = self.outer_fused_attr[idx]
454
+ if upper_fused_attr != 17:
455
+ for cls in upper_cls:
456
+ mask[self.segm[idx] == cls] = upper_fused_attr + 1
457
+
458
+ if lower_fused_attr != 17:
459
+ for cls in lower_cls:
460
+ mask[self.segm[idx] == cls] = lower_fused_attr + 1
461
+
462
+ if outer_fused_attr != 17:
463
+ for cls in outer_cls:
464
+ mask[self.segm[idx] == cls] = outer_fused_attr + 1
465
+
466
+ mask_batch.append(mask)
467
+ self.texture_mask = torch.stack(mask_batch, dim=0).to(torch.float32)
468
+
469
+ def feed_pose_data(self, pose_img):
470
+ # for ui demo
471
+
472
+ self.pose = pose_img.to(self.device)
473
+ self.batch_size = self.pose.size(0)
474
+
475
+ def feed_shape_attributes(self, shape_attr):
476
+ # for ui demo
477
+
478
+ self.shape_attr = shape_attr.to(self.device)
479
+
480
+ def feed_texture_attributes(self, texture_attr):
481
+ # for ui demo
482
+
483
+ self.upper_fused_attr = texture_attr[0].unsqueeze(0).to(self.device)
484
+ self.lower_fused_attr = texture_attr[1].unsqueeze(0).to(self.device)
485
+ self.outer_fused_attr = texture_attr[2].unsqueeze(0).to(self.device)
486
+
487
+ def palette_result(self, result):
488
+
489
+ seg = result[0]
490
+ palette = np.array(self.palette)
491
+ assert palette.shape[1] == 3
492
+ assert len(palette.shape) == 2
493
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
494
+ for label, color in enumerate(palette):
495
+ color_seg[seg == label, :] = color
496
+ # conversion to BGR is left disabled here, so the result stays in RGB
497
+ # color_seg = color_seg[..., ::-1]
498
+ return color_seg
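
Note: generate_texture_map above turns the predicted parsing map plus the three fused texture attributes into the texture mask that drives the texture-aware codebooks. The self-contained sketch below replays that rule on a toy 1x4x4 parsing map; the label values and attribute choices are made up, and 17 stands for "no texture attribute selected", exactly as in the code.

    import torch

    segm = torch.tensor([[[1, 1, 3, 3],
                          [1, 1, 3, 3],
                          [4, 4, 5, 5],
                          [0, 0, 0, 0]]])            # 1 x H x W parsing labels (toy)

    upper_cls, lower_cls, outer_cls = [1., 4.], [3., 5., 21.], [2.]
    upper_attr, lower_attr, outer_attr = 2, 17, 17   # 17 = attribute not set

    mask = torch.zeros_like(segm)
    if upper_attr != 17:
        for cls in upper_cls:
            mask[segm == cls] = upper_attr + 1       # upper-body pixels get codebook id 3
    if lower_attr != 17:
        for cls in lower_cls:
            mask[segm == cls] = lower_attr + 1
    if outer_attr != 17:
        for cls in outer_cls:
            mask[segm == cls] = outer_attr + 1

    texture_mask = mask.unsqueeze(0).float()         # 1 x 1 x H x W, as fed to the sampler
    print(texture_mask[0, 0])                        # pixels of classes 1 and 4 become 3, the rest stay 0
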
models/transformer_model.py ADDED
@@ -0,0 +1,482 @@
1
+ import logging
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.distributions as dists
8
+ import torch.nn.functional as F
9
+ from torchvision.utils import save_image
10
+
11
+ from models.archs.transformer_arch import TransformerMultiHead
12
+ from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
13
+ VectorQuantizerTexture)
14
+
15
+ logger = logging.getLogger('base')
16
+
17
+
18
+ class TransformerTextureAwareModel():
19
+ """Texture-Aware Diffusion based Transformer model.
20
+ """
21
+
22
+ def __init__(self, opt):
23
+ self.opt = opt
24
+ self.device = torch.device('cuda')
25
+ self.is_train = opt['is_train']
26
+
27
+ # VQVAE for image
28
+ self.img_encoder = Encoder(
29
+ ch=opt['img_ch'],
30
+ num_res_blocks=opt['img_num_res_blocks'],
31
+ attn_resolutions=opt['img_attn_resolutions'],
32
+ ch_mult=opt['img_ch_mult'],
33
+ in_channels=opt['img_in_channels'],
34
+ resolution=opt['img_resolution'],
35
+ z_channels=opt['img_z_channels'],
36
+ double_z=opt['img_double_z'],
37
+ dropout=opt['img_dropout']).to(self.device)
38
+ self.img_decoder = Decoder(
39
+ in_channels=opt['img_in_channels'],
40
+ resolution=opt['img_resolution'],
41
+ z_channels=opt['img_z_channels'],
42
+ ch=opt['img_ch'],
43
+ out_ch=opt['img_out_ch'],
44
+ num_res_blocks=opt['img_num_res_blocks'],
45
+ attn_resolutions=opt['img_attn_resolutions'],
46
+ ch_mult=opt['img_ch_mult'],
47
+ dropout=opt['img_dropout'],
48
+ resamp_with_conv=True,
49
+ give_pre_end=False).to(self.device)
50
+ self.img_quantizer = VectorQuantizerTexture(
51
+ opt['img_n_embed'], opt['img_embed_dim'],
52
+ beta=0.25).to(self.device)
53
+ self.img_quant_conv = torch.nn.Conv2d(opt["img_z_channels"],
54
+ opt['img_embed_dim'],
55
+ 1).to(self.device)
56
+ self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
57
+ opt["img_z_channels"],
58
+ 1).to(self.device)
59
+ self.load_pretrained_image_vae()
60
+
61
+ # VAE for segmentation mask
62
+ self.segm_encoder = Encoder(
63
+ ch=opt['segm_ch'],
64
+ num_res_blocks=opt['segm_num_res_blocks'],
65
+ attn_resolutions=opt['segm_attn_resolutions'],
66
+ ch_mult=opt['segm_ch_mult'],
67
+ in_channels=opt['segm_in_channels'],
68
+ resolution=opt['segm_resolution'],
69
+ z_channels=opt['segm_z_channels'],
70
+ double_z=opt['segm_double_z'],
71
+ dropout=opt['segm_dropout']).to(self.device)
72
+ self.segm_quantizer = VectorQuantizer(
73
+ opt['segm_n_embed'],
74
+ opt['segm_embed_dim'],
75
+ beta=0.25,
76
+ sane_index_shape=True).to(self.device)
77
+ self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
78
+ opt['segm_embed_dim'],
79
+ 1).to(self.device)
80
+ self.load_pretrained_segm_vae()
81
+
82
+ # define sampler
83
+ self._denoise_fn = TransformerMultiHead(
84
+ codebook_size=opt['codebook_size'],
85
+ segm_codebook_size=opt['segm_codebook_size'],
86
+ texture_codebook_size=opt['texture_codebook_size'],
87
+ bert_n_emb=opt['bert_n_emb'],
88
+ bert_n_layers=opt['bert_n_layers'],
89
+ bert_n_head=opt['bert_n_head'],
90
+ block_size=opt['block_size'],
91
+ latent_shape=opt['latent_shape'],
92
+ embd_pdrop=opt['embd_pdrop'],
93
+ resid_pdrop=opt['resid_pdrop'],
94
+ attn_pdrop=opt['attn_pdrop'],
95
+ num_head=opt['num_head']).to(self.device)
96
+
97
+ self.num_classes = opt['codebook_size']
98
+ self.shape = tuple(opt['latent_shape'])
99
+ self.num_timesteps = 1000
100
+
101
+ self.mask_id = opt['codebook_size']
102
+ self.loss_type = opt['loss_type']
103
+ self.mask_schedule = opt['mask_schedule']
104
+
105
+ self.sample_steps = opt['sample_steps']
106
+
107
+ self.init_training_settings()
108
+
109
+ def load_pretrained_image_vae(self):
110
+ # load pretrained vqgan for images
111
+ img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
112
+ self.img_encoder.load_state_dict(
113
+ img_ae_checkpoint['encoder'], strict=True)
114
+ self.img_decoder.load_state_dict(
115
+ img_ae_checkpoint['decoder'], strict=True)
116
+ self.img_quantizer.load_state_dict(
117
+ img_ae_checkpoint['quantize'], strict=True)
118
+ self.img_quant_conv.load_state_dict(
119
+ img_ae_checkpoint['quant_conv'], strict=True)
120
+ self.img_post_quant_conv.load_state_dict(
121
+ img_ae_checkpoint['post_quant_conv'], strict=True)
122
+ self.img_encoder.eval()
123
+ self.img_decoder.eval()
124
+ self.img_quantizer.eval()
125
+ self.img_quant_conv.eval()
126
+ self.img_post_quant_conv.eval()
127
+
128
+ def load_pretrained_segm_vae(self):
129
+ # load pretrained vqgan for segmentation mask
130
+ segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
131
+ self.segm_encoder.load_state_dict(
132
+ segm_ae_checkpoint['encoder'], strict=True)
133
+ self.segm_quantizer.load_state_dict(
134
+ segm_ae_checkpoint['quantize'], strict=True)
135
+ self.segm_quant_conv.load_state_dict(
136
+ segm_ae_checkpoint['quant_conv'], strict=True)
137
+ self.segm_encoder.eval()
138
+ self.segm_quantizer.eval()
139
+ self.segm_quant_conv.eval()
140
+
141
+ def init_training_settings(self):
142
+ optim_params = []
143
+ for v in self._denoise_fn.parameters():
144
+ if v.requires_grad:
145
+ optim_params.append(v)
146
+ # set up optimizer
147
+ self.optimizer = torch.optim.Adam(
148
+ optim_params,
149
+ self.opt['lr'],
150
+ weight_decay=self.opt['weight_decay'])
151
+ self.log_dict = OrderedDict()
152
+
153
+ @torch.no_grad()
154
+ def get_quantized_img(self, image, texture_mask):
155
+ encoded_img = self.img_encoder(image)
156
+ encoded_img = self.img_quant_conv(encoded_img)
157
+
158
+ # img_tokens_input holds the continuous (offset) indices fed to the transformer
159
+ # img_tokens_gt_list holds the ground-truth indices of the 18 texture-aware codebooks
160
+ _, _, [_, img_tokens_input, img_tokens_gt_list
161
+ ] = self.img_quantizer(encoded_img, texture_mask)
162
+
163
+ # reshape the tokens
164
+ b = image.size(0)
165
+ img_tokens_input = img_tokens_input.view(b, -1)
166
+ img_tokens_gt_return_list = [
167
+ img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
168
+ ]
169
+
170
+ return img_tokens_input, img_tokens_gt_return_list
171
+
172
+ @torch.no_grad()
173
+ def decode(self, quant):
174
+ quant = self.img_post_quant_conv(quant)
175
+ dec = self.img_decoder(quant)
176
+ return dec
177
+
178
+ @torch.no_grad()
179
+ def decode_image_indices(self, indices_list, texture_mask):
180
+ quant = self.img_quantizer.get_codebook_entry(
181
+ indices_list, texture_mask,
182
+ (indices_list[0].size(0), self.shape[0], self.shape[1],
183
+ self.opt["img_z_channels"]))
184
+ dec = self.decode(quant)
185
+
186
+ return dec
187
+
188
+ def sample_time(self, b, device, method='uniform'):
189
+ if method == 'importance':
190
+ if not (self.Lt_count > 10).all():
191
+ return self.sample_time(b, device, method='uniform')
192
+
193
+ Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
194
+ Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
195
+ pt_all = Lt_sqrt / Lt_sqrt.sum()
196
+
197
+ t = torch.multinomial(pt_all, num_samples=b, replacement=True)
198
+
199
+ pt = pt_all.gather(dim=0, index=t)
200
+
201
+ return t, pt
202
+
203
+ elif method == 'uniform':
204
+ t = torch.randint(
205
+ 1, self.num_timesteps + 1, (b, ), device=device).long()
206
+ pt = torch.ones_like(t).float() / self.num_timesteps
207
+ return t, pt
208
+
209
+ else:
210
+ raise ValueError
211
+
212
+ def q_sample(self, x_0, x_0_gt_list, t):
213
+ # samples q(x_t | x_0)
214
+ # randomly set token to mask with probability t/T
215
+ # x_t, x_0_ignore = x_0.clone(), x_0.clone()
216
+ x_t = x_0.clone()
217
+
218
+ mask = torch.rand_like(x_t.float()) < (
219
+ t.float().unsqueeze(-1) / self.num_timesteps)
220
+ x_t[mask] = self.mask_id
221
+ # x_0_ignore[torch.bitwise_not(mask)] = -1
222
+
223
+ # for every gt token list, we also need to do the mask
224
+ x_0_gt_ignore_list = []
225
+ for x_0_gt in x_0_gt_list:
226
+ x_0_gt_ignore = x_0_gt.clone()
227
+ x_0_gt_ignore[torch.bitwise_not(mask)] = -1
228
+ x_0_gt_ignore_list.append(x_0_gt_ignore)
229
+
230
+ return x_t, x_0_gt_ignore_list, mask
231
+
232
+ def _train_loss(self, x_0, x_0_gt_list):
233
+ b, device = x_0.size(0), x_0.device
234
+
235
+ # choose what time steps to compute loss at
236
+ t, pt = self.sample_time(b, device, 'uniform')
237
+
238
+ # make x noisy and denoise
239
+ if self.mask_schedule == 'random':
240
+ x_t, x_0_gt_ignore_list, mask = self.q_sample(
241
+ x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
242
+ else:
243
+ raise NotImplementedError
244
+
245
+ # sample p(x_0 | x_t)
246
+ x_0_hat_logits_list = self._denoise_fn(
247
+ x_t, self.segm_tokens, self.texture_tokens, t=t)
248
+
249
+ # Always compute ELBO for comparison purposes
250
+ cross_entropy_loss = 0
251
+ for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
252
+ x_0_gt_ignore_list):
253
+ cross_entropy_loss += F.cross_entropy(
254
+ x_0_hat_logits.permute(0, 2, 1),
255
+ x_0_gt_ignore,
256
+ ignore_index=-1,
257
+ reduction='none').sum(1)
258
+ vb_loss = cross_entropy_loss / t
259
+ vb_loss = vb_loss / pt
260
+ vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
261
+ if self.loss_type == 'elbo':
262
+ loss = vb_loss
263
+ elif self.loss_type == 'mlm':
264
+ denom = mask.float().sum(1)
265
+ denom[denom == 0] = 1 # prevent divide by 0 errors.
266
+ loss = cross_entropy_loss / denom
267
+ elif self.loss_type == 'reweighted_elbo':
268
+ weight = (1 - (t / self.num_timesteps))
269
+ loss = weight * cross_entropy_loss
270
+ loss = loss / (math.log(2) * x_0.shape[1:].numel())
271
+ else:
272
+ raise ValueError
273
+
274
+ return loss.mean(), vb_loss.mean()
275
+
276
+ def feed_data(self, data):
277
+ self.image = data['image'].to(self.device)
278
+ self.segm = data['segm'].to(self.device)
279
+ self.texture_mask = data['texture_mask'].to(self.device)
280
+ self.input_indices, self.gt_indices_list = self.get_quantized_img(
281
+ self.image, self.texture_mask)
282
+
283
+ self.texture_tokens = F.interpolate(
284
+ self.texture_mask, size=self.shape,
285
+ mode='nearest').view(self.image.size(0), -1).long()
286
+
287
+ self.segm_tokens = self.get_quantized_segm(self.segm)
288
+ self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)
289
+
290
+ def optimize_parameters(self):
291
+ self._denoise_fn.train()
292
+
293
+ loss, vb_loss = self._train_loss(self.input_indices,
294
+ self.gt_indices_list)
295
+
296
+ self.optimizer.zero_grad()
297
+ loss.backward()
298
+ self.optimizer.step()
299
+
300
+ self.log_dict['loss'] = loss
301
+ self.log_dict['vb_loss'] = vb_loss
302
+
303
+ self._denoise_fn.eval()
304
+
305
+ @torch.no_grad()
306
+ def get_quantized_segm(self, segm):
307
+ segm_one_hot = F.one_hot(
308
+ segm.squeeze(1).long(),
309
+ num_classes=self.opt['segm_num_segm_classes']).permute(
310
+ 0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
311
+ encoded_segm_mask = self.segm_encoder(segm_one_hot)
312
+ encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
313
+ _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
314
+
315
+ return segm_tokens
316
+
317
+ def sample_fn(self, temp=1.0, sample_steps=None):
318
+ self._denoise_fn.eval()
319
+
320
+ b, device = self.image.size(0), 'cuda'
321
+ x_t = torch.ones(
322
+ (b, np.prod(self.shape)), device=device).long() * self.mask_id
323
+ unmasked = torch.zeros_like(x_t, device=device).bool()
324
+ sample_steps = list(range(1, sample_steps + 1))
325
+
326
+ texture_mask_flatten = self.texture_tokens.view(-1)
327
+
328
+ # min_encodings_indices_list is later used to decode and visualize the image
329
+ min_encodings_indices_list = [
330
+ torch.full(
331
+ texture_mask_flatten.size(),
332
+ fill_value=-1,
333
+ dtype=torch.long,
334
+ device=texture_mask_flatten.device) for _ in range(18)
335
+ ]
336
+
337
+ for t in reversed(sample_steps):
338
+ print(f'Sample timestep {t:4d}', end='\r')
339
+ t = torch.full((b, ), t, device=device, dtype=torch.long)
340
+
341
+ # where to unmask
342
+ changes = torch.rand(
343
+ x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
344
+ # don't unmask somewhere already unmasked
345
+ changes = torch.bitwise_xor(changes,
346
+ torch.bitwise_and(changes, unmasked))
347
+ # update mask with changes
348
+ unmasked = torch.bitwise_or(unmasked, changes)
349
+
350
+ x_0_logits_list = self._denoise_fn(
351
+ x_t, self.segm_tokens, self.texture_tokens, t=t)
352
+
353
+ changes_flatten = changes.view(-1)
354
+ ori_shape = x_t.shape # [b, h*w]
355
+ x_t = x_t.view(-1) # [b*h*w]
356
+ for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
357
+ if torch.sum(texture_mask_flatten[changes_flatten] ==
358
+ codebook_idx) > 0:
359
+ # scale by temperature
360
+ x_0_logits = x_0_logits / temp
361
+ x_0_dist = dists.Categorical(logits=x_0_logits)
362
+ x_0_hat = x_0_dist.sample().long()
363
+ x_0_hat = x_0_hat.view(-1)
364
+
365
+ # only replace the changed indices with corresponding codebook_idx
366
+ changes_segm = torch.bitwise_and(
367
+ changes_flatten, texture_mask_flatten == codebook_idx)
368
+
369
+ # x_t is fed back to the transformer, so its indices must lie in one continuous range across the 18 codebooks
370
+ x_t[changes_segm] = x_0_hat[
371
+ changes_segm] + 1024 * codebook_idx
372
+ min_encodings_indices_list[codebook_idx][
373
+ changes_segm] = x_0_hat[changes_segm]
374
+
375
+ x_t = x_t.view(ori_shape) # [b, h*w]
376
+
377
+ min_encodings_indices_return_list = [
378
+ min_encodings_indices.view(ori_shape)
379
+ for min_encodings_indices in min_encodings_indices_list
380
+ ]
381
+
382
+ self._denoise_fn.train()
383
+
384
+ return min_encodings_indices_return_list
385
+
386
+ def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
387
+ save_path):
388
+ # original image
389
+ ori_img = self.decode_image_indices(gt_indices, texture_mask)
390
+ # pred image
391
+ pred_img = self.decode_image_indices(predicted_indices, texture_mask)
392
+ img_cat = torch.cat([
393
+ image,
394
+ ori_img,
395
+ pred_img,
396
+ ], dim=3).detach()
397
+ img_cat = ((img_cat + 1) / 2)
398
+ img_cat = img_cat.clamp_(0, 1)
399
+ save_image(img_cat, save_path, nrow=1, padding=4)
400
+
401
+ def inference(self, data_loader, save_dir):
402
+ self._denoise_fn.eval()
403
+
404
+ for _, data in enumerate(data_loader):
405
+ img_name = data['img_name']
406
+ self.feed_data(data)
407
+ b = self.image.size(0)
408
+ with torch.no_grad():
409
+ sampled_indices_list = self.sample_fn(
410
+ temp=1, sample_steps=self.sample_steps)
411
+ for idx in range(b):
412
+ self.get_vis(self.image[idx:idx + 1], [
413
+ gt_indices[idx:idx + 1]
414
+ for gt_indices in self.gt_indices_list
415
+ ], [
416
+ sampled_indices[idx:idx + 1]
417
+ for sampled_indices in sampled_indices_list
418
+ ], self.texture_mask[idx:idx + 1],
419
+ f'{save_dir}/{img_name[idx]}')
420
+
421
+ self._denoise_fn.train()
422
+
423
+ def get_current_log(self):
424
+ return self.log_dict
425
+
426
+ def update_learning_rate(self, epoch, iters=None):
427
+ """Update learning rate.
428
+
429
+ Args:
430
+ epoch (int): Current epoch, used by the epoch-based decay modes.
431
+ iters (int): Current iteration, only used by the 'warm_up' mode.
432
+ Default: None.
433
+ """
434
+ lr = self.optimizer.param_groups[0]['lr']
435
+
436
+ if self.opt['lr_decay'] == 'step':
437
+ lr = self.opt['lr'] * (
438
+ self.opt['gamma']**(epoch // self.opt['step']))
439
+ elif self.opt['lr_decay'] == 'cos':
440
+ lr = self.opt['lr'] * (
441
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
442
+ elif self.opt['lr_decay'] == 'linear':
443
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
444
+ elif self.opt['lr_decay'] == 'linear2exp':
445
+ if epoch < self.opt['turning_point'] + 1:
446
+ # the learning rate has decayed by 95%
447
+ # at the turning point (1 / 95% = 1.0526)
448
+ lr = self.opt['lr'] * (
449
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
450
+ else:
451
+ lr *= self.opt['gamma']
452
+ elif self.opt['lr_decay'] == 'schedule':
453
+ if epoch in self.opt['schedule']:
454
+ lr *= self.opt['gamma']
455
+ elif self.opt['lr_decay'] == 'warm_up':
456
+ if iters <= self.opt['warmup_iters']:
457
+ lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
458
+ else:
459
+ lr = self.opt['lr']
460
+ else:
461
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
462
+ # set learning rate
463
+ for param_group in self.optimizer.param_groups:
464
+ param_group['lr'] = lr
465
+
466
+ return lr
467
+
468
+ def save_network(self, net, save_path):
469
+ """Save networks.
470
+
471
+ Args:
472
+ net (nn.Module): Network to be saved.
473
+ save_path (str): Path where the state dict is saved.
475
+ """
476
+ state_dict = net.state_dict()
477
+ torch.save(state_dict, save_path)
478
+
479
+ def load_network(self):
480
+ checkpoint = torch.load(self.opt['pretrained_sampler'])
481
+ self._denoise_fn.load_state_dict(checkpoint, strict=True)
482
+ self._denoise_fn.eval()
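
Note: q_sample and sample_fn above implement an absorbing-state masking schedule: the forward pass masks each token with probability t/T, and the reverse pass unmasks roughly a 1/t fraction of the still-masked positions at every step before asking the transformer for the missing tokens. The toy sketch below replays just those two update rules on a 1x8 grid; the grid size, seed and codebook size are made up, and the point where a real model would fill in sampled tokens is left as a comment.

    import torch

    torch.manual_seed(0)
    T, mask_id = 4, 1024
    x_0 = torch.randint(0, 1024, (1, 8))             # toy ground-truth tokens

    # forward process: mask each token with probability t / T
    t = torch.tensor([3])
    mask = torch.rand_like(x_0.float()) < (t.float().unsqueeze(-1) / T)
    x_t = x_0.clone()
    x_t[mask] = mask_id

    # reverse process: at step t, reveal ~1/t of the not-yet-unmasked positions
    unmasked = torch.zeros_like(x_t).bool()
    for step in reversed(range(1, T + 1)):
        t = torch.full((1,), step)
        changes = torch.rand(x_t.shape) < 1 / t.float().unsqueeze(-1)
        changes = torch.bitwise_xor(changes, torch.bitwise_and(changes, unmasked))
        unmasked = torch.bitwise_or(unmasked, changes)
        # a real model would now sample x_0 logits and write them into x_t[changes]

    print(bool(unmasked.all()))                      # True: the t = 1 step unmasks every remaining position
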
models/vqgan_model.py ADDED
@@ -0,0 +1,551 @@
1
+ import math
2
+ import sys
3
+ from collections import OrderedDict
4
+
5
+ sys.path.append('..')
6
+ import lpips
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torchvision.utils import save_image
10
+
11
+ from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
12
+ VectorQuantizer, VectorQuantizerTexture)
13
+ from models.losses.segmentation_loss import BCELossWithQuant
14
+ from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
15
+ calculate_adaptive_weight, hinge_d_loss)
16
+
17
+
18
+ class VQModel():
19
+
20
+ def __init__(self, opt):
21
+ super().__init__()
22
+ self.opt = opt
23
+ self.device = torch.device('cuda')
24
+ self.encoder = Encoder(
25
+ ch=opt['ch'],
26
+ num_res_blocks=opt['num_res_blocks'],
27
+ attn_resolutions=opt['attn_resolutions'],
28
+ ch_mult=opt['ch_mult'],
29
+ in_channels=opt['in_channels'],
30
+ resolution=opt['resolution'],
31
+ z_channels=opt['z_channels'],
32
+ double_z=opt['double_z'],
33
+ dropout=opt['dropout']).to(self.device)
34
+ self.decoder = Decoder(
35
+ in_channels=opt['in_channels'],
36
+ resolution=opt['resolution'],
37
+ z_channels=opt['z_channels'],
38
+ ch=opt['ch'],
39
+ out_ch=opt['out_ch'],
40
+ num_res_blocks=opt['num_res_blocks'],
41
+ attn_resolutions=opt['attn_resolutions'],
42
+ ch_mult=opt['ch_mult'],
43
+ dropout=opt['dropout'],
44
+ resamp_with_conv=True,
45
+ give_pre_end=False).to(self.device)
46
+ self.quantize = VectorQuantizer(
47
+ opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
48
+ self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
49
+ 1).to(self.device)
50
+ self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
51
+ opt["z_channels"],
52
+ 1).to(self.device)
53
+
54
+ def init_training_settings(self):
55
+ self.loss = BCELossWithQuant()
56
+ self.log_dict = OrderedDict()
57
+ self.configure_optimizers()
58
+
59
+ def save_network(self, save_path):
60
+ """Save networks.
61
+
62
+ Args:
63
+ save_path (str): Path where the checkpoint dictionary is saved.
66
+ """
67
+
68
+ save_dict = {}
69
+ save_dict['encoder'] = self.encoder.state_dict()
70
+ save_dict['decoder'] = self.decoder.state_dict()
71
+ save_dict['quantize'] = self.quantize.state_dict()
72
+ save_dict['quant_conv'] = self.quant_conv.state_dict()
73
+ save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
74
+ save_dict['discriminator'] = self.disc.state_dict()
75
+ torch.save(save_dict, save_path)
76
+
77
+ def load_network(self):
78
+ checkpoint = torch.load(self.opt['pretrained_models'])
79
+ self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
80
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
81
+ self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
82
+ self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
83
+ self.post_quant_conv.load_state_dict(
84
+ checkpoint['post_quant_conv'], strict=True)
85
+
86
+ def optimize_parameters(self, data, current_iter):
87
+ self.encoder.train()
88
+ self.decoder.train()
89
+ self.quantize.train()
90
+ self.quant_conv.train()
91
+ self.post_quant_conv.train()
92
+
93
+ loss = self.training_step(data)
94
+ self.optimizer.zero_grad()
95
+ loss.backward()
96
+ self.optimizer.step()
97
+
98
+ def encode(self, x):
99
+ h = self.encoder(x)
100
+ h = self.quant_conv(h)
101
+ quant, emb_loss, info = self.quantize(h)
102
+ return quant, emb_loss, info
103
+
104
+ def decode(self, quant):
105
+ quant = self.post_quant_conv(quant)
106
+ dec = self.decoder(quant)
107
+ return dec
108
+
109
+ def decode_code(self, code_b):
110
+ quant_b = self.quantize.embed_code(code_b)
111
+ dec = self.decode(quant_b)
112
+ return dec
113
+
114
+ def forward_step(self, input):
115
+ quant, diff, _ = self.encode(input)
116
+ dec = self.decode(quant)
117
+ return dec, diff
118
+
119
+ def feed_data(self, data):
120
+ x = data['segm']
121
+ x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])
122
+
123
+ if len(x.shape) == 3:
124
+ x = x[..., None]
125
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
126
+ return x.float().to(self.device)
127
+
128
+ def get_current_log(self):
129
+ return self.log_dict
130
+
131
+ def update_learning_rate(self, epoch):
132
+ """Update learning rate.
133
+
134
+ Args:
135
+ epoch (int): Current epoch, used by the epoch-based decay modes.
138
+ """
139
+ lr = self.optimizer.param_groups[0]['lr']
140
+
141
+ if self.opt['lr_decay'] == 'step':
142
+ lr = self.opt['lr'] * (
143
+ self.opt['gamma']**(epoch // self.opt['step']))
144
+ elif self.opt['lr_decay'] == 'cos':
145
+ lr = self.opt['lr'] * (
146
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
147
+ elif self.opt['lr_decay'] == 'linear':
148
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
149
+ elif self.opt['lr_decay'] == 'linear2exp':
150
+ if epoch < self.opt['turning_point'] + 1:
151
+ # the learning rate has decayed by 95%
152
+ # at the turning point (1 / 95% = 1.0526)
153
+ lr = self.opt['lr'] * (
154
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
155
+ else:
156
+ lr *= self.opt['gamma']
157
+ elif self.opt['lr_decay'] == 'schedule':
158
+ if epoch in self.opt['schedule']:
159
+ lr *= self.opt['gamma']
160
+ else:
161
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
162
+ # set learning rate
163
+ for param_group in self.optimizer.param_groups:
164
+ param_group['lr'] = lr
165
+
166
+ return lr
167
+
168
+
169
+ class VQSegmentationModel(VQModel):
170
+
171
+ def __init__(self, opt):
172
+ super().__init__(opt)
173
+ self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
174
+ 1).to(self.device)
175
+
176
+ self.init_training_settings()
177
+
178
+ def configure_optimizers(self):
179
+ self.optimizer = torch.optim.Adam(
180
+ list(self.encoder.parameters()) + list(self.decoder.parameters()) +
181
+ list(self.quantize.parameters()) +
182
+ list(self.quant_conv.parameters()) +
183
+ list(self.post_quant_conv.parameters()),
184
+ lr=self.opt['lr'],
185
+ betas=(0.5, 0.9))
186
+
187
+ def training_step(self, data):
188
+ x = self.feed_data(data)
189
+ xrec, qloss = self.forward_step(x)
190
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
191
+ self.log_dict.update(log_dict_ae)
192
+ return aeloss
193
+
194
+ def to_rgb(self, x):
195
+ x = F.conv2d(x, weight=self.colorize)
196
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
197
+ return x
198
+
199
+ @torch.no_grad()
200
+ def inference(self, data_loader, save_dir):
201
+ self.encoder.eval()
202
+ self.decoder.eval()
203
+ self.quantize.eval()
204
+ self.quant_conv.eval()
205
+ self.post_quant_conv.eval()
206
+
207
+ loss_total = 0
208
+ loss_bce = 0
209
+ loss_quant = 0
210
+ num = 0
211
+
212
+ for _, data in enumerate(data_loader):
213
+ img_name = data['img_name'][0]
214
+ x = self.feed_data(data)
215
+ xrec, qloss = self.forward_step(x)
216
+ _, log_dict_ae = self.loss(qloss, x, xrec, split="val")
217
+
218
+ loss_total += log_dict_ae['val/total_loss']
219
+ loss_bce += log_dict_ae['val/bce_loss']
220
+ loss_quant += log_dict_ae['val/quant_loss']
221
+
222
+ num += x.size(0)
223
+
224
+ if x.shape[1] > 3:
225
+ # colorize with random projection
226
+ assert xrec.shape[1] > 3
227
+ # convert logits to indices
228
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
229
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
230
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
231
+ x = self.to_rgb(x)
232
+ xrec = self.to_rgb(xrec)
233
+
234
+ img_cat = torch.cat([x, xrec], dim=3).detach()
235
+ img_cat = ((img_cat + 1) / 2)
236
+ img_cat = img_cat.clamp_(0, 1)
237
+ save_image(
238
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
239
+
240
+ return (loss_total / num).item(), (loss_bce /
241
+ num).item(), (loss_quant /
242
+ num).item()
243
+
244
+
245
+ class VQImageModel(VQModel):
246
+
247
+ def __init__(self, opt):
248
+ super().__init__(opt)
249
+ self.disc = Discriminator(
250
+ opt['n_channels'], opt['ndf'],
251
+ n_layers=opt['disc_layers']).to(self.device)
252
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
253
+ self.perceptual_weight = opt['perceptual_weight']
254
+ self.disc_start_step = opt['disc_start_step']
255
+ self.disc_weight_max = opt['disc_weight_max']
256
+ self.diff_aug = opt['diff_aug']
257
+ self.policy = "color,translation"
258
+
259
+ self.disc.train()
260
+
261
+ self.init_training_settings()
262
+
263
+ def feed_data(self, data):
264
+ x = data['image']
265
+
266
+ return x.float().to(self.device)
267
+
268
+ def init_training_settings(self):
269
+ self.log_dict = OrderedDict()
270
+ self.configure_optimizers()
271
+
272
+ def configure_optimizers(self):
273
+ self.optimizer = torch.optim.Adam(
274
+ list(self.encoder.parameters()) + list(self.decoder.parameters()) +
275
+ list(self.quantize.parameters()) +
276
+ list(self.quant_conv.parameters()) +
277
+ list(self.post_quant_conv.parameters()),
278
+ lr=self.opt['lr'])
279
+
280
+ self.disc_optimizer = torch.optim.Adam(
281
+ self.disc.parameters(), lr=self.opt['lr'])
282
+
283
+ def training_step(self, data, step):
284
+ x = self.feed_data(data)
285
+ xrec, codebook_loss = self.forward_step(x)
286
+
287
+ # get recon/perceptual loss
288
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
289
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
290
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
291
+ nll_loss = torch.mean(nll_loss)
292
+
293
+ # augment for input to discriminator
294
+ if self.diff_aug:
295
+ xrec = DiffAugment(xrec, policy=self.policy)
296
+
297
+ # update generator
298
+ logits_fake = self.disc(xrec)
299
+ g_loss = -torch.mean(logits_fake)
300
+ last_layer = self.decoder.conv_out.weight
301
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
302
+ self.disc_weight_max)
303
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
304
+ loss = nll_loss + d_weight * g_loss + codebook_loss
305
+
306
+ self.log_dict["loss"] = loss
307
+ self.log_dict["l1"] = recon_loss.mean().item()
308
+ self.log_dict["perceptual"] = p_loss.mean().item()
309
+ self.log_dict["nll_loss"] = nll_loss.item()
310
+ self.log_dict["g_loss"] = g_loss.item()
311
+ self.log_dict["d_weight"] = d_weight
312
+ self.log_dict["codebook_loss"] = codebook_loss.item()
313
+
314
+ if step > self.disc_start_step:
315
+ if self.diff_aug:
316
+ logits_real = self.disc(
317
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
318
+ else:
319
+ logits_real = self.disc(x.contiguous().detach())
320
+ logits_fake = self.disc(xrec.contiguous().detach(
321
+ )) # detach so that generator isn't also updated
322
+ d_loss = hinge_d_loss(logits_real, logits_fake)
323
+ self.log_dict["d_loss"] = d_loss
324
+ else:
325
+ d_loss = None
326
+
327
+ return loss, d_loss
328
+
329
+ def optimize_parameters(self, data, step):
330
+ self.encoder.train()
331
+ self.decoder.train()
332
+ self.quantize.train()
333
+ self.quant_conv.train()
334
+ self.post_quant_conv.train()
335
+
336
+ loss, d_loss = self.training_step(data, step)
337
+ self.optimizer.zero_grad()
338
+ loss.backward()
339
+ self.optimizer.step()
340
+
341
+ if step > self.disc_start_step:
342
+ self.disc_optimizer.zero_grad()
343
+ d_loss.backward()
344
+ self.disc_optimizer.step()
345
+
346
+ @torch.no_grad()
347
+ def inference(self, data_loader, save_dir):
348
+ self.encoder.eval()
349
+ self.decoder.eval()
350
+ self.quantize.eval()
351
+ self.quant_conv.eval()
352
+ self.post_quant_conv.eval()
353
+
354
+ loss_total = 0
355
+ num = 0
356
+
357
+ for _, data in enumerate(data_loader):
358
+ img_name = data['img_name'][0]
359
+ x = self.feed_data(data)
360
+ xrec, _ = self.forward_step(x)
361
+
362
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
363
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
364
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
365
+ nll_loss = torch.mean(nll_loss)
366
+ loss_total += nll_loss
367
+
368
+ num += x.size(0)
369
+
370
+ if x.shape[1] > 3:
371
+ # colorize with random projection
372
+ assert xrec.shape[1] > 3
373
+ # convert logits to indices
374
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
375
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
376
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
377
+ x = self.to_rgb(x)
378
+ xrec = self.to_rgb(xrec)
379
+
380
+ img_cat = torch.cat([x, xrec], dim=3).detach()
381
+ img_cat = ((img_cat + 1) / 2)
382
+ img_cat = img_cat.clamp_(0, 1)
383
+ save_image(
384
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
385
+
386
+ return (loss_total / num).item()
387
+
388
+
389
+ class VQImageSegmTextureModel(VQImageModel):
390
+
391
+ def __init__(self, opt):
392
+ self.opt = opt
393
+ self.device = torch.device('cuda')
394
+ self.encoder = Encoder(
395
+ ch=opt['ch'],
396
+ num_res_blocks=opt['num_res_blocks'],
397
+ attn_resolutions=opt['attn_resolutions'],
398
+ ch_mult=opt['ch_mult'],
399
+ in_channels=opt['in_channels'],
400
+ resolution=opt['resolution'],
401
+ z_channels=opt['z_channels'],
402
+ double_z=opt['double_z'],
403
+ dropout=opt['dropout']).to(self.device)
404
+ self.decoder = Decoder(
405
+ in_channels=opt['in_channels'],
406
+ resolution=opt['resolution'],
407
+ z_channels=opt['z_channels'],
408
+ ch=opt['ch'],
409
+ out_ch=opt['out_ch'],
410
+ num_res_blocks=opt['num_res_blocks'],
411
+ attn_resolutions=opt['attn_resolutions'],
412
+ ch_mult=opt['ch_mult'],
413
+ dropout=opt['dropout'],
414
+ resamp_with_conv=True,
415
+ give_pre_end=False).to(self.device)
416
+ self.quantize = VectorQuantizerTexture(
417
+ opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
418
+ self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
419
+ 1).to(self.device)
420
+ self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
421
+ opt["z_channels"],
422
+ 1).to(self.device)
423
+
424
+ self.disc = Discriminator(
425
+ opt['n_channels'], opt['ndf'],
426
+ n_layers=opt['disc_layers']).to(self.device)
427
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
428
+ self.perceptual_weight = opt['perceptual_weight']
429
+ self.disc_start_step = opt['disc_start_step']
430
+ self.disc_weight_max = opt['disc_weight_max']
431
+ self.diff_aug = opt['diff_aug']
432
+ self.policy = "color,translation"
433
+
434
+ self.disc.train()
435
+
436
+ self.init_training_settings()
437
+
438
+ def feed_data(self, data):
439
+ x = data['image'].float().to(self.device)
440
+ mask = data['texture_mask'].float().to(self.device)
441
+
442
+ return x, mask
443
+
444
+ def training_step(self, data, step):
445
+ x, mask = self.feed_data(data)
446
+ xrec, codebook_loss = self.forward_step(x, mask)
447
+
448
+ # get recon/perceptual loss
449
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
450
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
451
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
452
+ nll_loss = torch.mean(nll_loss)
453
+
454
+ # augment for input to discriminator
455
+ if self.diff_aug:
456
+ xrec = DiffAugment(xrec, policy=self.policy)
457
+
458
+ # update generator
459
+ logits_fake = self.disc(xrec)
460
+ g_loss = -torch.mean(logits_fake)
461
+ last_layer = self.decoder.conv_out.weight
462
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
463
+ self.disc_weight_max)
464
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
465
+ loss = nll_loss + d_weight * g_loss + codebook_loss
466
+
467
+ self.log_dict["loss"] = loss
468
+ self.log_dict["l1"] = recon_loss.mean().item()
469
+ self.log_dict["perceptual"] = p_loss.mean().item()
470
+ self.log_dict["nll_loss"] = nll_loss.item()
471
+ self.log_dict["g_loss"] = g_loss.item()
472
+ self.log_dict["d_weight"] = d_weight
473
+ self.log_dict["codebook_loss"] = codebook_loss.item()
474
+
475
+ if step > self.disc_start_step:
476
+ if self.diff_aug:
477
+ logits_real = self.disc(
478
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
479
+ else:
480
+ logits_real = self.disc(x.contiguous().detach())
481
+ logits_fake = self.disc(xrec.contiguous().detach(
482
+ )) # detach so that generator isn't also updated
483
+ d_loss = hinge_d_loss(logits_real, logits_fake)
484
+ self.log_dict["d_loss"] = d_loss
485
+ else:
486
+ d_loss = None
487
+
488
+ return loss, d_loss
489
+
490
+ @torch.no_grad()
491
+ def inference(self, data_loader, save_dir):
492
+ self.encoder.eval()
493
+ self.decoder.eval()
494
+ self.quantize.eval()
495
+ self.quant_conv.eval()
496
+ self.post_quant_conv.eval()
497
+
498
+ loss_total = 0
499
+ num = 0
500
+
501
+ for _, data in enumerate(data_loader):
502
+ img_name = data['img_name'][0]
503
+ x, mask = self.feed_data(data)
504
+ xrec, _ = self.forward_step(x, mask)
505
+
506
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
507
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
508
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
509
+ nll_loss = torch.mean(nll_loss)
510
+ loss_total += nll_loss
511
+
512
+ num += x.size(0)
513
+
514
+ if x.shape[1] > 3:
515
+ # colorize with random projection
516
+ assert xrec.shape[1] > 3
517
+ # convert logits to indices
518
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
519
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
520
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
521
+ x = self.to_rgb(x)
522
+ xrec = self.to_rgb(xrec)
523
+
524
+ img_cat = torch.cat([x, xrec], dim=3).detach()
525
+ img_cat = ((img_cat + 1) / 2)
526
+ img_cat = img_cat.clamp_(0, 1)
527
+ save_image(
528
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
529
+
530
+ return (loss_total / num).item()
531
+
532
+ def encode(self, x, mask):
533
+ h = self.encoder(x)
534
+ h = self.quant_conv(h)
535
+ quant, emb_loss, info = self.quantize(h, mask)
536
+ return quant, emb_loss, info
537
+
538
+ def decode(self, quant):
539
+ quant = self.post_quant_conv(quant)
540
+ dec = self.decoder(quant)
541
+ return dec
542
+
543
+ def decode_code(self, code_b):
544
+ quant_b = self.quantize.embed_code(code_b)
545
+ dec = self.decode(quant_b)
546
+ return dec
547
+
548
+ def forward_step(self, input, mask):
549
+ quant, diff, _ = self.encode(input, mask)
550
+ dec = self.decode(quant)
551
+ return dec, diff
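
Note: the sketch below re-creates the loss assembly of VQImageModel.training_step with random tensors standing in for the network outputs, so it runs on its own. The perceptual term, the d_weight value and the step gating are stand-ins for lpips, calculate_adaptive_weight and adopt_weight from models/losses/vqgan_loss.py, whose exact definitions are only assumed here.

    import torch

    x = torch.rand(2, 3, 64, 64) * 2 - 1             # stand-in "real" batch in [-1, 1]
    xrec = torch.rand(2, 3, 64, 64) * 2 - 1          # stand-in reconstruction
    codebook_loss = torch.tensor(0.1)                # stand-in codebook/commitment term
    logits_fake = torch.randn(2, 1, 6, 6)            # stand-in discriminator output on xrec

    perceptual_weight = 1.0
    p_loss = torch.zeros(2, 1, 1, 1)                 # stand-in for self.perceptual(x, xrec)

    recon_loss = torch.abs(x - xrec)                 # per-pixel L1, as in training_step
    nll_loss = torch.mean(recon_loss + perceptual_weight * p_loss)

    g_loss = -torch.mean(logits_fake)                # generator adversarial term
    step, disc_start_step = 5000, 30001
    d_weight = 0.8                                   # normally computed by calculate_adaptive_weight
    d_weight = d_weight * (1.0 if step > disc_start_step else 0.0)   # assumed adopt_weight-style gating

    loss = nll_loss + d_weight * g_loss + codebook_loss
    print(float(loss))                               # total generator-side objective for this step
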